# Provenance: commit 61c324508e62bb640b4526183d0837fc57d742c2
# Author: Andrew Godwin, Tue, 8 Nov 2022 23:06:29 -0700
# Subject: Midway point in task refactor - changing direction
# Adds the "stator" app: admin.py, apps.py, graph.py, migrations, models.py,
# runner.py, tests/test_graph.py and views.py (550 insertions).


# --- stator/admin.py ---
@admin.register(StatorTask)
class DomainAdmin(admin.ModelAdmin):
    """
    Admin listing for queued StatorTask rows.

    NOTE(review): the class name says "Domain" but it is registered for
    StatorTask - looks like a copy/paste leftover; confirm before renaming.
    """

    list_display = ("id", "model_label", "instance_pk", "locked_until")


# --- stator/apps.py ---
class StatorConfig(AppConfig):
    """
    Django application configuration for the stator app.
    """

    name = "stator"
    default_auto_field = "django.db.models.BigAutoField"
class StateGraph:
    """
    Represents a graph of possible states and transitions to attempt on them.
    Does not support subclasses of existing graphs yet.

    Subclasses declare ``State`` instances as class attributes and connect
    them with ``State.add_transition`` / ``State.add_manual_transition``.
    When the subclass is created, ``__init_subclass__`` fills in:

    * ``states`` - state name -> ``State``
    * ``choices`` - ``(name, name)`` pairs, one per state
    * ``initial_state`` - the single state that has no parents
    * ``terminal_states`` - every state that has no children
    """

    states: ClassVar[Dict[str, "State"]]
    choices: ClassVar[List[Tuple[str, str]]]
    initial_state: ClassVar["State"]
    terminal_states: ClassVar[Set["State"]]

    def __init_subclass__(cls) -> None:
        """
        Collects the declared states and validates the graph layout.

        Raises:
            ValueError: if an attribute uses a reserved name or has an
                unsupported type, or if the graph does not have exactly one
                initial state.
        """
        super().__init_subclass__()
        # Collect state members
        cls.states = {}
        for name, value in list(cls.__dict__.items()):
            # Skip our own registry and every dunder the interpreter puts in
            # the class namespace (__module__, __doc__, __annotations__, and
            # on newer Pythons __firstlineno__ / __static_attributes__).
            if name == "states" or (name.startswith("__") and name.endswith("__")):
                continue
            if name in ("initial_state", "terminal_states", "choices"):
                raise ValueError(f"Cannot name a state {name} - this is reserved")
            if isinstance(value, State):
                value._add_to_graph(cls, name)
            elif callable(value) or isinstance(value, (classmethod, staticmethod)):
                # Handlers and helper methods may live next to the states
                continue
            else:
                raise ValueError(
                    f"Graph has item {name} of unallowed type {type(value)}"
                )
        # Check the graph layout
        terminal_states = set()
        initial_state = None
        for state in cls.states.values():
            if state.initial:
                if initial_state is not None:
                    raise ValueError(
                        f"The graph has more than one initial state: {initial_state} and {state}"
                    )
                initial_state = state
            if state.terminal:
                terminal_states.add(state)
        if initial_state is None:
            raise ValueError("The graph has no initial state")
        cls.initial_state = initial_state
        cls.terminal_states = terminal_states
        # Generate choices
        cls.choices = [(name, name) for name in cls.states]


class State:
    """
    Represents an individual state
    """

    def __init__(self, try_interval: float = 300):
        # Stored only; nothing in this module reads try_interval.
        self.try_interval = try_interval
        # States that have a transition leading into this one
        self.parents: Set["State"] = set()
        # Target state -> the transition that leads to it
        self.children: Dict["State", "Transition"] = {}

    def _add_to_graph(self, graph: StateGraph, name: str):
        """
        Binds this state to its graph under the given attribute name.

        ``graph`` is the StateGraph subclass itself (it is passed ``cls``
        from ``StateGraph.__init_subclass__``), not an instance.
        """
        self.graph = graph
        self.name = name
        self.graph.states[name] = self

    def __repr__(self):
        # name is only set once the state has been attached to a graph
        return f"<State {getattr(self, 'name', '(unbound)')}>"

    def add_transition(
        self,
        other: "State",
        handler: Optional[Union[str, Callable]] = None,
        priority: int = 0,
    ) -> Callable:
        """
        Registers an automatic transition from this state to ``other``.

        Can be called directly with ``handler`` (a callable, or the name of
        an attribute on the graph to resolve later), or with no handler and
        used as a decorator on a function in the graph body.

        Returns the decorator; when used as one, the decorated callable is
        returned wrapped in ``classmethod``.
        """

        def decorator(handler: Union[str, Callable]):
            self.children[other] = Transition(
                self,
                other,
                handler,
                priority=priority,
            )
            other.parents.add(self)
            # All handlers should be class methods, so do that automatically.
            if callable(handler):
                return classmethod(handler)

        # If we're not being called as a decorator, invoke it immediately
        if handler is not None:
            decorator(handler)
        return decorator

    def add_manual_transition(self, other: "State"):
        """
        Registers a transition to ``other`` that has no handler.
        """
        self.children[other] = ManualTransition(self, other)
        other.parents.add(self)

    @property
    def initial(self):
        """True when no other state transitions into this one."""
        return not self.parents

    @property
    def terminal(self):
        """True when this state has no outgoing transitions."""
        return not self.children

    def transitions(self, automatic_only=False) -> List["Transition"]:
        """
        Returns all transitions from this State in priority order
        (highest priority first), optionally only the automatic ones.
        """
        candidates = self.children.values()
        if automatic_only:
            candidates = [t for t in candidates if t.automatic]
        return sorted(candidates, key=lambda t: t.priority, reverse=True)


class Transition:
    """
    A possible transition from one state to another
    """

    def __init__(
        self,
        from_state: State,
        to_state: State,
        handler: Optional[Union[str, Callable]],
        priority: int = 0,
    ):
        self.from_state = from_state
        self.to_state = to_state
        self.handler = handler
        self.priority = priority
        self.automatic = True

    def get_handler(self) -> Optional[Callable]:
        """
        Returns the handler (it might need resolving from a string).

        A string handler is looked up as an attribute on the source state's
        graph the first time this is called, and the result is cached on the
        transition. Returns None for transitions that have no handler.
        """
        if isinstance(self.handler, str):
            self.handler = getattr(self.from_state.graph, self.handler)
        return self.handler


class ManualTransition(Transition):
    """
    A possible transition from one state to another that cannot be done by
    the stator task runner, and must come from an external source.
    """

    def __init__(
        self,
        from_state: State,
        to_state: State,
    ):
        super().__init__(from_state, to_state, handler=None, priority=0)
        # Excluded from State.transitions(automatic_only=True)
        self.automatic = False
# --- stator/migrations/0001_initial.py ---
# Generated by Django 4.1.3 on 2022-11-09 05:46 (reproduced unchanged)
class Migration(migrations.Migration):

    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name="StatorTask",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("model_label", models.CharField(max_length=200)),
                ("instance_pk", models.CharField(max_length=200)),
                ("locked_until", models.DateTimeField(blank=True, null=True)),
                ("priority", models.IntegerField(default=0)),
            ],
        ),
    ]


# --- stator/models.py ---
class StateField(models.CharField):
    """
    A special field that automatically gets choices from a state graph
    """

    def __init__(self, graph: Type[StateGraph], **kwargs):
        # Sensible default for state length
        kwargs.setdefault("max_length", 100)
        # Add choices and initial
        self.graph = graph
        kwargs["choices"] = self.graph.choices
        kwargs["default"] = self.graph.initial_state.name
        super().__init__(**kwargs)

    def deconstruct(self):
        """
        Adds the graph class to the kwargs needed to rebuild the field.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs["graph"] = self.graph
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Converts a stored state name into its State object (None stays None).
        """
        if value is None:
            return value
        return self.graph.states[value]

    def to_python(self, value):
        """
        Returns State objects and None untouched; looks anything else up by
        name in the graph (an unknown name raises KeyError).
        """
        if isinstance(value, State) or value is None:
            return value
        return self.graph.states[value]

    def get_prep_value(self, value):
        """
        Converts a State back to its name for storage; other values pass
        through unchanged.
        """
        if isinstance(value, State):
            return value.name
        return value


class StatorModel(models.Model):
    """
    A model base class that has a state machine backing it, with tasks to work
    out when to move the state to the next one.

    You need to provide a "state" field as an instance of StateField on the
    concrete model yourself.
    """

    # When the state last actually changed, or the date of instance creation
    state_changed = models.DateTimeField(auto_now_add=True)

    # When the last state change for the current state was attempted
    # (and not successful, as this is cleared on transition)
    state_attempted = models.DateTimeField(blank=True, null=True)

    class Meta:
        abstract = True

    @classmethod
    def schedule_overdue(cls, now=None) -> models.QuerySet:
        """
        Finds instances of this model that need to run and schedule them.
        """
        # NOTE(review): the graph module shown with this change defines
        # neither StateGraph.transitions() nor Transition.get_query(), so
        # this will raise AttributeError as written - the intended query is
        # not visible here, so it is left for the author to finish.
        q = models.Q()
        for transition in cls.state_graph.transitions(automatic_only=True):
            q = q | transition.get_query(now=now)
        return cls.objects.filter(q)

    @classproperty
    def state_graph(cls) -> Type[StateGraph]:
        """The StateGraph attached to this model's "state" field."""
        return cls._meta.get_field("state").graph

    def schedule_transition(self, priority: int = 0):
        """
        Adds this instance to the queue to get its state transition attempted.

        The scheduler will call this, but you can also call it directly if you
        know it'll be ready and want to lower latency.
        """
        StatorTask.schedule_for_execution(self, priority=priority)

    async def attempt_transition(self):
        """
        Attempts to transition the current state by running its handler(s).

        Handlers are awaited in priority order with this instance as their
        argument; the first truthy result wins and the instance is moved to
        that transition's target state. If none succeed, only
        state_attempted is updated.
        """
        # StateField hands back State objects when loading from the database,
        # but a fresh instance still holds the default state *name*; the
        # graph's registry is keyed by name, so normalise before looking up.
        current = self.state
        if not isinstance(current, State):
            current = self.state_graph.states[current]
        # Try each transition in priority order
        for transition in current.transitions(automatic_only=True):
            handler = transition.get_handler()
            if await handler(self):
                await self.perform_transition(transition.to_state.name)
                return
        await self.__class__.objects.filter(pk=self.pk).aupdate(
            state_attempted=timezone.now()
        )

    async def perform_transition(self, state_name):
        """
        Transitions the instance to the given state name

        Updates the database row only (via a queryset update); the in-memory
        instance is not modified.

        Raises:
            ValueError: if state_name is not a state in the graph.
        """
        if state_name not in self.state_graph.states:
            raise ValueError(f"Invalid state {state_name}")
        await self.__class__.objects.filter(pk=self.pk).aupdate(
            state=state_name,
            state_changed=timezone.now(),
            state_attempted=None,
        )
class StatorTask(models.Model):
    """
    The model that we use for an internal scheduling queue.

    Entries in this queue are up for checking and execution - it also performs
    locking to ensure we get closer to exactly-once execution (but we err on
    the side of at-least-once)
    """

    # appname.modelname (lowercased) label for the model this represents
    model_label = models.CharField(max_length=200)

    # The primary key of that model (probably int or str)
    instance_pk = models.CharField(max_length=200)

    # Locking columns (no runner ID, as we have no heartbeats - all runners
    # only live for a short amount of time anyway)
    locked_until = models.DateTimeField(null=True, blank=True)

    # Basic total ordering priority - higher is more important
    priority = models.IntegerField(default=0)

    def __str__(self):
        return f"#{self.pk}: {self.model_label}.{self.instance_pk}"

    @classmethod
    def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
        """
        Queues a task for the given instance unless an unlocked one for the
        same model label and primary key is already waiting.
        """
        # We don't do a transaction here as it's fine to occasionally double up
        model_label = model_instance._meta.label_lower
        pk = model_instance.pk
        # TODO: Increase priority of existing if present
        already_queued = cls.objects.filter(
            model_label=model_label,
            instance_pk=pk,
            # The lock column is "locked_until"; there is no "locked" field
            locked_until__isnull=True,
        ).exists()
        if not already_queued:
            cls.objects.create(
                model_label=model_label,
                instance_pk=pk,
                priority=priority,
            )

    @classmethod
    def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
        """
        Returns up to `number` tasks for execution, having locked them.

        Unlocked tasks are picked highest priority first and their
        locked_until is set to `lock_expiry`, so the lock is released by
        aclean_old_locks() once that time has passed. The returned instances
        are the ones read before the update, so their in-memory locked_until
        is still None.
        """
        with transaction.atomic():
            selected = list(
                cls.objects.filter(locked_until__isnull=True)
                .order_by("-priority", "pk")
                .select_for_update()[:number]
            )
            cls.objects.filter(pk__in=[task.pk for task in selected]).update(
                locked_until=lock_expiry
            )
        return selected

    @classmethod
    async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
        """
        Async wrapper that runs get_for_execution() through sync_to_async.
        """
        return await sync_to_async(cls.get_for_execution)(number, lock_expiry)

    @classmethod
    async def aclean_old_locks(cls):
        """
        Clears the lock on every task whose locked_until is now or earlier.
        """
        await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
            locked_until=None
        )

    async def aget_model_instance(self) -> StatorModel:
        """
        Loads the model instance this task points at.
        """
        model = apps.get_model(self.model_label)
        # Look up by the stored target key, not by this task row's own pk
        return cast(StatorModel, await model.objects.aget(pk=self.instance_pk))

    async def adelete(self):
        """
        Deletes this task's row from the queue.
        """
        await self.__class__.objects.filter(pk=self.pk).adelete()
class StatorRunner:
    """
    Runs tasks on models that are looking for state changes.
    Designed to run in a one-shot mode, living inside a request.
    """

    # Seconds (measured from the start of run()) during which new tasks are
    # still picked up, and after which run() stops waiting altogether
    START_TIMEOUT = 30
    TOTAL_TIMEOUT = 60
    # Seconds added to "now" to build the lock expiry passed to the queue
    LOCK_TIMEOUT = 120

    # Upper bound on tasks in flight at the same time
    MAX_TASKS = 30

    def __init__(self, models: List[Type[StatorModel]]):
        self.models = models
        self.runner_id = uuid.uuid4().hex

    async def run(self):
        """
        Pulls tasks off the queue and runs them, returning how many were
        started.
        """
        began = time.monotonic()
        self.handled = 0
        self.tasks = []
        # Clean up old locks
        await StatorTask.aclean_old_locks()
        # Examine what needs scheduling

        # For the first time period, launch tasks
        while time.monotonic() - began < self.START_TIMEOUT:
            self.remove_completed_tasks()
            free_slots = self.MAX_TASKS - len(self.tasks)
            # Fetch new tasks
            if free_slots > 0:
                lock_expiry = timezone.now() + datetime.timedelta(
                    seconds=self.LOCK_TIMEOUT
                )
                fetched = await StatorTask.aget_for_execution(free_slots, lock_expiry)
                for stator_task in fetched:
                    self.tasks.append(asyncio.create_task(self.run_task(stator_task)))
                    self.handled += 1
            # Prevent busylooping
            await asyncio.sleep(0.01)
        # Then wait for tasks to finish
        while time.monotonic() - began < self.TOTAL_TIMEOUT:
            self.remove_completed_tasks()
            if not self.tasks:
                break
            # Prevent busylooping
            await asyncio.sleep(1)
        return self.handled

    async def run_task(self, task: StatorTask):
        """
        Runs one queue entry: attempts the instance's transition, then
        removes the entry.
        """
        # Resolve the model instance
        model_instance = await task.aget_model_instance()
        await model_instance.attempt_transition()
        # Remove ourselves from the database as complete
        await task.adelete()

    def remove_completed_tasks(self):
        """
        Drops finished asyncio tasks from the in-flight list.
        """
        still_running = []
        for pending in self.tasks:
            if not pending.done():
                still_running.append(pending)
        self.tasks = still_running
def test_declare():
    """
    Tests a basic graph declaration and various kinds of handler
    lookups.
    """

    def fake_handler():
        return True

    class TestGraph(StateGraph):
        initial = State()
        second = State()
        third = State()
        fourth = State()
        final = State()

        # add_transition's second positional parameter is the handler, so
        # the number has to be passed by keyword or the call raises
        # TypeError (handler given twice).
        # NOTE(review): 60 is mapped to priority because that is the only
        # numeric parameter add_transition has - confirm that was the intent.
        initial.add_transition(second, handler=fake_handler, priority=60)
        second.add_transition(third, handler="check_third", priority=60)

        def check_third(cls):
            return True

        @third.add_transition(fourth, priority=60)
        def check_fourth(cls):
            return True

        fourth.add_manual_transition(final)

    assert TestGraph.initial_state == TestGraph.initial
    assert TestGraph.terminal_states == {TestGraph.final}

    # Direct callable, string lookup and decorator registration
    assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
    assert (
        TestGraph.second.children[TestGraph.third].get_handler()
        == TestGraph.check_third
    )
    assert (
        TestGraph.third.children[TestGraph.fourth].get_handler().__name__
        == "check_fourth"
    )
def test_bad_declarations():
    """
    Tests that you can't declare an invalid graph.
    """
    # More than one initial state
    with pytest.raises(ValueError):

        class TestGraph(StateGraph):
            initial = State()
            initial2 = State()

    # No initial states
    with pytest.raises(ValueError):

        class TestGraph(StateGraph):
            loop = State()
            loop2 = State()

            # The number must go by keyword: passed positionally it lands in
            # the handler slot and the call fails with TypeError before the
            # graph is ever validated.
            loop.add_transition(loop2, handler="fake", priority=1)
            loop2.add_transition(loop, handler="fake", priority=1)


# --- stator/views.py ---
class RequestRunner(View):
    """
    Runs a Stator runner within a HTTP request. For when you're on something
    serverless.
    """

    async def get(self, request):
        """
        Runs one StatorRunner pass and reports how many tasks it handled.
        """
        runner = StatorRunner([Follow])
        handled = await runner.run()
        return HttpResponse(f"Handled {handled}")