summaryrefslogtreecommitdiffstats
path: root/users/models
diff options
context:
space:
mode:
Diffstat (limited to 'users/models')
-rw-r--r--users/models/domain.py4
-rw-r--r--users/models/follow.py50
-rw-r--r--users/models/identity.py176
3 files changed, 139 insertions, 91 deletions
diff --git a/users/models/domain.py b/users/models/domain.py
index f503b89..8467ac3 100644
--- a/users/models/domain.py
+++ b/users/models/domain.py
@@ -48,14 +48,14 @@ class Domain(models.Model):
updated = models.DateTimeField(auto_now=True)
@classmethod
- def get_remote_domain(cls, domain) -> "Domain":
+ def get_remote_domain(cls, domain: str) -> "Domain":
try:
return cls.objects.get(domain=domain, local=False)
except cls.DoesNotExist:
return cls.objects.create(domain=domain, local=False)
@classmethod
- def get_local_domain(cls, domain) -> Optional["Domain"]:
+ def get_local_domain(cls, domain: str) -> Optional["Domain"]:
try:
return cls.objects.get(
models.Q(domain=domain) | models.Q(service_domain=domain)
diff --git a/users/models/follow.py b/users/models/follow.py
index 7287900..29d036e 100644
--- a/users/models/follow.py
+++ b/users/models/follow.py
@@ -1,5 +1,9 @@
+from typing import Optional
+
from django.db import models
+from miniq.models import Task
+
class Follow(models.Model):
"""
@@ -17,7 +21,53 @@ class Follow(models.Model):
related_name="inbound_follows",
)
+ uri = models.CharField(blank=True, null=True, max_length=500)
note = models.TextField(blank=True, null=True)
+ requested = models.BooleanField(default=False)
+ accepted = models.BooleanField(default=False)
+
created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True)
+
+ class Meta:
+ unique_together = [("source", "target")]
+
+ @classmethod
+ def maybe_get(cls, source, target) -> Optional["Follow"]:
+ """
+ Returns a follow if it exists between source and target
+ """
+ try:
+ return Follow.objects.get(source=source, target=target)
+ except Follow.DoesNotExist:
+ return None
+
+ @classmethod
+ def create_local(cls, source, target):
+ """
+ Creates a Follow from a local Identity to the target
+ (which can be local or remote).
+ """
+ if not source.local:
+ raise ValueError("You cannot initiate follows on a remote Identity")
+ try:
+ follow = Follow.objects.get(source=source, target=target)
+ except Follow.DoesNotExist:
+ follow = Follow.objects.create(source=source, target=target, uri="")
+ follow.uri = source.actor_uri + f"follow/{follow.pk}/"
+ if target.local:
+ follow.requested = True
+ follow.accepted = True
+ else:
+ Task.submit("follow_request", str(follow.pk))
+ follow.save()
+ return follow
+
+ def undo(self):
+ """
+ Undoes this follow
+ """
+ if not self.target.local:
+ Task.submit("follow_undo", str(self.pk))
+ self.delete()
diff --git a/users/models/identity.py b/users/models/identity.py
index 4939535..1f44e98 100644
--- a/users/models/identity.py
+++ b/users/models/identity.py
@@ -6,12 +6,11 @@ from urllib.parse import urlparse
import httpx
import urlman
-from asgiref.sync import sync_to_async
-from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from asgiref.sync import async_to_sync, sync_to_async
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
from django.db import models
from django.utils import timezone
-from django.utils.http import http_date
from OpenSSL import crypto
from core.ld import canonicalise
@@ -34,7 +33,7 @@ class Identity(models.Model):
# The Actor URI is essentially also a PK - we keep the default numeric
# one around as well for making nice URLs etc.
- actor_uri = models.CharField(max_length=500, blank=True, null=True, unique=True)
+ actor_uri = models.CharField(max_length=500, unique=True)
local = models.BooleanField()
users = models.ManyToManyField("users.User", related_name="identities")
@@ -73,10 +72,35 @@ class Identity(models.Model):
fetched = models.DateTimeField(null=True, blank=True)
deleted = models.DateTimeField(null=True, blank=True)
+ ### Model attributes ###
+
class Meta:
verbose_name_plural = "identities"
unique_together = [("username", "domain")]
+ class urls(urlman.Urls):
+ view = "/@{self.username}@{self.domain_id}/"
+ view_short = "/@{self.username}/"
+ action = "{view}action/"
+ actor = "{view}actor/"
+ activate = "{view}activate/"
+ key = "{actor}#main-key"
+ inbox = "{actor}inbox/"
+ outbox = "{actor}outbox/"
+
+ def get_scheme(self, url):
+ return "https"
+
+ def get_hostname(self, url):
+ return self.instance.domain.uri_domain
+
+ def __str__(self):
+ if self.username and self.domain_id:
+ return self.handle
+ return self.actor_uri
+
+ ### Alternate constructors/fetchers ###
+
@classmethod
def by_handle(cls, handle, fetch=False, local=False):
if handle.startswith("@"):
@@ -91,7 +115,15 @@ class Identity(models.Model):
return cls.objects.get(username=username, domain_id=domain)
except cls.DoesNotExist:
if fetch and not local:
- return cls.objects.create(handle=handle, local=False)
+ actor_uri, handle = async_to_sync(cls.fetch_webfinger)(handle)
+ username, domain = handle.split("@")
+ domain = Domain.get_remote_domain(domain)
+ return cls.objects.create(
+ actor_uri=actor_uri,
+ username=username,
+ domain_id=domain,
+ local=False,
+ )
return None
@classmethod
@@ -108,9 +140,17 @@ class Identity(models.Model):
except cls.DoesNotExist:
return cls.objects.create(actor_uri=uri, local=False)
+ ### Dynamic properties ###
+
+ @property
+ def name_or_handle(self):
+ return self.name or self.handle
+
@property
def handle(self):
- return f"{self.username}@{self.domain_id}"
+ if self.domain_id:
+ return f"{self.username}@{self.domain_id}"
+ return f"{self.username}@UNKNOWN-DOMAIN"
@property
def data_age(self) -> float:
@@ -123,23 +163,12 @@ class Identity(models.Model):
return 10000000000
return (timezone.now() - self.fetched).total_seconds()
- def generate_keypair(self):
- if not self.local:
- raise ValueError("Cannot generate keypair for remote user")
- private_key = rsa.generate_private_key(
- public_exponent=65537,
- key_size=2048,
- )
- self.private_key = private_key.private_bytes(
- encoding=serialization.Encoding.PEM,
- format=serialization.PrivateFormat.PKCS8,
- encryption_algorithm=serialization.NoEncryption(),
- )
- self.public_key = private_key.public_key().public_bytes(
- encoding=serialization.Encoding.PEM,
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
- )
- self.save()
+ @property
+ def outdated(self) -> bool:
+ # TODO: Setting
+ return self.data_age > 60 * 24 * 24
+
+ ### Actor/Webfinger fetching ###
@classmethod
async def fetch_webfinger(cls, handle: str) -> Tuple[Optional[str], Optional[str]]:
@@ -189,6 +218,8 @@ class Identity(models.Model):
self.outbox_uri = document.get("outbox")
self.summary = document.get("summary")
self.username = document.get("preferredUsername")
+ if "@value" in self.username:
+ self.username = self.username["@value"]
self.manually_approves_followers = document.get(
"as:manuallyApprovesFollowers"
)
@@ -214,23 +245,42 @@ class Identity(models.Model):
await sync_to_async(self.save)()
return True
- def sign(self, cleartext: str) -> str:
+ ### Cryptography ###
+
+ def generate_keypair(self):
+ if not self.local:
+ raise ValueError("Cannot generate keypair for remote user")
+ private_key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ )
+ self.private_key = private_key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ ).decode("ascii")
+ self.public_key = (
+ private_key.public_key()
+ .public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+ .decode("ascii")
+ )
+ self.save()
+
+ def sign(self, cleartext: str) -> bytes:
if not self.private_key:
raise ValueError("Cannot sign - no private key")
- private_key = serialization.load_pem_private_key(
+ pkey = crypto.load_privatekey(
+ crypto.FILETYPE_PEM,
self.private_key.encode("ascii"),
- password=None,
)
- return base64.b64encode(
- private_key.sign(
- cleartext.encode("ascii"),
- padding.PSS(
- mgf=padding.MGF1(hashes.SHA256()),
- salt_length=padding.PSS.MAX_LENGTH,
- ),
- hashes.SHA256(),
- )
- ).decode("ascii")
+ return crypto.sign(
+ pkey,
+ cleartext.encode("ascii"),
+ "sha256",
+ )
def verify_signature(self, signature: bytes, cleartext: str) -> bool:
if not self.public_key:
@@ -247,55 +297,3 @@ class Identity(models.Model):
except crypto.Error:
return False
return True
-
- async def signed_request(self, host, method, path, document):
- """
- Delivers the document to the specified host, method, path and signed
- as this user.
- """
- date_string = http_date(timezone.now().timestamp())
- headers = {
- "(request-target)": f"{method} {path}",
- "Host": host,
- "Date": date_string,
- }
- headers_string = " ".join(headers.keys())
- signed_string = "\n".join(f"{name}: {value}" for name, value in headers.items())
- signature = self.sign(signed_string)
- del headers["(request-target)"]
- headers[
- "Signature"
- ] = f'keyId="{self.urls.key.full()}",headers="{headers_string}",signature="{signature}"'
- async with httpx.AsyncClient() as client:
- return await client.request(
- method,
- "https://{host}{path}",
- headers=headers,
- data=document,
- )
-
- def validate_signature(self, request):
- """
- Attempts to validate the signature on an incoming request.
- Returns False if the signature is invalid, None if it cannot be verified
- as we do not have the key locally, or the name of the actor if it is valid.
- """
- pass
-
- def __str__(self):
- return self.handle or self.actor_uri
-
- class urls(urlman.Urls):
- view = "/@{self.username}@{self.domain_id}/"
- view_short = "/@{self.username}/"
- actor = "{view}actor/"
- key = "{actor}#main-key"
- inbox = "{actor}inbox/"
- outbox = "{actor}outbox/"
- activate = "{view}activate/"
-
- def get_scheme(self, url):
- return "https"
-
- def get_hostname(self, url):
- return self.instance.domain.uri_domain