summaryrefslogtreecommitdiffstats
path: root/stator/graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'stator/graph.py')
-rw-r--r--stator/graph.py47
1 files changed, 30 insertions, 17 deletions
diff --git a/stator/graph.py b/stator/graph.py
index b06ffb8..7fc23f7 100644
--- a/stator/graph.py
+++ b/stator/graph.py
@@ -1,9 +1,16 @@
-import datetime
-from functools import wraps
-from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
-
-from django.db import models
-from django.utils import timezone
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
class StateGraph:
@@ -13,7 +20,7 @@ class StateGraph:
"""
states: ClassVar[Dict[str, "State"]]
- choices: ClassVar[List[Tuple[str, str]]]
+ choices: ClassVar[List[Tuple[object, str]]]
initial_state: ClassVar["State"]
terminal_states: ClassVar[Set["State"]]
@@ -50,7 +57,7 @@ class StateGraph:
cls.initial_state = initial_state
cls.terminal_states = terminal_states
# Generate choices
- cls.choices = [(name, name) for name in cls.states.keys()]
+ cls.choices = [(state, name) for name, state in cls.states.items()]
class State:
@@ -63,7 +70,7 @@ class State:
self.parents: Set["State"] = set()
self.children: Dict["State", "Transition"] = {}
- def _add_to_graph(self, graph: StateGraph, name: str):
+ def _add_to_graph(self, graph: Type[StateGraph], name: str):
self.graph = graph
self.name = name
self.graph.states[name] = self
@@ -71,13 +78,19 @@ class State:
def __repr__(self):
return f"<State {self.name}>"
+ def __str__(self):
+ return self.name
+
+ def __len__(self):
+ return len(self.name)
+
def add_transition(
self,
other: "State",
- handler: Optional[Union[str, Callable]] = None,
+ handler: Optional[Callable] = None,
priority: int = 0,
) -> Callable:
- def decorator(handler: Union[str, Callable]):
+ def decorator(handler: Callable[[Any], bool]):
self.children[other] = Transition(
self,
other,
@@ -85,9 +98,7 @@ class State:
priority=priority,
)
other.parents.add(self)
- # All handlers should be class methods, so do that automatically.
- if callable(handler):
- return classmethod(handler)
+ return handler
# If we're not being called as a decorator, invoke it immediately
if handler is not None:
@@ -113,7 +124,7 @@ class State:
if automatic_only:
transitions = [t for t in self.children.values() if t.automatic]
else:
- transitions = self.children.values()
+ transitions = list(self.children.values())
return sorted(transitions, key=lambda t: t.priority, reverse=True)
@@ -141,7 +152,10 @@ class Transition:
"""
if isinstance(self.handler, str):
self.handler = getattr(self.from_state.graph, self.handler)
- return self.handler
+ return cast(Callable, self.handler)
+
+ def __repr__(self):
+ return f"<Transition {self.from_state} -> {self.to_state}>"
class ManualTransition(Transition):
@@ -157,6 +171,5 @@ class ManualTransition(Transition):
):
self.from_state = from_state
self.to_state = to_state
- self.handler = None
self.priority = 0
self.automatic = False