summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndrew Godwin2022-11-08 23:06:29 -0700
committerAndrew Godwin2022-11-09 22:29:49 -0700
commit61c324508e62bb640b4526183d0837fc57d742c2 (patch)
tree618ee8c88ce8a28224a187dc33b7c5fad6831d04
parent8a0a7558894afce8d25b7f0dc16775e899b72a94 (diff)
downloadtakahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.gz
takahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.bz2
takahe-61c324508e62bb640b4526183d0837fc57d742c2.zip
Midway point in task refactor - changing direction
-rw-r--r--miniq/admin.py21
-rw-r--r--miniq/migrations/0001_initial.py48
-rw-r--r--miniq/models.py71
-rw-r--r--miniq/tasks.py34
-rw-r--r--miniq/views.py51
-rw-r--r--stator/__init__.py (renamed from miniq/__init__.py)0
-rw-r--r--stator/admin.py8
-rw-r--r--stator/apps.py (renamed from miniq/apps.py)4
-rw-r--r--stator/graph.py162
-rw-r--r--stator/migrations/0001_initial.py31
-rw-r--r--stator/migrations/__init__.py (renamed from miniq/migrations/__init__.py)0
-rw-r--r--stator/models.py191
-rw-r--r--stator/runner.py69
-rw-r--r--stator/tests/test_graph.py66
-rw-r--r--stator/views.py17
-rw-r--r--takahe/settings.py2
-rw-r--r--takahe/urls.py4
-rw-r--r--users/admin.py2
-rw-r--r--users/migrations/0002_follow_state_follow_state_attempted_and_more.py44
-rw-r--r--users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py31
-rw-r--r--users/migrations/0004_remove_follow_state_locked_and_more.py21
-rw-r--r--users/models/follow.py28
-rw-r--r--users/tasks/follow.py33
-rw-r--r--users/views/identity.py1
24 files changed, 698 insertions, 241 deletions
diff --git a/miniq/admin.py b/miniq/admin.py
deleted file mode 100644
index 1166f89..0000000
--- a/miniq/admin.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from django.contrib import admin
-
-from miniq.models import Task
-
-
-@admin.register(Task)
-class TaskAdmin(admin.ModelAdmin):
-
- list_display = ["id", "created", "type", "subject", "completed", "failed"]
- ordering = ["-created"]
- actions = ["reset"]
-
- @admin.action(description="Reset Task")
- def reset(self, request, queryset):
- queryset.update(
- failed=None,
- completed=None,
- locked=None,
- locked_by=None,
- error=None,
- )
diff --git a/miniq/migrations/0001_initial.py b/miniq/migrations/0001_initial.py
deleted file mode 100644
index dc6d42b..0000000
--- a/miniq/migrations/0001_initial.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Generated by Django 4.1.3 on 2022-11-07 04:19
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
- initial = True
-
- dependencies = []
-
- operations = [
- migrations.CreateModel(
- name="Task",
- fields=[
- (
- "id",
- models.BigAutoField(
- auto_created=True,
- primary_key=True,
- serialize=False,
- verbose_name="ID",
- ),
- ),
- (
- "type",
- models.CharField(
- choices=[
- ("identity_fetch", "Identity Fetch"),
- ("inbox_item", "Inbox Item"),
- ("follow_request", "Follow Request"),
- ("follow_acknowledge", "Follow Acknowledge"),
- ],
- max_length=500,
- ),
- ),
- ("priority", models.IntegerField(default=0)),
- ("subject", models.TextField()),
- ("payload", models.JSONField(blank=True, null=True)),
- ("error", models.TextField(blank=True, null=True)),
- ("created", models.DateTimeField(auto_now_add=True)),
- ("completed", models.DateTimeField(blank=True, null=True)),
- ("failed", models.DateTimeField(blank=True, null=True)),
- ("locked", models.DateTimeField(blank=True, null=True)),
- ("locked_by", models.CharField(blank=True, max_length=500, null=True)),
- ],
- ),
- ]
diff --git a/miniq/models.py b/miniq/models.py
deleted file mode 100644
index 24d311c..0000000
--- a/miniq/models.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from typing import Optional
-
-from django.db import models, transaction
-from django.utils import timezone
-
-
-class Task(models.Model):
- """
- A task that must be done by a queue processor
- """
-
- class TypeChoices(models.TextChoices):
- identity_fetch = "identity_fetch"
- inbox_item = "inbox_item"
- follow_request = "follow_request"
- follow_acknowledge = "follow_acknowledge"
-
- type = models.CharField(max_length=500, choices=TypeChoices.choices)
- priority = models.IntegerField(default=0)
- subject = models.TextField()
- payload = models.JSONField(blank=True, null=True)
- error = models.TextField(blank=True, null=True)
-
- created = models.DateTimeField(auto_now_add=True)
- completed = models.DateTimeField(blank=True, null=True)
- failed = models.DateTimeField(blank=True, null=True)
- locked = models.DateTimeField(blank=True, null=True)
- locked_by = models.CharField(max_length=500, blank=True, null=True)
-
- def __str__(self):
- return f"{self.id}/{self.type}({self.subject})"
-
- @classmethod
- def get_one_available(cls, processor_id) -> Optional["Task"]:
- """
- Gets one task off the list while reserving it, atomically.
- """
- with transaction.atomic():
- next_task = cls.objects.filter(locked__isnull=True).first()
- if next_task is None:
- return None
- next_task.locked = timezone.now()
- next_task.locked_by = processor_id
- next_task.save()
- return next_task
-
- @classmethod
- def submit(cls, type, subject: str, payload=None, deduplicate=True):
- # Deduplication is done against tasks that have not started yet only,
- # and only on tasks without payloads
- if deduplicate and not payload:
- if cls.objects.filter(
- type=type,
- subject=subject,
- completed__isnull=True,
- failed__isnull=True,
- locked__isnull=True,
- ).exists():
- return
- cls.objects.create(type=type, subject=subject, payload=payload)
-
- async def complete(self):
- await self.__class__.objects.filter(id=self.id).aupdate(
- completed=timezone.now()
- )
-
- async def fail(self, error):
- await self.__class__.objects.filter(id=self.id).aupdate(
- failed=timezone.now(),
- error=error,
- )
diff --git a/miniq/tasks.py b/miniq/tasks.py
deleted file mode 100644
index fedf8fd..0000000
--- a/miniq/tasks.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import traceback
-
-from users.tasks.follow import handle_follow_request
-from users.tasks.identity import handle_identity_fetch
-from users.tasks.inbox import handle_inbox_item
-
-
-class TaskHandler:
-
- handlers = {
- "identity_fetch": handle_identity_fetch,
- "inbox_item": handle_inbox_item,
- "follow_request": handle_follow_request,
- }
-
- def __init__(self, task):
- self.task = task
- self.subject = self.task.subject
- self.payload = self.task.payload
-
- async def handle(self):
- try:
- print(f"Task {self.task}: Starting")
- if self.task.type not in self.handlers:
- raise ValueError(f"Cannot handle type {self.task.type}")
- await self.handlers[self.task.type](
- self,
- )
- await self.task.complete()
- print(f"Task {self.task}: Complete")
- except BaseException as e:
- print(f"Task {self.task}: Error {e}")
- traceback.print_exc()
- await self.task.fail(f"{e}\n\n" + traceback.format_exc())
diff --git a/miniq/views.py b/miniq/views.py
deleted file mode 100644
index 80c9ee2..0000000
--- a/miniq/views.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import asyncio
-import time
-import uuid
-
-from asgiref.sync import sync_to_async
-from django.http import HttpResponse
-from django.views import View
-
-from miniq.models import Task
-from miniq.tasks import TaskHandler
-
-
-class QueueProcessor(View):
- """
- A view that takes some items off the queue and processes them.
- Tries to limit its own runtime so it's within HTTP timeout limits.
- """
-
- START_TIMEOUT = 30
- TOTAL_TIMEOUT = 60
- LOCK_TIMEOUT = 200
- MAX_TASKS = 20
-
- async def get(self, request):
- start_time = time.monotonic()
- processor_id = uuid.uuid4().hex
- handled = 0
- self.tasks = []
- # For the first time period, launch tasks
- while (time.monotonic() - start_time) < self.START_TIMEOUT:
- # Remove completed tasks
- self.tasks = [t for t in self.tasks if not t.done()]
- # See if there's a new task
- if len(self.tasks) < self.MAX_TASKS:
- # Pop a task off the queue and run it
- task = await sync_to_async(Task.get_one_available)(processor_id)
- if task is not None:
- self.tasks.append(asyncio.create_task(TaskHandler(task).handle()))
- handled += 1
- # Prevent busylooping
- await asyncio.sleep(0.01)
- # TODO: Clean up old locks here
- # Then wait for tasks to finish
- while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
- # Remove completed tasks
- self.tasks = [t for t in self.tasks if not t.done()]
- if not self.tasks:
- break
- # Prevent busylooping
- await asyncio.sleep(1)
- return HttpResponse(f"{handled} tasks handled")
diff --git a/miniq/__init__.py b/stator/__init__.py
index e69de29..e69de29 100644
--- a/miniq/__init__.py
+++ b/stator/__init__.py
diff --git a/stator/admin.py b/stator/admin.py
new file mode 100644
index 0000000..c04d775
--- /dev/null
+++ b/stator/admin.py
@@ -0,0 +1,8 @@
+from django.contrib import admin
+
+from stator.models import StatorTask
+
+
+@admin.register(StatorTask)
+class DomainAdmin(admin.ModelAdmin):
+ list_display = ["id", "model_label", "instance_pk", "locked_until"]
diff --git a/miniq/apps.py b/stator/apps.py
index 4c7e773..8910ecb 100644
--- a/miniq/apps.py
+++ b/stator/apps.py
@@ -1,6 +1,6 @@
from django.apps import AppConfig
-class MiniqConfig(AppConfig):
+class StatorConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
- name = "miniq"
+ name = "stator"
diff --git a/stator/graph.py b/stator/graph.py
new file mode 100644
index 0000000..b06ffb8
--- /dev/null
+++ b/stator/graph.py
@@ -0,0 +1,162 @@
+import datetime
+from functools import wraps
+from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
+
+from django.db import models
+from django.utils import timezone
+
+
+class StateGraph:
+ """
+ Represents a graph of possible states and transitions to attempt on them.
+ Does not support subclasses of existing graphs yet.
+ """
+
+ states: ClassVar[Dict[str, "State"]]
+ choices: ClassVar[List[Tuple[str, str]]]
+ initial_state: ClassVar["State"]
+ terminal_states: ClassVar[Set["State"]]
+
+ def __init_subclass__(cls) -> None:
+ # Collect state memebers
+ cls.states = {}
+ for name, value in cls.__dict__.items():
+ if name in ["__module__", "__doc__", "states"]:
+ pass
+ elif name in ["initial_state", "terminal_states", "choices"]:
+ raise ValueError(f"Cannot name a state {name} - this is reserved")
+ elif isinstance(value, State):
+ value._add_to_graph(cls, name)
+ elif callable(value) or isinstance(value, classmethod):
+ pass
+ else:
+ raise ValueError(
+ f"Graph has item {name} of unallowed type {type(value)}"
+ )
+ # Check the graph layout
+ terminal_states = set()
+ initial_state = None
+ for state in cls.states.values():
+ if state.initial:
+ if initial_state:
+ raise ValueError(
+ f"The graph has more than one initial state: {initial_state} and {state}"
+ )
+ initial_state = state
+ if state.terminal:
+ terminal_states.add(state)
+ if initial_state is None:
+ raise ValueError("The graph has no initial state")
+ cls.initial_state = initial_state
+ cls.terminal_states = terminal_states
+ # Generate choices
+ cls.choices = [(name, name) for name in cls.states.keys()]
+
+
+class State:
+ """
+ Represents an individual state
+ """
+
+ def __init__(self, try_interval: float = 300):
+ self.try_interval = try_interval
+ self.parents: Set["State"] = set()
+ self.children: Dict["State", "Transition"] = {}
+
+ def _add_to_graph(self, graph: StateGraph, name: str):
+ self.graph = graph
+ self.name = name
+ self.graph.states[name] = self
+
+ def __repr__(self):
+ return f"<State {self.name}>"
+
+ def add_transition(
+ self,
+ other: "State",
+ handler: Optional[Union[str, Callable]] = None,
+ priority: int = 0,
+ ) -> Callable:
+ def decorator(handler: Union[str, Callable]):
+ self.children[other] = Transition(
+ self,
+ other,
+ handler,
+ priority=priority,
+ )
+ other.parents.add(self)
+ # All handlers should be class methods, so do that automatically.
+ if callable(handler):
+ return classmethod(handler)
+
+ # If we're not being called as a decorator, invoke it immediately
+ if handler is not None:
+ decorator(handler)
+ return decorator
+
+ def add_manual_transition(self, other: "State"):
+ self.children[other] = ManualTransition(self, other)
+ other.parents.add(self)
+
+ @property
+ def initial(self):
+ return not self.parents
+
+ @property
+ def terminal(self):
+ return not self.children
+
+ def transitions(self, automatic_only=False) -> List["Transition"]:
+ """
+ Returns all transitions from this State in priority order
+ """
+ if automatic_only:
+ transitions = [t for t in self.children.values() if t.automatic]
+ else:
+ transitions = self.children.values()
+ return sorted(transitions, key=lambda t: t.priority, reverse=True)
+
+
+class Transition:
+ """
+ A possible transition from one state to another
+ """
+
+ def __init__(
+ self,
+ from_state: State,
+ to_state: State,
+ handler: Union[str, Callable],
+ priority: int = 0,
+ ):
+ self.from_state = from_state
+ self.to_state = to_state
+ self.handler = handler
+ self.priority = priority
+ self.automatic = True
+
+ def get_handler(self) -> Callable:
+ """
+ Returns the handler (it might need resolving from a string)
+ """
+ if isinstance(self.handler, str):
+ self.handler = getattr(self.from_state.graph, self.handler)
+ return self.handler
+
+
+class ManualTransition(Transition):
+ """
+ A possible transition from one state to another that cannot be done by
+ the stator task runner, and must come from an external source.
+ """
+
+ def __init__(
+ self,
+ from_state: State,
+ to_state: State,
+ ):
+ self.from_state = from_state
+ self.to_state = to_state
+ self.handler = None
+ self.priority = 0
+ self.automatic = False
diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py
new file mode 100644
index 0000000..f485836
--- /dev/null
+++ b/stator/migrations/0001_initial.py
@@ -0,0 +1,31 @@
+# Generated by Django 4.1.3 on 2022-11-09 05:46
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = []
+
+ operations = [
+ migrations.CreateModel(
+ name="StatorTask",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("model_label", models.CharField(max_length=200)),
+ ("instance_pk", models.CharField(max_length=200)),
+ ("locked_until", models.DateTimeField(blank=True, null=True)),
+ ("priority", models.IntegerField(default=0)),
+ ],
+ ),
+ ]
diff --git a/miniq/migrations/__init__.py b/stator/migrations/__init__.py
index e69de29..e69de29 100644
--- a/miniq/migrations/__init__.py
+++ b/stator/migrations/__init__.py
diff --git a/stator/models.py b/stator/models.py
new file mode 100644
index 0000000..3b0da0a
--- /dev/null
+++ b/stator/models.py
@@ -0,0 +1,191 @@
+import datetime
+from functools import reduce
+from typing import Type, cast
+
+from asgiref.sync import sync_to_async
+from django.apps import apps
+from django.db import models, transaction
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from stator.graph import State, StateGraph
+
+
+class StateField(models.CharField):
+ """
+ A special field that automatically gets choices from a state graph
+ """
+
+ def __init__(self, graph: Type[StateGraph], **kwargs):
+ # Sensible default for state length
+ kwargs.setdefault("max_length", 100)
+ # Add choices and initial
+ self.graph = graph
+ kwargs["choices"] = self.graph.choices
+ kwargs["default"] = self.graph.initial_state.name
+ super().__init__(**kwargs)
+
+ def deconstruct(self):
+ name, path, args, kwargs = super().deconstruct()
+ kwargs["graph"] = self.graph
+ return name, path, args, kwargs
+
+ def from_db_value(self, value, expression, connection):
+ if value is None:
+ return value
+ return self.graph.states[value]
+
+ def to_python(self, value):
+ if isinstance(value, State) or value is None:
+ return value
+ return self.graph.states[value]
+
+ def get_prep_value(self, value):
+ if isinstance(value, State):
+ return value.name
+ return value
+
+
+class StatorModel(models.Model):
+ """
+ A model base class that has a state machine backing it, with tasks to work
+ out when to move the state to the next one.
+
+ You need to provide a "state" field as an instance of StateField on the
+ concrete model yourself.
+ """
+
+ # When the state last actually changed, or the date of instance creation
+ state_changed = models.DateTimeField(auto_now_add=True)
+
+ # When the last state change for the current state was attempted
+ # (and not successful, as this is cleared on transition)
+ state_attempted = models.DateTimeField(blank=True, null=True)
+
+ class Meta:
+ abstract = True
+
+ @classmethod
+ def schedule_overdue(cls, now=None) -> models.QuerySet:
+ """
+ Finds instances of this model that need to run and schedule them.
+ """
+ q = models.Q()
+ for transition in cls.state_graph.transitions(automatic_only=True):
+ q = q | transition.get_query(now=now)
+ return cls.objects.filter(q)
+
+ @classproperty
+ def state_graph(cls) -> Type[StateGraph]:
+ return cls._meta.get_field("state").graph
+
+ def schedule_transition(self, priority: int = 0):
+ """
+ Adds this instance to the queue to get its state transition attempted.
+
+ The scheduler will call this, but you can also call it directly if you
+ know it'll be ready and want to lower latency.
+ """
+ StatorTask.schedule_for_execution(self, priority=priority)
+
+ async def attempt_transition(self):
+ """
+ Attempts to transition the current state by running its handler(s).
+ """
+ # Try each transition in priority order
+ for transition in self.state_graph.states[self.state].transitions(
+ automatic_only=True
+ ):
+ success = await transition.get_handler()(self)
+ if success:
+ await self.perform_transition(transition.to_state.name)
+ return
+ await self.__class__.objects.filter(pk=self.pk).aupdate(
+ state_attempted=timezone.now()
+ )
+
+ async def perform_transition(self, state_name):
+ """
+ Transitions the instance to the given state name
+ """
+ if state_name not in self.state_graph.states:
+ raise ValueError(f"Invalid state {state_name}")
+ await self.__class__.objects.filter(pk=self.pk).aupdate(
+ state=state_name,
+ state_changed=timezone.now(),
+ state_attempted=None,
+ )
+
+
+class StatorTask(models.Model):
+ """
+ The model that we use for an internal scheduling queue.
+
+ Entries in this queue are up for checking and execution - it also performs
+ locking to ensure we get closer to exactly-once execution (but we err on
+ the side of at-least-once)
+ """
+
+ # appname.modelname (lowercased) label for the model this represents
+ model_label = models.CharField(max_length=200)
+
+ # The primary key of that model (probably int or str)
+ instance_pk = models.CharField(max_length=200)
+
+ # Locking columns (no runner ID, as we have no heartbeats - all runners
+ # only live for a short amount of time anyway)
+ locked_until = models.DateTimeField(null=True, blank=True)
+
+ # Basic total ordering priority - higher is more important
+ priority = models.IntegerField(default=0)
+
+ def __str__(self):
+ return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
+
+ @classmethod
+ def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
+ # We don't do a transaction here as it's fine to occasionally double up
+ model_label = model_instance._meta.label_lower
+ pk = model_instance.pk
+ # TODO: Increase priority of existing if present
+ if not cls.objects.filter(
+ model_label=model_label, instance_pk=pk, locked__isnull=True
+ ).exists():
+ StatorTask.objects.create(
+ model_label=model_label,
+ instance_pk=pk,
+ priority=priority,
+ )
+
+ @classmethod
+ def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
+ """
+ Returns up to `number` tasks for execution, having locked them.
+ """
+ with transaction.atomic():
+ selected = list(
+ cls.objects.filter(locked_until__isnull=True)[
+ :number
+ ].select_for_update()
+ )
+ cls.objects.filter(pk__in=[i.pk for i in selected]).update(
+ locked_until=timezone.now()
+ )
+ return selected
+
+ @classmethod
+ async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
+ return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
+
+ @classmethod
+ async def aclean_old_locks(cls):
+ await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
+ locked_until=None
+ )
+
+ async def aget_model_instance(self) -> StatorModel:
+ model = apps.get_model(self.model_label)
+ return cast(StatorModel, await model.objects.aget(pk=self.pk))
+
+ async def adelete(self):
+ self.__class__.objects.adelete(pk=self.pk)
diff --git a/stator/runner.py b/stator/runner.py
new file mode 100644
index 0000000..8c6e0f1
--- /dev/null
+++ b/stator/runner.py
@@ -0,0 +1,69 @@
+import asyncio
+import datetime
+import time
+import uuid
+from typing import List, Type
+
+from asgiref.sync import sync_to_async
+from django.db import transaction
+from django.utils import timezone
+
+from stator.models import StatorModel, StatorTask
+
+
+class StatorRunner:
+ """
+ Runs tasks on models that are looking for state changes.
+ Designed to run in a one-shot mode, living inside a request.
+ """
+
+ START_TIMEOUT = 30
+ TOTAL_TIMEOUT = 60
+ LOCK_TIMEOUT = 120
+
+ MAX_TASKS = 30
+
+ def __init__(self, models: List[Type[StatorModel]]):
+ self.models = models
+ self.runner_id = uuid.uuid4().hex
+
+ async def run(self):
+ start_time = time.monotonic()
+ self.handled = 0
+ self.tasks = []
+ # Clean up old locks
+ await StatorTask.aclean_old_locks()
+ # Examine what needs scheduling
+
+ # For the first time period, launch tasks
+ while (time.monotonic() - start_time) < self.START_TIMEOUT:
+ self.remove_completed_tasks()
+ space_remaining = self.MAX_TASKS - len(self.tasks)
+ # Fetch new tasks
+ if space_remaining > 0:
+ for new_task in await StatorTask.aget_for_execution(
+ space_remaining,
+ timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
+ ):
+ self.tasks.append(asyncio.create_task(self.run_task(new_task)))
+ self.handled += 1
+ # Prevent busylooping
+ await asyncio.sleep(0.01)
+ # Then wait for tasks to finish
+ while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
+ self.remove_completed_tasks()
+ if not self.tasks:
+ break
+ # Prevent busylooping
+ await asyncio.sleep(1)
+ return self.handled
+
+ async def run_task(self, task: StatorTask):
+ # Resolve the model instance
+ model_instance = await task.aget_model_instance()
+ await model_instance.attempt_transition()
+ # Remove ourselves from the database as complete
+ await task.adelete()
+
+ def remove_completed_tasks(self):
+ self.tasks = [t for t in self.tasks if not t.done()]
diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py
new file mode 100644
index 0000000..f6b8404
--- /dev/null
+++ b/stator/tests/test_graph.py
@@ -0,0 +1,66 @@
+import pytest
+
+from stator.graph import State, StateGraph
+
+
+def test_declare():
+ """
+ Tests a basic graph declaration and various kinds of handler
+ lookups.
+ """
+
+ fake_handler = lambda: True
+
+ class TestGraph(StateGraph):
+ initial = State()
+ second = State()
+ third = State()
+ fourth = State()
+ final = State()
+
+ initial.add_transition(second, 60, handler=fake_handler)
+ second.add_transition(third, 60, handler="check_third")
+
+ def check_third(cls):
+ return True
+
+ @third.add_transition(fourth, 60)
+ def check_fourth(cls):
+ return True
+
+ fourth.add_manual_transition(final)
+
+ assert TestGraph.initial_state == TestGraph.initial
+ assert TestGraph.terminal_states == {TestGraph.final}
+
+ assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
+ assert (
+ TestGraph.second.children[TestGraph.third].get_handler()
+ == TestGraph.check_third
+ )
+ assert (
+ TestGraph.third.children[TestGraph.fourth].get_handler().__name__
+ == "check_fourth"
+ )
+
+
+def test_bad_declarations():
+ """
+ Tests that you can't declare an invalid graph.
+ """
+ # More than one initial state
+ with pytest.raises(ValueError):
+
+ class TestGraph(StateGraph):
+ initial = State()
+ initial2 = State()
+
+ # No initial states
+ with pytest.raises(ValueError):
+
+ class TestGraph(StateGraph):
+ loop = State()
+ loop2 = State()
+
+ loop.add_transition(loop2, 1, handler="fake")
+ loop2.add_transition(loop, 1, handler="fake")
diff --git a/stator/views.py b/stator/views.py
new file mode 100644
index 0000000..ef09b8e
--- /dev/null
+++ b/stator/views.py
@@ -0,0 +1,17 @@
+from django.http import HttpResponse
+from django.views import View
+
+from stator.runner import StatorRunner
+from users.models import Follow
+
+
+class RequestRunner(View):
+ """
+ Runs a Stator runner within a HTTP request. For when you're on something
+ serverless.
+ """
+
+ async def get(self, request):
+ runner = StatorRunner([Follow])
+ handled = await runner.run()
+ return HttpResponse(f"Handled {handled}")
diff --git a/takahe/settings.py b/takahe/settings.py
index 62065d2..cefbb35 100644
--- a/takahe/settings.py
+++ b/takahe/settings.py
@@ -26,7 +26,7 @@ INSTALLED_APPS = [
"core",
"statuses",
"users",
- "miniq",
+ "stator",
]
MIDDLEWARE = [
diff --git a/takahe/urls.py b/takahe/urls.py
index 304bc23..764c8e9 100644
--- a/takahe/urls.py
+++ b/takahe/urls.py
@@ -2,7 +2,7 @@ from django.contrib import admin
from django.urls import path
from core import views as core
-from miniq import views as miniq
+from stator import views as stator
from users.views import auth, identity
urlpatterns = [
@@ -22,7 +22,7 @@ urlpatterns = [
# Well-known endpoints
path(".well-known/webfinger", identity.Webfinger.as_view()),
# Task runner
- path(".queue/process/", miniq.QueueProcessor.as_view()),
+ path(".stator/runner/", stator.RequestRunner.as_view()),
# Django admin
path("djadmin/", admin.site.urls),
]
diff --git a/users/admin.py b/users/admin.py
index 0b0cc80..e517b0a 100644
--- a/users/admin.py
+++ b/users/admin.py
@@ -25,4 +25,4 @@ class IdentityAdmin(admin.ModelAdmin):
@admin.register(Follow)
class FollowAdmin(admin.ModelAdmin):
- list_display = ["id", "source", "target", "requested", "accepted"]
+ list_display = ["id", "source", "target", "state"]
diff --git a/users/migrations/0002_follow_state_follow_state_attempted_and_more.py b/users/migrations/0002_follow_state_follow_state_attempted_and_more.py
new file mode 100644
index 0000000..b33236a
--- /dev/null
+++ b/users/migrations/0002_follow_state_follow_state_attempted_and_more.py
@@ -0,0 +1,44 @@
+# Generated by Django 4.1.3 on 2022-11-07 19:22
+
+import django.utils.timezone
+from django.db import migrations, models
+
+import stator.models
+import users.models.follow
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("users", "0001_initial"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="follow",
+ name="state",
+ field=stator.models.StateField(
+ choices=[
+ ("pending", "pending"),
+ ("requested", "requested"),
+ ("accepted", "accepted"),
+ ],
+ default="pending",
+ graph=users.models.follow.FollowStates,
+ max_length=100,
+ ),
+ ),
+ migrations.AddField(
+ model_name="follow",
+ name="state_attempted",
+ field=models.DateTimeField(blank=True, null=True),
+ ),
+ migrations.AddField(
+ model_name="follow",
+ name="state_changed",
+ field=models.DateTimeField(
+ auto_now_add=True, default=django.utils.timezone.now
+ ),
+ preserve_default=False,
+ ),
+ ]
diff --git a/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py b/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py
new file mode 100644
index 0000000..180bfdd
--- /dev/null
+++ b/users/migrations/0003_remove_follow_accepted_remove_follow_requested_and_more.py
@@ -0,0 +1,31 @@
+# Generated by Django 4.1.3 on 2022-11-08 03:58
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("users", "0002_follow_state_follow_state_attempted_and_more"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="follow",
+ name="accepted",
+ ),
+ migrations.RemoveField(
+ model_name="follow",
+ name="requested",
+ ),
+ migrations.AddField(
+ model_name="follow",
+ name="state_locked",
+ field=models.DateTimeField(blank=True, null=True),
+ ),
+ migrations.AddField(
+ model_name="follow",
+ name="state_runner",
+ field=models.CharField(blank=True, max_length=100, null=True),
+ ),
+ ]
diff --git a/users/migrations/0004_remove_follow_state_locked_and_more.py b/users/migrations/0004_remove_follow_state_locked_and_more.py
new file mode 100644
index 0000000..bf98080
--- /dev/null
+++ b/users/migrations/0004_remove_follow_state_locked_and_more.py
@@ -0,0 +1,21 @@
+# Generated by Django 4.1.3 on 2022-11-09 05:15
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="follow",
+ name="state_locked",
+ ),
+ migrations.RemoveField(
+ model_name="follow",
+ name="state_runner",
+ ),
+ ]
diff --git a/users/models/follow.py b/users/models/follow.py
index 29d036e..04f90ee 100644
--- a/users/models/follow.py
+++ b/users/models/follow.py
@@ -2,10 +2,23 @@ from typing import Optional
from django.db import models
-from miniq.models import Task
+from stator.models import State, StateField, StateGraph, StatorModel
-class Follow(models.Model):
+class FollowStates(StateGraph):
+ pending = State(try_interval=3600)
+ requested = State()
+ accepted = State()
+
+ @pending.add_transition(requested)
+ async def try_request(cls, instance):
+ print("Would have tried to follow")
+ return False
+
+ requested.add_manual_transition(accepted)
+
+
+class Follow(StatorModel):
"""
When one user (the source) follows other (the target)
"""
@@ -24,8 +37,7 @@ class Follow(models.Model):
uri = models.CharField(blank=True, null=True, max_length=500)
note = models.TextField(blank=True, null=True)
- requested = models.BooleanField(default=False)
- accepted = models.BooleanField(default=False)
+ state = StateField(FollowStates)
created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True)
@@ -50,17 +62,15 @@ class Follow(models.Model):
(which can be local or remote).
"""
if not source.local:
- raise ValueError("You cannot initiate follows on a remote Identity")
+ raise ValueError("You cannot initiate follows from a remote Identity")
try:
follow = Follow.objects.get(source=source, target=target)
except Follow.DoesNotExist:
follow = Follow.objects.create(source=source, target=target, uri="")
follow.uri = source.actor_uri + f"follow/{follow.pk}/"
+ # TODO: Local follow approvals
if target.local:
- follow.requested = True
- follow.accepted = True
- else:
- Task.submit("follow_request", str(follow.pk))
+ follow.state = FollowStates.accepted
follow.save()
return follow
diff --git a/users/tasks/follow.py b/users/tasks/follow.py
index 872b35f..0f802cf 100644
--- a/users/tasks/follow.py
+++ b/users/tasks/follow.py
@@ -27,3 +27,36 @@ async def handle_follow_request(task_handler):
if response.status_code >= 400:
raise ValueError(f"Request error: {response.status_code} {response.content}")
await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
+
+
+def send_follow_undo(id):
+ """
+ Request a follow from a remote server
+ """
+ follow = Follow.objects.select_related("source", "source__domain", "target").get(
+ pk=id
+ )
+ # Construct the request
+ request = canonicalise(
+ {
+ "@context": "https://www.w3.org/ns/activitystreams",
+ "id": follow.uri + "#undo",
+ "type": "Undo",
+ "actor": follow.source.actor_uri,
+ "object": {
+ "id": follow.uri,
+ "type": "Follow",
+ "actor": follow.source.actor_uri,
+ "object": follow.target.actor_uri,
+ },
+ }
+ )
+ # Sign it and send it
+ from asgiref.sync import async_to_sync
+
+ response = async_to_sync(HttpSignature.signed_request)(
+ follow.target.inbox_uri, request, follow.source
+ )
+ if response.status_code >= 400:
+ raise ValueError(f"Request error: {response.status_code} {response.content}")
+ print(response)
diff --git a/users/views/identity.py b/users/views/identity.py
index 98fcdd6..41c7880 100644
--- a/users/views/identity.py
+++ b/users/views/identity.py
@@ -16,7 +16,6 @@ from django.views.generic import FormView, TemplateView, View
from core.forms import FormHelper
from core.ld import canonicalise
from core.signatures import HttpSignature
-from miniq.models import Task
from users.decorators import identity_required
from users.models import Domain, Follow, Identity
from users.shortcuts import by_handle_or_404