diff options
| -rw-r--r-- | core/signatures.py | 16 | ||||
| -rw-r--r-- | stator/graph.py | 27 | ||||
| -rw-r--r-- | stator/models.py | 20 | ||||
| -rw-r--r-- | stator/runner.py | 6 | ||||
| -rw-r--r-- | statuses/models/status.py | 3 | ||||
| -rw-r--r-- | users/admin.py | 13 | ||||
| -rw-r--r-- | users/models/domain.py | 12 | ||||
| -rw-r--r-- | users/models/follow.py | 114 | ||||
| -rw-r--r-- | users/models/inbox_message.py | 18 | ||||
| -rw-r--r-- | users/shortcuts.py | 5 | ||||
| -rw-r--r-- | users/tasks/__init__.py | 0 | ||||
| -rw-r--r-- | users/tasks/follow.py | 62 | ||||
| -rw-r--r-- | users/views/identity.py | 15 | 
13 files changed, 207 insertions, 104 deletions
diff --git a/core/signatures.py b/core/signatures.py index 805ae91..27e7f7d 100644 --- a/core/signatures.py +++ b/core/signatures.py @@ -1,6 +1,6 @@  import base64  import json -from typing import Dict, List, Literal, TypedDict +from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict  from urllib.parse import urlparse  import httpx @@ -8,7 +8,9 @@ from cryptography.hazmat.primitives import hashes  from django.http import HttpRequest  from django.utils.http import http_date -from users.models import Identity +# Prevent a circular import +if TYPE_CHECKING: +    from users.models import Identity  class HttpSignature: @@ -73,7 +75,7 @@ class HttpSignature:          self,          uri: str,          body: Dict, -        identity: Identity, +        identity: "Identity",          content_type: str = "application/json",          method: Literal["post"] = "post",      ): @@ -105,13 +107,17 @@ class HttpSignature:          del headers["(request-target)"]          async with httpx.AsyncClient() as client:              print(f"Calling {method} {uri}") -            print(body) -            return await client.request( +            response = await client.request(                  method,                  uri,                  headers=headers,                  content=body_bytes,              ) +            if response.status_code >= 400: +                raise ValueError( +                    f"Request error: {response.status_code} {response.content}" +                ) +            return response  class SignatureDetails(TypedDict): diff --git a/stator/graph.py b/stator/graph.py index 7a8455c..00ef1c4 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -41,6 +41,7 @@ class StateGraph:                  initial_state = state              # Collect terminal states              if state.terminal: +                state.externally_progressed = True                  terminal_states.add(state)                  # Ensure they do NOT have a handler                  try: @@ -52,17 +53,18 @@ class StateGraph:                          f"Terminal state '{state}' should not have a handler method ({state.handler_name})"                      )              else: -                # Ensure non-terminal states have a try interval and a handler -                if not state.try_interval: -                    raise ValueError( -                        f"State '{state}' has no try_interval and is not terminal" -                    ) -                try: -                    state.handler -                except AttributeError: -                    raise ValueError( -                        f"State '{state}' does not have a handler method ({state.handler_name})" -                    ) +                # Ensure non-terminal/manual states have a try interval and a handler +                if not state.externally_progressed: +                    if not state.try_interval: +                        raise ValueError( +                            f"State '{state}' has no try_interval and is not terminal or manual" +                        ) +                    try: +                        state.handler +                    except AttributeError: +                        raise ValueError( +                            f"State '{state}' does not have a handler method ({state.handler_name})" +                        )          if initial_state is None:              raise ValueError("The graph has no initial state")          cls.initial_state = initial_state @@ -80,9 +82,11 @@ class State:          self,          try_interval: Optional[float] = None,          handler_name: Optional[str] = None, +        externally_progressed: bool = False,      ):          self.try_interval = try_interval          self.handler_name = handler_name +        self.externally_progressed = externally_progressed          self.parents: Set["State"] = set()          self.children: Set["State"] = set() @@ -118,6 +122,7 @@ class State:      @property      def handler(self) -> Callable[[Any], Optional[str]]: +        # Retrieve it by name off the graph          if self.handler_name is None:              raise AttributeError("No handler defined")          return getattr(self.graph, self.handler_name) diff --git a/stator/models.py b/stator/models.py index 50ee622..072a3ed 100644 --- a/stator/models.py +++ b/stator/models.py @@ -80,7 +80,7 @@ class StatorModel(models.Model):          q = models.Q()          for state in cls.state_graph.states.values():              state = cast(State, state) -            if not state.terminal: +            if not state.externally_progressed:                  q = q | models.Q(                      (                          models.Q( @@ -135,17 +135,31 @@ class StatorModel(models.Model):          self.state_ready = True          self.save() -    async def atransition_attempt(self) -> Optional[str]: +    async def atransition_attempt(self) -> Optional[State]:          """          Attempts to transition the current state by running its handler(s).          """ +        current_state = self.state_graph.states[self.state] +        # If it's a manual progression state don't even try +        # We shouldn't really be here in this case, but it could be a race condition +        if current_state.externally_progressed: +            print("Externally progressed state!") +            return None          try: -            next_state = await self.state_graph.states[self.state].handler(self) +            next_state = await current_state.handler(self)          except BaseException as e:              await StatorError.acreate_from_instance(self, e)              traceback.print_exc()          else:              if next_state: +                # Ensure it's a State object +                if isinstance(next_state, str): +                    next_state = self.state_graph.states[next_state] +                # Ensure it's a child +                if next_state not in current_state.children: +                    raise ValueError( +                        f"Cannot transition from {current_state} to {next_state} - not a declared transition" +                    )                  await self.atransition_perform(next_state)                  return next_state          await self.__class__.objects.filter(pk=self.pk).aupdate( diff --git a/stator/runner.py b/stator/runner.py index 1392e4d..0b42b27 100644 --- a/stator/runner.py +++ b/stator/runner.py @@ -50,9 +50,6 @@ class StatorRunner:                          min(space_remaining, self.MAX_TASKS_PER_MODEL),                          timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),                      ): -                        print( -                            f"Attempting transition on {instance._meta.label_lower}#{instance.pk}" -                        )                          self.tasks.append(                              asyncio.create_task(self.run_transition(instance))                          ) @@ -76,6 +73,9 @@ class StatorRunner:          Wrapper for atransition_attempt with fallback error handling          """          try: +            print( +                f"Attempting transition on {instance._meta.label_lower}#{instance.pk} from state {instance.state}" +            )              await instance.atransition_attempt()          except BaseException:              traceback.print_exc() diff --git a/statuses/models/status.py b/statuses/models/status.py index bfc8eb9..b12a595 100644 --- a/statuses/models/status.py +++ b/statuses/models/status.py @@ -27,6 +27,9 @@ class Status(models.Model):      updated = models.DateTimeField(auto_now=True)      deleted = models.DateTimeField(null=True, blank=True) +    class Meta: +        verbose_name_plural = "statuses" +      @classmethod      def create_local(cls, identity, text: str):          return cls.objects.create( diff --git a/users/admin.py b/users/admin.py index f2b807c..d8f2931 100644 --- a/users/admin.py +++ b/users/admin.py @@ -1,6 +1,6 @@  from django.contrib import admin -from users.models import Domain, Follow, Identity, User, UserEvent +from users.models import Domain, Follow, Identity, InboxMessage, User, UserEvent  @admin.register(Domain) @@ -26,3 +26,14 @@ class IdentityAdmin(admin.ModelAdmin):  @admin.register(Follow)  class FollowAdmin(admin.ModelAdmin):      list_display = ["id", "source", "target", "state"] + + +@admin.register(InboxMessage) +class InboxMessageAdmin(admin.ModelAdmin): +    list_display = ["id", "state", "message_type"] +    actions = ["reset_state"] + +    @admin.action(description="Reset State") +    def reset_state(self, request, queryset): +        for instance in queryset: +            instance.transition_perform("received") diff --git a/users/models/domain.py b/users/models/domain.py index 4ac6ee9..a3815ee 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -81,3 +81,15 @@ class Domain(models.Model):      def __str__(self):          return self.domain + +    def save(self, *args, **kwargs): +        # Ensure that we are not conflicting with other domains +        if Domain.objects.filter(service_domain=self.domain).exists(): +            raise ValueError( +                f"Domain {self.domain} is already a service domain elsewhere!" +            ) +        if self.service_domain: +            if Domain.objects.filter(domain=self.service_domain).exists(): +                raise ValueError( +                    f"Service domain {self.service_domain} is already a domain elsewhere!" +                ) diff --git a/users/models/follow.py b/users/models/follow.py index 6f62481..94ad40f 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -2,24 +2,110 @@ from typing import Optional  from django.db import models +from core.ld import canonicalise +from core.signatures import HttpSignature  from stator.models import State, StateField, StateGraph, StatorModel  class FollowStates(StateGraph):      unrequested = State(try_interval=30) -    requested = State(try_interval=24 * 60 * 60) -    accepted = State() - -    unrequested.transitions_to(requested) -    requested.transitions_to(accepted) +    local_requested = State(try_interval=24 * 60 * 60) +    remote_requested = State(try_interval=24 * 60 * 60) +    accepted = State(externally_progressed=True) +    undone_locally = State(try_interval=60 * 60) +    undone_remotely = State() + +    unrequested.transitions_to(local_requested) +    unrequested.transitions_to(remote_requested) +    local_requested.transitions_to(accepted) +    remote_requested.transitions_to(accepted) +    accepted.transitions_to(undone_locally) +    undone_locally.transitions_to(undone_remotely)      @classmethod      async def handle_unrequested(cls, instance: "Follow"): -        print("Would have tried to follow on", instance) +        # Re-retrieve the follow with more things linked +        follow = await Follow.objects.select_related( +            "source", "source__domain", "target" +        ).aget(pk=instance.pk) +        # Remote follows should not be here +        if not follow.source.local: +            return cls.remote_requested +        # Construct the request +        request = canonicalise( +            { +                "@context": "https://www.w3.org/ns/activitystreams", +                "id": follow.uri, +                "type": "Follow", +                "actor": follow.source.actor_uri, +                "object": follow.target.actor_uri, +            } +        ) +        # Sign it and send it +        await HttpSignature.signed_request( +            follow.target.inbox_uri, request, follow.source +        ) +        return cls.local_requested + +    @classmethod +    async def handle_local_requested(cls, instance: "Follow"): +        # TODO: Resend follow requests occasionally +        pass + +    @classmethod +    async def handle_remote_requested(cls, instance: "Follow"): +        # Re-retrieve the follow with more things linked +        follow = await Follow.objects.select_related( +            "source", "source__domain", "target" +        ).aget(pk=instance.pk) +        # Send an accept +        request = canonicalise( +            { +                "@context": "https://www.w3.org/ns/activitystreams", +                "id": follow.target.actor_uri + f"follow/{follow.pk}/#accept", +                "type": "Follow", +                "actor": follow.source.actor_uri, +                "object": { +                    "id": follow.uri, +                    "type": "Follow", +                    "actor": follow.source.actor_uri, +                    "object": follow.target.actor_uri, +                }, +            } +        ) +        # Sign it and send it +        await HttpSignature.signed_request( +            follow.source.inbox_uri, +            request, +            identity=follow.target, +        ) +        return cls.accepted      @classmethod -    async def handle_requested(cls, instance: "Follow"): -        print("Would have tried to requested on", instance) +    async def handle_undone_locally(cls, instance: "Follow"): +        follow = Follow.objects.select_related( +            "source", "source__domain", "target" +        ).get(pk=instance.pk) +        # Construct the request +        request = canonicalise( +            { +                "@context": "https://www.w3.org/ns/activitystreams", +                "id": follow.uri + "#undo", +                "type": "Undo", +                "actor": follow.source.actor_uri, +                "object": { +                    "id": follow.uri, +                    "type": "Follow", +                    "actor": follow.source.actor_uri, +                    "object": follow.target.actor_uri, +                }, +            } +        ) +        # Sign it and send it +        await HttpSignature.signed_request( +            follow.target.inbox_uri, request, follow.source +        ) +        return cls.undone_remotely  class Follow(StatorModel): @@ -83,11 +169,17 @@ class Follow(StatorModel):          follow = cls.maybe_get(source=source, target=target)          if follow is None:              follow = Follow.objects.create(source=source, target=target, uri=uri) -        if follow.state == FollowStates.fresh: -            follow.transition_perform(FollowStates.requested) +        if follow.state == FollowStates.unrequested: +            follow.transition_perform(FollowStates.remote_requested)      @classmethod      def remote_accepted(cls, source, target): +        print(f"accepted follow source {source} target {target}")          follow = cls.maybe_get(source=source, target=target) -        if follow and follow.state == FollowStates.requested: +        print(f"accepting follow {follow}") +        if follow and follow.state in [ +            FollowStates.unrequested, +            FollowStates.local_requested, +        ]:              follow.transition_perform(FollowStates.accepted) +            print("accepted") diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py index 0dbdc3a..54b05e9 100644 --- a/users/models/inbox_message.py +++ b/users/models/inbox_message.py @@ -13,7 +13,7 @@ class InboxMessageStates(StateGraph):      @classmethod      async def handle_received(cls, instance: "InboxMessage"): -        type = instance.message["type"].lower() +        type = instance.message_type          if type == "follow":              await instance.follow_request()          elif type == "accept": @@ -30,6 +30,7 @@ class InboxMessageStates(StateGraph):                  raise ValueError(f"Cannot handle activity of type undo.{inner_type}")          else:              raise ValueError(f"Cannot handle activity of type {type}") +        return cls.processed  class InboxMessage(StatorModel): @@ -60,10 +61,17 @@ class InboxMessage(StatorModel):          """          Handles an incoming acceptance of one of our follow requests          """ -        Follow.remote_accepted( -            source=Identity.by_actor_uri_with_create(self.message["actor"]), -            target=Identity.by_actor_uri(self.message["object"]), -        ) +        target = Identity.by_actor_uri_with_create(self.message["actor"]) +        source = Identity.by_actor_uri(self.message["object"]["actor"]) +        if source is None: +            raise ValueError( +                f"Follow-Accept has invalid source {self.message['object']['actor']}" +            ) +        Follow.remote_accepted(source=source, target=target) + +    @property +    def message_type(self): +        return self.message["type"].lower()      async def follow_undo(self):          """ diff --git a/users/shortcuts.py b/users/shortcuts.py index 8e20a09..0726218 100644 --- a/users/shortcuts.py +++ b/users/shortcuts.py @@ -19,7 +19,10 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:      else:          username, domain = handle.split("@", 1)          # Resolve the domain to the display domain -        domain = Domain.get_remote_domain(domain).domain +        domain_instance = Domain.get_domain(domain) +        if domain_instance is None: +            domain_instance = Domain.get_remote_domain(domain) +        domain = domain_instance.domain      identity = Identity.by_username_and_domain(          username,          domain, diff --git a/users/tasks/__init__.py b/users/tasks/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/users/tasks/__init__.py +++ /dev/null diff --git a/users/tasks/follow.py b/users/tasks/follow.py deleted file mode 100644 index 0f802cf..0000000 --- a/users/tasks/follow.py +++ /dev/null @@ -1,62 +0,0 @@ -from core.ld import canonicalise -from core.signatures import HttpSignature -from users.models import Follow - - -async def handle_follow_request(task_handler): -    """ -    Request a follow from a remote server -    """ -    follow = await Follow.objects.select_related( -        "source", "source__domain", "target" -    ).aget(pk=task_handler.subject) -    # Construct the request -    request = canonicalise( -        { -            "@context": "https://www.w3.org/ns/activitystreams", -            "id": follow.uri, -            "type": "Follow", -            "actor": follow.source.actor_uri, -            "object": follow.target.actor_uri, -        } -    ) -    # Sign it and send it -    response = await HttpSignature.signed_request( -        follow.target.inbox_uri, request, follow.source -    ) -    if response.status_code >= 400: -        raise ValueError(f"Request error: {response.status_code} {response.content}") -    await Follow.objects.filter(pk=follow.pk).aupdate(requested=True) - - -def send_follow_undo(id): -    """ -    Request a follow from a remote server -    """ -    follow = Follow.objects.select_related("source", "source__domain", "target").get( -        pk=id -    ) -    # Construct the request -    request = canonicalise( -        { -            "@context": "https://www.w3.org/ns/activitystreams", -            "id": follow.uri + "#undo", -            "type": "Undo", -            "actor": follow.source.actor_uri, -            "object": { -                "id": follow.uri, -                "type": "Follow", -                "actor": follow.source.actor_uri, -                "object": follow.target.actor_uri, -            }, -        } -    ) -    # Sign it and send it -    from asgiref.sync import async_to_sync - -    response = async_to_sync(HttpSignature.signed_request)( -        follow.target.inbox_uri, request, follow.source -    ) -    if response.status_code >= 400: -        raise ValueError(f"Request error: {response.status_code} {response.content}") -    print(response) diff --git a/users/views/identity.py b/users/views/identity.py index 3e69dae..0aed7fa 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -21,6 +21,10 @@ from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage  from users.shortcuts import by_handle_or_404 +class HttpResponseUnauthorized(HttpResponse): +    status_code = 401 + +  class ViewIdentity(TemplateView):      template_name = "identity/view.html" @@ -188,20 +192,26 @@ class Inbox(View):          if "HTTP_DIGEST" in request.META:              expected_digest = HttpSignature.calculate_digest(request.body)              if request.META["HTTP_DIGEST"] != expected_digest: +                print("Wrong digest")                  return HttpResponseBadRequest("Digest is incorrect")          # Verify date header          if "HTTP_DATE" in request.META:              header_date = parse_http_date(request.META["HTTP_DATE"])              if abs(timezone.now().timestamp() - header_date) > 60: +                print( +                    f"Date mismatch - they sent {header_date}, now is {timezone.now().timestamp()}" +                )                  return HttpResponseBadRequest("Date is too far away")          # Get the signature details          if "HTTP_SIGNATURE" not in request.META: +            print("No signature")              return HttpResponseBadRequest("No signature present")          signature_details = HttpSignature.parse_signature(              request.META["HTTP_SIGNATURE"]          )          # Reject unknown algorithms          if signature_details["algorithm"] != "rsa-sha256": +            print("Unknown sig algo")              return HttpResponseBadRequest("Unknown signature algorithm")          # Create the signature payload          headers_string = HttpSignature.headers_from_request( @@ -217,13 +227,14 @@ class Inbox(View):              # See if we can fetch it right now              async_to_sync(identity.fetch_actor)()          if not identity.public_key: +            print("Cannot get actor")              return HttpResponseBadRequest("Cannot retrieve actor")          if not identity.verify_signature(              signature_details["signature"], headers_string          ): -            return HttpResponseBadRequest("Bad signature") +            return HttpResponseUnauthorized("Bad signature")          # Hand off the item to be processed by the queue -        InboxMessage.objects.create(message=document) +        InboxMessage.objects.create(message=document, state_ready=True)          return HttpResponse(status=202)  | 
