summaryrefslogtreecommitdiffstats
path: root/stator
diff options
context:
space:
mode:
authorAndrew Godwin2022-11-09 22:29:33 -0700
committerAndrew Godwin2022-11-09 22:29:49 -0700
commit7746abbbb7700fa918450101bbc6d29ed9b4b608 (patch)
tree8768efd8201faa2fee18e5d3b46f33785002f5d6 /stator
parent61c324508e62bb640b4526183d0837fc57d742c2 (diff)
downloadtakahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.tar.gz
takahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.tar.bz2
takahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.zip
Most of the way through the stator refactor
Diffstat (limited to 'stator')
-rw-r--r--stator/admin.py15
-rw-r--r--stator/graph.py47
-rw-r--r--stator/management/__init__.py0
-rw-r--r--stator/management/commands/__init__.py0
-rw-r--r--stator/management/commands/runstator.py28
-rw-r--r--stator/migrations/0001_initial.py11
-rw-r--r--stator/models.py195
-rw-r--r--stator/runner.py47
-rw-r--r--stator/tests/test_graph.py4
9 files changed, 220 insertions, 127 deletions
diff --git a/stator/admin.py b/stator/admin.py
index c04d775..025f225 100644
--- a/stator/admin.py
+++ b/stator/admin.py
@@ -1,8 +1,17 @@
from django.contrib import admin
-from stator.models import StatorTask
+from stator.models import StatorError
-@admin.register(StatorTask)
+@admin.register(StatorError)
class DomainAdmin(admin.ModelAdmin):
- list_display = ["id", "model_label", "instance_pk", "locked_until"]
+ list_display = [
+ "id",
+ "date",
+ "model_label",
+ "instance_pk",
+ "from_state",
+ "to_state",
+ "error",
+ ]
+ ordering = ["-date"]
diff --git a/stator/graph.py b/stator/graph.py
index b06ffb8..7fc23f7 100644
--- a/stator/graph.py
+++ b/stator/graph.py
@@ -1,9 +1,16 @@
-import datetime
-from functools import wraps
-from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
-
-from django.db import models
-from django.utils import timezone
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
class StateGraph:
@@ -13,7 +20,7 @@ class StateGraph:
"""
states: ClassVar[Dict[str, "State"]]
- choices: ClassVar[List[Tuple[str, str]]]
+ choices: ClassVar[List[Tuple[object, str]]]
initial_state: ClassVar["State"]
terminal_states: ClassVar[Set["State"]]
@@ -50,7 +57,7 @@ class StateGraph:
cls.initial_state = initial_state
cls.terminal_states = terminal_states
# Generate choices
- cls.choices = [(name, name) for name in cls.states.keys()]
+ cls.choices = [(state, name) for name, state in cls.states.items()]
class State:
@@ -63,7 +70,7 @@ class State:
self.parents: Set["State"] = set()
self.children: Dict["State", "Transition"] = {}
- def _add_to_graph(self, graph: StateGraph, name: str):
+ def _add_to_graph(self, graph: Type[StateGraph], name: str):
self.graph = graph
self.name = name
self.graph.states[name] = self
@@ -71,13 +78,19 @@ class State:
def __repr__(self):
return f"<State {self.name}>"
+ def __str__(self):
+ return self.name
+
+ def __len__(self):
+ return len(self.name)
+
def add_transition(
self,
other: "State",
- handler: Optional[Union[str, Callable]] = None,
+ handler: Optional[Callable] = None,
priority: int = 0,
) -> Callable:
- def decorator(handler: Union[str, Callable]):
+ def decorator(handler: Callable[[Any], bool]):
self.children[other] = Transition(
self,
other,
@@ -85,9 +98,7 @@ class State:
priority=priority,
)
other.parents.add(self)
- # All handlers should be class methods, so do that automatically.
- if callable(handler):
- return classmethod(handler)
+ return handler
# If we're not being called as a decorator, invoke it immediately
if handler is not None:
@@ -113,7 +124,7 @@ class State:
if automatic_only:
transitions = [t for t in self.children.values() if t.automatic]
else:
- transitions = self.children.values()
+ transitions = list(self.children.values())
return sorted(transitions, key=lambda t: t.priority, reverse=True)
@@ -141,7 +152,10 @@ class Transition:
"""
if isinstance(self.handler, str):
self.handler = getattr(self.from_state.graph, self.handler)
- return self.handler
+ return cast(Callable, self.handler)
+
+ def __repr__(self):
+ return f"<Transition {self.from_state} -> {self.to_state}>"
class ManualTransition(Transition):
@@ -157,6 +171,5 @@ class ManualTransition(Transition):
):
self.from_state = from_state
self.to_state = to_state
- self.handler = None
self.priority = 0
self.automatic = False
diff --git a/stator/management/__init__.py b/stator/management/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/stator/management/__init__.py
diff --git a/stator/management/commands/__init__.py b/stator/management/commands/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/stator/management/commands/__init__.py
diff --git a/stator/management/commands/runstator.py b/stator/management/commands/runstator.py
new file mode 100644
index 0000000..1307fef
--- /dev/null
+++ b/stator/management/commands/runstator.py
@@ -0,0 +1,28 @@
+from typing import List, Type, cast
+
+from asgiref.sync import async_to_sync
+from django.apps import apps
+from django.core.management.base import BaseCommand
+
+from stator.models import StatorModel
+from stator.runner import StatorRunner
+
+
+class Command(BaseCommand):
+ help = "Runs a Stator runner for a short period"
+
+ def add_arguments(self, parser):
+ parser.add_argument("model_labels", nargs="*", type=str)
+
+ def handle(self, model_labels: List[str], *args, **options):
+ # Resolve the models list into names
+ models = cast(
+ List[Type[StatorModel]],
+ [apps.get_model(label) for label in model_labels],
+ )
+ if not models:
+ models = StatorModel.subclasses
+ print("Running for models: " + " ".join(m._meta.label_lower for m in models))
+ # Run a runner
+ runner = StatorRunner(models)
+ async_to_sync(runner.run)()
diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py
index f485836..d56ed5c 100644
--- a/stator/migrations/0001_initial.py
+++ b/stator/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.1.3 on 2022-11-09 05:46
+# Generated by Django 4.1.3 on 2022-11-10 03:24
from django.db import migrations, models
@@ -11,7 +11,7 @@ class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
- name="StatorTask",
+ name="StatorError",
fields=[
(
"id",
@@ -24,8 +24,11 @@ class Migration(migrations.Migration):
),
("model_label", models.CharField(max_length=200)),
("instance_pk", models.CharField(max_length=200)),
- ("locked_until", models.DateTimeField(blank=True, null=True)),
- ("priority", models.IntegerField(default=0)),
+ ("from_state", models.CharField(max_length=200)),
+ ("to_state", models.CharField(max_length=200)),
+ ("date", models.DateTimeField(auto_now_add=True)),
+ ("error", models.TextField()),
+ ("error_details", models.TextField(blank=True, null=True)),
],
),
]
diff --git a/stator/models.py b/stator/models.py
index 3b0da0a..235b18c 100644
--- a/stator/models.py
+++ b/stator/models.py
@@ -1,14 +1,13 @@
import datetime
-from functools import reduce
-from typing import Type, cast
+import traceback
+from typing import ClassVar, List, Optional, Type, cast
from asgiref.sync import sync_to_async
-from django.apps import apps
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
-from stator.graph import State, StateGraph
+from stator.graph import State, StateGraph, Transition
class StateField(models.CharField):
@@ -55,6 +54,9 @@ class StatorModel(models.Model):
concrete model yourself.
"""
+ # If this row is up for transition attempts
+ state_ready = models.BooleanField(default=False)
+
# When the state last actually changed, or the date of instance creation
state_changed = models.DateTimeField(auto_now_add=True)
@@ -62,68 +64,128 @@ class StatorModel(models.Model):
# (and not successful, as this is cleared on transition)
state_attempted = models.DateTimeField(blank=True, null=True)
+ # If a lock is out on this row, when it is locked until
+ # (we don't identify the lock owner, as there's no heartbeats)
+ state_locked_until = models.DateTimeField(null=True, blank=True)
+
+ # Collection of subclasses of us
+ subclasses: ClassVar[List[Type["StatorModel"]]] = []
+
class Meta:
abstract = True
+ def __init_subclass__(cls) -> None:
+ if cls is not StatorModel:
+ cls.subclasses.append(cls)
+
+ @classproperty
+ def state_graph(cls) -> Type[StateGraph]:
+ return cls._meta.get_field("state").graph
+
@classmethod
- def schedule_overdue(cls, now=None) -> models.QuerySet:
+ async def atransition_schedule_due(cls, now=None) -> models.QuerySet:
"""
Finds instances of this model that need to run and schedule them.
"""
q = models.Q()
- for transition in cls.state_graph.transitions(automatic_only=True):
- q = q | transition.get_query(now=now)
- return cls.objects.filter(q)
+ for state in cls.state_graph.states.values():
+ state = cast(State, state)
+ if not state.terminal:
+ q = q | models.Q(
+ (
+ models.Q(
+ state_attempted__lte=timezone.now()
+ - datetime.timedelta(seconds=state.try_interval)
+ )
+ | models.Q(state_attempted__isnull=True)
+ ),
+ state=state.name,
+ )
+ await cls.objects.filter(q).aupdate(state_ready=True)
- @classproperty
- def state_graph(cls) -> Type[StateGraph]:
- return cls._meta.get_field("state").graph
+ @classmethod
+ def transition_get_with_lock(
+ cls, number: int, lock_expiry: datetime.datetime
+ ) -> List["StatorModel"]:
+ """
+ Returns up to `number` tasks for execution, having locked them.
+ """
+ with transaction.atomic():
+ selected = list(
+ cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[
+ :number
+ ].select_for_update()
+ )
+ cls.objects.filter(pk__in=[i.pk for i in selected]).update(
+ state_locked_until=timezone.now()
+ )
+ return selected
+
+ @classmethod
+ async def atransition_get_with_lock(
+ cls, number: int, lock_expiry: datetime.datetime
+ ) -> List["StatorModel"]:
+ return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry)
- def schedule_transition(self, priority: int = 0):
+ @classmethod
+ async def atransition_clean_locks(cls):
+ await cls.objects.filter(state_locked_until__lte=timezone.now()).aupdate(
+ state_locked_until=None
+ )
+
+ def transition_schedule(self):
"""
Adds this instance to the queue to get its state transition attempted.
The scheduler will call this, but you can also call it directly if you
know it'll be ready and want to lower latency.
"""
- StatorTask.schedule_for_execution(self, priority=priority)
+ self.state_ready = True
+ self.save()
- async def attempt_transition(self):
+ async def atransition_attempt(self) -> bool:
"""
Attempts to transition the current state by running its handler(s).
"""
# Try each transition in priority order
- for transition in self.state_graph.states[self.state].transitions(
- automatic_only=True
- ):
- success = await transition.get_handler()(self)
+ for transition in self.state.transitions(automatic_only=True):
+ try:
+ success = await transition.get_handler()(self)
+ except BaseException as e:
+ await StatorError.acreate_from_instance(self, transition, e)
+ traceback.print_exc()
+ continue
if success:
- await self.perform_transition(transition.to_state.name)
- return
+ await self.atransition_perform(transition.to_state.name)
+ return True
await self.__class__.objects.filter(pk=self.pk).aupdate(
- state_attempted=timezone.now()
+ state_attempted=timezone.now(),
+ state_locked_until=None,
+ state_ready=False,
)
+ return False
- async def perform_transition(self, state_name):
+ def transition_perform(self, state_name):
"""
- Transitions the instance to the given state name
+ Transitions the instance to the given state name, forcibly.
"""
if state_name not in self.state_graph.states:
raise ValueError(f"Invalid state {state_name}")
- await self.__class__.objects.filter(pk=self.pk).aupdate(
+ self.__class__.objects.filter(pk=self.pk).update(
state=state_name,
state_changed=timezone.now(),
state_attempted=None,
+ state_locked_until=None,
+ state_ready=False,
)
+ atransition_perform = sync_to_async(transition_perform)
-class StatorTask(models.Model):
- """
- The model that we use for an internal scheduling queue.
- Entries in this queue are up for checking and execution - it also performs
- locking to ensure we get closer to exactly-once execution (but we err on
- the side of at-least-once)
+class StatorError(models.Model):
+ """
+ Tracks any errors running the transitions.
+ Meant to be cleaned out regularly. Should probably be a log.
"""
# appname.modelname (lowercased) label for the model this represents
@@ -132,60 +194,33 @@ class StatorTask(models.Model):
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
- # Locking columns (no runner ID, as we have no heartbeats - all runners
- # only live for a short amount of time anyway)
- locked_until = models.DateTimeField(null=True, blank=True)
+ # The state we moved from
+ from_state = models.CharField(max_length=200)
- # Basic total ordering priority - higher is more important
- priority = models.IntegerField(default=0)
+ # The state we moved to (or tried to)
+ to_state = models.CharField(max_length=200)
- def __str__(self):
- return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
+ # When it happened
+ date = models.DateTimeField(auto_now_add=True)
- @classmethod
- def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
- # We don't do a transaction here as it's fine to occasionally double up
- model_label = model_instance._meta.label_lower
- pk = model_instance.pk
- # TODO: Increase priority of existing if present
- if not cls.objects.filter(
- model_label=model_label, instance_pk=pk, locked__isnull=True
- ).exists():
- StatorTask.objects.create(
- model_label=model_label,
- instance_pk=pk,
- priority=priority,
- )
-
- @classmethod
- def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
- """
- Returns up to `number` tasks for execution, having locked them.
- """
- with transaction.atomic():
- selected = list(
- cls.objects.filter(locked_until__isnull=True)[
- :number
- ].select_for_update()
- )
- cls.objects.filter(pk__in=[i.pk for i in selected]).update(
- locked_until=timezone.now()
- )
- return selected
+ # Error name
+ error = models.TextField()
- @classmethod
- async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
- return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
+ # Error details
+ error_details = models.TextField(blank=True, null=True)
@classmethod
- async def aclean_old_locks(cls):
- await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
- locked_until=None
+ async def acreate_from_instance(
+ cls,
+ instance: StatorModel,
+ transition: Transition,
+ exception: Optional[BaseException] = None,
+ ):
+ return await cls.objects.acreate(
+ model_label=instance._meta.label_lower,
+ instance_pk=str(instance.pk),
+ from_state=transition.from_state,
+ to_state=transition.to_state,
+ error=str(exception),
+ error_details=traceback.format_exc(),
)
-
- async def aget_model_instance(self) -> StatorModel:
- model = apps.get_model(self.model_label)
- return cast(StatorModel, await model.objects.aget(pk=self.pk))
-
- async def adelete(self):
- self.__class__.objects.adelete(pk=self.pk)
diff --git a/stator/runner.py b/stator/runner.py
index 8c6e0f1..f9c726e 100644
--- a/stator/runner.py
+++ b/stator/runner.py
@@ -4,11 +4,9 @@ import time
import uuid
from typing import List, Type
-from asgiref.sync import sync_to_async
-from django.db import transaction
from django.utils import timezone
-from stator.models import StatorModel, StatorTask
+from stator.models import StatorModel
class StatorRunner:
@@ -22,6 +20,7 @@ class StatorRunner:
LOCK_TIMEOUT = 120
MAX_TASKS = 30
+ MAX_TASKS_PER_MODEL = 5
def __init__(self, models: List[Type[StatorModel]]):
self.models = models
@@ -32,38 +31,44 @@ class StatorRunner:
self.handled = 0
self.tasks = []
# Clean up old locks
- await StatorTask.aclean_old_locks()
- # Examine what needs scheduling
-
+ print("Running initial cleaning and scheduling")
+ initial_tasks = []
+ for model in self.models:
+ initial_tasks.append(model.atransition_clean_locks())
+ initial_tasks.append(model.atransition_schedule_due())
+ await asyncio.gather(*initial_tasks)
# For the first time period, launch tasks
+ print("Running main task loop")
while (time.monotonic() - start_time) < self.START_TIMEOUT:
self.remove_completed_tasks()
space_remaining = self.MAX_TASKS - len(self.tasks)
# Fetch new tasks
- if space_remaining > 0:
- for new_task in await StatorTask.aget_for_execution(
- space_remaining,
- timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
- ):
- self.tasks.append(asyncio.create_task(self.run_task(new_task)))
- self.handled += 1
+ for model in self.models:
+ if space_remaining > 0:
+ for instance in await model.atransition_get_with_lock(
+ min(space_remaining, self.MAX_TASKS_PER_MODEL),
+ timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
+ ):
+ print(
+ f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
+ )
+ self.tasks.append(
+ asyncio.create_task(instance.atransition_attempt())
+ )
+ self.handled += 1
+ space_remaining -= 1
# Prevent busylooping
- await asyncio.sleep(0.01)
+ await asyncio.sleep(0.1)
# Then wait for tasks to finish
+ print("Waiting for tasks to complete")
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
self.remove_completed_tasks()
if not self.tasks:
break
# Prevent busylooping
await asyncio.sleep(1)
+ print("Complete")
return self.handled
- async def run_task(self, task: StatorTask):
- # Resolve the model instance
- model_instance = await task.aget_model_instance()
- await model_instance.attempt_transition()
- # Remove ourselves from the database as complete
- await task.adelete()
-
def remove_completed_tasks(self):
self.tasks = [t for t in self.tasks if not t.done()]
diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py
index f6b8404..0a7113d 100644
--- a/stator/tests/test_graph.py
+++ b/stator/tests/test_graph.py
@@ -51,14 +51,14 @@ def test_bad_declarations():
# More than one initial state
with pytest.raises(ValueError):
- class TestGraph(StateGraph):
+ class TestGraph2(StateGraph):
initial = State()
initial2 = State()
# No initial states
with pytest.raises(ValueError):
- class TestGraph(StateGraph):
+ class TestGraph3(StateGraph):
loop = State()
loop2 = State()