summaryrefslogtreecommitdiffstats
path: root/stator/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'stator/models.py')
-rw-r--r--stator/models.py65
1 files changed, 26 insertions, 39 deletions
diff --git a/stator/models.py b/stator/models.py
index 235b18c..50ee622 100644
--- a/stator/models.py
+++ b/stator/models.py
@@ -1,13 +1,13 @@
import datetime
import traceback
-from typing import ClassVar, List, Optional, Type, cast
+from typing import ClassVar, List, Optional, Type, Union, cast
from asgiref.sync import sync_to_async
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
-from stator.graph import State, StateGraph, Transition
+from stator.graph import State, StateGraph
class StateField(models.CharField):
@@ -29,16 +29,6 @@ class StateField(models.CharField):
kwargs["graph"] = self.graph
return name, path, args, kwargs
- def from_db_value(self, value, expression, connection):
- if value is None:
- return value
- return self.graph.states[value]
-
- def to_python(self, value):
- if isinstance(value, State) or value is None:
- return value
- return self.graph.states[value]
-
def get_prep_value(self, value):
if isinstance(value, State):
return value.name
@@ -95,7 +85,9 @@ class StatorModel(models.Model):
(
models.Q(
state_attempted__lte=timezone.now()
- - datetime.timedelta(seconds=state.try_interval)
+ - datetime.timedelta(
+ seconds=cast(float, state.try_interval)
+ )
)
| models.Q(state_attempted__isnull=True)
),
@@ -117,7 +109,7 @@ class StatorModel(models.Model):
].select_for_update()
)
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
- state_locked_until=timezone.now()
+ state_locked_until=lock_expiry
)
return selected
@@ -143,36 +135,36 @@ class StatorModel(models.Model):
self.state_ready = True
self.save()
- async def atransition_attempt(self) -> bool:
+ async def atransition_attempt(self) -> Optional[str]:
"""
Attempts to transition the current state by running its handler(s).
"""
- # Try each transition in priority order
- for transition in self.state.transitions(automatic_only=True):
- try:
- success = await transition.get_handler()(self)
- except BaseException as e:
- await StatorError.acreate_from_instance(self, transition, e)
- traceback.print_exc()
- continue
- if success:
- await self.atransition_perform(transition.to_state.name)
- return True
+ try:
+ next_state = await self.state_graph.states[self.state].handler(self)
+ except BaseException as e:
+ await StatorError.acreate_from_instance(self, e)
+ traceback.print_exc()
+ else:
+ if next_state:
+ await self.atransition_perform(next_state)
+ return next_state
await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now(),
state_locked_until=None,
state_ready=False,
)
- return False
+ return None
- def transition_perform(self, state_name):
+ def transition_perform(self, state: Union[State, str]):
"""
Transitions the instance to the given state name, forcibly.
"""
- if state_name not in self.state_graph.states:
- raise ValueError(f"Invalid state {state_name}")
+ if isinstance(state, State):
+ state = state.name
+ if state not in self.state_graph.states:
+ raise ValueError(f"Invalid state {state}")
self.__class__.objects.filter(pk=self.pk).update(
- state=state_name,
+ state=state,
state_changed=timezone.now(),
state_attempted=None,
state_locked_until=None,
@@ -194,11 +186,8 @@ class StatorError(models.Model):
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
- # The state we moved from
- from_state = models.CharField(max_length=200)
-
- # The state we moved to (or tried to)
- to_state = models.CharField(max_length=200)
+ # The state we were on
+ state = models.CharField(max_length=200)
# When it happened
date = models.DateTimeField(auto_now_add=True)
@@ -213,14 +202,12 @@ class StatorError(models.Model):
async def acreate_from_instance(
cls,
instance: StatorModel,
- transition: Transition,
exception: Optional[BaseException] = None,
):
return await cls.objects.acreate(
model_label=instance._meta.label_lower,
instance_pk=str(instance.pk),
- from_state=transition.from_state,
- to_state=transition.to_state,
+ state=instance.state,
error=str(exception),
error_details=traceback.format_exc(),
)