diff options
Diffstat (limited to 'users/models')
-rw-r--r-- | users/models/domain.py | 4 | ||||
-rw-r--r-- | users/models/follow.py | 50 | ||||
-rw-r--r-- | users/models/identity.py | 176 |
3 files changed, 139 insertions, 91 deletions
diff --git a/users/models/domain.py b/users/models/domain.py index f503b89..8467ac3 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -48,14 +48,14 @@ class Domain(models.Model): updated = models.DateTimeField(auto_now=True) @classmethod - def get_remote_domain(cls, domain) -> "Domain": + def get_remote_domain(cls, domain: str) -> "Domain": try: return cls.objects.get(domain=domain, local=False) except cls.DoesNotExist: return cls.objects.create(domain=domain, local=False) @classmethod - def get_local_domain(cls, domain) -> Optional["Domain"]: + def get_local_domain(cls, domain: str) -> Optional["Domain"]: try: return cls.objects.get( models.Q(domain=domain) | models.Q(service_domain=domain) diff --git a/users/models/follow.py b/users/models/follow.py index 7287900..29d036e 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -1,5 +1,9 @@ +from typing import Optional + from django.db import models +from miniq.models import Task + class Follow(models.Model): """ @@ -17,7 +21,53 @@ class Follow(models.Model): related_name="inbound_follows", ) + uri = models.CharField(blank=True, null=True, max_length=500) note = models.TextField(blank=True, null=True) + requested = models.BooleanField(default=False) + accepted = models.BooleanField(default=False) + created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + + class Meta: + unique_together = [("source", "target")] + + @classmethod + def maybe_get(cls, source, target) -> Optional["Follow"]: + """ + Returns a follow if it exists between source and target + """ + try: + return Follow.objects.get(source=source, target=target) + except Follow.DoesNotExist: + return None + + @classmethod + def create_local(cls, source, target): + """ + Creates a Follow from a local Identity to the target + (which can be local or remote). + """ + if not source.local: + raise ValueError("You cannot initiate follows on a remote Identity") + try: + follow = Follow.objects.get(source=source, target=target) + except Follow.DoesNotExist: + follow = Follow.objects.create(source=source, target=target, uri="") + follow.uri = source.actor_uri + f"follow/{follow.pk}/" + if target.local: + follow.requested = True + follow.accepted = True + else: + Task.submit("follow_request", str(follow.pk)) + follow.save() + return follow + + def undo(self): + """ + Undoes this follow + """ + if not self.target.local: + Task.submit("follow_undo", str(self.pk)) + self.delete() diff --git a/users/models/identity.py b/users/models/identity.py index 4939535..1f44e98 100644 --- a/users/models/identity.py +++ b/users/models/identity.py @@ -6,12 +6,11 @@ from urllib.parse import urlparse import httpx import urlman -from asgiref.sync import sync_to_async -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import padding, rsa +from asgiref.sync import async_to_sync, sync_to_async +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from django.db import models from django.utils import timezone -from django.utils.http import http_date from OpenSSL import crypto from core.ld import canonicalise @@ -34,7 +33,7 @@ class Identity(models.Model): # The Actor URI is essentially also a PK - we keep the default numeric # one around as well for making nice URLs etc. - actor_uri = models.CharField(max_length=500, blank=True, null=True, unique=True) + actor_uri = models.CharField(max_length=500, unique=True) local = models.BooleanField() users = models.ManyToManyField("users.User", related_name="identities") @@ -73,10 +72,35 @@ class Identity(models.Model): fetched = models.DateTimeField(null=True, blank=True) deleted = models.DateTimeField(null=True, blank=True) + ### Model attributes ### + class Meta: verbose_name_plural = "identities" unique_together = [("username", "domain")] + class urls(urlman.Urls): + view = "/@{self.username}@{self.domain_id}/" + view_short = "/@{self.username}/" + action = "{view}action/" + actor = "{view}actor/" + activate = "{view}activate/" + key = "{actor}#main-key" + inbox = "{actor}inbox/" + outbox = "{actor}outbox/" + + def get_scheme(self, url): + return "https" + + def get_hostname(self, url): + return self.instance.domain.uri_domain + + def __str__(self): + if self.username and self.domain_id: + return self.handle + return self.actor_uri + + ### Alternate constructors/fetchers ### + @classmethod def by_handle(cls, handle, fetch=False, local=False): if handle.startswith("@"): @@ -91,7 +115,15 @@ class Identity(models.Model): return cls.objects.get(username=username, domain_id=domain) except cls.DoesNotExist: if fetch and not local: - return cls.objects.create(handle=handle, local=False) + actor_uri, handle = async_to_sync(cls.fetch_webfinger)(handle) + username, domain = handle.split("@") + domain = Domain.get_remote_domain(domain) + return cls.objects.create( + actor_uri=actor_uri, + username=username, + domain_id=domain, + local=False, + ) return None @classmethod @@ -108,9 +140,17 @@ class Identity(models.Model): except cls.DoesNotExist: return cls.objects.create(actor_uri=uri, local=False) + ### Dynamic properties ### + + @property + def name_or_handle(self): + return self.name or self.handle + @property def handle(self): - return f"{self.username}@{self.domain_id}" + if self.domain_id: + return f"{self.username}@{self.domain_id}" + return f"{self.username}@UNKNOWN-DOMAIN" @property def data_age(self) -> float: @@ -123,23 +163,12 @@ class Identity(models.Model): return 10000000000 return (timezone.now() - self.fetched).total_seconds() - def generate_keypair(self): - if not self.local: - raise ValueError("Cannot generate keypair for remote user") - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - ) - self.private_key = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - self.public_key = private_key.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - self.save() + @property + def outdated(self) -> bool: + # TODO: Setting + return self.data_age > 60 * 24 * 24 + + ### Actor/Webfinger fetching ### @classmethod async def fetch_webfinger(cls, handle: str) -> Tuple[Optional[str], Optional[str]]: @@ -189,6 +218,8 @@ class Identity(models.Model): self.outbox_uri = document.get("outbox") self.summary = document.get("summary") self.username = document.get("preferredUsername") + if "@value" in self.username: + self.username = self.username["@value"] self.manually_approves_followers = document.get( "as:manuallyApprovesFollowers" ) @@ -214,23 +245,42 @@ class Identity(models.Model): await sync_to_async(self.save)() return True - def sign(self, cleartext: str) -> str: + ### Cryptography ### + + def generate_keypair(self): + if not self.local: + raise ValueError("Cannot generate keypair for remote user") + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + self.private_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("ascii") + self.public_key = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("ascii") + ) + self.save() + + def sign(self, cleartext: str) -> bytes: if not self.private_key: raise ValueError("Cannot sign - no private key") - private_key = serialization.load_pem_private_key( + pkey = crypto.load_privatekey( + crypto.FILETYPE_PEM, self.private_key.encode("ascii"), - password=None, ) - return base64.b64encode( - private_key.sign( - cleartext.encode("ascii"), - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH, - ), - hashes.SHA256(), - ) - ).decode("ascii") + return crypto.sign( + pkey, + cleartext.encode("ascii"), + "sha256", + ) def verify_signature(self, signature: bytes, cleartext: str) -> bool: if not self.public_key: @@ -247,55 +297,3 @@ class Identity(models.Model): except crypto.Error: return False return True - - async def signed_request(self, host, method, path, document): - """ - Delivers the document to the specified host, method, path and signed - as this user. - """ - date_string = http_date(timezone.now().timestamp()) - headers = { - "(request-target)": f"{method} {path}", - "Host": host, - "Date": date_string, - } - headers_string = " ".join(headers.keys()) - signed_string = "\n".join(f"{name}: {value}" for name, value in headers.items()) - signature = self.sign(signed_string) - del headers["(request-target)"] - headers[ - "Signature" - ] = f'keyId="{self.urls.key.full()}",headers="{headers_string}",signature="{signature}"' - async with httpx.AsyncClient() as client: - return await client.request( - method, - "https://{host}{path}", - headers=headers, - data=document, - ) - - def validate_signature(self, request): - """ - Attempts to validate the signature on an incoming request. - Returns False if the signature is invalid, None if it cannot be verified - as we do not have the key locally, or the name of the actor if it is valid. - """ - pass - - def __str__(self): - return self.handle or self.actor_uri - - class urls(urlman.Urls): - view = "/@{self.username}@{self.domain_id}/" - view_short = "/@{self.username}/" - actor = "{view}actor/" - key = "{actor}#main-key" - inbox = "{actor}inbox/" - outbox = "{actor}outbox/" - activate = "{view}activate/" - - def get_scheme(self, url): - return "https" - - def get_hostname(self, url): - return self.instance.domain.uri_domain |