mirror of
https://github.com/jointakahe/takahe.git
synced 2024-11-25 08:41:00 +00:00
Midway point in task refactor - changing direction
This commit is contained in:
parent
8a0a755889
commit
61c324508e
24 changed files with 698 additions and 241 deletions
|
@ -1,21 +0,0 @@
|
||||||
from django.contrib import admin
|
|
||||||
|
|
||||||
from miniq.models import Task
|
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Task)
|
|
||||||
class TaskAdmin(admin.ModelAdmin):
|
|
||||||
|
|
||||||
list_display = ["id", "created", "type", "subject", "completed", "failed"]
|
|
||||||
ordering = ["-created"]
|
|
||||||
actions = ["reset"]
|
|
||||||
|
|
||||||
@admin.action(description="Reset Task")
|
|
||||||
def reset(self, request, queryset):
|
|
||||||
queryset.update(
|
|
||||||
failed=None,
|
|
||||||
completed=None,
|
|
||||||
locked=None,
|
|
||||||
locked_by=None,
|
|
||||||
error=None,
|
|
||||||
)
|
|
|
@ -1,48 +0,0 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-07 04:19
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
initial = True
|
|
||||||
|
|
||||||
dependencies = []
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="Task",
|
|
||||||
fields=[
|
|
||||||
(
|
|
||||||
"id",
|
|
||||||
models.BigAutoField(
|
|
||||||
auto_created=True,
|
|
||||||
primary_key=True,
|
|
||||||
serialize=False,
|
|
||||||
verbose_name="ID",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"type",
|
|
||||||
models.CharField(
|
|
||||||
choices=[
|
|
||||||
("identity_fetch", "Identity Fetch"),
|
|
||||||
("inbox_item", "Inbox Item"),
|
|
||||||
("follow_request", "Follow Request"),
|
|
||||||
("follow_acknowledge", "Follow Acknowledge"),
|
|
||||||
],
|
|
||||||
max_length=500,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
("priority", models.IntegerField(default=0)),
|
|
||||||
("subject", models.TextField()),
|
|
||||||
("payload", models.JSONField(blank=True, null=True)),
|
|
||||||
("error", models.TextField(blank=True, null=True)),
|
|
||||||
("created", models.DateTimeField(auto_now_add=True)),
|
|
||||||
("completed", models.DateTimeField(blank=True, null=True)),
|
|
||||||
("failed", models.DateTimeField(blank=True, null=True)),
|
|
||||||
("locked", models.DateTimeField(blank=True, null=True)),
|
|
||||||
("locked_by", models.CharField(blank=True, max_length=500, null=True)),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -1,71 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from django.db import models, transaction
|
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
|
|
||||||
class Task(models.Model):
|
|
||||||
"""
|
|
||||||
A task that must be done by a queue processor
|
|
||||||
"""
|
|
||||||
|
|
||||||
class TypeChoices(models.TextChoices):
|
|
||||||
identity_fetch = "identity_fetch"
|
|
||||||
inbox_item = "inbox_item"
|
|
||||||
follow_request = "follow_request"
|
|
||||||
follow_acknowledge = "follow_acknowledge"
|
|
||||||
|
|
||||||
type = models.CharField(max_length=500, choices=TypeChoices.choices)
|
|
||||||
priority = models.IntegerField(default=0)
|
|
||||||
subject = models.TextField()
|
|
||||||
payload = models.JSONField(blank=True, null=True)
|
|
||||||
error = models.TextField(blank=True, null=True)
|
|
||||||
|
|
||||||
created = models.DateTimeField(auto_now_add=True)
|
|
||||||
completed = models.DateTimeField(blank=True, null=True)
|
|
||||||
failed = models.DateTimeField(blank=True, null=True)
|
|
||||||
locked = models.DateTimeField(blank=True, null=True)
|
|
||||||
locked_by = models.CharField(max_length=500, blank=True, null=True)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.id}/{self.type}({self.subject})"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_one_available(cls, processor_id) -> Optional["Task"]:
|
|
||||||
"""
|
|
||||||
Gets one task off the list while reserving it, atomically.
|
|
||||||
"""
|
|
||||||
with transaction.atomic():
|
|
||||||
next_task = cls.objects.filter(locked__isnull=True).first()
|
|
||||||
if next_task is None:
|
|
||||||
return None
|
|
||||||
next_task.locked = timezone.now()
|
|
||||||
next_task.locked_by = processor_id
|
|
||||||
next_task.save()
|
|
||||||
return next_task
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def submit(cls, type, subject: str, payload=None, deduplicate=True):
|
|
||||||
# Deduplication is done against tasks that have not started yet only,
|
|
||||||
# and only on tasks without payloads
|
|
||||||
if deduplicate and not payload:
|
|
||||||
if cls.objects.filter(
|
|
||||||
type=type,
|
|
||||||
subject=subject,
|
|
||||||
completed__isnull=True,
|
|
||||||
failed__isnull=True,
|
|
||||||
locked__isnull=True,
|
|
||||||
).exists():
|
|
||||||
return
|
|
||||||
cls.objects.create(type=type, subject=subject, payload=payload)
|
|
||||||
|
|
||||||
async def complete(self):
|
|
||||||
await self.__class__.objects.filter(id=self.id).aupdate(
|
|
||||||
completed=timezone.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def fail(self, error):
|
|
||||||
await self.__class__.objects.filter(id=self.id).aupdate(
|
|
||||||
failed=timezone.now(),
|
|
||||||
error=error,
|
|
||||||
)
|
|
|
@ -1,34 +0,0 @@
|
||||||
import traceback
|
|
||||||
|
|
||||||
from users.tasks.follow import handle_follow_request
|
|
||||||
from users.tasks.identity import handle_identity_fetch
|
|
||||||
from users.tasks.inbox import handle_inbox_item
|
|
||||||
|
|
||||||
|
|
||||||
class TaskHandler:
|
|
||||||
|
|
||||||
handlers = {
|
|
||||||
"identity_fetch": handle_identity_fetch,
|
|
||||||
"inbox_item": handle_inbox_item,
|
|
||||||
"follow_request": handle_follow_request,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, task):
|
|
||||||
self.task = task
|
|
||||||
self.subject = self.task.subject
|
|
||||||
self.payload = self.task.payload
|
|
||||||
|
|
||||||
async def handle(self):
|
|
||||||
try:
|
|
||||||
print(f"Task {self.task}: Starting")
|
|
||||||
if self.task.type not in self.handlers:
|
|
||||||
raise ValueError(f"Cannot handle type {self.task.type}")
|
|
||||||
await self.handlers[self.task.type](
|
|
||||||
self,
|
|
||||||
)
|
|
||||||
await self.task.complete()
|
|
||||||
print(f"Task {self.task}: Complete")
|
|
||||||
except BaseException as e:
|
|
||||||
print(f"Task {self.task}: Error {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
await self.task.fail(f"{e}\n\n" + traceback.format_exc())
|
|
|
@ -1,51 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
from django.http import HttpResponse
|
|
||||||
from django.views import View
|
|
||||||
|
|
||||||
from miniq.models import Task
|
|
||||||
from miniq.tasks import TaskHandler
|
|
||||||
|
|
||||||
|
|
||||||
class QueueProcessor(View):
|
|
||||||
"""
|
|
||||||
A view that takes some items off the queue and processes them.
|
|
||||||
Tries to limit its own runtime so it's within HTTP timeout limits.
|
|
||||||
"""
|
|
||||||
|
|
||||||
START_TIMEOUT = 30
|
|
||||||
TOTAL_TIMEOUT = 60
|
|
||||||
LOCK_TIMEOUT = 200
|
|
||||||
MAX_TASKS = 20
|
|
||||||
|
|
||||||
async def get(self, request):
|
|
||||||
start_time = time.monotonic()
|
|
||||||
processor_id = uuid.uuid4().hex
|
|
||||||
handled = 0
|
|
||||||
self.tasks = []
|
|
||||||
# For the first time period, launch tasks
|
|
||||||
while (time.monotonic() - start_time) < self.START_TIMEOUT:
|
|
||||||
# Remove completed tasks
|
|
||||||
self.tasks = [t for t in self.tasks if not t.done()]
|
|
||||||
# See if there's a new task
|
|
||||||
if len(self.tasks) < self.MAX_TASKS:
|
|
||||||
# Pop a task off the queue and run it
|
|
||||||
task = await sync_to_async(Task.get_one_available)(processor_id)
|
|
||||||
if task is not None:
|
|
||||||
self.tasks.append(asyncio.create_task(TaskHandler(task).handle()))
|
|
||||||
handled += 1
|
|
||||||
# Prevent busylooping
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
# TODO: Clean up old locks here
|
|
||||||
# Then wait for tasks to finish
|
|
||||||
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
|
|
||||||
# Remove completed tasks
|
|
||||||
self.tasks = [t for t in self.tasks if not t.done()]
|
|
||||||
if not self.tasks:
|
|
||||||
break
|
|
||||||
# Prevent busylooping
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
return HttpResponse(f"{handled} tasks handled")
|
|
8
stator/admin.py
Normal file
8
stator/admin.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from django.contrib import admin
|
||||||
|
|
||||||
|
from stator.models import StatorTask
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(StatorTask)
|
||||||
|
class DomainAdmin(admin.ModelAdmin):
|
||||||
|
list_display = ["id", "model_label", "instance_pk", "locked_until"]
|
|
@ -1,6 +1,6 @@
|
||||||
from django.apps import AppConfig
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class MiniqConfig(AppConfig):
|
class StatorConfig(AppConfig):
|
||||||
default_auto_field = "django.db.models.BigAutoField"
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
name = "miniq"
|
name = "stator"
|
162
stator/graph.py
Normal file
162
stator/graph.py
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
import datetime
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
|
||||||
|
class StateGraph:
|
||||||
|
"""
|
||||||
|
Represents a graph of possible states and transitions to attempt on them.
|
||||||
|
Does not support subclasses of existing graphs yet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
states: ClassVar[Dict[str, "State"]]
|
||||||
|
choices: ClassVar[List[Tuple[str, str]]]
|
||||||
|
initial_state: ClassVar["State"]
|
||||||
|
terminal_states: ClassVar[Set["State"]]
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
# Collect state memebers
|
||||||
|
cls.states = {}
|
||||||
|
for name, value in cls.__dict__.items():
|
||||||
|
if name in ["__module__", "__doc__", "states"]:
|
||||||
|
pass
|
||||||
|
elif name in ["initial_state", "terminal_states", "choices"]:
|
||||||
|
raise ValueError(f"Cannot name a state {name} - this is reserved")
|
||||||
|
elif isinstance(value, State):
|
||||||
|
value._add_to_graph(cls, name)
|
||||||
|
elif callable(value) or isinstance(value, classmethod):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Graph has item {name} of unallowed type {type(value)}"
|
||||||
|
)
|
||||||
|
# Check the graph layout
|
||||||
|
terminal_states = set()
|
||||||
|
initial_state = None
|
||||||
|
for state in cls.states.values():
|
||||||
|
if state.initial:
|
||||||
|
if initial_state:
|
||||||
|
raise ValueError(
|
||||||
|
f"The graph has more than one initial state: {initial_state} and {state}"
|
||||||
|
)
|
||||||
|
initial_state = state
|
||||||
|
if state.terminal:
|
||||||
|
terminal_states.add(state)
|
||||||
|
if initial_state is None:
|
||||||
|
raise ValueError("The graph has no initial state")
|
||||||
|
cls.initial_state = initial_state
|
||||||
|
cls.terminal_states = terminal_states
|
||||||
|
# Generate choices
|
||||||
|
cls.choices = [(name, name) for name in cls.states.keys()]
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
"""
|
||||||
|
Represents an individual state
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, try_interval: float = 300):
|
||||||
|
self.try_interval = try_interval
|
||||||
|
self.parents: Set["State"] = set()
|
||||||
|
self.children: Dict["State", "Transition"] = {}
|
||||||
|
|
||||||
|
def _add_to_graph(self, graph: StateGraph, name: str):
|
||||||
|
self.graph = graph
|
||||||
|
self.name = name
|
||||||
|
self.graph.states[name] = self
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<State {self.name}>"
|
||||||
|
|
||||||
|
def add_transition(
|
||||||
|
self,
|
||||||
|
other: "State",
|
||||||
|
handler: Optional[Union[str, Callable]] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> Callable:
|
||||||
|
def decorator(handler: Union[str, Callable]):
|
||||||
|
self.children[other] = Transition(
|
||||||
|
self,
|
||||||
|
other,
|
||||||
|
handler,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
other.parents.add(self)
|
||||||
|
# All handlers should be class methods, so do that automatically.
|
||||||
|
if callable(handler):
|
||||||
|
return classmethod(handler)
|
||||||
|
|
||||||
|
# If we're not being called as a decorator, invoke it immediately
|
||||||
|
if handler is not None:
|
||||||
|
decorator(handler)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def add_manual_transition(self, other: "State"):
|
||||||
|
self.children[other] = ManualTransition(self, other)
|
||||||
|
other.parents.add(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initial(self):
|
||||||
|
return not self.parents
|
||||||
|
|
||||||
|
@property
|
||||||
|
def terminal(self):
|
||||||
|
return not self.children
|
||||||
|
|
||||||
|
def transitions(self, automatic_only=False) -> List["Transition"]:
|
||||||
|
"""
|
||||||
|
Returns all transitions from this State in priority order
|
||||||
|
"""
|
||||||
|
if automatic_only:
|
||||||
|
transitions = [t for t in self.children.values() if t.automatic]
|
||||||
|
else:
|
||||||
|
transitions = self.children.values()
|
||||||
|
return sorted(transitions, key=lambda t: t.priority, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Transition:
|
||||||
|
"""
|
||||||
|
A possible transition from one state to another
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
from_state: State,
|
||||||
|
to_state: State,
|
||||||
|
handler: Union[str, Callable],
|
||||||
|
priority: int = 0,
|
||||||
|
):
|
||||||
|
self.from_state = from_state
|
||||||
|
self.to_state = to_state
|
||||||
|
self.handler = handler
|
||||||
|
self.priority = priority
|
||||||
|
self.automatic = True
|
||||||
|
|
||||||
|
def get_handler(self) -> Callable:
|
||||||
|
"""
|
||||||
|
Returns the handler (it might need resolving from a string)
|
||||||
|
"""
|
||||||
|
if isinstance(self.handler, str):
|
||||||
|
self.handler = getattr(self.from_state.graph, self.handler)
|
||||||
|
return self.handler
|
||||||
|
|
||||||
|
|
||||||
|
class ManualTransition(Transition):
|
||||||
|
"""
|
||||||
|
A possible transition from one state to another that cannot be done by
|
||||||
|
the stator task runner, and must come from an external source.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
from_state: State,
|
||||||
|
to_state: State,
|
||||||
|
):
|
||||||
|
self.from_state = from_state
|
||||||
|
self.to_state = to_state
|
||||||
|
self.handler = None
|
||||||
|
self.priority = 0
|
||||||
|
self.automatic = False
|
31
stator/migrations/0001_initial.py
Normal file
31
stator/migrations/0001_initial.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Generated by Django 4.1.3 on 2022-11-09 05:46
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
initial = True
|
||||||
|
|
||||||
|
dependencies = []
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="StatorTask",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"id",
|
||||||
|
models.BigAutoField(
|
||||||
|
auto_created=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("model_label", models.CharField(max_length=200)),
|
||||||
|
("instance_pk", models.CharField(max_length=200)),
|
||||||
|
("locked_until", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("priority", models.IntegerField(default=0)),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
191
stator/models.py
Normal file
191
stator/models.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
import datetime
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Type, cast
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.apps import apps
|
||||||
|
from django.db import models, transaction
|
||||||
|
from django.utils import timezone
|
||||||
|
from django.utils.functional import classproperty
|
||||||
|
|
||||||
|
from stator.graph import State, StateGraph
|
||||||
|
|
||||||
|
|
||||||
|
class StateField(models.CharField):
|
||||||
|
"""
|
||||||
|
A special field that automatically gets choices from a state graph
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, graph: Type[StateGraph], **kwargs):
|
||||||
|
# Sensible default for state length
|
||||||
|
kwargs.setdefault("max_length", 100)
|
||||||
|
# Add choices and initial
|
||||||
|
self.graph = graph
|
||||||
|
kwargs["choices"] = self.graph.choices
|
||||||
|
kwargs["default"] = self.graph.initial_state.name
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def deconstruct(self):
|
||||||
|
name, path, args, kwargs = super().deconstruct()
|
||||||
|
kwargs["graph"] = self.graph
|
||||||
|
return name, path, args, kwargs
|
||||||
|
|
||||||
|
def from_db_value(self, value, expression, connection):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return self.graph.states[value]
|
||||||
|
|
||||||
|
def to_python(self, value):
|
||||||
|
if isinstance(value, State) or value is None:
|
||||||
|
return value
|
||||||
|
return self.graph.states[value]
|
||||||
|
|
||||||
|
def get_prep_value(self, value):
|
||||||
|
if isinstance(value, State):
|
||||||
|
return value.name
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class StatorModel(models.Model):
|
||||||
|
"""
|
||||||
|
A model base class that has a state machine backing it, with tasks to work
|
||||||
|
out when to move the state to the next one.
|
||||||
|
|
||||||
|
You need to provide a "state" field as an instance of StateField on the
|
||||||
|
concrete model yourself.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# When the state last actually changed, or the date of instance creation
|
||||||
|
state_changed = models.DateTimeField(auto_now_add=True)
|
||||||
|
|
||||||
|
# When the last state change for the current state was attempted
|
||||||
|
# (and not successful, as this is cleared on transition)
|
||||||
|
state_attempted = models.DateTimeField(blank=True, null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def schedule_overdue(cls, now=None) -> models.QuerySet:
|
||||||
|
"""
|
||||||
|
Finds instances of this model that need to run and schedule them.
|
||||||
|
"""
|
||||||
|
q = models.Q()
|
||||||
|
for transition in cls.state_graph.transitions(automatic_only=True):
|
||||||
|
q = q | transition.get_query(now=now)
|
||||||
|
return cls.objects.filter(q)
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def state_graph(cls) -> Type[StateGraph]:
|
||||||
|
return cls._meta.get_field("state").graph
|
||||||
|
|
||||||
|
def schedule_transition(self, priority: int = 0):
|
||||||
|
"""
|
||||||
|
Adds this instance to the queue to get its state transition attempted.
|
||||||
|
|
||||||
|
The scheduler will call this, but you can also call it directly if you
|
||||||
|
know it'll be ready and want to lower latency.
|
||||||
|
"""
|
||||||
|
StatorTask.schedule_for_execution(self, priority=priority)
|
||||||
|
|
||||||
|
async def attempt_transition(self):
|
||||||
|
"""
|
||||||
|
Attempts to transition the current state by running its handler(s).
|
||||||
|
"""
|
||||||
|
# Try each transition in priority order
|
||||||
|
for transition in self.state_graph.states[self.state].transitions(
|
||||||
|
automatic_only=True
|
||||||
|
):
|
||||||
|
success = await transition.get_handler()(self)
|
||||||
|
if success:
|
||||||
|
await self.perform_transition(transition.to_state.name)
|
||||||
|
return
|
||||||
|
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
||||||
|
state_attempted=timezone.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def perform_transition(self, state_name):
|
||||||
|
"""
|
||||||
|
Transitions the instance to the given state name
|
||||||
|
"""
|
||||||
|
if state_name not in self.state_graph.states:
|
||||||
|
raise ValueError(f"Invalid state {state_name}")
|
||||||
|
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
||||||
|
state=state_name,
|
||||||
|
state_changed=timezone.now(),
|
||||||
|
state_attempted=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StatorTask(models.Model):
|
||||||
|
"""
|
||||||
|
The model that we use for an internal scheduling queue.
|
||||||
|
|
||||||
|
Entries in this queue are up for checking and execution - it also performs
|
||||||
|
locking to ensure we get closer to exactly-once execution (but we err on
|
||||||
|
the side of at-least-once)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# appname.modelname (lowercased) label for the model this represents
|
||||||
|
model_label = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
# The primary key of that model (probably int or str)
|
||||||
|
instance_pk = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
# Locking columns (no runner ID, as we have no heartbeats - all runners
|
||||||
|
# only live for a short amount of time anyway)
|
||||||
|
locked_until = models.DateTimeField(null=True, blank=True)
|
||||||
|
|
||||||
|
# Basic total ordering priority - higher is more important
|
||||||
|
priority = models.IntegerField(default=0)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
|
||||||
|
# We don't do a transaction here as it's fine to occasionally double up
|
||||||
|
model_label = model_instance._meta.label_lower
|
||||||
|
pk = model_instance.pk
|
||||||
|
# TODO: Increase priority of existing if present
|
||||||
|
if not cls.objects.filter(
|
||||||
|
model_label=model_label, instance_pk=pk, locked__isnull=True
|
||||||
|
).exists():
|
||||||
|
StatorTask.objects.create(
|
||||||
|
model_label=model_label,
|
||||||
|
instance_pk=pk,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
|
||||||
|
"""
|
||||||
|
Returns up to `number` tasks for execution, having locked them.
|
||||||
|
"""
|
||||||
|
with transaction.atomic():
|
||||||
|
selected = list(
|
||||||
|
cls.objects.filter(locked_until__isnull=True)[
|
||||||
|
:number
|
||||||
|
].select_for_update()
|
||||||
|
)
|
||||||
|
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
|
||||||
|
locked_until=timezone.now()
|
||||||
|
)
|
||||||
|
return selected
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
|
||||||
|
return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def aclean_old_locks(cls):
|
||||||
|
await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
|
||||||
|
locked_until=None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aget_model_instance(self) -> StatorModel:
|
||||||
|
model = apps.get_model(self.model_label)
|
||||||
|
return cast(StatorModel, await model.objects.aget(pk=self.pk))
|
||||||
|
|
||||||
|
async def adelete(self):
|
||||||
|
self.__class__.objects.adelete(pk=self.pk)
|
69
stator/runner.py
Normal file
69
stator/runner.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.db import transaction
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from stator.models import StatorModel, StatorTask
|
||||||
|
|
||||||
|
|
||||||
|
class StatorRunner:
|
||||||
|
"""
|
||||||
|
Runs tasks on models that are looking for state changes.
|
||||||
|
Designed to run in a one-shot mode, living inside a request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
START_TIMEOUT = 30
|
||||||
|
TOTAL_TIMEOUT = 60
|
||||||
|
LOCK_TIMEOUT = 120
|
||||||
|
|
||||||
|
MAX_TASKS = 30
|
||||||
|
|
||||||
|
def __init__(self, models: List[Type[StatorModel]]):
|
||||||
|
self.models = models
|
||||||
|
self.runner_id = uuid.uuid4().hex
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
start_time = time.monotonic()
|
||||||
|
self.handled = 0
|
||||||
|
self.tasks = []
|
||||||
|
# Clean up old locks
|
||||||
|
await StatorTask.aclean_old_locks()
|
||||||
|
# Examine what needs scheduling
|
||||||
|
|
||||||
|
# For the first time period, launch tasks
|
||||||
|
while (time.monotonic() - start_time) < self.START_TIMEOUT:
|
||||||
|
self.remove_completed_tasks()
|
||||||
|
space_remaining = self.MAX_TASKS - len(self.tasks)
|
||||||
|
# Fetch new tasks
|
||||||
|
if space_remaining > 0:
|
||||||
|
for new_task in await StatorTask.aget_for_execution(
|
||||||
|
space_remaining,
|
||||||
|
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
|
||||||
|
):
|
||||||
|
self.tasks.append(asyncio.create_task(self.run_task(new_task)))
|
||||||
|
self.handled += 1
|
||||||
|
# Prevent busylooping
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# Then wait for tasks to finish
|
||||||
|
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
|
||||||
|
self.remove_completed_tasks()
|
||||||
|
if not self.tasks:
|
||||||
|
break
|
||||||
|
# Prevent busylooping
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
return self.handled
|
||||||
|
|
||||||
|
async def run_task(self, task: StatorTask):
|
||||||
|
# Resolve the model instance
|
||||||
|
model_instance = await task.aget_model_instance()
|
||||||
|
await model_instance.attempt_transition()
|
||||||
|
# Remove ourselves from the database as complete
|
||||||
|
await task.adelete()
|
||||||
|
|
||||||
|
def remove_completed_tasks(self):
|
||||||
|
self.tasks = [t for t in self.tasks if not t.done()]
|
66
stator/tests/test_graph.py
Normal file
66
stator/tests/test_graph.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from stator.graph import State, StateGraph
|
||||||
|
|
||||||
|
|
||||||
|
def test_declare():
|
||||||
|
"""
|
||||||
|
Tests a basic graph declaration and various kinds of handler
|
||||||
|
lookups.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fake_handler = lambda: True
|
||||||
|
|
||||||
|
class TestGraph(StateGraph):
|
||||||
|
initial = State()
|
||||||
|
second = State()
|
||||||
|
third = State()
|
||||||
|
fourth = State()
|
||||||
|
final = State()
|
||||||
|
|
||||||
|
initial.add_transition(second, 60, handler=fake_handler)
|
||||||
|
second.add_transition(third, 60, handler="check_third")
|
||||||
|
|
||||||
|
def check_third(cls):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@third.add_transition(fourth, 60)
|
||||||
|
def check_fourth(cls):
|
||||||
|
return True
|
||||||
|
|
||||||
|
fourth.add_manual_transition(final)
|
||||||
|
|
||||||
|
assert TestGraph.initial_state == TestGraph.initial
|
||||||
|
assert TestGraph.terminal_states == {TestGraph.final}
|
||||||
|
|
||||||
|
assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
|
||||||
|
assert (
|
||||||
|
TestGraph.second.children[TestGraph.third].get_handler()
|
||||||
|
== TestGraph.check_third
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
TestGraph.third.children[TestGraph.fourth].get_handler().__name__
|
||||||
|
== "check_fourth"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_declarations():
|
||||||
|
"""
|
||||||
|
Tests that you can't declare an invalid graph.
|
||||||
|
"""
|
||||||
|
# More than one initial state
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
class TestGraph(StateGraph):
|
||||||
|
initial = State()
|
||||||
|
initial2 = State()
|
||||||
|
|
||||||
|
# No initial states
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
class TestGraph(StateGraph):
|
||||||
|
loop = State()
|
||||||
|
loop2 = State()
|
||||||
|
|
||||||
|
loop.add_transition(loop2, 1, handler="fake")
|
||||||
|
loop2.add_transition(loop, 1, handler="fake")
|
17
stator/views.py
Normal file
17
stator/views.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from django.http import HttpResponse
|
||||||
|
from django.views import View
|
||||||
|
|
||||||
|
from stator.runner import StatorRunner
|
||||||
|
from users.models import Follow
|
||||||
|
|
||||||
|
|
||||||
|
class RequestRunner(View):
|
||||||
|
"""
|
||||||
|
Runs a Stator runner within a HTTP request. For when you're on something
|
||||||
|
serverless.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
runner = StatorRunner([Follow])
|
||||||
|
handled = await runner.run()
|
||||||
|
return HttpResponse(f"Handled {handled}")
|
|
@ -26,7 +26,7 @@ INSTALLED_APPS = [
|
||||||
"core",
|
"core",
|
||||||
"statuses",
|
"statuses",
|
||||||
"users",
|
"users",
|
||||||
"miniq",
|
"stator",
|
||||||
]
|
]
|
||||||
|
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
|
|
|
@ -2,7 +2,7 @@ from django.contrib import admin
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
from core import views as core
|
from core import views as core
|
||||||
from miniq import views as miniq
|
from stator import views as stator
|
||||||
from users.views import auth, identity
|
from users.views import auth, identity
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
|
@ -22,7 +22,7 @@ urlpatterns = [
|
||||||
# Well-known endpoints
|
# Well-known endpoints
|
||||||
path(".well-known/webfinger", identity.Webfinger.as_view()),
|
path(".well-known/webfinger", identity.Webfinger.as_view()),
|
||||||
# Task runner
|
# Task runner
|
||||||
path(".queue/process/", miniq.QueueProcessor.as_view()),
|
path(".stator/runner/", stator.RequestRunner.as_view()),
|
||||||
# Django admin
|
# Django admin
|
||||||
path("djadmin/", admin.site.urls),
|
path("djadmin/", admin.site.urls),
|
||||||
]
|
]
|
||||||
|
|
|
@ -25,4 +25,4 @@ class IdentityAdmin(admin.ModelAdmin):
|
||||||
|
|
||||||
@admin.register(Follow)
|
@admin.register(Follow)
|
||||||
class FollowAdmin(admin.ModelAdmin):
|
class FollowAdmin(admin.ModelAdmin):
|
||||||
list_display = ["id", "source", "target", "requested", "accepted"]
|
list_display = ["id", "source", "target", "state"]
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Generated by Django 4.1.3 on 2022-11-07 19:22
|
||||||
|
|
||||||
|
import django.utils.timezone
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
import stator.models
|
||||||
|
import users.models.follow
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("users", "0001_initial"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state",
|
||||||
|
field=stator.models.StateField(
|
||||||
|
choices=[
|
||||||
|
("pending", "pending"),
|
||||||
|
("requested", "requested"),
|
||||||
|
("accepted", "accepted"),
|
||||||
|
],
|
||||||
|
default="pending",
|
||||||
|
graph=users.models.follow.FollowStates,
|
||||||
|
max_length=100,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_attempted",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_changed",
|
||||||
|
field=models.DateTimeField(
|
||||||
|
auto_now_add=True, default=django.utils.timezone.now
|
||||||
|
),
|
||||||
|
preserve_default=False,
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Generated by Django 4.1.3 on 2022-11-08 03:58
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("users", "0002_follow_state_follow_state_attempted_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name="follow",
|
||||||
|
name="accepted",
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name="follow",
|
||||||
|
name="requested",
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_locked",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_runner",
|
||||||
|
field=models.CharField(blank=True, max_length=100, null=True),
|
||||||
|
),
|
||||||
|
]
|
21
users/migrations/0004_remove_follow_state_locked_and_more.py
Normal file
21
users/migrations/0004_remove_follow_state_locked_and_more.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# Generated by Django 4.1.3 on 2022-11-09 05:15
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_locked",
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name="follow",
|
||||||
|
name="state_runner",
|
||||||
|
),
|
||||||
|
]
|
|
@ -2,10 +2,23 @@ from typing import Optional
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
from miniq.models import Task
|
from stator.models import State, StateField, StateGraph, StatorModel
|
||||||
|
|
||||||
|
|
||||||
class Follow(models.Model):
|
class FollowStates(StateGraph):
|
||||||
|
pending = State(try_interval=3600)
|
||||||
|
requested = State()
|
||||||
|
accepted = State()
|
||||||
|
|
||||||
|
@pending.add_transition(requested)
|
||||||
|
async def try_request(cls, instance):
|
||||||
|
print("Would have tried to follow")
|
||||||
|
return False
|
||||||
|
|
||||||
|
requested.add_manual_transition(accepted)
|
||||||
|
|
||||||
|
|
||||||
|
class Follow(StatorModel):
|
||||||
"""
|
"""
|
||||||
When one user (the source) follows other (the target)
|
When one user (the source) follows other (the target)
|
||||||
"""
|
"""
|
||||||
|
@ -24,8 +37,7 @@ class Follow(models.Model):
|
||||||
uri = models.CharField(blank=True, null=True, max_length=500)
|
uri = models.CharField(blank=True, null=True, max_length=500)
|
||||||
note = models.TextField(blank=True, null=True)
|
note = models.TextField(blank=True, null=True)
|
||||||
|
|
||||||
requested = models.BooleanField(default=False)
|
state = StateField(FollowStates)
|
||||||
accepted = models.BooleanField(default=False)
|
|
||||||
|
|
||||||
created = models.DateTimeField(auto_now_add=True)
|
created = models.DateTimeField(auto_now_add=True)
|
||||||
updated = models.DateTimeField(auto_now=True)
|
updated = models.DateTimeField(auto_now=True)
|
||||||
|
@ -50,17 +62,15 @@ class Follow(models.Model):
|
||||||
(which can be local or remote).
|
(which can be local or remote).
|
||||||
"""
|
"""
|
||||||
if not source.local:
|
if not source.local:
|
||||||
raise ValueError("You cannot initiate follows on a remote Identity")
|
raise ValueError("You cannot initiate follows from a remote Identity")
|
||||||
try:
|
try:
|
||||||
follow = Follow.objects.get(source=source, target=target)
|
follow = Follow.objects.get(source=source, target=target)
|
||||||
except Follow.DoesNotExist:
|
except Follow.DoesNotExist:
|
||||||
follow = Follow.objects.create(source=source, target=target, uri="")
|
follow = Follow.objects.create(source=source, target=target, uri="")
|
||||||
follow.uri = source.actor_uri + f"follow/{follow.pk}/"
|
follow.uri = source.actor_uri + f"follow/{follow.pk}/"
|
||||||
|
# TODO: Local follow approvals
|
||||||
if target.local:
|
if target.local:
|
||||||
follow.requested = True
|
follow.state = FollowStates.accepted
|
||||||
follow.accepted = True
|
|
||||||
else:
|
|
||||||
Task.submit("follow_request", str(follow.pk))
|
|
||||||
follow.save()
|
follow.save()
|
||||||
return follow
|
return follow
|
||||||
|
|
||||||
|
|
|
@ -27,3 +27,36 @@ async def handle_follow_request(task_handler):
|
||||||
if response.status_code >= 400:
|
if response.status_code >= 400:
|
||||||
raise ValueError(f"Request error: {response.status_code} {response.content}")
|
raise ValueError(f"Request error: {response.status_code} {response.content}")
|
||||||
await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
|
await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
|
||||||
|
|
||||||
|
|
||||||
|
def send_follow_undo(id):
|
||||||
|
"""
|
||||||
|
Request a follow from a remote server
|
||||||
|
"""
|
||||||
|
follow = Follow.objects.select_related("source", "source__domain", "target").get(
|
||||||
|
pk=id
|
||||||
|
)
|
||||||
|
# Construct the request
|
||||||
|
request = canonicalise(
|
||||||
|
{
|
||||||
|
"@context": "https://www.w3.org/ns/activitystreams",
|
||||||
|
"id": follow.uri + "#undo",
|
||||||
|
"type": "Undo",
|
||||||
|
"actor": follow.source.actor_uri,
|
||||||
|
"object": {
|
||||||
|
"id": follow.uri,
|
||||||
|
"type": "Follow",
|
||||||
|
"actor": follow.source.actor_uri,
|
||||||
|
"object": follow.target.actor_uri,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Sign it and send it
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
|
||||||
|
response = async_to_sync(HttpSignature.signed_request)(
|
||||||
|
follow.target.inbox_uri, request, follow.source
|
||||||
|
)
|
||||||
|
if response.status_code >= 400:
|
||||||
|
raise ValueError(f"Request error: {response.status_code} {response.content}")
|
||||||
|
print(response)
|
||||||
|
|
|
@ -16,7 +16,6 @@ from django.views.generic import FormView, TemplateView, View
|
||||||
from core.forms import FormHelper
|
from core.forms import FormHelper
|
||||||
from core.ld import canonicalise
|
from core.ld import canonicalise
|
||||||
from core.signatures import HttpSignature
|
from core.signatures import HttpSignature
|
||||||
from miniq.models import Task
|
|
||||||
from users.decorators import identity_required
|
from users.decorators import identity_required
|
||||||
from users.models import Domain, Follow, Identity
|
from users.models import Domain, Follow, Identity
|
||||||
from users.shortcuts import by_handle_or_404
|
from users.shortcuts import by_handle_or_404
|
||||||
|
|
Loading…
Reference in a new issue