diff options
author | Andrew Godwin | 2022-11-08 23:06:29 -0700 |
---|---|---|
committer | Andrew Godwin | 2022-11-09 22:29:49 -0700 |
commit | 61c324508e62bb640b4526183d0837fc57d742c2 (patch) | |
tree | 618ee8c88ce8a28224a187dc33b7c5fad6831d04 | |
parent | 8a0a7558894afce8d25b7f0dc16775e899b72a94 (diff) | |
download | takahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.gz takahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.bz2 takahe-61c324508e62bb640b4526183d0837fc57d742c2.zip |
Midway point in task refactor - changing direction
24 files changed, 698 insertions, 241 deletions
diff --git a/miniq/admin.py b/miniq/admin.py deleted file mode 100644 index 1166f89..0000000 --- a/miniq/admin.py +++ /dev/null @@ -1,21 +0,0 @@ -from django.contrib import admin - -from miniq.models import Task - - -@admin.register(Task) -class TaskAdmin(admin.ModelAdmin): - - list_display = ["id", "created", "type", "subject", "completed", "failed"] - ordering = ["-created"] - actions = ["reset"] - - @admin.action(description="Reset Task") - def reset(self, request, queryset): - queryset.update( - failed=None, - completed=None, - locked=None, - locked_by=None, - error=None, - ) diff --git a/miniq/migrations/0001_initial.py b/miniq/migrations/0001_initial.py deleted file mode 100644 index dc6d42b..0000000 --- a/miniq/migrations/0001_initial.py +++ /dev/null @@ -1,48 +0,0 @@ -# Generated by Django 4.1.3 on 2022-11-07 04:19 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [] - - operations = [ - migrations.CreateModel( - name="Task", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "type", - models.CharField( - choices=[ - ("identity_fetch", "Identity Fetch"), - ("inbox_item", "Inbox Item"), - ("follow_request", "Follow Request"), - ("follow_acknowledge", "Follow Acknowledge"), - ], - max_length=500, - ), - ), - ("priority", models.IntegerField(default=0)), - ("subject", models.TextField()), - ("payload", models.JSONField(blank=True, null=True)), - ("error", models.TextField(blank=True, null=True)), - ("created", models.DateTimeField(auto_now_add=True)), - ("completed", models.DateTimeField(blank=True, null=True)), - ("failed", models.DateTimeField(blank=True, null=True)), - ("locked", models.DateTimeField(blank=True, null=True)), - ("locked_by", models.CharField(blank=True, max_length=500, null=True)), - ], - ), - ] diff --git a/miniq/models.py b/miniq/models.py deleted file mode 100644 index 24d311c..0000000 --- a/miniq/models.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Optional - -from django.db import models, transaction -from django.utils import timezone - - -class Task(models.Model): - """ - A task that must be done by a queue processor - """ - - class TypeChoices(models.TextChoices): - identity_fetch = "identity_fetch" - inbox_item = "inbox_item" - follow_request = "follow_request" - follow_acknowledge = "follow_acknowledge" - - type = models.CharField(max_length=500, choices=TypeChoices.choices) - priority = models.IntegerField(default=0) - subject = models.TextField() - payload = models.JSONField(blank=True, null=True) - error = models.TextField(blank=True, null=True) - - created = models.DateTimeField(auto_now_add=True) - completed = models.DateTimeField(blank=True, null=True) - failed = models.DateTimeField(blank=True, null=True) - locked = models.DateTimeField(blank=True, null=True) - locked_by = models.CharField(max_length=500, blank=True, null=True) - - def __str__(self): - return f"{self.id}/{self.type}({self.subject})" - - @classmethod - def get_one_available(cls, processor_id) -> Optional["Task"]: - """ - Gets one task off the list while reserving it, atomically. - """ - with transaction.atomic(): - next_task = cls.objects.filter(locked__isnull=True).first() - if next_task is None: - return None - next_task.locked = timezone.now() - next_task.locked_by = processor_id - next_task.save() - return next_task - - @classmethod - def submit(cls, type, subject: str, payload=None, deduplicate=True): - # Deduplication is done against tasks that have not started yet only, - # and only on tasks without payloads - if deduplicate and not payload: - if cls.objects.filter( - type=type, - subject=subject, - completed__isnull=True, - failed__isnull=True, - locked__isnull=True, - ).exists(): - return - cls.objects.create(type=type, subject=subject, payload=payload) - - async def complete(self): - await self.__class__.objects.filter(id=self.id).aupdate( - completed=timezone.now() - ) - - async def fail(self, error): - await self.__class__.objects.filter(id=self.id).aupdate( - failed=timezone.now(), - error=error, - ) diff --git a/miniq/tasks.py b/miniq/tasks.py deleted file mode 100644 index fedf8fd..0000000 --- a/miniq/tasks.py +++ /dev/null @@ -1,34 +0,0 @@ -import traceback - -from users.tasks.follow import handle_follow_request -from users.tasks.identity import handle_identity_fetch -from users.tasks.inbox import handle_inbox_item - - -class TaskHandler: - - handlers = { - "identity_fetch": handle_identity_fetch, - "inbox_item": handle_inbox_item, - "follow_request": handle_follow_request, - } - - def __init__(self, task): - self.task = task - self.subject = self.task.subject - self.payload = self.task.payload - - async def handle(self): - try: - print(f"Task {self.task}: Starting") - if self.task.type not in self.handlers: - raise ValueError(f"Cannot handle type {self.task.type}") - await self.handlers[self.task.type]( - self, - ) - await self.task.complete() - print(f"Task {self.task}: Complete") - except BaseException as e: - print(f"Task {self.task}: Error {e}") - traceback.print_exc() - await self.task.fail(f"{e}\n\n" + traceback.format_exc()) diff --git a/miniq/views.py b/miniq/views.py deleted file mode 100644 index 80c9ee2..0000000 --- a/miniq/views.py +++ /dev/null @@ -1,51 +0,0 @@ -import asyncio -import time -import uuid - -from asgiref.sync import sync_to_async -from django.http import HttpResponse -from django.views import View - -from miniq.models import Task -from miniq.tasks import TaskHandler - - -class QueueProcessor(View): - """ - A view that takes some items off the queue and processes them. - Tries to limit its own runtime so it's within HTTP timeout limits. - """ - - START_TIMEOUT = 30 - TOTAL_TIMEOUT = 60 - LOCK_TIMEOUT = 200 - MAX_TASKS = 20 - - async def get(self, request): - start_time = time.monotonic() - processor_id = uuid.uuid4().hex - handled = 0 - self.tasks = [] - # For the first time period, launch tasks - while (time.monotonic() - start_time) < self.START_TIMEOUT: - # Remove completed tasks - self.tasks = [t for t in self.tasks if not t.done()] - # See if there's a new task - if len(self.tasks) < self.MAX_TASKS: - # Pop a task off the queue and run it - task = await sync_to_async(Task.get_one_available)(processor_id) - if task is not None: - self.tasks.append(asyncio.create_task(TaskHandler(task).handle())) - handled += 1 - # Prevent busylooping - await asyncio.sleep(0.01) - # TODO: Clean up old locks here - # Then wait for tasks to finish - while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT: - # Remove completed tasks - self.tasks = [t for t in self.tasks if not t.done()] - if not self.tasks: - break - # Prevent busylooping - await asyncio.sleep(1) - return HttpResponse(f"{handled} tasks handled") diff --git a/miniq/__init__.py b/stator/__init__.py index e69de29..e69de29 100644 --- a/miniq/__init__.py +++ b/stator/__init__.py diff --git a/stator/admin.py b/stator/admin.py new file mode 100644 index 0000000..c04d775 --- /dev/null +++ b/stator/admin.py @@ -0,0 +1,8 @@ +from django.contrib import admin + +from stator.models import StatorTask + + +@admin.register(StatorTask) +class DomainAdmin(admin.ModelAdmin): + list_display = ["id", "model_label", "instance_pk", "locked_until"] diff --git a/miniq/apps.py b/stator/apps.py index 4c7e773..8910ecb 100644 --- a/miniq/apps.py +++ b/stator/apps.py @@ -1,6 +1,6 @@ from django.apps import AppConfig -class MiniqConfig(AppConfig): +class StatorConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "miniq" + name = "stator" diff --git a/stator/graph.py b/stator/graph.py new file mode 100644 index 0000000..b06ffb8 --- /dev/null +++ b/stator/graph.py @@ -0,0 +1,162 @@ +import datetime +from functools import wraps +from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union + +from django.db import models +from django.utils import timezone + + +class StateGraph: + """ + Represents a graph of possible states and transitions to attempt on them. + Does not support subclasses of existing graphs yet. + """ + + states: ClassVar[Dict[str, "State"]] + choices: ClassVar[List[Tuple[str, str]]] + initial_state: ClassVar["State"] + terminal_states: ClassVar[Set["State"]] + + def __init_subclass__(cls) -> None: + # Collect state memebers + cls.states = {} + for name, value in cls.__dict__.items(): + if name in ["__module__", "__doc__", "states"]: + pass + elif name in ["initial_state", "terminal_states", "choices"]: + raise ValueError(f"Cannot name a state {name} - this is reserved") + elif isinstance(value, State): + value._add_to_graph(cls, name) + elif callable(value) or isinstance(value, classmethod): + pass + else: + raise ValueError( + f"Graph has item {name} of unallowed type {type(value)}" + ) + # Check the graph layout + terminal_states = set() + initial_state = None + for state in cls.states.values(): + if state.initial: + if initial_state: + raise ValueError( + f"The graph has more than one initial state: {initial_state} and {state}" + ) + initial_state = state + if state.terminal: + terminal_states.add(state) + if initial_state is None: + raise ValueError("The graph has no initial state") + cls.initial_state = initial_state + cls.terminal_states = terminal_states + # Generate choices + cls.choices = [(name, name) for name in cls.states.keys()] + + +class State: + """ + Represents an individual state + """ + + def __init__(self, try_interval: float = 300): + self.try_interval = try_interval + self.parents: Set["State"] = set() + self.children: Dict["State", "Transition"] = {} + + def _add_to_graph(self, graph: StateGraph, name: str): + self.graph = graph + self.name = name + self.graph.states[name] = self + + def __repr__(self): + return f"<State {self.name}>" + + def add_transition( + self, + other: "State", + handler: Optional[Union[str, Callable]] = None, + priority: int = 0, + ) -> Callable: + def decorator(handler: Union[str, Callable]): + self.children[other] = Transition( + self, + other, + handler, + priority=priority, + ) + other.parents.add(self) + # All handlers should be class methods, so do that automatically. + if callable(handler): + return classmethod(handler) + + # If we're not being called as a decorator, invoke it immediately + if handler is not None: + decorator(handler) + return decorator + + def add_manual_transition(self, other: "State"): + self.children[other] = ManualTransition(self, other) + other.parents.add(self) + + @property + def initial(self): + return not self.parents + + @property + def terminal(self): + return not self.children + + def transitions(self, automatic_only=False) -> List["Transition"]: + """ + Returns all transitions from this State in priority order + """ + if automatic_only: + transitions = [t for t in self.children.values() if t.automatic] + else: + transitions = self.children.values() + return sorted(transitions, key=lambda t: t.priority, reverse=True) + + +class Transition: + """ + A possible transition from one state to another + """ + + def __init__( + self, + from_state: State, + to_state: State, + handler: Union[str, Callable], + priority: int = 0, + ): + self.from_state = from_state + self.to_state = to_state + self.handler = handler + self.priority = priority + self.automatic = True + + def get_handler(self) -> Callable: + """ + Returns the handler (it might need resolving from a string) + """ + if isinstance(self.handler, str): + self.handler = getattr(self.from_state.graph, self.handler) + return self.handler + + +class ManualTransition(Transition): + """ + A possible transition from one state to another that cannot be done by + the stator task runner, and must come from an external source. + """ + + def __init__( + self, + from_state: State, + to_state: State, + ): + self.from_state = from_state + self.to_state = to_state + self.handler = None + self.priority = 0 + self.automatic = False diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py new file mode 100644 index 0000000..f485836 --- /dev/null +++ b/stator/migrations/0001_initial.py @@ -0,0 +1,31 @@ +# Generated by Django 4.1.3 on 2022-11-09 05:46 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="StatorTask", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("model_label", models.CharField(max_length=200)), + ("instance_pk", models.CharField(max_length=200)), + ("locked_until", models.DateTimeField(blank=True, null=True)), + ("priority", models.IntegerField(default=0)), + ], + ), + ] diff --git a/miniq/migrations/__init__.py b/stator/migrations/__init__.py index e69de29..e69de29 100644 --- a/miniq/migrations/__init__.py +++ b/stator/migrations/__init__.py diff --git a/stator/models.py b/stator/models.py new file mode 100644 index 0000000..3b0da0a --- /dev/null +++ b/stator/models.py @@ -0,0 +1,191 @@ +import datetime +from functools import reduce +from typing import Type, cast + +from asgiref.sync import sync_to_async +from django.apps import apps +from django.db import models, transaction +from django.utils import timezone +from django.utils.functional import classproperty + +from stator.graph import State, StateGraph + + +class StateField(models.CharField): + """ + A special field that automatically gets choices from a state graph + """ + + def __init__(self, graph: Type[StateGraph], **kwargs): + # Sensible default for state length + kwargs.setdefault("max_length", 100) + # Add choices and initial + self.graph = graph + kwargs["choices"] = self.graph.choices + kwargs["default"] = self.graph.initial_state.name + super().__init__(**kwargs) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + kwargs["graph"] = self.graph + return name, path, args, kwargs + + def from_db_value(self, value, expression, connection): + if value is None: + return value + return self.graph.states[value] + + def to_python(self, value): + if isinstance(value, State) or value is None: + return value + return self.graph.states[value] + + def get_prep_value(self, value): + if isinstance(value, State): + return value.name + return value + + +class StatorModel(models.Model): + """ + A model base class that has a state machine backing it, with tasks to work + out when to move the state to the next one. + + You need to provide a "state" field as an instance of StateField on the + concrete model yourself. + """ + + # When the state last actually changed, or the date of instance creation + state_changed = models.DateTimeField(auto_now_add=True) + + # When the last state change for the current state was attempted + # (and not successful, as this is cleared on transition) + state_attempted = models.DateTimeField(blank=True, null=True) + + class Meta: + abstract = True + + @classmethod + def schedule_overdue(cls, now=None) -> models.QuerySet: + """ + Finds instances of this model that need to run and schedule them. + """ + q = models.Q() + for transition in cls.state_graph.transitions(automatic_only=True): + q = q | transition.get_query(now=now) + return cls.objects.filter(q) + + @classproperty + def state_graph(cls) -> Type[StateGraph]: + return cls._meta.get_field("state").graph + + def schedule_transition(self, priority: int = 0): + """ + Adds this instance to the queue to get its state transition attempted. + + The scheduler will call this, but you can also call it directly if you + know it'll be ready and want to lower latency. + """ + StatorTask.schedule_for_execution(self, priority=priority) + + async def attempt_transition(self): + """ + Attempts to transition the current state by running its handler(s). + """ + # Try each transition in priority order + for transition in self.state_graph.states[self.state].transitions( + automatic_only=True + ): + success = await transition.get_handler()(self) + if success: + await self.perform_transition(transition.to_state.name) + return + await self.__class__.objects.filter(pk=self.pk).aupdate( + state_attempted=timezone.now() + ) + + async def perform_transition(self, state_name): + """ + Transitions the instance to the given state name + """ + if state_name not in self.state_graph.states: + raise ValueError(f"Invalid state {state_name}") + await self.__class__.objects.filter(pk=self.pk).aupdate( + state=state_name, + state_changed=timezone.now(), + state_attempted=None, + ) + + +class StatorTask(models.Model): + """ + The model that we use for an internal scheduling queue. + + Entries in this queue are up for checking and execution - it also performs + locking to ensure we get closer to exactly-once execution (but we err on + the side of at-least-once) + """ + + # appname.modelname (lowercased) label for the model this represents + model_label = models.CharField(max_length=200) + + # The primary key of that model (probably int or str) + instance_pk = models.CharField(max_length=200) + + # Locking columns (no runner ID, as we have no heartbeats - all runners + # only live for a short amount of time anyway) + locked_until = models.DateTimeField(null=True, blank=True) + + # Basic total ordering priority - higher is more important + priority = models.IntegerField(default=0) + + def __str__(self): + return f"#{self.pk}: {self.model_label}.{self.instance_pk}" + + @classmethod + def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0): + # We don't do a transaction here as it's fine to occasionally double up + model_label = model_instance._meta.label_lower + pk = model_instance.pk + # TODO: Increase priority of existing if present + if not cls.objects.filter( + model_label=model_label, instance_pk=pk, locked__isnull=True + ).exists(): + StatorTask.objects.create( + model_label=model_label, + instance_pk=pk, + priority=priority, + ) + + @classmethod + def get_for_execution(cls, number: int, lock_expiry: datetime.datetime): + """ + Returns up to `number` tasks for execution, having locked them. + """ + with transaction.atomic(): + selected = list( + cls.objects.filter(locked_until__isnull=True)[ + :number + ].select_for_update() + ) + cls.objects.filter(pk__in=[i.pk for i in selected]).update( + locked_until=timezone.now() + ) + return selected + + @classmethod + async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime): + return await sync_to_async(cls.get_for_execution)(number, lock_expiry) + + @classmethod + async def aclean_old_locks(cls): + await cls.objects.filter(locked_until__lte=timezone.now()).aupdate( + locked_until=None + ) + + async def aget_model_instance(self) -> StatorModel: + model = apps.get_model(self.model_label) + return cast(StatorModel, await model.objects.aget(pk=self.pk)) + + async def adelete(self): + self.__class__.objects.adelete(pk=self.pk) diff --git a/stator/runner.py b/stator/runner.py new file mode 100644 index 0000000..8c6e0f1 --- /dev/null +++ b/stator/runner.py @@ -0,0 +1,69 @@ +import asyncio +import datetime +import time +import uuid +from typing import List, Type + +from asgiref.sync import sync_to_async +from django.db import transaction +from django.utils import timezone + +from stator.models import StatorModel, StatorTask + + +class StatorRunner: + """ + Runs tasks on models that are looking for state changes. + Designed to run in a one-shot mode, living inside a request. + """ + + START_TIMEOUT = 30 + TOTAL_TIMEOUT = 60 + LOCK_TIMEOUT = 120 + + MAX_TASKS = 30 + + def __init__(self, models: List[Type[StatorModel]]): + self.models = models + self.runner_id = uuid.uuid4().hex + + async def run(self): + start_time = time.monotonic() + self.handled = 0 + self.tasks = [] + # Clean up old locks + await StatorTask.aclean_old_locks() + # Examine what needs scheduling + + # For the first time period, launch tasks + while (time.monotonic() - start_time) < self.START_TIMEOUT: + self.remove_completed_tasks() + space_remaining = self.MAX_TASKS - len(self.tasks) + # Fetch new tasks + if space_remaining > 0: + for new_task in await StatorTask.aget_for_execution( + space_remaining, + timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT), + ): + self.tasks.append(asyncio.create_task(self.run_task(new_task))) + self.handled += 1 + # Prevent busylooping + await asyncio.sleep(0.01) + # Then wait for tasks to finish + while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT: + self.remove_completed_tasks() + if not self.tasks: + break + # Prevent busylooping + await asyncio.sleep(1) + return self.handled + + async def run_task(self, task: StatorTask): + # Resolve the model instance + model_instance = await task.aget_model_instance() + await model_instance.attempt_transition() + # Remove ourselves from the database as complete + await task.adelete() + + def remove_completed_tasks(self): + self.tasks = [t for t in self.tasks if not t.done()] diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py new file mode 100644 index 0000000..f6b8404 --- /dev/null +++ b/stator/tests/test_graph.py @@ -0,0 +1,66 @@ +import pytest + +from stator.graph import State, StateGraph + + +def test_declare(): + """ + Tests a basic graph declaration and various kinds of handler + lookups. + """ + + fake_handler = lambda: True + + class TestGraph(StateGraph): + initial = State() + second = State() + third = State() + fourth = State() + final = State() + + initial.add_transition(second, 60, handler=fake_handler) + second.add_transition(third, 60, handler="check_third") + + def check_third(cls): + return True + + @third.add_transition(fourth, 60) + def check_fourth(cls): + return True + + fourth.add_manual_transition(final) + + assert TestGraph.initial_state == TestGraph.initial + assert TestGraph.terminal_states == {TestGraph.final} + + assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler + assert ( + TestGraph.second.children[TestGraph.third].get_handler() + == TestGraph.check_third + ) + assert ( + TestGraph.third.children[TestGraph.fourth].get_handler().__name__ + == "check_fourth" + ) + + +def test_bad_declarations(): + """ + Tests that you can't declare an invalid graph. + """ + # More than one initial state + with pytest.raises(ValueError): + + class TestGraph(StateGraph): + initial = State() + initial2 = State() + + # No initial states + with pytest.raises(ValueError): + + class TestGraph(StateGraph): + loop = State() + loop2 = State() + + loop.add_transition(loop2, 1, handler="fake") + loop2.add_transition(loop, 1, handler="fake") diff --git a/stator/views.py b/stator/views.py new file mode 100644 index 0000000..ef09b8e --- /dev/null +++ b/stator/views.py @@ -0,0 +1,17 @@ +from django.http import HttpResponse +from django.views import View + +from stator.runner import StatorRunner +from users.models import Follow + + +class RequestRunner(View): + """ + Runs a Stator runner within a HTTP request. For when you're on something + serverless. + """ + + async def get(self, request): + runner = StatorRunner([Follow]) + handled = await runner.run() + return HttpResponse(f"Handled {handled}") diff --git a/takahe/settings.py b/takahe/settings.py index 62065d2..cefbb35 100644 --- a/takahe/settings.py +++ b/takahe/settings.py @@ -26,7 +26,7 @@ INSTALLED_APPS = [ "core", "statuses", "users", - "miniq", + "stator", ] MIDDLEWARE = [ diff --git a/takahe/urls.py b/takahe/urls.py index 304bc23..764c8e9 100644 --- a/takahe/urls.py +++ b/takahe/urls.py @@ -2,7 +2,7 @@ from django.contrib import admin from django.urls import path from core import views as core -from miniq import views as miniq +from stator import views as stator from users.views import auth, identity urlpatterns = [ @@ -22,7 +22,7 @@ urlpatterns = [ # Well-known endpoints path(".well-known/webfinger", identity.Webfinger.as_view()), # Task runner - path(".queue/process/", miniq.QueueProcessor.as_view()), + path(".stator/runner/", stator.RequestRunner.as_view()), # Django admin path("djadmin/", admin.site.urls), ] diff --git a/users/admin.py b/users/admin.py index 0b0cc80..e517b0a 100644 --- a/users/admin.py +++ b/users/admin.py @@ -25,4 +25,4 @@ class IdentityAdmin(admin.ModelAdmin): @admin.register(Follow) class FollowAdmin(admin.ModelAdmin): - list_display = ["id", "source", "target", "requested", "accepted"] + list_display = ["id", "source", "target", "state"] diff --git a/users/migrations/0002_follow_state_follow_state_attempted_and_more.py b/users/migrations/0002_follow_state_follow_state_attempted_and_more.py new file mode 100644 index 0000000..b33236a --- /dev/null +++ b/users/migrations/0002_follow_state_follow_state_attempted_and_more.py @@ -0,0 +1,44 @@ +# Generated by Django 4.1.3 on 2022-11-07 19:22 + +import django.utils.timezone +from django.db import migrations, models + +import stator.models +import users.models.follow + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="follow", + name="state", + field=stator.models.StateField( + choices=[ + ("pending", "pending"), + ("requested", "requested"), + ("accepted", "accepted"), + ], + default="pending", + graph=users.models.follow.FollowStates, + max_length=100, + ), + ), + migrations.AddField( + model_name="follow", + name="state_attempted", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="follow", + name="state_changed", + field=models.DateTimeField( + auto_now_add=True, default=django.utils.timezone.now + ), + preserve_default=False, + ), + ] diff --git a/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py b/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py new file mode 100644 index 0000000..180bfdd --- /dev/null +++ b/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py @@ -0,0 +1,31 @@ +# Generated by Django 4.1.3 on 2022-11-08 03:58 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0002_follow_state_follow_state_attempted_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="follow", + name="accepted", + ), + migrations.RemoveField( + model_name="follow", + name="requested", + ), + migrations.AddField( + model_name="follow", + name="state_locked", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="follow", + name="state_runner", + field=models.CharField(blank=True, max_length=100, null=True), + ), + ] diff --git a/users/migrations/0004_remove_follow_state_locked_and_more.py b/users/migrations/0004_remove_follow_state_locked_and_more.py new file mode 100644 index 0000000..bf98080 --- /dev/null +++ b/users/migrations/0004_remove_follow_state_locked_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.1.3 on 2022-11-09 05:15 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="follow", + name="state_locked", + ), + migrations.RemoveField( + model_name="follow", + name="state_runner", + ), + ] diff --git a/users/models/follow.py b/users/models/follow.py index 29d036e..04f90ee 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -2,10 +2,23 @@ from typing import Optional from django.db import models -from miniq.models import Task +from stator.models import State, StateField, StateGraph, StatorModel -class Follow(models.Model): +class FollowStates(StateGraph): + pending = State(try_interval=3600) + requested = State() + accepted = State() + + @pending.add_transition(requested) + async def try_request(cls, instance): + print("Would have tried to follow") + return False + + requested.add_manual_transition(accepted) + + +class Follow(StatorModel): """ When one user (the source) follows other (the target) """ @@ -24,8 +37,7 @@ class Follow(models.Model): uri = models.CharField(blank=True, null=True, max_length=500) note = models.TextField(blank=True, null=True) - requested = models.BooleanField(default=False) - accepted = models.BooleanField(default=False) + state = StateField(FollowStates) created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) @@ -50,17 +62,15 @@ class Follow(models.Model): (which can be local or remote). """ if not source.local: - raise ValueError("You cannot initiate follows on a remote Identity") + raise ValueError("You cannot initiate follows from a remote Identity") try: follow = Follow.objects.get(source=source, target=target) except Follow.DoesNotExist: follow = Follow.objects.create(source=source, target=target, uri="") follow.uri = source.actor_uri + f"follow/{follow.pk}/" + # TODO: Local follow approvals if target.local: - follow.requested = True - follow.accepted = True - else: - Task.submit("follow_request", str(follow.pk)) + follow.state = FollowStates.accepted follow.save() return follow diff --git a/users/tasks/follow.py b/users/tasks/follow.py index 872b35f..0f802cf 100644 --- a/users/tasks/follow.py +++ b/users/tasks/follow.py @@ -27,3 +27,36 @@ async def handle_follow_request(task_handler): if response.status_code >= 400: raise ValueError(f"Request error: {response.status_code} {response.content}") await Follow.objects.filter(pk=follow.pk).aupdate(requested=True) + + +def send_follow_undo(id): + """ + Request a follow from a remote server + """ + follow = Follow.objects.select_related("source", "source__domain", "target").get( + pk=id + ) + # Construct the request + request = canonicalise( + { + "@context": "https://www.w3.org/ns/activitystreams", + "id": follow.uri + "#undo", + "type": "Undo", + "actor": follow.source.actor_uri, + "object": { + "id": follow.uri, + "type": "Follow", + "actor": follow.source.actor_uri, + "object": follow.target.actor_uri, + }, + } + ) + # Sign it and send it + from asgiref.sync import async_to_sync + + response = async_to_sync(HttpSignature.signed_request)( + follow.target.inbox_uri, request, follow.source + ) + if response.status_code >= 400: + raise ValueError(f"Request error: {response.status_code} {response.content}") + print(response) diff --git a/users/views/identity.py b/users/views/identity.py index 98fcdd6..41c7880 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -16,7 +16,6 @@ from django.views.generic import FormView, TemplateView, View from core.forms import FormHelper from core.ld import canonicalise from core.signatures import HttpSignature -from miniq.models import Task from users.decorators import identity_required from users.models import Domain, Follow, Identity from users.shortcuts import by_handle_or_404 |