summaryrefslogtreecommitdiffstats
path: root/stator/models.py
diff options
context:
space:
mode:
authorAndrew Godwin2022-11-09 22:29:33 -0700
committerAndrew Godwin2022-11-09 22:29:49 -0700
commit7746abbbb7700fa918450101bbc6d29ed9b4b608 (patch)
tree8768efd8201faa2fee18e5d3b46f33785002f5d6 /stator/models.py
parent61c324508e62bb640b4526183d0837fc57d742c2 (diff)
downloadtakahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.tar.gz
takahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.tar.bz2
takahe-7746abbbb7700fa918450101bbc6d29ed9b4b608.zip
Most of the way through the stator refactor
Diffstat (limited to 'stator/models.py')
-rw-r--r--stator/models.py195
1 files changed, 115 insertions, 80 deletions
diff --git a/stator/models.py b/stator/models.py
index 3b0da0a..235b18c 100644
--- a/stator/models.py
+++ b/stator/models.py
@@ -1,14 +1,13 @@
import datetime
-from functools import reduce
-from typing import Type, cast
+import traceback
+from typing import ClassVar, List, Optional, Type, cast
from asgiref.sync import sync_to_async
-from django.apps import apps
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
-from stator.graph import State, StateGraph
+from stator.graph import State, StateGraph, Transition
class StateField(models.CharField):
@@ -55,6 +54,9 @@ class StatorModel(models.Model):
concrete model yourself.
"""
+ # If this row is up for transition attempts
+ state_ready = models.BooleanField(default=False)
+
# When the state last actually changed, or the date of instance creation
state_changed = models.DateTimeField(auto_now_add=True)
@@ -62,68 +64,128 @@ class StatorModel(models.Model):
# (and not successful, as this is cleared on transition)
state_attempted = models.DateTimeField(blank=True, null=True)
+ # If a lock is out on this row, when it is locked until
+ # (we don't identify the lock owner, as there's no heartbeats)
+ state_locked_until = models.DateTimeField(null=True, blank=True)
+
+ # Collection of subclasses of us
+ subclasses: ClassVar[List[Type["StatorModel"]]] = []
+
class Meta:
abstract = True
+ def __init_subclass__(cls) -> None:
+ if cls is not StatorModel:
+ cls.subclasses.append(cls)
+
+ @classproperty
+ def state_graph(cls) -> Type[StateGraph]:
+ return cls._meta.get_field("state").graph
+
@classmethod
- def schedule_overdue(cls, now=None) -> models.QuerySet:
+ async def atransition_schedule_due(cls, now=None) -> models.QuerySet:
"""
Finds instances of this model that need to run and schedule them.
"""
q = models.Q()
- for transition in cls.state_graph.transitions(automatic_only=True):
- q = q | transition.get_query(now=now)
- return cls.objects.filter(q)
+ for state in cls.state_graph.states.values():
+ state = cast(State, state)
+ if not state.terminal:
+ q = q | models.Q(
+ (
+ models.Q(
+ state_attempted__lte=timezone.now()
+ - datetime.timedelta(seconds=state.try_interval)
+ )
+ | models.Q(state_attempted__isnull=True)
+ ),
+ state=state.name,
+ )
+ await cls.objects.filter(q).aupdate(state_ready=True)
- @classproperty
- def state_graph(cls) -> Type[StateGraph]:
- return cls._meta.get_field("state").graph
+ @classmethod
+ def transition_get_with_lock(
+ cls, number: int, lock_expiry: datetime.datetime
+ ) -> List["StatorModel"]:
+ """
+ Returns up to `number` tasks for execution, having locked them.
+ """
+ with transaction.atomic():
+ selected = list(
+ cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[
+ :number
+ ].select_for_update()
+ )
+ cls.objects.filter(pk__in=[i.pk for i in selected]).update(
+ state_locked_until=timezone.now()
+ )
+ return selected
+
+ @classmethod
+ async def atransition_get_with_lock(
+ cls, number: int, lock_expiry: datetime.datetime
+ ) -> List["StatorModel"]:
+ return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry)
- def schedule_transition(self, priority: int = 0):
+ @classmethod
+ async def atransition_clean_locks(cls):
+ await cls.objects.filter(state_locked_until__lte=timezone.now()).aupdate(
+ state_locked_until=None
+ )
+
+ def transition_schedule(self):
"""
Adds this instance to the queue to get its state transition attempted.
The scheduler will call this, but you can also call it directly if you
know it'll be ready and want to lower latency.
"""
- StatorTask.schedule_for_execution(self, priority=priority)
+ self.state_ready = True
+ self.save()
- async def attempt_transition(self):
+ async def atransition_attempt(self) -> bool:
"""
Attempts to transition the current state by running its handler(s).
"""
# Try each transition in priority order
- for transition in self.state_graph.states[self.state].transitions(
- automatic_only=True
- ):
- success = await transition.get_handler()(self)
+ for transition in self.state.transitions(automatic_only=True):
+ try:
+ success = await transition.get_handler()(self)
+ except BaseException as e:
+ await StatorError.acreate_from_instance(self, transition, e)
+ traceback.print_exc()
+ continue
if success:
- await self.perform_transition(transition.to_state.name)
- return
+ await self.atransition_perform(transition.to_state.name)
+ return True
await self.__class__.objects.filter(pk=self.pk).aupdate(
- state_attempted=timezone.now()
+ state_attempted=timezone.now(),
+ state_locked_until=None,
+ state_ready=False,
)
+ return False
- async def perform_transition(self, state_name):
+ def transition_perform(self, state_name):
"""
- Transitions the instance to the given state name
+ Transitions the instance to the given state name, forcibly.
"""
if state_name not in self.state_graph.states:
raise ValueError(f"Invalid state {state_name}")
- await self.__class__.objects.filter(pk=self.pk).aupdate(
+ self.__class__.objects.filter(pk=self.pk).update(
state=state_name,
state_changed=timezone.now(),
state_attempted=None,
+ state_locked_until=None,
+ state_ready=False,
)
+ atransition_perform = sync_to_async(transition_perform)
-class StatorTask(models.Model):
- """
- The model that we use for an internal scheduling queue.
- Entries in this queue are up for checking and execution - it also performs
- locking to ensure we get closer to exactly-once execution (but we err on
- the side of at-least-once)
+class StatorError(models.Model):
+ """
+ Tracks any errors running the transitions.
+ Meant to be cleaned out regularly. Should probably be a log.
"""
# appname.modelname (lowercased) label for the model this represents
@@ -132,60 +194,33 @@ class StatorTask(models.Model):
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
- # Locking columns (no runner ID, as we have no heartbeats - all runners
- # only live for a short amount of time anyway)
- locked_until = models.DateTimeField(null=True, blank=True)
+ # The state we moved from
+ from_state = models.CharField(max_length=200)
- # Basic total ordering priority - higher is more important
- priority = models.IntegerField(default=0)
+ # The state we moved to (or tried to)
+ to_state = models.CharField(max_length=200)
- def __str__(self):
- return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
+ # When it happened
+ date = models.DateTimeField(auto_now_add=True)
- @classmethod
- def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
- # We don't do a transaction here as it's fine to occasionally double up
- model_label = model_instance._meta.label_lower
- pk = model_instance.pk
- # TODO: Increase priority of existing if present
- if not cls.objects.filter(
- model_label=model_label, instance_pk=pk, locked__isnull=True
- ).exists():
- StatorTask.objects.create(
- model_label=model_label,
- instance_pk=pk,
- priority=priority,
- )
-
- @classmethod
- def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
- """
- Returns up to `number` tasks for execution, having locked them.
- """
- with transaction.atomic():
- selected = list(
- cls.objects.filter(locked_until__isnull=True)[
- :number
- ].select_for_update()
- )
- cls.objects.filter(pk__in=[i.pk for i in selected]).update(
- locked_until=timezone.now()
- )
- return selected
+ # Error name
+ error = models.TextField()
- @classmethod
- async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
- return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
+ # Error details
+ error_details = models.TextField(blank=True, null=True)
@classmethod
- async def aclean_old_locks(cls):
- await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
- locked_until=None
+ async def acreate_from_instance(
+ cls,
+ instance: StatorModel,
+ transition: Transition,
+ exception: Optional[BaseException] = None,
+ ):
+ return await cls.objects.acreate(
+ model_label=instance._meta.label_lower,
+ instance_pk=str(instance.pk),
+ from_state=transition.from_state,
+ to_state=transition.to_state,
+ error=str(exception),
+ error_details=traceback.format_exc(),
)
-
- async def aget_model_instance(self) -> StatorModel:
- model = apps.get_model(self.model_label)
- return cast(StatorModel, await model.objects.aget(pk=self.pk))
-
- async def adelete(self):
- self.__class__.objects.adelete(pk=self.pk)