summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/signatures.py16
-rw-r--r--stator/graph.py27
-rw-r--r--stator/models.py20
-rw-r--r--stator/runner.py6
-rw-r--r--statuses/models/status.py3
-rw-r--r--users/admin.py13
-rw-r--r--users/models/domain.py12
-rw-r--r--users/models/follow.py114
-rw-r--r--users/models/inbox_message.py18
-rw-r--r--users/shortcuts.py5
-rw-r--r--users/tasks/__init__.py0
-rw-r--r--users/tasks/follow.py62
-rw-r--r--users/views/identity.py15
13 files changed, 207 insertions, 104 deletions
diff --git a/core/signatures.py b/core/signatures.py
index 805ae91..27e7f7d 100644
--- a/core/signatures.py
+++ b/core/signatures.py
@@ -1,6 +1,6 @@
import base64
import json
-from typing import Dict, List, Literal, TypedDict
+from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict
from urllib.parse import urlparse
import httpx
@@ -8,7 +8,9 @@ from cryptography.hazmat.primitives import hashes
from django.http import HttpRequest
from django.utils.http import http_date
-from users.models import Identity
+# Prevent a circular import
+if TYPE_CHECKING:
+ from users.models import Identity
class HttpSignature:
@@ -73,7 +75,7 @@ class HttpSignature:
self,
uri: str,
body: Dict,
- identity: Identity,
+ identity: "Identity",
content_type: str = "application/json",
method: Literal["post"] = "post",
):
@@ -105,13 +107,17 @@ class HttpSignature:
del headers["(request-target)"]
async with httpx.AsyncClient() as client:
print(f"Calling {method} {uri}")
- print(body)
- return await client.request(
+ response = await client.request(
method,
uri,
headers=headers,
content=body_bytes,
)
+ if response.status_code >= 400:
+ raise ValueError(
+ f"Request error: {response.status_code} {response.content}"
+ )
+ return response
class SignatureDetails(TypedDict):
diff --git a/stator/graph.py b/stator/graph.py
index 7a8455c..00ef1c4 100644
--- a/stator/graph.py
+++ b/stator/graph.py
@@ -41,6 +41,7 @@ class StateGraph:
initial_state = state
# Collect terminal states
if state.terminal:
+ state.externally_progressed = True
terminal_states.add(state)
# Ensure they do NOT have a handler
try:
@@ -52,17 +53,18 @@ class StateGraph:
f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
)
else:
- # Ensure non-terminal states have a try interval and a handler
- if not state.try_interval:
- raise ValueError(
- f"State '{state}' has no try_interval and is not terminal"
- )
- try:
- state.handler
- except AttributeError:
- raise ValueError(
- f"State '{state}' does not have a handler method ({state.handler_name})"
- )
+ # Ensure non-terminal/manual states have a try interval and a handler
+ if not state.externally_progressed:
+ if not state.try_interval:
+ raise ValueError(
+ f"State '{state}' has no try_interval and is not terminal or manual"
+ )
+ try:
+ state.handler
+ except AttributeError:
+ raise ValueError(
+ f"State '{state}' does not have a handler method ({state.handler_name})"
+ )
if initial_state is None:
raise ValueError("The graph has no initial state")
cls.initial_state = initial_state
@@ -80,9 +82,11 @@ class State:
self,
try_interval: Optional[float] = None,
handler_name: Optional[str] = None,
+ externally_progressed: bool = False,
):
self.try_interval = try_interval
self.handler_name = handler_name
+ self.externally_progressed = externally_progressed
self.parents: Set["State"] = set()
self.children: Set["State"] = set()
@@ -118,6 +122,7 @@ class State:
@property
def handler(self) -> Callable[[Any], Optional[str]]:
+ # Retrieve it by name off the graph
if self.handler_name is None:
raise AttributeError("No handler defined")
return getattr(self.graph, self.handler_name)
diff --git a/stator/models.py b/stator/models.py
index 50ee622..072a3ed 100644
--- a/stator/models.py
+++ b/stator/models.py
@@ -80,7 +80,7 @@ class StatorModel(models.Model):
q = models.Q()
for state in cls.state_graph.states.values():
state = cast(State, state)
- if not state.terminal:
+ if not state.externally_progressed:
q = q | models.Q(
(
models.Q(
@@ -135,17 +135,31 @@ class StatorModel(models.Model):
self.state_ready = True
self.save()
- async def atransition_attempt(self) -> Optional[str]:
+ async def atransition_attempt(self) -> Optional[State]:
"""
Attempts to transition the current state by running its handler(s).
"""
+ current_state = self.state_graph.states[self.state]
+ # If it's a manual progression state don't even try
+ # We shouldn't really be here in this case, but it could be a race condition
+ if current_state.externally_progressed:
+ print("Externally progressed state!")
+ return None
try:
- next_state = await self.state_graph.states[self.state].handler(self)
+ next_state = await current_state.handler(self)
except BaseException as e:
await StatorError.acreate_from_instance(self, e)
traceback.print_exc()
else:
if next_state:
+ # Ensure it's a State object
+ if isinstance(next_state, str):
+ next_state = self.state_graph.states[next_state]
+ # Ensure it's a child
+ if next_state not in current_state.children:
+ raise ValueError(
+ f"Cannot transition from {current_state} to {next_state} - not a declared transition"
+ )
await self.atransition_perform(next_state)
return next_state
await self.__class__.objects.filter(pk=self.pk).aupdate(
diff --git a/stator/runner.py b/stator/runner.py
index 1392e4d..0b42b27 100644
--- a/stator/runner.py
+++ b/stator/runner.py
@@ -50,9 +50,6 @@ class StatorRunner:
min(space_remaining, self.MAX_TASKS_PER_MODEL),
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
):
- print(
- f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
- )
self.tasks.append(
asyncio.create_task(self.run_transition(instance))
)
@@ -76,6 +73,9 @@ class StatorRunner:
Wrapper for atransition_attempt with fallback error handling
"""
try:
+ print(
+ f"Attempting transition on {instance._meta.label_lower}#{instance.pk} from state {instance.state}"
+ )
await instance.atransition_attempt()
except BaseException:
traceback.print_exc()
diff --git a/statuses/models/status.py b/statuses/models/status.py
index bfc8eb9..b12a595 100644
--- a/statuses/models/status.py
+++ b/statuses/models/status.py
@@ -27,6 +27,9 @@ class Status(models.Model):
updated = models.DateTimeField(auto_now=True)
deleted = models.DateTimeField(null=True, blank=True)
+ class Meta:
+ verbose_name_plural = "statuses"
+
@classmethod
def create_local(cls, identity, text: str):
return cls.objects.create(
diff --git a/users/admin.py b/users/admin.py
index f2b807c..d8f2931 100644
--- a/users/admin.py
+++ b/users/admin.py
@@ -1,6 +1,6 @@
from django.contrib import admin
-from users.models import Domain, Follow, Identity, User, UserEvent
+from users.models import Domain, Follow, Identity, InboxMessage, User, UserEvent
@admin.register(Domain)
@@ -26,3 +26,14 @@ class IdentityAdmin(admin.ModelAdmin):
@admin.register(Follow)
class FollowAdmin(admin.ModelAdmin):
list_display = ["id", "source", "target", "state"]
+
+
+@admin.register(InboxMessage)
+class InboxMessageAdmin(admin.ModelAdmin):
+ list_display = ["id", "state", "message_type"]
+ actions = ["reset_state"]
+
+ @admin.action(description="Reset State")
+ def reset_state(self, request, queryset):
+ for instance in queryset:
+ instance.transition_perform("received")
diff --git a/users/models/domain.py b/users/models/domain.py
index 4ac6ee9..a3815ee 100644
--- a/users/models/domain.py
+++ b/users/models/domain.py
@@ -81,3 +81,15 @@ class Domain(models.Model):
def __str__(self):
return self.domain
+
+ def save(self, *args, **kwargs):
+ # Ensure that we are not conflicting with other domains
+ if Domain.objects.filter(service_domain=self.domain).exists():
+ raise ValueError(
+ f"Domain {self.domain} is already a service domain elsewhere!"
+ )
+ if self.service_domain:
+ if Domain.objects.filter(domain=self.service_domain).exists():
+ raise ValueError(
+ f"Service domain {self.service_domain} is already a domain elsewhere!"
+ )
diff --git a/users/models/follow.py b/users/models/follow.py
index 6f62481..94ad40f 100644
--- a/users/models/follow.py
+++ b/users/models/follow.py
@@ -2,24 +2,110 @@ from typing import Optional
from django.db import models
+from core.ld import canonicalise
+from core.signatures import HttpSignature
from stator.models import State, StateField, StateGraph, StatorModel
class FollowStates(StateGraph):
unrequested = State(try_interval=30)
- requested = State(try_interval=24 * 60 * 60)
- accepted = State()
-
- unrequested.transitions_to(requested)
- requested.transitions_to(accepted)
+ local_requested = State(try_interval=24 * 60 * 60)
+ remote_requested = State(try_interval=24 * 60 * 60)
+ accepted = State(externally_progressed=True)
+ undone_locally = State(try_interval=60 * 60)
+ undone_remotely = State()
+
+ unrequested.transitions_to(local_requested)
+ unrequested.transitions_to(remote_requested)
+ local_requested.transitions_to(accepted)
+ remote_requested.transitions_to(accepted)
+ accepted.transitions_to(undone_locally)
+ undone_locally.transitions_to(undone_remotely)
@classmethod
async def handle_unrequested(cls, instance: "Follow"):
- print("Would have tried to follow on", instance)
+ # Re-retrieve the follow with more things linked
+ follow = await Follow.objects.select_related(
+ "source", "source__domain", "target"
+ ).aget(pk=instance.pk)
+ # Remote follows should not be here
+ if not follow.source.local:
+ return cls.remote_requested
+ # Construct the request
+ request = canonicalise(
+ {
+ "@context": "https://www.w3.org/ns/activitystreams",
+ "id": follow.uri,
+ "type": "Follow",
+ "actor": follow.source.actor_uri,
+ "object": follow.target.actor_uri,
+ }
+ )
+ # Sign it and send it
+ await HttpSignature.signed_request(
+ follow.target.inbox_uri, request, follow.source
+ )
+ return cls.local_requested
+
+ @classmethod
+ async def handle_local_requested(cls, instance: "Follow"):
+ # TODO: Resend follow requests occasionally
+ pass
+
+ @classmethod
+ async def handle_remote_requested(cls, instance: "Follow"):
+ # Re-retrieve the follow with more things linked
+ follow = await Follow.objects.select_related(
+ "source", "source__domain", "target"
+ ).aget(pk=instance.pk)
+ # Send an accept
+ request = canonicalise(
+ {
+ "@context": "https://www.w3.org/ns/activitystreams",
+ "id": follow.target.actor_uri + f"follow/{follow.pk}/#accept",
+ "type": "Follow",
+ "actor": follow.source.actor_uri,
+ "object": {
+ "id": follow.uri,
+ "type": "Follow",
+ "actor": follow.source.actor_uri,
+ "object": follow.target.actor_uri,
+ },
+ }
+ )
+ # Sign it and send it
+ await HttpSignature.signed_request(
+ follow.source.inbox_uri,
+ request,
+ identity=follow.target,
+ )
+ return cls.accepted
@classmethod
- async def handle_requested(cls, instance: "Follow"):
- print("Would have tried to requested on", instance)
+ async def handle_undone_locally(cls, instance: "Follow"):
+ follow = Follow.objects.select_related(
+ "source", "source__domain", "target"
+ ).get(pk=instance.pk)
+ # Construct the request
+ request = canonicalise(
+ {
+ "@context": "https://www.w3.org/ns/activitystreams",
+ "id": follow.uri + "#undo",
+ "type": "Undo",
+ "actor": follow.source.actor_uri,
+ "object": {
+ "id": follow.uri,
+ "type": "Follow",
+ "actor": follow.source.actor_uri,
+ "object": follow.target.actor_uri,
+ },
+ }
+ )
+ # Sign it and send it
+ await HttpSignature.signed_request(
+ follow.target.inbox_uri, request, follow.source
+ )
+ return cls.undone_remotely
class Follow(StatorModel):
@@ -83,11 +169,17 @@ class Follow(StatorModel):
follow = cls.maybe_get(source=source, target=target)
if follow is None:
follow = Follow.objects.create(source=source, target=target, uri=uri)
- if follow.state == FollowStates.fresh:
- follow.transition_perform(FollowStates.requested)
+ if follow.state == FollowStates.unrequested:
+ follow.transition_perform(FollowStates.remote_requested)
@classmethod
def remote_accepted(cls, source, target):
+ print(f"accepted follow source {source} target {target}")
follow = cls.maybe_get(source=source, target=target)
- if follow and follow.state == FollowStates.requested:
+ print(f"accepting follow {follow}")
+ if follow and follow.state in [
+ FollowStates.unrequested,
+ FollowStates.local_requested,
+ ]:
follow.transition_perform(FollowStates.accepted)
+ print("accepted")
diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py
index 0dbdc3a..54b05e9 100644
--- a/users/models/inbox_message.py
+++ b/users/models/inbox_message.py
@@ -13,7 +13,7 @@ class InboxMessageStates(StateGraph):
@classmethod
async def handle_received(cls, instance: "InboxMessage"):
- type = instance.message["type"].lower()
+ type = instance.message_type
if type == "follow":
await instance.follow_request()
elif type == "accept":
@@ -30,6 +30,7 @@ class InboxMessageStates(StateGraph):
raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
else:
raise ValueError(f"Cannot handle activity of type {type}")
+ return cls.processed
class InboxMessage(StatorModel):
@@ -60,10 +61,17 @@ class InboxMessage(StatorModel):
"""
Handles an incoming acceptance of one of our follow requests
"""
- Follow.remote_accepted(
- source=Identity.by_actor_uri_with_create(self.message["actor"]),
- target=Identity.by_actor_uri(self.message["object"]),
- )
+ target = Identity.by_actor_uri_with_create(self.message["actor"])
+ source = Identity.by_actor_uri(self.message["object"]["actor"])
+ if source is None:
+ raise ValueError(
+ f"Follow-Accept has invalid source {self.message['object']['actor']}"
+ )
+ Follow.remote_accepted(source=source, target=target)
+
+ @property
+ def message_type(self):
+ return self.message["type"].lower()
async def follow_undo(self):
"""
diff --git a/users/shortcuts.py b/users/shortcuts.py
index 8e20a09..0726218 100644
--- a/users/shortcuts.py
+++ b/users/shortcuts.py
@@ -19,7 +19,10 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
else:
username, domain = handle.split("@", 1)
# Resolve the domain to the display domain
- domain = Domain.get_remote_domain(domain).domain
+ domain_instance = Domain.get_domain(domain)
+ if domain_instance is None:
+ domain_instance = Domain.get_remote_domain(domain)
+ domain = domain_instance.domain
identity = Identity.by_username_and_domain(
username,
domain,
diff --git a/users/tasks/__init__.py b/users/tasks/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/users/tasks/__init__.py
+++ /dev/null
diff --git a/users/tasks/follow.py b/users/tasks/follow.py
deleted file mode 100644
index 0f802cf..0000000
--- a/users/tasks/follow.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from core.ld import canonicalise
-from core.signatures import HttpSignature
-from users.models import Follow
-
-
-async def handle_follow_request(task_handler):
- """
- Request a follow from a remote server
- """
- follow = await Follow.objects.select_related(
- "source", "source__domain", "target"
- ).aget(pk=task_handler.subject)
- # Construct the request
- request = canonicalise(
- {
- "@context": "https://www.w3.org/ns/activitystreams",
- "id": follow.uri,
- "type": "Follow",
- "actor": follow.source.actor_uri,
- "object": follow.target.actor_uri,
- }
- )
- # Sign it and send it
- response = await HttpSignature.signed_request(
- follow.target.inbox_uri, request, follow.source
- )
- if response.status_code >= 400:
- raise ValueError(f"Request error: {response.status_code} {response.content}")
- await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
-
-
-def send_follow_undo(id):
- """
- Request a follow from a remote server
- """
- follow = Follow.objects.select_related("source", "source__domain", "target").get(
- pk=id
- )
- # Construct the request
- request = canonicalise(
- {
- "@context": "https://www.w3.org/ns/activitystreams",
- "id": follow.uri + "#undo",
- "type": "Undo",
- "actor": follow.source.actor_uri,
- "object": {
- "id": follow.uri,
- "type": "Follow",
- "actor": follow.source.actor_uri,
- "object": follow.target.actor_uri,
- },
- }
- )
- # Sign it and send it
- from asgiref.sync import async_to_sync
-
- response = async_to_sync(HttpSignature.signed_request)(
- follow.target.inbox_uri, request, follow.source
- )
- if response.status_code >= 400:
- raise ValueError(f"Request error: {response.status_code} {response.content}")
- print(response)
diff --git a/users/views/identity.py b/users/views/identity.py
index 3e69dae..0aed7fa 100644
--- a/users/views/identity.py
+++ b/users/views/identity.py
@@ -21,6 +21,10 @@ from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage
from users.shortcuts import by_handle_or_404
+class HttpResponseUnauthorized(HttpResponse):
+ status_code = 401
+
+
class ViewIdentity(TemplateView):
template_name = "identity/view.html"
@@ -188,20 +192,26 @@ class Inbox(View):
if "HTTP_DIGEST" in request.META:
expected_digest = HttpSignature.calculate_digest(request.body)
if request.META["HTTP_DIGEST"] != expected_digest:
+ print("Wrong digest")
return HttpResponseBadRequest("Digest is incorrect")
# Verify date header
if "HTTP_DATE" in request.META:
header_date = parse_http_date(request.META["HTTP_DATE"])
if abs(timezone.now().timestamp() - header_date) > 60:
+ print(
+ f"Date mismatch - they sent {header_date}, now is {timezone.now().timestamp()}"
+ )
return HttpResponseBadRequest("Date is too far away")
# Get the signature details
if "HTTP_SIGNATURE" not in request.META:
+ print("No signature")
return HttpResponseBadRequest("No signature present")
signature_details = HttpSignature.parse_signature(
request.META["HTTP_SIGNATURE"]
)
# Reject unknown algorithms
if signature_details["algorithm"] != "rsa-sha256":
+ print("Unknown sig algo")
return HttpResponseBadRequest("Unknown signature algorithm")
# Create the signature payload
headers_string = HttpSignature.headers_from_request(
@@ -217,13 +227,14 @@ class Inbox(View):
# See if we can fetch it right now
async_to_sync(identity.fetch_actor)()
if not identity.public_key:
+ print("Cannot get actor")
return HttpResponseBadRequest("Cannot retrieve actor")
if not identity.verify_signature(
signature_details["signature"], headers_string
):
- return HttpResponseBadRequest("Bad signature")
+ return HttpResponseUnauthorized("Bad signature")
# Hand off the item to be processed by the queue
- InboxMessage.objects.create(message=document)
+ InboxMessage.objects.create(message=document, state_ready=True)
return HttpResponse(status=202)