From 7746abbbb7700fa918450101bbc6d29ed9b4b608 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 9 Nov 2022 22:29:33 -0700 Subject: Most of the way through the stator refactor --- stator/admin.py | 15 +- stator/graph.py | 47 +++-- stator/management/__init__.py | 0 stator/management/commands/__init__.py | 0 stator/management/commands/runstator.py | 28 +++ stator/migrations/0001_initial.py | 11 +- stator/models.py | 195 ++++++++++++--------- stator/runner.py | 47 ++--- stator/tests/test_graph.py | 4 +- ...follow_state_locked_until_follow_state_ready.py | 23 +++ users/models/__init__.py | 4 +- users/models/domain.py | 2 +- users/models/follow.py | 14 +- users/models/identity.py | 16 +- users/shortcuts.py | 9 +- users/views/identity.py | 6 +- 16 files changed, 273 insertions(+), 148 deletions(-) create mode 100644 stator/management/__init__.py create mode 100644 stator/management/commands/__init__.py create mode 100644 stator/management/commands/runstator.py create mode 100644 users/migrations/0005_follow_state_locked_until_follow_state_ready.py diff --git a/stator/admin.py b/stator/admin.py index c04d775..025f225 100644 --- a/stator/admin.py +++ b/stator/admin.py @@ -1,8 +1,17 @@ from django.contrib import admin -from stator.models import StatorTask +from stator.models import StatorError -@admin.register(StatorTask) +@admin.register(StatorError) class DomainAdmin(admin.ModelAdmin): - list_display = ["id", "model_label", "instance_pk", "locked_until"] + list_display = [ + "id", + "date", + "model_label", + "instance_pk", + "from_state", + "to_state", + "error", + ] + ordering = ["-date"] diff --git a/stator/graph.py b/stator/graph.py index b06ffb8..7fc23f7 100644 --- a/stator/graph.py +++ b/stator/graph.py @@ -1,9 +1,16 @@ -import datetime -from functools import wraps -from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union - -from django.db import models -from django.utils import timezone +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) class StateGraph: @@ -13,7 +20,7 @@ class StateGraph: """ states: ClassVar[Dict[str, "State"]] - choices: ClassVar[List[Tuple[str, str]]] + choices: ClassVar[List[Tuple[object, str]]] initial_state: ClassVar["State"] terminal_states: ClassVar[Set["State"]] @@ -50,7 +57,7 @@ class StateGraph: cls.initial_state = initial_state cls.terminal_states = terminal_states # Generate choices - cls.choices = [(name, name) for name in cls.states.keys()] + cls.choices = [(state, name) for name, state in cls.states.items()] class State: @@ -63,7 +70,7 @@ class State: self.parents: Set["State"] = set() self.children: Dict["State", "Transition"] = {} - def _add_to_graph(self, graph: StateGraph, name: str): + def _add_to_graph(self, graph: Type[StateGraph], name: str): self.graph = graph self.name = name self.graph.states[name] = self @@ -71,13 +78,19 @@ class State: def __repr__(self): return f"" + def __str__(self): + return self.name + + def __len__(self): + return len(self.name) + def add_transition( self, other: "State", - handler: Optional[Union[str, Callable]] = None, + handler: Optional[Callable] = None, priority: int = 0, ) -> Callable: - def decorator(handler: Union[str, Callable]): + def decorator(handler: Callable[[Any], bool]): self.children[other] = Transition( self, other, @@ -85,9 +98,7 @@ class State: priority=priority, ) other.parents.add(self) - # All handlers should be class methods, so do that automatically. - if callable(handler): - return classmethod(handler) + return handler # If we're not being called as a decorator, invoke it immediately if handler is not None: @@ -113,7 +124,7 @@ class State: if automatic_only: transitions = [t for t in self.children.values() if t.automatic] else: - transitions = self.children.values() + transitions = list(self.children.values()) return sorted(transitions, key=lambda t: t.priority, reverse=True) @@ -141,7 +152,10 @@ class Transition: """ if isinstance(self.handler, str): self.handler = getattr(self.from_state.graph, self.handler) - return self.handler + return cast(Callable, self.handler) + + def __repr__(self): + return f" {self.to_state}>" class ManualTransition(Transition): @@ -157,6 +171,5 @@ class ManualTransition(Transition): ): self.from_state = from_state self.to_state = to_state - self.handler = None self.priority = 0 self.automatic = False diff --git a/stator/management/__init__.py b/stator/management/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stator/management/commands/__init__.py b/stator/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stator/management/commands/runstator.py b/stator/management/commands/runstator.py new file mode 100644 index 0000000..1307fef --- /dev/null +++ b/stator/management/commands/runstator.py @@ -0,0 +1,28 @@ +from typing import List, Type, cast + +from asgiref.sync import async_to_sync +from django.apps import apps +from django.core.management.base import BaseCommand + +from stator.models import StatorModel +from stator.runner import StatorRunner + + +class Command(BaseCommand): + help = "Runs a Stator runner for a short period" + + def add_arguments(self, parser): + parser.add_argument("model_labels", nargs="*", type=str) + + def handle(self, model_labels: List[str], *args, **options): + # Resolve the models list into names + models = cast( + List[Type[StatorModel]], + [apps.get_model(label) for label in model_labels], + ) + if not models: + models = StatorModel.subclasses + print("Running for models: " + " ".join(m._meta.label_lower for m in models)) + # Run a runner + runner = StatorRunner(models) + async_to_sync(runner.run)() diff --git a/stator/migrations/0001_initial.py b/stator/migrations/0001_initial.py index f485836..d56ed5c 100644 --- a/stator/migrations/0001_initial.py +++ b/stator/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.3 on 2022-11-09 05:46 +# Generated by Django 4.1.3 on 2022-11-10 03:24 from django.db import migrations, models @@ -11,7 +11,7 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name="StatorTask", + name="StatorError", fields=[ ( "id", @@ -24,8 +24,11 @@ class Migration(migrations.Migration): ), ("model_label", models.CharField(max_length=200)), ("instance_pk", models.CharField(max_length=200)), - ("locked_until", models.DateTimeField(blank=True, null=True)), - ("priority", models.IntegerField(default=0)), + ("from_state", models.CharField(max_length=200)), + ("to_state", models.CharField(max_length=200)), + ("date", models.DateTimeField(auto_now_add=True)), + ("error", models.TextField()), + ("error_details", models.TextField(blank=True, null=True)), ], ), ] diff --git a/stator/models.py b/stator/models.py index 3b0da0a..235b18c 100644 --- a/stator/models.py +++ b/stator/models.py @@ -1,14 +1,13 @@ import datetime -from functools import reduce -from typing import Type, cast +import traceback +from typing import ClassVar, List, Optional, Type, cast from asgiref.sync import sync_to_async -from django.apps import apps from django.db import models, transaction from django.utils import timezone from django.utils.functional import classproperty -from stator.graph import State, StateGraph +from stator.graph import State, StateGraph, Transition class StateField(models.CharField): @@ -55,6 +54,9 @@ class StatorModel(models.Model): concrete model yourself. """ + # If this row is up for transition attempts + state_ready = models.BooleanField(default=False) + # When the state last actually changed, or the date of instance creation state_changed = models.DateTimeField(auto_now_add=True) @@ -62,68 +64,128 @@ class StatorModel(models.Model): # (and not successful, as this is cleared on transition) state_attempted = models.DateTimeField(blank=True, null=True) + # If a lock is out on this row, when it is locked until + # (we don't identify the lock owner, as there's no heartbeats) + state_locked_until = models.DateTimeField(null=True, blank=True) + + # Collection of subclasses of us + subclasses: ClassVar[List[Type["StatorModel"]]] = [] + class Meta: abstract = True + def __init_subclass__(cls) -> None: + if cls is not StatorModel: + cls.subclasses.append(cls) + + @classproperty + def state_graph(cls) -> Type[StateGraph]: + return cls._meta.get_field("state").graph + @classmethod - def schedule_overdue(cls, now=None) -> models.QuerySet: + async def atransition_schedule_due(cls, now=None) -> models.QuerySet: """ Finds instances of this model that need to run and schedule them. """ q = models.Q() - for transition in cls.state_graph.transitions(automatic_only=True): - q = q | transition.get_query(now=now) - return cls.objects.filter(q) + for state in cls.state_graph.states.values(): + state = cast(State, state) + if not state.terminal: + q = q | models.Q( + ( + models.Q( + state_attempted__lte=timezone.now() + - datetime.timedelta(seconds=state.try_interval) + ) + | models.Q(state_attempted__isnull=True) + ), + state=state.name, + ) + await cls.objects.filter(q).aupdate(state_ready=True) - @classproperty - def state_graph(cls) -> Type[StateGraph]: - return cls._meta.get_field("state").graph + @classmethod + def transition_get_with_lock( + cls, number: int, lock_expiry: datetime.datetime + ) -> List["StatorModel"]: + """ + Returns up to `number` tasks for execution, having locked them. + """ + with transaction.atomic(): + selected = list( + cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[ + :number + ].select_for_update() + ) + cls.objects.filter(pk__in=[i.pk for i in selected]).update( + state_locked_until=timezone.now() + ) + return selected + + @classmethod + async def atransition_get_with_lock( + cls, number: int, lock_expiry: datetime.datetime + ) -> List["StatorModel"]: + return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry) - def schedule_transition(self, priority: int = 0): + @classmethod + async def atransition_clean_locks(cls): + await cls.objects.filter(state_locked_until__lte=timezone.now()).aupdate( + state_locked_until=None + ) + + def transition_schedule(self): """ Adds this instance to the queue to get its state transition attempted. The scheduler will call this, but you can also call it directly if you know it'll be ready and want to lower latency. """ - StatorTask.schedule_for_execution(self, priority=priority) + self.state_ready = True + self.save() - async def attempt_transition(self): + async def atransition_attempt(self) -> bool: """ Attempts to transition the current state by running its handler(s). """ # Try each transition in priority order - for transition in self.state_graph.states[self.state].transitions( - automatic_only=True - ): - success = await transition.get_handler()(self) + for transition in self.state.transitions(automatic_only=True): + try: + success = await transition.get_handler()(self) + except BaseException as e: + await StatorError.acreate_from_instance(self, transition, e) + traceback.print_exc() + continue if success: - await self.perform_transition(transition.to_state.name) - return + await self.atransition_perform(transition.to_state.name) + return True await self.__class__.objects.filter(pk=self.pk).aupdate( - state_attempted=timezone.now() + state_attempted=timezone.now(), + state_locked_until=None, + state_ready=False, ) + return False - async def perform_transition(self, state_name): + def transition_perform(self, state_name): """ - Transitions the instance to the given state name + Transitions the instance to the given state name, forcibly. """ if state_name not in self.state_graph.states: raise ValueError(f"Invalid state {state_name}") - await self.__class__.objects.filter(pk=self.pk).aupdate( + self.__class__.objects.filter(pk=self.pk).update( state=state_name, state_changed=timezone.now(), state_attempted=None, + state_locked_until=None, + state_ready=False, ) + atransition_perform = sync_to_async(transition_perform) -class StatorTask(models.Model): - """ - The model that we use for an internal scheduling queue. - Entries in this queue are up for checking and execution - it also performs - locking to ensure we get closer to exactly-once execution (but we err on - the side of at-least-once) +class StatorError(models.Model): + """ + Tracks any errors running the transitions. + Meant to be cleaned out regularly. Should probably be a log. """ # appname.modelname (lowercased) label for the model this represents @@ -132,60 +194,33 @@ class StatorTask(models.Model): # The primary key of that model (probably int or str) instance_pk = models.CharField(max_length=200) - # Locking columns (no runner ID, as we have no heartbeats - all runners - # only live for a short amount of time anyway) - locked_until = models.DateTimeField(null=True, blank=True) + # The state we moved from + from_state = models.CharField(max_length=200) - # Basic total ordering priority - higher is more important - priority = models.IntegerField(default=0) + # The state we moved to (or tried to) + to_state = models.CharField(max_length=200) - def __str__(self): - return f"#{self.pk}: {self.model_label}.{self.instance_pk}" + # When it happened + date = models.DateTimeField(auto_now_add=True) - @classmethod - def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0): - # We don't do a transaction here as it's fine to occasionally double up - model_label = model_instance._meta.label_lower - pk = model_instance.pk - # TODO: Increase priority of existing if present - if not cls.objects.filter( - model_label=model_label, instance_pk=pk, locked__isnull=True - ).exists(): - StatorTask.objects.create( - model_label=model_label, - instance_pk=pk, - priority=priority, - ) - - @classmethod - def get_for_execution(cls, number: int, lock_expiry: datetime.datetime): - """ - Returns up to `number` tasks for execution, having locked them. - """ - with transaction.atomic(): - selected = list( - cls.objects.filter(locked_until__isnull=True)[ - :number - ].select_for_update() - ) - cls.objects.filter(pk__in=[i.pk for i in selected]).update( - locked_until=timezone.now() - ) - return selected + # Error name + error = models.TextField() - @classmethod - async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime): - return await sync_to_async(cls.get_for_execution)(number, lock_expiry) + # Error details + error_details = models.TextField(blank=True, null=True) @classmethod - async def aclean_old_locks(cls): - await cls.objects.filter(locked_until__lte=timezone.now()).aupdate( - locked_until=None + async def acreate_from_instance( + cls, + instance: StatorModel, + transition: Transition, + exception: Optional[BaseException] = None, + ): + return await cls.objects.acreate( + model_label=instance._meta.label_lower, + instance_pk=str(instance.pk), + from_state=transition.from_state, + to_state=transition.to_state, + error=str(exception), + error_details=traceback.format_exc(), ) - - async def aget_model_instance(self) -> StatorModel: - model = apps.get_model(self.model_label) - return cast(StatorModel, await model.objects.aget(pk=self.pk)) - - async def adelete(self): - self.__class__.objects.adelete(pk=self.pk) diff --git a/stator/runner.py b/stator/runner.py index 8c6e0f1..f9c726e 100644 --- a/stator/runner.py +++ b/stator/runner.py @@ -4,11 +4,9 @@ import time import uuid from typing import List, Type -from asgiref.sync import sync_to_async -from django.db import transaction from django.utils import timezone -from stator.models import StatorModel, StatorTask +from stator.models import StatorModel class StatorRunner: @@ -22,6 +20,7 @@ class StatorRunner: LOCK_TIMEOUT = 120 MAX_TASKS = 30 + MAX_TASKS_PER_MODEL = 5 def __init__(self, models: List[Type[StatorModel]]): self.models = models @@ -32,38 +31,44 @@ class StatorRunner: self.handled = 0 self.tasks = [] # Clean up old locks - await StatorTask.aclean_old_locks() - # Examine what needs scheduling - + print("Running initial cleaning and scheduling") + initial_tasks = [] + for model in self.models: + initial_tasks.append(model.atransition_clean_locks()) + initial_tasks.append(model.atransition_schedule_due()) + await asyncio.gather(*initial_tasks) # For the first time period, launch tasks + print("Running main task loop") while (time.monotonic() - start_time) < self.START_TIMEOUT: self.remove_completed_tasks() space_remaining = self.MAX_TASKS - len(self.tasks) # Fetch new tasks - if space_remaining > 0: - for new_task in await StatorTask.aget_for_execution( - space_remaining, - timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT), - ): - self.tasks.append(asyncio.create_task(self.run_task(new_task))) - self.handled += 1 + for model in self.models: + if space_remaining > 0: + for instance in await model.atransition_get_with_lock( + min(space_remaining, self.MAX_TASKS_PER_MODEL), + timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT), + ): + print( + f"Attempting transition on {instance._meta.label_lower}#{instance.pk}" + ) + self.tasks.append( + asyncio.create_task(instance.atransition_attempt()) + ) + self.handled += 1 + space_remaining -= 1 # Prevent busylooping - await asyncio.sleep(0.01) + await asyncio.sleep(0.1) # Then wait for tasks to finish + print("Waiting for tasks to complete") while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT: self.remove_completed_tasks() if not self.tasks: break # Prevent busylooping await asyncio.sleep(1) + print("Complete") return self.handled - async def run_task(self, task: StatorTask): - # Resolve the model instance - model_instance = await task.aget_model_instance() - await model_instance.attempt_transition() - # Remove ourselves from the database as complete - await task.adelete() - def remove_completed_tasks(self): self.tasks = [t for t in self.tasks if not t.done()] diff --git a/stator/tests/test_graph.py b/stator/tests/test_graph.py index f6b8404..0a7113d 100644 --- a/stator/tests/test_graph.py +++ b/stator/tests/test_graph.py @@ -51,14 +51,14 @@ def test_bad_declarations(): # More than one initial state with pytest.raises(ValueError): - class TestGraph(StateGraph): + class TestGraph2(StateGraph): initial = State() initial2 = State() # No initial states with pytest.raises(ValueError): - class TestGraph(StateGraph): + class TestGraph3(StateGraph): loop = State() loop2 = State() diff --git a/users/migrations/0005_follow_state_locked_until_follow_state_ready.py b/users/migrations/0005_follow_state_locked_until_follow_state_ready.py new file mode 100644 index 0000000..3aba08e --- /dev/null +++ b/users/migrations/0005_follow_state_locked_until_follow_state_ready.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.3 on 2022-11-10 03:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0004_remove_follow_state_locked_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="follow", + name="state_locked_until", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="follow", + name="state_ready", + field=models.BooleanField(default=False), + ), + ] diff --git a/users/models/__init__.py b/users/models/__init__.py index e1877bc..d46003f 100644 --- a/users/models/__init__.py +++ b/users/models/__init__.py @@ -1,6 +1,6 @@ from .block import Block # noqa from .domain import Domain # noqa -from .follow import Follow # noqa -from .identity import Identity # noqa +from .follow import Follow, FollowStates # noqa +from .identity import Identity, IdentityStates # noqa from .user import User # noqa from .user_event import UserEvent # noqa diff --git a/users/models/domain.py b/users/models/domain.py index 8467ac3..4ac6ee9 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -55,7 +55,7 @@ class Domain(models.Model): return cls.objects.create(domain=domain, local=False) @classmethod - def get_local_domain(cls, domain: str) -> Optional["Domain"]: + def get_domain(cls, domain: str) -> Optional["Domain"]: try: return cls.objects.get( models.Q(domain=domain) | models.Q(service_domain=domain) diff --git a/users/models/follow.py b/users/models/follow.py index 04f90ee..3325a0b 100644 --- a/users/models/follow.py +++ b/users/models/follow.py @@ -6,13 +6,13 @@ from stator.models import State, StateField, StateGraph, StatorModel class FollowStates(StateGraph): - pending = State(try_interval=3600) + pending = State(try_interval=30) requested = State() accepted = State() @pending.add_transition(requested) - async def try_request(cls, instance): - print("Would have tried to follow") + async def try_request(instance: "Follow"): # type:ignore + print("Would have tried to follow on", instance) return False requested.add_manual_transition(accepted) @@ -73,11 +73,3 @@ class Follow(StatorModel): follow.state = FollowStates.accepted follow.save() return follow - - def undo(self): - """ - Undoes this follow - """ - if not self.target.local: - Task.submit("follow_undo", str(self.pk)) - self.delete() diff --git a/users/models/identity.py b/users/models/identity.py index 98262bc..5e2cd06 100644 --- a/users/models/identity.py +++ b/users/models/identity.py @@ -14,9 +14,21 @@ from django.utils import timezone from OpenSSL import crypto from core.ld import canonicalise +from stator.models import State, StateField, StateGraph, StatorModel from users.models.domain import Domain +class IdentityStates(StateGraph): + outdated = State(try_interval=3600) + updated = State() + + @outdated.add_transition(updated) + async def fetch_identity(identity: "Identity"): # type:ignore + if identity.local: + return True + return await identity.fetch_actor() + + def upload_namer(prefix, instance, filename): """ Names uploaded images etc. @@ -26,7 +38,7 @@ def upload_namer(prefix, instance, filename): return f"{prefix}/{now.year}/{now.month}/{now.day}/{filename}" -class Identity(models.Model): +class Identity(StatorModel): """ Represents both local and remote Fediverse identities (actors) """ @@ -35,6 +47,8 @@ class Identity(models.Model): # one around as well for making nice URLs etc. actor_uri = models.CharField(max_length=500, unique=True) + state = StateField(IdentityStates) + local = models.BooleanField() users = models.ManyToManyField("users.User", related_name="identities") diff --git a/users/shortcuts.py b/users/shortcuts.py index 65206a3..3e7618a 100644 --- a/users/shortcuts.py +++ b/users/shortcuts.py @@ -3,7 +3,7 @@ from django.http import Http404 from users.models import Domain, Identity -def by_handle_or_404(request, handle, local=True, fetch=False): +def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity: """ Retrieves an Identity by its long or short handle. Domain-sensitive, so it will understand short handles on alternate domains. @@ -12,14 +12,17 @@ def by_handle_or_404(request, handle, local=True, fetch=False): if "HTTP_HOST" not in request.META: raise Http404("No hostname available") username = handle - domain_instance = Domain.get_local_domain(request.META["HTTP_HOST"]) + domain_instance = Domain.get_domain(request.META["HTTP_HOST"]) if domain_instance is None: raise Http404("No matching domains found") domain = domain_instance.domain else: username, domain = handle.split("@", 1) # Resolve the domain to the display domain - domain = Domain.get_local_domain(request.META["HTTP_HOST"]).domain + domain_instance = Domain.get_domain(domain) + if domain_instance is None: + raise Http404("No matching domains found") + domain = domain_instance.domain identity = Identity.by_username_and_domain( username, domain, diff --git a/users/views/identity.py b/users/views/identity.py index 41c7880..d02505f 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -17,7 +17,7 @@ from core.forms import FormHelper from core.ld import canonicalise from core.signatures import HttpSignature from users.decorators import identity_required -from users.models import Domain, Follow, Identity +from users.models import Domain, Follow, Identity, IdentityStates from users.shortcuts import by_handle_or_404 @@ -34,7 +34,7 @@ class ViewIdentity(TemplateView): ) statuses = identity.statuses.all()[:100] if identity.data_age > settings.IDENTITY_MAX_AGE: - Task.submit("identity_fetch", identity.handle) + identity.transition_perform(IdentityStates.outdated) return { "identity": identity, "statuses": statuses, @@ -129,7 +129,7 @@ class CreateIdentity(FormView): def form_valid(self, form): username = form.cleaned_data["username"] domain = form.cleaned_data["domain"] - domain_instance = Domain.get_local_domain(domain) + domain_instance = Domain.get_domain(domain) new_identity = Identity.objects.create( actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/actor/", username=username, -- cgit v1.2.3