summaryrefslogtreecommitdiffstats
path: root/stator/graph.py
blob: 0ec5ee73fd8e43ec3a57614b8d5e871c01328723 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from collections.abc import Callable
from typing import Any, ClassVar


class StateGraph:
    """
    Represents a graph of possible states and transitions to attempt on them.

    Subclasses declare their states as ``State`` class attributes (linked with
    ``State.transitions_to``) alongside handler methods. When a subclass is
    created, ``__init_subclass__`` registers every state, validates the graph
    layout and fills in the class-level metadata annotated below.

    Does not support subclasses of existing graphs yet.
    """

    # State name -> State, in class-body order.
    states: ClassVar[dict[str, "State"]]
    # One (name, name) tuple per state.
    choices: ClassVar[list[tuple[object, str]]]
    # The single state with no parents, or the one declared force_initial.
    initial_state: ClassVar["State"]
    # States with no outgoing transitions.
    terminal_states: ClassVar[set["State"]]
    # Non-terminal states that are not externally progressed; each of these
    # is required to have a try_interval and a handler.
    automatic_states: ClassVar[set["State"]]

    def __init_subclass__(cls) -> None:
        """
        Register and validate the states declared on a new graph subclass.

        Raises:
            ValueError: if a state uses a reserved name; if the class body
                holds an attribute that is neither a State nor callable; if
                there is not exactly one initial state; if a terminal state
                has a handler; or if an automatic state lacks a try_interval
                or a handler.
        """
        # Attributes assigned on the subclass at the end of this method; a
        # state with one of these names would be silently overwritten.
        reserved = {"initial_state", "terminal_states", "automatic_states", "choices"}
        # Collect state members
        cls.states = {}
        for name, value in cls.__dict__.items():
            if name == "states":
                # The registry dict created just above.
                continue
            if name in reserved:
                raise ValueError(f"Cannot name a state {name} - this is reserved")
            if isinstance(value, State):
                value._add_to_graph(cls, name)
            elif name.startswith("__") and name.endswith("__"):
                # Interpreter-supplied class metadata (__module__, __doc__,
                # __annotations__, and on 3.13+ __firstlineno__ and
                # __static_attributes__) - never states, often non-callable.
                continue
            elif callable(value) or isinstance(value, classmethod):
                # Handler methods and other helpers (classmethod objects are
                # not callable themselves, hence the explicit check).
                continue
            else:
                raise ValueError(
                    f"Graph has item {name} of unallowed type {type(value)}"
                )
        # Check the graph layout
        terminal_states = set()
        automatic_states = set()
        initial_state = None
        for state in cls.states.values():
            # Check for multiple initial states
            if state.initial:
                if initial_state is not None:
                    raise ValueError(
                        f"The graph has more than one initial state: {initial_state} and {state}"
                    )
                initial_state = state
            if state.terminal:
                # Terminal states are never attempted by a handler.
                state.externally_progressed = True
                terminal_states.add(state)
                # Ensure they do NOT have a handler
                try:
                    state.handler
                except AttributeError:
                    pass
                else:
                    raise ValueError(
                        f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
                    )
            elif not state.externally_progressed:
                # Non-terminal, non-manual states need a try interval and a handler
                if not state.try_interval:
                    raise ValueError(
                        f"State '{state}' has no try_interval and is not terminal or manual"
                    )
                try:
                    state.handler
                except AttributeError:
                    raise ValueError(
                        f"State '{state}' does not have a handler method ({state.handler_name})"
                    ) from None
                automatic_states.add(state)
        if initial_state is None:
            raise ValueError("The graph has no initial state")
        cls.initial_state = initial_state
        cls.terminal_states = terminal_states
        cls.automatic_states = automatic_states
        # Generate choices
        cls.choices = [(name, name) for name in cls.states]


class State:
    """
    Represents an individual state
    """

    def __init__(
        self,
        try_interval: float | None = None,
        handler_name: str | None = None,
        externally_progressed: bool = False,
        attempt_immediately: bool = True,
        force_initial: bool = False,
    ):
        self.try_interval = try_interval
        self.handler_name = handler_name
        self.externally_progressed = externally_progressed
        self.attempt_immediately = attempt_immediately
        self.force_initial = force_initial
        self.parents: set["State"] = set()
        self.children: set["State"] = set()

    def _add_to_graph(self, graph: type[StateGraph], name: str):
        self.graph = graph
        self.name = name
        self.graph.states[name] = self
        if self.handler_name is None:
            self.handler_name = f"handle_{self.name}"

    def __repr__(self):
        return f"<State {self.name}>"

    def __str__(self):
        return self.name

    def __eq__(self, other):
        if isinstance(other, State):
            return self is other
        return self.name == other

    def __hash__(self):
        return hash(id(self))

    def transitions_to(self, other: "State"):
        self.children.add(other)
        other.parents.add(other)

    @property
    def initial(self):
        return self.force_initial or (not self.parents)

    @property
    def terminal(self):
        return not self.children

    @property
    def handler(self) -> Callable[[Any], str | None]:
        # Retrieve it by name off the graph
        if self.handler_name is None:
            raise AttributeError("No handler defined")
        return getattr(self.graph, self.handler_name)