summaryrefslogtreecommitdiffstats
path: root/stator
diff options
context:
space:
mode:
authorAndrew Godwin2022-11-08 23:06:29 -0700
committerAndrew Godwin2022-11-09 22:29:49 -0700
commit61c324508e62bb640b4526183d0837fc57d742c2 (patch)
tree618ee8c88ce8a28224a187dc33b7c5fad6831d04 /stator
parent8a0a7558894afce8d25b7f0dc16775e899b72a94 (diff)
downloadtakahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.gz
takahe-61c324508e62bb640b4526183d0837fc57d742c2.tar.bz2
takahe-61c324508e62bb640b4526183d0837fc57d742c2.zip
Midway point in task refactor - changing direction
Diffstat (limited to 'stator')
-rw-r--r--stator/__init__.py0
-rw-r--r--stator/admin.py8
-rw-r--r--stator/apps.py6
-rw-r--r--stator/graph.py162
-rw-r--r--stator/migrations/0001_initial.py31
-rw-r--r--stator/migrations/__init__.py0
-rw-r--r--stator/models.py191
-rw-r--r--stator/runner.py69
-rw-r--r--stator/tests/test_graph.py66
-rw-r--r--stator/views.py17
10 files changed, 550 insertions, 0 deletions
diff --git a/stator/__init__.py b/stator/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/stator/__init__.py
diff --git a/stator/admin.py b/stator/admin.py
new file mode 100644
index 0000000..c04d775
--- /dev/null
+++ b/stator/admin.py
@@ -0,0 +1,8 @@
+from django.contrib import admin
+
+from stator.models import StatorTask
+
+
+@admin.register(StatorTask)
+class DomainAdmin(admin.ModelAdmin):
+ list_display = ["id", "model_label", "instance_pk", "locked_until"]
diff --git a/stator/apps.py b/stator/apps.py
new file mode 100644
index 0000000..8910ecb
--- /dev/null
+++ b/stator/apps.py
@@ -0,0 +1,6 @@
+from django.apps import AppConfig
+
+
+class StatorConfig(AppConfig):
+ default_auto_field = "django.db.models.BigAutoField"
+ name = "stator"
diff --git a/stator/graph.py b/stator/graph.py
new file mode 100644
index 0000000..b06ffb8
--- /dev/null
+++ b/stator/graph.py
@@ -0,0 +1,162 @@
+import datetime
+from functools import wraps
+from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
+
+from django.db import models
+from django.utils import timezone
+
+
+class StateGraph:
+ """
+ Represents a graph of possible states and transitions to attempt on them.
+ Does not support subclasses of existing graphs yet.
+ """
+
+ states: ClassVar[Dict[str, "State"]]
+ choices: ClassVar[List[Tuple[str, str]]]
+ initial_state: ClassVar["State"]
+ terminal_states: ClassVar[Set["State"]]
+
+ def __init_subclass__(cls) -> None:
+ # Collect state memebers
+ cls.states = {}
+ for name, value in cls.__dict__.items():
+ if name in ["__module__", "__doc__", "states"]:
+ pass
+ elif name in ["initial_state", "terminal_states", "choices"]:
+ raise ValueError(f"Cannot name a state {name} - this is reserved")
+ elif isinstance(value, State):
+ value._add_to_graph(cls, name)
+ elif callable(value) or isinstance(value, classmethod):
+ pass
+ else:
+ raise ValueError(
+ f"Graph has item {name} of unallowed type {type(value)}"
+ )
+ # Check the graph layout
+ terminal_states = set()
+ initial_state = None
+ for state in cls.states.values():
+ if state.initial:
+ if initial_state:
+ raise ValueError(
+ f"The graph has more than one initial state: {initial_state} and {state}"
+ )
+ initial_state = state
+ if state.terminal:
+ terminal_states.add(state)
+ if initial_state is None:
+ raise ValueError("The graph has no initial state")
+ cls.initial_state = initial_state
+ cls.terminal_states = terminal_states
+ # Generate choices
+ cls.choices = [(name, name) for name in cls.states.keys()]
+
+
+class State:
+ """
+ Represents an individual state
+ """
+
+ def __init__(self, try_interval: float = 300):
+ self.try_interval = try_interval
+ self.parents: Set["State"] = set()
+ self.children: Dict["State", "Transition"] = {}
+
+ def _add_to_graph(self, graph: StateGraph, name: str):
+ self.graph = graph
+ self.name = name
+ self.graph.states[name] = self
+
+ def __repr__(self):
+ return f"<State {self.name}>"
+
+ def add_transition(
+ self,
+ other: "State",
+ handler: Optional[Union[str, Callable]] = None,
+ priority: int = 0,
+ ) -> Callable:
+ def decorator(handler: Union[str, Callable]):
+ self.children[other] = Transition(
+ self,
+ other,
+ handler,
+ priority=priority,
+ )
+ other.parents.add(self)
+ # All handlers should be class methods, so do that automatically.
+ if callable(handler):
+ return classmethod(handler)
+
+ # If we're not being called as a decorator, invoke it immediately
+ if handler is not None:
+ decorator(handler)
+ return decorator
+
+ def add_manual_transition(self, other: "State"):
+ self.children[other] = ManualTransition(self, other)
+ other.parents.add(self)
+
+ @property
+ def initial(self):
+ return not self.parents
+
+ @property
+ def terminal(self):
+ return not self.children
+
+ def transitions(self, automatic_only=False) -> List["Transition"]:
+ """
+ Returns all transitions from this State in priority order
+ """
+ if automatic_only:
+ transitions = [t for t in self.children.values() if t.automatic]
+ else:
+ transitions = self.children.values()
+ return sorted(transitions, key=lambda t: t.priority, reverse=True)
+
+
+class Transition:
+ """
+ A possible transition from one state to another
+ """
+
+ def __init__(
+ self,
+ from_state: State,
+ to_state: State,
+ handler: Union[str, Callable],
+ priority: int = 0,
+ ):
+ self.from_state = from_state
+ self.to_state = to_state
+ self.handler = handler
+ self.priority = priority
+ self.automatic = True
+
+ def get_handler(self) -> Callable:
+ """
+ Returns the handler (it might need resolving from a string)
+ """
+ if isinstance(self.handler, str):
+ self.handler = getattr(self.from_state.graph, self.handler)
+ return self.handler
+
+
+class ManualTransition(Transition):
+ """
+ A possible transition from one state to another that cannot be done by
+ the stator task runner, and must come from an external source.
+ """
+
+ def __init__(
+ self,
+ from_state: State,
+ to_state: State,
+ ):
+ self.from_state = from_state
+ self.to_state = to_state
+ self.handler = None
+ self.priority = 0
+ self.automatic = False
diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py
new file mode 100644
index 0000000..f485836
--- /dev/null
+++ b/stator/migrations/0001_initial.py
@@ -0,0 +1,31 @@
+# Generated by Django 4.1.3 on 2022-11-09 05:46
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = []
+
+ operations = [
+ migrations.CreateModel(
+ name="StatorTask",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("model_label", models.CharField(max_length=200)),
+ ("instance_pk", models.CharField(max_length=200)),
+ ("locked_until", models.DateTimeField(blank=True, null=True)),
+ ("priority", models.IntegerField(default=0)),
+ ],
+ ),
+ ]
diff --git a/stator/migrations/__init__.py b/stator/migrations/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/stator/migrations/__init__.py
diff --git a/stator/models.py b/stator/models.py
new file mode 100644
index 0000000..3b0da0a
--- /dev/null
+++ b/stator/models.py
@@ -0,0 +1,191 @@
+import datetime
+from functools import reduce
+from typing import Type, cast
+
+from asgiref.sync import sync_to_async
+from django.apps import apps
+from django.db import models, transaction
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from stator.graph import State, StateGraph
+
+
+class StateField(models.CharField):
+ """
+ A special field that automatically gets choices from a state graph
+ """
+
+ def __init__(self, graph: Type[StateGraph], **kwargs):
+ # Sensible default for state length
+ kwargs.setdefault("max_length", 100)
+ # Add choices and initial
+ self.graph = graph
+ kwargs["choices"] = self.graph.choices
+ kwargs["default"] = self.graph.initial_state.name
+ super().__init__(**kwargs)
+
+ def deconstruct(self):
+ name, path, args, kwargs = super().deconstruct()
+ kwargs["graph"] = self.graph
+ return name, path, args, kwargs
+
+ def from_db_value(self, value, expression, connection):
+ if value is None:
+ return value
+ return self.graph.states[value]
+
+ def to_python(self, value):
+ if isinstance(value, State) or value is None:
+ return value
+ return self.graph.states[value]
+
+ def get_prep_value(self, value):
+ if isinstance(value, State):
+ return value.name
+ return value
+
+
+class StatorModel(models.Model):
+ """
+ A model base class that has a state machine backing it, with tasks to work
+ out when to move the state to the next one.
+
+ You need to provide a "state" field as an instance of StateField on the
+ concrete model yourself.
+ """
+
+ # When the state last actually changed, or the date of instance creation
+ state_changed = models.DateTimeField(auto_now_add=True)
+
+ # When the last state change for the current state was attempted
+ # (and not successful, as this is cleared on transition)
+ state_attempted = models.DateTimeField(blank=True, null=True)
+
+ class Meta:
+ abstract = True
+
+ @classmethod
+ def schedule_overdue(cls, now=None) -> models.QuerySet:
+ """
+ Finds instances of this model that need to run and schedule them.
+ """
+ q = models.Q()
+ for transition in cls.state_graph.transitions(automatic_only=True):
+ q = q | transition.get_query(now=now)
+ return cls.objects.filter(q)
+
+ @classproperty
+ def state_graph(cls) -> Type[StateGraph]:
+ return cls._meta.get_field("state").graph
+
+ def schedule_transition(self, priority: int = 0):
+ """
+ Adds this instance to the queue to get its state transition attempted.
+
+ The scheduler will call this, but you can also call it directly if you
+ know it'll be ready and want to lower latency.
+ """
+ StatorTask.schedule_for_execution(self, priority=priority)
+
+ async def attempt_transition(self):
+ """
+ Attempts to transition the current state by running its handler(s).
+ """
+ # Try each transition in priority order
+ for transition in self.state_graph.states[self.state].transitions(
+ automatic_only=True
+ ):
+ success = await transition.get_handler()(self)
+ if success:
+ await self.perform_transition(transition.to_state.name)
+ return
+ await self.__class__.objects.filter(pk=self.pk).aupdate(
+ state_attempted=timezone.now()
+ )
+
+ async def perform_transition(self, state_name):
+ """
+ Transitions the instance to the given state name
+ """
+ if state_name not in self.state_graph.states:
+ raise ValueError(f"Invalid state {state_name}")
+ await self.__class__.objects.filter(pk=self.pk).aupdate(
+ state=state_name,
+ state_changed=timezone.now(),
+ state_attempted=None,
+ )
+
+
+class StatorTask(models.Model):
+ """
+ The model that we use for an internal scheduling queue.
+
+ Entries in this queue are up for checking and execution - it also performs
+ locking to ensure we get closer to exactly-once execution (but we err on
+ the side of at-least-once)
+ """
+
+ # appname.modelname (lowercased) label for the model this represents
+ model_label = models.CharField(max_length=200)
+
+ # The primary key of that model (probably int or str)
+ instance_pk = models.CharField(max_length=200)
+
+ # Locking columns (no runner ID, as we have no heartbeats - all runners
+ # only live for a short amount of time anyway)
+ locked_until = models.DateTimeField(null=True, blank=True)
+
+ # Basic total ordering priority - higher is more important
+ priority = models.IntegerField(default=0)
+
+ def __str__(self):
+ return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
+
+ @classmethod
+ def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
+ # We don't do a transaction here as it's fine to occasionally double up
+ model_label = model_instance._meta.label_lower
+ pk = model_instance.pk
+ # TODO: Increase priority of existing if present
+ if not cls.objects.filter(
+ model_label=model_label, instance_pk=pk, locked__isnull=True
+ ).exists():
+ StatorTask.objects.create(
+ model_label=model_label,
+ instance_pk=pk,
+ priority=priority,
+ )
+
+ @classmethod
+ def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
+ """
+ Returns up to `number` tasks for execution, having locked them.
+ """
+ with transaction.atomic():
+ selected = list(
+ cls.objects.filter(locked_until__isnull=True)[
+ :number
+ ].select_for_update()
+ )
+ cls.objects.filter(pk__in=[i.pk for i in selected]).update(
+ locked_until=timezone.now()
+ )
+ return selected
+
+ @classmethod
+ async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
+ return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
+
+ @classmethod
+ async def aclean_old_locks(cls):
+ await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
+ locked_until=None
+ )
+
+ async def aget_model_instance(self) -> StatorModel:
+ model = apps.get_model(self.model_label)
+ return cast(StatorModel, await model.objects.aget(pk=self.pk))
+
+ async def adelete(self):
+ self.__class__.objects.adelete(pk=self.pk)
diff --git a/stator/runner.py b/stator/runner.py
new file mode 100644
index 0000000..8c6e0f1
--- /dev/null
+++ b/stator/runner.py
@@ -0,0 +1,69 @@
+import asyncio
+import datetime
+import time
+import uuid
+from typing import List, Type
+
+from asgiref.sync import sync_to_async
+from django.db import transaction
+from django.utils import timezone
+
+from stator.models import StatorModel, StatorTask
+
+
+class StatorRunner:
+ """
+ Runs tasks on models that are looking for state changes.
+ Designed to run in a one-shot mode, living inside a request.
+ """
+
+ START_TIMEOUT = 30
+ TOTAL_TIMEOUT = 60
+ LOCK_TIMEOUT = 120
+
+ MAX_TASKS = 30
+
+ def __init__(self, models: List[Type[StatorModel]]):
+ self.models = models
+ self.runner_id = uuid.uuid4().hex
+
+ async def run(self):
+ start_time = time.monotonic()
+ self.handled = 0
+ self.tasks = []
+ # Clean up old locks
+ await StatorTask.aclean_old_locks()
+ # Examine what needs scheduling
+
+ # For the first time period, launch tasks
+ while (time.monotonic() - start_time) < self.START_TIMEOUT:
+ self.remove_completed_tasks()
+ space_remaining = self.MAX_TASKS - len(self.tasks)
+ # Fetch new tasks
+ if space_remaining > 0:
+ for new_task in await StatorTask.aget_for_execution(
+ space_remaining,
+ timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
+ ):
+ self.tasks.append(asyncio.create_task(self.run_task(new_task)))
+ self.handled += 1
+ # Prevent busylooping
+ await asyncio.sleep(0.01)
+ # Then wait for tasks to finish
+ while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
+ self.remove_completed_tasks()
+ if not self.tasks:
+ break
+ # Prevent busylooping
+ await asyncio.sleep(1)
+ return self.handled
+
+ async def run_task(self, task: StatorTask):
+ # Resolve the model instance
+ model_instance = await task.aget_model_instance()
+ await model_instance.attempt_transition()
+ # Remove ourselves from the database as complete
+ await task.adelete()
+
+ def remove_completed_tasks(self):
+ self.tasks = [t for t in self.tasks if not t.done()]
diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py
new file mode 100644
index 0000000..f6b8404
--- /dev/null
+++ b/stator/tests/test_graph.py
@@ -0,0 +1,66 @@
+import pytest
+
+from stator.graph import State, StateGraph
+
+
+def test_declare():
+ """
+ Tests a basic graph declaration and various kinds of handler
+ lookups.
+ """
+
+ fake_handler = lambda: True
+
+ class TestGraph(StateGraph):
+ initial = State()
+ second = State()
+ third = State()
+ fourth = State()
+ final = State()
+
+ initial.add_transition(second, 60, handler=fake_handler)
+ second.add_transition(third, 60, handler="check_third")
+
+ def check_third(cls):
+ return True
+
+ @third.add_transition(fourth, 60)
+ def check_fourth(cls):
+ return True
+
+ fourth.add_manual_transition(final)
+
+ assert TestGraph.initial_state == TestGraph.initial
+ assert TestGraph.terminal_states == {TestGraph.final}
+
+ assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
+ assert (
+ TestGraph.second.children[TestGraph.third].get_handler()
+ == TestGraph.check_third
+ )
+ assert (
+ TestGraph.third.children[TestGraph.fourth].get_handler().__name__
+ == "check_fourth"
+ )
+
+
+def test_bad_declarations():
+ """
+ Tests that you can't declare an invalid graph.
+ """
+ # More than one initial state
+ with pytest.raises(ValueError):
+
+ class TestGraph(StateGraph):
+ initial = State()
+ initial2 = State()
+
+ # No initial states
+ with pytest.raises(ValueError):
+
+ class TestGraph(StateGraph):
+ loop = State()
+ loop2 = State()
+
+ loop.add_transition(loop2, 1, handler="fake")
+ loop2.add_transition(loop, 1, handler="fake")
diff --git a/stator/views.py b/stator/views.py
new file mode 100644
index 0000000..ef09b8e
--- /dev/null
+++ b/stator/views.py
@@ -0,0 +1,17 @@
+from django.http import HttpResponse
+from django.views import View
+
+from stator.runner import StatorRunner
+from users.models import Follow
+
+
+class RequestRunner(View):
+ """
+ Runs a Stator runner within a HTTP request. For when you're on something
+ serverless.
+ """
+
+ async def get(self, request):
+ runner = StatorRunner([Follow])
+ handled = await runner.run()
+ return HttpResponse(f"Handled {handled}")