From 7746abbbb7700fa918450101bbc6d29ed9b4b608 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 9 Nov 2022 22:29:33 -0700 Subject: Most of the way through the stator refactor --- stator/graph.py | 47 ++++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 17 deletions(-) (limited to 'stator/graph.py') diff --git a/stator/graph.py b/stator/graph.py index b06ffb8..7fc23f7 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -1,9 +1,16 @@ -import datetime -from functools import wraps -from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union - -from django.db import models -from django.utils import timezone +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) class StateGraph: @@ -13,7 +20,7 @@ class StateGraph: """ states: ClassVar[Dict[str, "State"]] - choices: ClassVar[List[Tuple[str, str]]] + choices: ClassVar[List[Tuple[object, str]]] initial_state: ClassVar["State"] terminal_states: ClassVar[Set["State"]] @@ -50,7 +57,7 @@ class StateGraph: cls.initial_state = initial_state cls.terminal_states = terminal_states # Generate choices - cls.choices = [(name, name) for name in cls.states.keys()] + cls.choices = [(state, name) for name, state in cls.states.items()] class State: @@ -63,7 +70,7 @@ class State: self.parents: Set["State"] = set() self.children: Dict["State", "Transition"] = {} - def _add_to_graph(self, graph: StateGraph, name: str): + def _add_to_graph(self, graph: Type[StateGraph], name: str): self.graph = graph self.name = name self.graph.states[name] = self @@ -71,13 +78,19 @@ class State: def __repr__(self): return f"" + def __str__(self): + return self.name + + def __len__(self): + return len(self.name) + def add_transition( self, other: "State", - handler: Optional[Union[str, Callable]] = None, + handler: Optional[Callable] = None, priority: int = 0, ) -> Callable: - def decorator(handler: Union[str, Callable]): + def decorator(handler: Callable[[Any], bool]): self.children[other] = Transition( self, other, @@ -85,9 +98,7 @@ class State: priority=priority, ) other.parents.add(self) - # All handlers should be class methods, so do that automatically. - if callable(handler): - return classmethod(handler) + return handler # If we're not being called as a decorator, invoke it immediately if handler is not None: @@ -113,7 +124,7 @@ class State: if automatic_only: transitions = [t for t in self.children.values() if t.automatic] else: - transitions = self.children.values() + transitions = list(self.children.values()) return sorted(transitions, key=lambda t: t.priority, reverse=True) @@ -141,7 +152,10 @@ class Transition: """ if isinstance(self.handler, str): self.handler = getattr(self.from_state.graph, self.handler) - return self.handler + return cast(Callable, self.handler) + + def __repr__(self): + return f" {self.to_state}>" class ManualTransition(Transition): @@ -157,6 +171,5 @@ class ManualTransition(Transition): ): self.from_state = from_state self.to_state = to_state - self.handler = None self.priority = 0 self.automatic = False -- cgit v1.2.3