diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | core/ld.py | 59 | ||||
-rw-r--r-- | core/signatures.py | 67 | ||||
-rw-r--r-- | miniq/migrations/0001_initial.py | 10 | ||||
-rw-r--r-- | miniq/models.py | 5 | ||||
-rw-r--r-- | miniq/tasks.py | 34 | ||||
-rw-r--r-- | miniq/views.py | 30 | ||||
-rw-r--r-- | statuses/migrations/0001_initial.py | 2 | ||||
-rw-r--r-- | takahe/settings.py | 5 | ||||
-rw-r--r-- | takahe/urls.py | 1 | ||||
-rw-r--r-- | templates/identity/view.html | 17 | ||||
-rw-r--r-- | users/admin.py | 2 | ||||
-rw-r--r-- | users/migrations/0001_initial.py | 41 | ||||
-rw-r--r-- | users/models/domain.py | 4 | ||||
-rw-r--r-- | users/models/follow.py | 50 | ||||
-rw-r--r-- | users/models/identity.py | 176 | ||||
-rw-r--r-- | users/shortcuts.py | 19 | ||||
-rw-r--r-- | users/tasks/__init__.py | 0 | ||||
-rw-r--r-- | users/tasks/follow.py | 28 | ||||
-rw-r--r-- | users/tasks/identity.py | 11 | ||||
-rw-r--r-- | users/tasks/inbox.py | 36 | ||||
-rw-r--r-- | users/views/identity.py | 59 |
22 files changed, 476 insertions, 181 deletions
@@ -1,2 +1,3 @@ *.psql *.sqlite3 +notes.md @@ -227,6 +227,49 @@ schemas = { } }, }, + "*/schemas/litepub-0.1.jsonld": { + "contentType": "application/ld+json", + "documentUrl": "http://w3id.org/security/v1", + "contextUrl": None, + "document": { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1", + { + "Emoji": "toot:Emoji", + "Hashtag": "as:Hashtag", + "PropertyValue": "schema:PropertyValue", + "atomUri": "ostatus:atomUri", + "conversation": {"@id": "ostatus:conversation", "@type": "@id"}, + "discoverable": "toot:discoverable", + "manuallyApprovesFollowers": "as:manuallyApprovesFollowers", + "capabilities": "litepub:capabilities", + "ostatus": "http://ostatus.org#", + "schema": "http://schema.org#", + "toot": "http://joinmastodon.org/ns#", + "misskey": "https://misskey-hub.net/ns#", + "fedibird": "http://fedibird.com/ns#", + "value": "schema:value", + "sensitive": "as:sensitive", + "litepub": "http://litepub.social/ns#", + "invisible": "litepub:invisible", + "directMessage": "litepub:directMessage", + "listMessage": {"@id": "litepub:listMessage", "@type": "@id"}, + "quoteUrl": "as:quoteUrl", + "quoteUri": "fedibird:quoteUri", + "oauthRegistrationEndpoint": { + "@id": "litepub:oauthRegistrationEndpoint", + "@type": "@id", + }, + "EmojiReact": "litepub:EmojiReact", + "ChatMessage": "litepub:ChatMessage", + "alsoKnownAs": {"@id": "as:alsoKnownAs", "@type": "@id"}, + "vcard": "http://www.w3.org/2006/vcard/ns#", + "formerRepresentations": "litepub:formerRepresentations", + }, + ] + }, + }, } @@ -244,12 +287,16 @@ def builtin_document_loader(url: str, options={}): try: return schemas[key] except KeyError: - raise JsonLdError( - f"No schema built-in for {key!r}", - "jsonld.LoadDocumentError", - code="loading document failed", - cause="KeyError", - ) + try: + key = "*" + pieces.path.rstrip("/") + return schemas[key] + except KeyError: + raise JsonLdError( + f"No schema built-in for {key!r}", + "jsonld.LoadDocumentError", + code="loading document failed", + cause="KeyError", + ) def canonicalise(json_data, include_security=False): diff --git a/core/signatures.py b/core/signatures.py index a5e4fed..6f4d9ef 100644 --- a/core/signatures.py +++ b/core/signatures.py @@ -1,8 +1,14 @@ import base64 -from typing import List, TypedDict +import json +from typing import Dict, List, Literal, TypedDict +from urllib.parse import urlparse +import httpx from cryptography.hazmat.primitives import hashes from django.http import HttpRequest +from django.utils.http import http_date + +from users.models import Identity class HttpSignature: @@ -25,7 +31,8 @@ class HttpSignature: @classmethod def headers_from_request(cls, request: HttpRequest, header_names: List[str]) -> str: """ - Creates the to-be-signed header payload from a Django request""" + Creates the to-be-signed header payload from a Django request + """ headers = {} for header_name in header_names: if header_name == "(request-target)": @@ -38,7 +45,7 @@ class HttpSignature: return "\n".join(f"{name.lower()}: {value}" for name, value in headers.items()) @classmethod - def parse_signature(cls, signature) -> "SignatureDetails": + def parse_signature(cls, signature: str) -> "SignatureDetails": bits = {} for item in signature.split(","): name, value = item.split("=", 1) @@ -52,6 +59,60 @@ class HttpSignature: } return signature_details + @classmethod + def compile_signature(cls, details: "SignatureDetails") -> str: + value = f'keyId="{details["keyid"]}",headers="' + value += " ".join(h.lower() for h in details["headers"]) + value += '",signature="' + value += base64.b64encode(details["signature"]).decode("ascii") + value += f'",algorithm="{details["algorithm"]}"' + return value + + @classmethod + async def signed_request( + self, + uri: str, + body: Dict, + identity: Identity, + content_type: str = "application/json", + method: Literal["post"] = "post", + ): + """ + Performs an async request to the given path, with a document, signed + as an identity. + """ + uri_parts = urlparse(uri) + date_string = http_date() + body_bytes = json.dumps(body).encode("utf8") + headers = { + "(request-target)": f"{method} {uri_parts.path}", + "Host": uri_parts.hostname, + "Date": date_string, + "Digest": self.calculate_digest(body_bytes), + "Content-Type": content_type, + } + signed_string = "\n".join( + f"{name.lower()}: {value}" for name, value in headers.items() + ) + headers["Signature"] = self.compile_signature( + { + "keyid": identity.urls.key.full(), # type:ignore + "headers": list(headers.keys()), + "signature": identity.sign(signed_string), + "algorithm": "rsa-sha256", + } + ) + del headers["(request-target)"] + async with httpx.AsyncClient() as client: + print(f"Calling {method} {uri}") + print(body) + return await client.request( + method, + uri, + headers=headers, + content=body_bytes, + ) + class SignatureDetails(TypedDict): algorithm: str diff --git a/miniq/migrations/0001_initial.py b/miniq/migrations/0001_initial.py index 32c5d53..dc6d42b 100644 --- a/miniq/migrations/0001_initial.py +++ b/miniq/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-06 19:58 +# Generated by Django 4.1.3 on 2022-11-07 04:19 from django.db import migrations, models @@ -25,7 +25,13 @@ class Migration(migrations.Migration): ( "type", models.CharField( - choices=[("identity_fetch", "Identity Fetch")], max_length=500 + choices=[ + ("identity_fetch", "Identity Fetch"), + ("inbox_item", "Inbox Item"), + ("follow_request", "Follow Request"), + ("follow_acknowledge", "Follow Acknowledge"), + ], + max_length=500, ), ), ("priority", models.IntegerField(default=0)), diff --git a/miniq/models.py b/miniq/models.py index 996b482..24d311c 100644 --- a/miniq/models.py +++ b/miniq/models.py @@ -11,6 +11,9 @@ class Task(models.Model): class TypeChoices(models.TextChoices): identity_fetch = "identity_fetch" + inbox_item = "inbox_item" + follow_request = "follow_request" + follow_acknowledge = "follow_acknowledge" type = models.CharField(max_length=500, choices=TypeChoices.choices) priority = models.IntegerField(default=0) @@ -42,7 +45,7 @@ class Task(models.Model): return next_task @classmethod - def submit(cls, type, subject, payload=None, deduplicate=True): + def submit(cls, type, subject: str, payload=None, deduplicate=True): # Deduplication is done against tasks that have not started yet only, # and only on tasks without payloads if deduplicate and not payload: diff --git a/miniq/tasks.py b/miniq/tasks.py new file mode 100644 index 0000000..fedf8fd --- /dev/null +++ b/miniq/tasks.py @@ -0,0 +1,34 @@ +import traceback + +from users.tasks.follow import handle_follow_request +from users.tasks.identity import handle_identity_fetch +from users.tasks.inbox import handle_inbox_item + + +class TaskHandler: + + handlers = { + "identity_fetch": handle_identity_fetch, + "inbox_item": handle_inbox_item, + "follow_request": handle_follow_request, + } + + def __init__(self, task): + self.task = task + self.subject = self.task.subject + self.payload = self.task.payload + + async def handle(self): + try: + print(f"Task {self.task}: Starting") + if self.task.type not in self.handlers: + raise ValueError(f"Cannot handle type {self.task.type}") + await self.handlers[self.task.type]( + self, + ) + await self.task.complete() + print(f"Task {self.task}: Complete") + except BaseException as e: + print(f"Task {self.task}: Error {e}") + traceback.print_exc() + await self.task.fail(f"{e}\n\n" + traceback.format_exc()) diff --git a/miniq/views.py b/miniq/views.py index 21275f8..80c9ee2 100644 --- a/miniq/views.py +++ b/miniq/views.py @@ -1,6 +1,5 @@ import asyncio import time -import traceback import uuid from asgiref.sync import sync_to_async @@ -8,7 +7,7 @@ from django.http import HttpResponse from django.views import View from miniq.models import Task -from users.models import Identity +from miniq.tasks import TaskHandler class QueueProcessor(View): @@ -19,7 +18,8 @@ class QueueProcessor(View): START_TIMEOUT = 30 TOTAL_TIMEOUT = 60 - MAX_TASKS = 10 + LOCK_TIMEOUT = 200 + MAX_TASKS = 20 async def get(self, request): start_time = time.monotonic() @@ -35,10 +35,11 @@ class QueueProcessor(View): # Pop a task off the queue and run it task = await sync_to_async(Task.get_one_available)(processor_id) if task is not None: - self.tasks.append(asyncio.create_task(self.run_task(task))) + self.tasks.append(asyncio.create_task(TaskHandler(task).handle())) handled += 1 # Prevent busylooping await asyncio.sleep(0.01) + # TODO: Clean up old locks here # Then wait for tasks to finish while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT: # Remove completed tasks @@ -48,24 +49,3 @@ class QueueProcessor(View): # Prevent busylooping await asyncio.sleep(1) return HttpResponse(f"{handled} tasks handled") - - async def run_task(self, task): - try: - print(f"Task {task}: Starting") - handler = getattr(self, f"handle_{task.type}", None) - if handler is None: - raise ValueError(f"Cannot handle type {task.type}") - await handler(task.subject, task.payload) - await task.complete() - print(f"Task {task}: Complete") - except BaseException as e: - print(f"Task {task}: Error {e}") - traceback.print_exc() - await task.fail(f"{e}\n\n" + traceback.format_exc()) - - async def handle_identity_fetch(self, subject, payload): - # Get the actor URI via webfinger - actor_uri, handle = await Identity.fetch_webfinger(subject) - # Get or create the identity, then fetch - identity = await sync_to_async(Identity.by_actor_uri)(actor_uri, create=True) - await identity.fetch_actor() diff --git a/statuses/migrations/0001_initial.py b/statuses/migrations/0001_initial.py index 933c526..55c6c6c 100644 --- a/statuses/migrations/0001_initial.py +++ b/statuses/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-06 19:58 +# Generated by Django 4.1.3 on 2022-11-07 04:19 import django.db.models.deletion from django.db import migrations, models diff --git a/takahe/settings.py b/takahe/settings.py index c3c8d38..78a8403 100644 --- a/takahe/settings.py +++ b/takahe/settings.py @@ -62,8 +62,9 @@ WSGI_APPLICATION = "takahe.wsgi.application" DATABASES = { "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": BASE_DIR / "db.sqlite3", + "ENGINE": "django.db.backends.postgresql_psycopg2", + "NAME": "takahe", + "USER": "postgres", } } diff --git a/takahe/urls.py b/takahe/urls.py index f8bff07..304bc23 100644 --- a/takahe/urls.py +++ b/takahe/urls.py @@ -14,6 +14,7 @@ urlpatterns = [ path("@<handle>/", identity.ViewIdentity.as_view()), path("@<handle>/actor/", identity.Actor.as_view()), path("@<handle>/actor/inbox/", identity.Inbox.as_view()), + path("@<handle>/action/", identity.ActionIdentity.as_view()), # Identity selection path("@<handle>/activate/", identity.ActivateIdentity.as_view()), path("identity/select/", identity.SelectIdentity.as_view()), diff --git a/templates/identity/view.html b/templates/identity/view.html index ffb76db..d82543e 100644 --- a/templates/identity/view.html +++ b/templates/identity/view.html @@ -10,11 +10,11 @@ {% else %} <img src="{% static "img/unknown-icon-128.png" %}" class="icon"> {% endif %} - {{ identity }} <small>@{{ identity.handle }}</small> + {{ identity.name_or_handle }} <small>@{{ identity.handle }}</small> </h1> {% if not identity.local %} - {% if not identity.actor_uri %} + {% if identity.outdated and not identity.name %} <p class="system-note"> The system is still fetching this profile. Refresh to see updates. </p> @@ -26,6 +26,19 @@ {% endif %} {% endif %} + {% if request.identity %} + <form action="{{ identity.urls.action }}" method="POST"> + {% csrf_token %} + {% if follow %} + <input type="hidden" name="action" value="unfollow"> + <button>Unfollow</button> + {% else %} + <input type="hidden" name="action" value="follow"> + <button>Follow</button> + {% endif %} + </form> + {% endif %} + {% for status in statuses %} {% include "statuses/_status.html" %} {% empty %} diff --git a/users/admin.py b/users/admin.py index 5672876..bb07aa1 100644 --- a/users/admin.py +++ b/users/admin.py @@ -21,4 +21,4 @@ class UserEventAdmin(admin.ModelAdmin): @admin.register(Identity) class IdentityAdmin(admin.ModelAdmin): - list_display = ["id", "handle", "name", "local"] + list_display = ["id", "handle", "actor_uri", "name", "local"] diff --git a/users/migrations/0001_initial.py b/users/migrations/0001_initial.py index 364daaa..f5ebf55 100644 --- a/users/migrations/0001_initial.py +++ b/users/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-06 19:58 +# Generated by Django 4.1.3 on 2022-11-07 04:19 import functools @@ -56,11 +56,17 @@ class Migration(migrations.Migration): ), ( "service_domain", - models.CharField(blank=True, max_length=250, null=True), + models.CharField( + blank=True, + db_index=True, + max_length=250, + null=True, + unique=True, + ), ), ("local", models.BooleanField()), ("blocked", models.BooleanField(default=False)), - ("public", models.BooleanField()), + ("public", models.BooleanField(default=False)), ("created", models.DateTimeField(auto_now_add=True)), ("updated", models.DateTimeField(auto_now=True)), ( @@ -118,12 +124,7 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ( - "actor_uri", - models.CharField( - blank=True, max_length=500, null=True, unique=True - ), - ), + ("actor_uri", models.CharField(max_length=500, unique=True)), ("local", models.BooleanField()), ("username", models.CharField(blank=True, max_length=500, null=True)), ("name", models.CharField(blank=True, max_length=500, null=True)), @@ -192,7 +193,7 @@ class Migration(migrations.Migration): }, ), migrations.CreateModel( - name="Follow", + name="Block", fields=[ ( "id", @@ -203,6 +204,8 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), + ("mute", models.BooleanField()), + ("expires", models.DateTimeField(blank=True, null=True)), ("note", models.TextField(blank=True, null=True)), ("created", models.DateTimeField(auto_now_add=True)), ("updated", models.DateTimeField(auto_now=True)), @@ -210,7 +213,7 @@ class Migration(migrations.Migration): "source", models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, - related_name="outbound_follows", + related_name="outbound_blocks", to="users.identity", ), ), @@ -218,14 +221,14 @@ class Migration(migrations.Migration): "target", models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, - related_name="inbound_follows", + related_name="inbound_blocks", to="users.identity", ), ), ], ), migrations.CreateModel( - name="Block", + name="Follow", fields=[ ( "id", @@ -236,16 +239,17 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("mute", models.BooleanField()), - ("expires", models.DateTimeField(blank=True, null=True)), + ("uri", models.CharField(blank=True, max_length=500, null=True)), ("note", models.TextField(blank=True, null=True)), + ("requested", models.BooleanField(default=False)), + ("accepted", models.BooleanField(default=False)), ("created", models.DateTimeField(auto_now_add=True)), ("updated", models.DateTimeField(auto_now=True)), ( "source", models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, - related_name="outbound_blocks", + related_name="outbound_follows", to="users.identity", ), ), @@ -253,10 +257,13 @@ class Migration(migrations.Migration): "target", models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, - related_name="inbound_blocks", + related_name="inbound_follows", to="users.identity", ), ), ], + options={ + "unique_together": {("source", "target")}, + }, ), ] diff --git a/users/models/domain.py b/users/models/domain.py index f503b89..8467ac3 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -48,14 +48,14 @@ class Domain(models.Model): updated = models.DateTimeField(auto_now=True) @classmethod - def get_remote_domain(cls, domain) -> "Domain": + def get_remote_domain(cls, domain: str) -> "Domain": try: return cls.objects.get(domain=domain, local=False) except cls.DoesNotExist: return cls.objects.create(domain=domain, local=False) @classmethod - def get_local_domain(cls, domain) -> Optional["Domain"]: + def get_local_domain(cls, domain: str) -> Optional["Domain"]: try: return cls.objects.get( models.Q(domain=domain) | models.Q(service_domain=domain) diff --git a/users/models/follow.py b/users/models/follow.py index 7287900..29d036e 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -1,5 +1,9 @@ +from typing import Optional + from django.db import models +from miniq.models import Task + class Follow(models.Model): """ @@ -17,7 +21,53 @@ class Follow(models.Model): related_name="inbound_follows", ) + uri = models.CharField(blank=True, null=True, max_length=500) note = models.TextField(blank=True, null=True) + requested = models.BooleanField(default=False) + accepted = models.BooleanField(default=False) + created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + + class Meta: + unique_together = [("source", "target")] + + @classmethod + def maybe_get(cls, source, target) -> Optional["Follow"]: + """ + Returns a follow if it exists between source and target + """ + try: + return Follow.objects.get(source=source, target=target) + except Follow.DoesNotExist: + return None + + @classmethod + def create_local(cls, source, target): + """ + Creates a Follow from a local Identity to the target + (which can be local or remote). + """ + if not source.local: + raise ValueError("You cannot initiate follows on a remote Identity") + try: + follow = Follow.objects.get(source=source, target=target) + except Follow.DoesNotExist: + follow = Follow.objects.create(source=source, target=target, uri="") + follow.uri = source.actor_uri + f"follow/{follow.pk}/" + if target.local: + follow.requested = True + follow.accepted = True + else: + Task.submit("follow_request", str(follow.pk)) + follow.save() + return follow + + def undo(self): + """ + Undoes this follow + """ + if not self.target.local: + Task.submit("follow_undo", str(self.pk)) + self.delete() diff --git a/users/models/identity.py b/users/models/identity.py index 4939535..1f44e98 100644 --- a/users/models/identity.py +++ b/users/models/identity.py @@ -6,12 +6,11 @@ from urllib.parse import urlparse import httpx import urlman -from asgiref.sync import sync_to_async -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import padding, rsa +from asgiref.sync import async_to_sync, sync_to_async +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from django.db import models from django.utils import timezone -from django.utils.http import http_date from OpenSSL import crypto from core.ld import canonicalise @@ -34,7 +33,7 @@ class Identity(models.Model): # The Actor URI is essentially also a PK - we keep the default numeric # one around as well for making nice URLs etc. - actor_uri = models.CharField(max_length=500, blank=True, null=True, unique=True) + actor_uri = models.CharField(max_length=500, unique=True) local = models.BooleanField() users = models.ManyToManyField("users.User", related_name="identities") @@ -73,10 +72,35 @@ class Identity(models.Model): fetched = models.DateTimeField(null=True, blank=True) deleted = models.DateTimeField(null=True, blank=True) + ### Model attributes ### + class Meta: verbose_name_plural = "identities" unique_together = [("username", "domain")] + class urls(urlman.Urls): + view = "/@{self.username}@{self.domain_id}/" + view_short = "/@{self.username}/" + action = "{view}action/" + actor = "{view}actor/" + activate = "{view}activate/" + key = "{actor}#main-key" + inbox = "{actor}inbox/" + outbox = "{actor}outbox/" + + def get_scheme(self, url): + return "https" + + def get_hostname(self, url): + return self.instance.domain.uri_domain + + def __str__(self): + if self.username and self.domain_id: + return self.handle + return self.actor_uri + + ### Alternate constructors/fetchers ### + @classmethod def by_handle(cls, handle, fetch=False, local=False): if handle.startswith("@"): @@ -91,7 +115,15 @@ class Identity(models.Model): return cls.objects.get(username=username, domain_id=domain) except cls.DoesNotExist: if fetch and not local: - return cls.objects.create(handle=handle, local=False) + actor_uri, handle = async_to_sync(cls.fetch_webfinger)(handle) + username, domain = handle.split("@") + domain = Domain.get_remote_domain(domain) + return cls.objects.create( + actor_uri=actor_uri, + username=username, + domain_id=domain, + local=False, + ) return None @classmethod @@ -108,9 +140,17 @@ class Identity(models.Model): except cls.DoesNotExist: return cls.objects.create(actor_uri=uri, local=False) + ### Dynamic properties ### + + @property + def name_or_handle(self): + return self.name or self.handle + @property def handle(self): - return f"{self.username}@{self.domain_id}" + if self.domain_id: + return f"{self.username}@{self.domain_id}" + return f"{self.username}@UNKNOWN-DOMAIN" @property def data_age(self) -> float: @@ -123,23 +163,12 @@ class Identity(models.Model): return 10000000000 return (timezone.now() - self.fetched).total_seconds() - def generate_keypair(self): - if not self.local: - raise ValueError("Cannot generate keypair for remote user") - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - ) - self.private_key = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - self.public_key = private_key.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - self.save() + @property + def outdated(self) -> bool: + # TODO: Setting + return self.data_age > 60 * 24 * 24 + + ### Actor/Webfinger fetching ### @classmethod async def fetch_webfinger(cls, handle: str) -> Tuple[Optional[str], Optional[str]]: @@ -189,6 +218,8 @@ class Identity(models.Model): self.outbox_uri = document.get("outbox") self.summary = document.get("summary") self.username = document.get("preferredUsername") + if "@value" in self.username: + self.username = self.username["@value"] self.manually_approves_followers = document.get( "as:manuallyApprovesFollowers" ) @@ -214,23 +245,42 @@ class Identity(models.Model): await sync_to_async(self.save)() return True - def sign(self, cleartext: str) -> str: + ### Cryptography ### + + def generate_keypair(self): + if not self.local: + raise ValueError("Cannot generate keypair for remote user") + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + self.private_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("ascii") + self.public_key = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("ascii") + ) + self.save() + + def sign(self, cleartext: str) -> bytes: if not self.private_key: raise ValueError("Cannot sign - no private key") - private_key = serialization.load_pem_private_key( + pkey = crypto.load_privatekey( + crypto.FILETYPE_PEM, self.private_key.encode("ascii"), - password=None, ) - return base64.b64encode( - private_key.sign( - cleartext.encode("ascii"), - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH, - ), - hashes.SHA256(), - ) - ).decode("ascii") + return crypto.sign( + pkey, + cleartext.encode("ascii"), + "sha256", + ) def verify_signature(self, signature: bytes, cleartext: str) -> bool: if not self.public_key: @@ -247,55 +297,3 @@ class Identity(models.Mod |