diff options
Diffstat (limited to 'stator/models.py')
-rw-r--r-- | stator/models.py | 65 |
1 files changed, 26 insertions, 39 deletions
diff --git a/stator/models.py b/stator/models.py index 235b18c..50ee622 100644 --- a/stator/models.py +++ b/stator/models.py @@ -1,13 +1,13 @@ import datetime import traceback -from typing import ClassVar, List, Optional, Type, cast +from typing import ClassVar, List, Optional, Type, Union, cast from asgiref.sync import sync_to_async from django.db import models, transaction from django.utils import timezone from django.utils.functional import classproperty -from stator.graph import State, StateGraph, Transition +from stator.graph import State, StateGraph class StateField(models.CharField): @@ -29,16 +29,6 @@ class StateField(models.CharField): kwargs["graph"] = self.graph return name, path, args, kwargs - def from_db_value(self, value, expression, connection): - if value is None: - return value - return self.graph.states[value] - - def to_python(self, value): - if isinstance(value, State) or value is None: - return value - return self.graph.states[value] - def get_prep_value(self, value): if isinstance(value, State): return value.name @@ -95,7 +85,9 @@ class StatorModel(models.Model): ( models.Q( state_attempted__lte=timezone.now() - - datetime.timedelta(seconds=state.try_interval) + - datetime.timedelta( + seconds=cast(float, state.try_interval) + ) ) | models.Q(state_attempted__isnull=True) ), @@ -117,7 +109,7 @@ class StatorModel(models.Model): ].select_for_update() ) cls.objects.filter(pk__in=[i.pk for i in selected]).update( - state_locked_until=timezone.now() + state_locked_until=lock_expiry ) return selected @@ -143,36 +135,36 @@ class StatorModel(models.Model): self.state_ready = True self.save() - async def atransition_attempt(self) -> bool: + async def atransition_attempt(self) -> Optional[str]: """ Attempts to transition the current state by running its handler(s). """ - # Try each transition in priority order - for transition in self.state.transitions(automatic_only=True): - try: - success = await transition.get_handler()(self) - except BaseException as e: - await StatorError.acreate_from_instance(self, transition, e) - traceback.print_exc() - continue - if success: - await self.atransition_perform(transition.to_state.name) - return True + try: + next_state = await self.state_graph.states[self.state].handler(self) + except BaseException as e: + await StatorError.acreate_from_instance(self, e) + traceback.print_exc() + else: + if next_state: + await self.atransition_perform(next_state) + return next_state await self.__class__.objects.filter(pk=self.pk).aupdate( state_attempted=timezone.now(), state_locked_until=None, state_ready=False, ) - return False + return None - def transition_perform(self, state_name): + def transition_perform(self, state: Union[State, str]): """ Transitions the instance to the given state name, forcibly. """ - if state_name not in self.state_graph.states: - raise ValueError(f"Invalid state {state_name}") + if isinstance(state, State): + state = state.name + if state not in self.state_graph.states: + raise ValueError(f"Invalid state {state}") self.__class__.objects.filter(pk=self.pk).update( - state=state_name, + state=state, state_changed=timezone.now(), state_attempted=None, state_locked_until=None, @@ -194,11 +186,8 @@ class StatorError(models.Model): # The primary key of that model (probably int or str) instance_pk = models.CharField(max_length=200) - # The state we moved from - from_state = models.CharField(max_length=200) - - # The state we moved to (or tried to) - to_state = models.CharField(max_length=200) + # The state we were on + state = models.CharField(max_length=200) # When it happened date = models.DateTimeField(auto_now_add=True) @@ -213,14 +202,12 @@ class StatorError(models.Model): async def acreate_from_instance( cls, instance: StatorModel, - transition: Transition, exception: Optional[BaseException] = None, ): return await cls.objects.acreate( model_label=instance._meta.label_lower, instance_pk=str(instance.pk), - from_state=transition.from_state, - to_state=transition.to_state, + state=instance.state, error=str(exception), error_details=traceback.format_exc(), ) |