summaryrefslogtreecommitdiffstats
path: root/stator
diff options
context:
space:
mode:
Diffstat (limited to 'stator')
-rw-r--r--stator/admin.py3
-rw-r--r--stator/graph.py148
-rw-r--r--stator/migrations/0001_initial.py5
-rw-r--r--stator/models.py65
-rw-r--r--stator/runner.py15
-rw-r--r--stator/tests/test_graph.py57
6 files changed, 121 insertions, 172 deletions
diff --git a/stator/admin.py b/stator/admin.py
index 025f225..790fc38 100644
--- a/stator/admin.py
+++ b/stator/admin.py
@@ -10,8 +10,7 @@ class DomainAdmin(admin.ModelAdmin):
"date",
"model_label",
"instance_pk",
- "from_state",
- "to_state",
+ "state",
"error",
]
ordering = ["-date"]
diff --git a/stator/graph.py b/stator/graph.py
index 7fc23f7..7a8455c 100644
--- a/stator/graph.py
+++ b/stator/graph.py
@@ -1,16 +1,4 @@
-from typing import (
- Any,
- Callable,
- ClassVar,
- Dict,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- Union,
- cast,
-)
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type
class StateGraph:
@@ -44,20 +32,43 @@ class StateGraph:
terminal_states = set()
initial_state = None
for state in cls.states.values():
+ # Check for multiple initial states
if state.initial:
if initial_state:
raise ValueError(
f"The graph has more than one initial state: {initial_state} and {state}"
)
initial_state = state
+ # Collect terminal states
if state.terminal:
terminal_states.add(state)
+ # Ensure they do NOT have a handler
+ try:
+ state.handler
+ except AttributeError:
+ pass
+ else:
+ raise ValueError(
+ f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
+ )
+ else:
+ # Ensure non-terminal states have a try interval and a handler
+ if not state.try_interval:
+ raise ValueError(
+ f"State '{state}' has no try_interval and is not terminal"
+ )
+ try:
+ state.handler
+ except AttributeError:
+ raise ValueError(
+ f"State '{state}' does not have a handler method ({state.handler_name})"
+ )
if initial_state is None:
raise ValueError("The graph has no initial state")
cls.initial_state = initial_state
cls.terminal_states = terminal_states
# Generate choices
- cls.choices = [(state, name) for name, state in cls.states.items()]
+ cls.choices = [(name, name) for name in cls.states.keys()]
class State:
@@ -65,49 +76,37 @@ class State:
Represents an individual state
"""
- def __init__(self, try_interval: float = 300):
+ def __init__(
+ self,
+ try_interval: Optional[float] = None,
+ handler_name: Optional[str] = None,
+ ):
self.try_interval = try_interval
+ self.handler_name = handler_name
self.parents: Set["State"] = set()
- self.children: Dict["State", "Transition"] = {}
+ self.children: Set["State"] = set()
def _add_to_graph(self, graph: Type[StateGraph], name: str):
self.graph = graph
self.name = name
self.graph.states[name] = self
+ if self.handler_name is None:
+ self.handler_name = f"handle_{self.name}"
def __repr__(self):
return f"<State {self.name}>"
- def __str__(self):
- return self.name
-
- def __len__(self):
- return len(self.name)
-
- def add_transition(
- self,
- other: "State",
- handler: Optional[Callable] = None,
- priority: int = 0,
- ) -> Callable:
- def decorator(handler: Callable[[Any], bool]):
- self.children[other] = Transition(
- self,
- other,
- handler,
- priority=priority,
- )
- other.parents.add(self)
- return handler
+ def __eq__(self, other):
+ if isinstance(other, State):
+ return self is other
+ return self.name == other
- # If we're not being called as a decorator, invoke it immediately
- if handler is not None:
- decorator(handler)
- return decorator
+ def __hash__(self):
+ return hash(id(self))
- def add_manual_transition(self, other: "State"):
- self.children[other] = ManualTransition(self, other)
- other.parents.add(self)
+ def transitions_to(self, other: "State"):
+ self.children.add(other)
+ other.parents.add(other)
@property
def initial(self):
@@ -117,59 +116,8 @@ class State:
def terminal(self):
return not self.children
- def transitions(self, automatic_only=False) -> List["Transition"]:
- """
- Returns all transitions from this State in priority order
- """
- if automatic_only:
- transitions = [t for t in self.children.values() if t.automatic]
- else:
- transitions = list(self.children.values())
- return sorted(transitions, key=lambda t: t.priority, reverse=True)
-
-
-class Transition:
- """
- A possible transition from one state to another
- """
-
- def __init__(
- self,
- from_state: State,
- to_state: State,
- handler: Union[str, Callable],
- priority: int = 0,
- ):
- self.from_state = from_state
- self.to_state = to_state
- self.handler = handler
- self.priority = priority
- self.automatic = True
-
- def get_handler(self) -> Callable:
- """
- Returns the handler (it might need resolving from a string)
- """
- if isinstance(self.handler, str):
- self.handler = getattr(self.from_state.graph, self.handler)
- return cast(Callable, self.handler)
-
- def __repr__(self):
- return f"<Transition {self.from_state} -> {self.to_state}>"
-
-
-class ManualTransition(Transition):
- """
- A possible transition from one state to another that cannot be done by
- the stator task runner, and must come from an external source.
- """
-
- def __init__(
- self,
- from_state: State,
- to_state: State,
- ):
- self.from_state = from_state
- self.to_state = to_state
- self.priority = 0
- self.automatic = False
+ @property
+ def handler(self) -> Callable[[Any], Optional[str]]:
+ if self.handler_name is None:
+ raise AttributeError("No handler defined")
+ return getattr(self.graph, self.handler_name)
diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py
index d56ed5c..f7d652e 100644
--- a/stator/migrations/0001_initial.py
+++ b/stator/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.1.3 on 2022-11-10 03:24
+# Generated by Django 4.1.3 on 2022-11-10 05:56
from django.db import migrations, models
@@ -24,8 +24,7 @@ class Migration(migrations.Migration):
),
("model_label", models.CharField(max_length=200)),
("instance_pk", models.CharField(max_length=200)),
- ("from_state", models.CharField(max_length=200)),
- ("to_state", models.CharField(max_length=200)),
+ ("state", models.CharField(max_length=200)),
("date", models.DateTimeField(auto_now_add=True)),
("error", models.TextField()),
("error_details", models.TextField(blank=True, null=True)),
diff --git a/stator/models.py b/stator/models.py
index 235b18c..50ee622 100644
--- a/stator/models.py
+++ b/stator/models.py
@@ -1,13 +1,13 @@
import datetime
import traceback
-from typing import ClassVar, List, Optional, Type, cast
+from typing import ClassVar, List, Optional, Type, Union, cast
from asgiref.sync import sync_to_async
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
-from stator.graph import State, StateGraph, Transition
+from stator.graph import State, StateGraph
class StateField(models.CharField):
@@ -29,16 +29,6 @@ class StateField(models.CharField):
kwargs["graph"] = self.graph
return name, path, args, kwargs
- def from_db_value(self, value, expression, connection):
- if value is None:
- return value
- return self.graph.states[value]
-
- def to_python(self, value):
- if isinstance(value, State) or value is None:
- return value
- return self.graph.states[value]
-
def get_prep_value(self, value):
if isinstance(value, State):
return value.name
@@ -95,7 +85,9 @@ class StatorModel(models.Model):
(
models.Q(
state_attempted__lte=timezone.now()
- - datetime.timedelta(seconds=state.try_interval)
+ - datetime.timedelta(
+ seconds=cast(float, state.try_interval)
+ )
)
| models.Q(state_attempted__isnull=True)
),
@@ -117,7 +109,7 @@ class StatorModel(models.Model):
].select_for_update()
)
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
- state_locked_until=timezone.now()
+ state_locked_until=lock_expiry
)
return selected
@@ -143,36 +135,36 @@ class StatorModel(models.Model):
self.state_ready = True
self.save()
- async def atransition_attempt(self) -> bool:
+ async def atransition_attempt(self) -> Optional[str]:
"""
Attempts to transition the current state by running its handler(s).
"""
- # Try each transition in priority order
- for transition in self.state.transitions(automatic_only=True):
- try:
- success = await transition.get_handler()(self)
- except BaseException as e:
- await StatorError.acreate_from_instance(self, transition, e)
- traceback.print_exc()
- continue
- if success:
- await self.atransition_perform(transition.to_state.name)
- return True
+ try:
+ next_state = await self.state_graph.states[self.state].handler(self)
+ except BaseException as e:
+ await StatorError.acreate_from_instance(self, e)
+ traceback.print_exc()
+ else:
+ if next_state:
+ await self.atransition_perform(next_state)
+ return next_state
await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now(),
state_locked_until=None,
state_ready=False,
)
- return False
+ return None
- def transition_perform(self, state_name):
+ def transition_perform(self, state: Union[State, str]):
"""
Transitions the instance to the given state name, forcibly.
"""
- if state_name not in self.state_graph.states:
- raise ValueError(f"Invalid state {state_name}")
+ if isinstance(state, State):
+ state = state.name
+ if state not in self.state_graph.states:
+ raise ValueError(f"Invalid state {state}")
self.__class__.objects.filter(pk=self.pk).update(
- state=state_name,
+ state=state,
state_changed=timezone.now(),
state_attempted=None,
state_locked_until=None,
@@ -194,11 +186,8 @@ class StatorError(models.Model):
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
- # The state we moved from
- from_state = models.CharField(max_length=200)
-
- # The state we moved to (or tried to)
- to_state = models.CharField(max_length=200)
+ # The state we were on
+ state = models.CharField(max_length=200)
# When it happened
date = models.DateTimeField(auto_now_add=True)
@@ -213,14 +202,12 @@ class StatorError(models.Model):
async def acreate_from_instance(
cls,
instance: StatorModel,
- transition: Transition,
exception: Optional[BaseException] = None,
):
return await cls.objects.acreate(
model_label=instance._meta.label_lower,
instance_pk=str(instance.pk),
- from_state=transition.from_state,
- to_state=transition.to_state,
+ state=instance.state,
error=str(exception),
error_details=traceback.format_exc(),
)
diff --git a/stator/runner.py b/stator/runner.py
index f9c726e..1392e4d 100644
--- a/stator/runner.py
+++ b/stator/runner.py
@@ -1,6 +1,7 @@
import asyncio
import datetime
import time
+import traceback
import uuid
from typing import List, Type
@@ -53,7 +54,7 @@ class StatorRunner:
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
)
self.tasks.append(
- asyncio.create_task(instance.atransition_attempt())
+ asyncio.create_task(self.run_transition(instance))
)
self.handled += 1
space_remaining -= 1
@@ -70,5 +71,17 @@ class StatorRunner:
print("Complete")
return self.handled
+ async def run_transition(self, instance: StatorModel):
+ """
+ Wrapper for atransition_attempt with fallback error handling
+ """
+ try:
+ await instance.atransition_attempt()
+ except BaseException:
+ traceback.print_exc()
+
def remove_completed_tasks(self):
+ """
+ Removes all completed asyncio.Tasks from our local in-progress list
+ """
self.tasks = [t for t in self.tasks if not t.done()]
diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py
index 0a7113d..c66f441 100644
--- a/stator/tests/test_graph.py
+++ b/stator/tests/test_graph.py
@@ -9,39 +9,29 @@ def test_declare():
lookups.
"""
- fake_handler = lambda: True
-
class TestGraph(StateGraph):
- initial = State()
- second = State()
+ initial = State(try_interval=3600)
+ second = State(try_interval=1)
third = State()
- fourth = State()
- final = State()
-
- initial.add_transition(second, 60, handler=fake_handler)
- second.add_transition(third, 60, handler="check_third")
- def check_third(cls):
- return True
+ initial.transitions_to(second)
+ second.transitions_to(third)
- @third.add_transition(fourth, 60)
- def check_fourth(cls):
- return True
+ @classmethod
+ def handle_initial(cls):
+ pass
- fourth.add_manual_transition(final)
+ @classmethod
+ def handle_second(cls):
+ pass
assert TestGraph.initial_state == TestGraph.initial
- assert TestGraph.terminal_states == {TestGraph.final}
+ assert TestGraph.terminal_states == {TestGraph.third}
- assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
- assert (
- TestGraph.second.children[TestGraph.third].get_handler()
- == TestGraph.check_third
- )
- assert (
- TestGraph.third.children[TestGraph.fourth].get_handler().__name__
- == "check_fourth"
- )
+ assert TestGraph.initial.handler == TestGraph.handle_initial
+ assert TestGraph.initial.try_interval == 3600
+ assert TestGraph.second.handler == TestGraph.handle_second
+ assert TestGraph.second.try_interval == 1
def test_bad_declarations():
@@ -62,5 +52,18 @@ def test_bad_declarations():
loop = State()
loop2 = State()
- loop.add_transition(loop2, 1, handler="fake")
- loop2.add_transition(loop, 1, handler="fake")
+ loop.transitions_to(loop2)
+ loop2.transitions_to(loop)
+
+
+def test_state():
+ """
+ Tests basic values of the State class
+ """
+
+ class TestGraph(StateGraph):
+ initial = State()
+
+ assert "initial" == TestGraph.initial
+ assert TestGraph.initial == "initial"
+ assert TestGraph.initial == TestGraph.initial