mirror of
https://github.com/jointakahe/takahe.git
synced 2024-11-25 08:41:00 +00:00
Profile fetching now working on state machine
This commit is contained in:
parent
7746abbbb7
commit
2c3a129970
21 changed files with 304 additions and 383 deletions
|
@ -10,8 +10,7 @@ class DomainAdmin(admin.ModelAdmin):
|
||||||
"date",
|
"date",
|
||||||
"model_label",
|
"model_label",
|
||||||
"instance_pk",
|
"instance_pk",
|
||||||
"from_state",
|
"state",
|
||||||
"to_state",
|
|
||||||
"error",
|
"error",
|
||||||
]
|
]
|
||||||
ordering = ["-date"]
|
ordering = ["-date"]
|
||||||
|
|
148
stator/graph.py
148
stator/graph.py
|
@ -1,16 +1,4 @@
|
||||||
from typing import (
|
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StateGraph:
|
class StateGraph:
|
||||||
|
@ -44,20 +32,43 @@ class StateGraph:
|
||||||
terminal_states = set()
|
terminal_states = set()
|
||||||
initial_state = None
|
initial_state = None
|
||||||
for state in cls.states.values():
|
for state in cls.states.values():
|
||||||
|
# Check for multiple initial states
|
||||||
if state.initial:
|
if state.initial:
|
||||||
if initial_state:
|
if initial_state:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The graph has more than one initial state: {initial_state} and {state}"
|
f"The graph has more than one initial state: {initial_state} and {state}"
|
||||||
)
|
)
|
||||||
initial_state = state
|
initial_state = state
|
||||||
|
# Collect terminal states
|
||||||
if state.terminal:
|
if state.terminal:
|
||||||
terminal_states.add(state)
|
terminal_states.add(state)
|
||||||
|
# Ensure they do NOT have a handler
|
||||||
|
try:
|
||||||
|
state.handler
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Terminal state '{state}' should not have a handler method ({state.handler_name})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Ensure non-terminal states have a try interval and a handler
|
||||||
|
if not state.try_interval:
|
||||||
|
raise ValueError(
|
||||||
|
f"State '{state}' has no try_interval and is not terminal"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
state.handler
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"State '{state}' does not have a handler method ({state.handler_name})"
|
||||||
|
)
|
||||||
if initial_state is None:
|
if initial_state is None:
|
||||||
raise ValueError("The graph has no initial state")
|
raise ValueError("The graph has no initial state")
|
||||||
cls.initial_state = initial_state
|
cls.initial_state = initial_state
|
||||||
cls.terminal_states = terminal_states
|
cls.terminal_states = terminal_states
|
||||||
# Generate choices
|
# Generate choices
|
||||||
cls.choices = [(state, name) for name, state in cls.states.items()]
|
cls.choices = [(name, name) for name in cls.states.keys()]
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
|
@ -65,49 +76,37 @@ class State:
|
||||||
Represents an individual state
|
Represents an individual state
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, try_interval: float = 300):
|
def __init__(
|
||||||
|
self,
|
||||||
|
try_interval: Optional[float] = None,
|
||||||
|
handler_name: Optional[str] = None,
|
||||||
|
):
|
||||||
self.try_interval = try_interval
|
self.try_interval = try_interval
|
||||||
|
self.handler_name = handler_name
|
||||||
self.parents: Set["State"] = set()
|
self.parents: Set["State"] = set()
|
||||||
self.children: Dict["State", "Transition"] = {}
|
self.children: Set["State"] = set()
|
||||||
|
|
||||||
def _add_to_graph(self, graph: Type[StateGraph], name: str):
|
def _add_to_graph(self, graph: Type[StateGraph], name: str):
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.name = name
|
self.name = name
|
||||||
self.graph.states[name] = self
|
self.graph.states[name] = self
|
||||||
|
if self.handler_name is None:
|
||||||
|
self.handler_name = f"handle_{self.name}"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<State {self.name}>"
|
return f"<State {self.name}>"
|
||||||
|
|
||||||
def __str__(self):
|
def __eq__(self, other):
|
||||||
return self.name
|
if isinstance(other, State):
|
||||||
|
return self is other
|
||||||
|
return self.name == other
|
||||||
|
|
||||||
def __len__(self):
|
def __hash__(self):
|
||||||
return len(self.name)
|
return hash(id(self))
|
||||||
|
|
||||||
def add_transition(
|
def transitions_to(self, other: "State"):
|
||||||
self,
|
self.children.add(other)
|
||||||
other: "State",
|
other.parents.add(other)
|
||||||
handler: Optional[Callable] = None,
|
|
||||||
priority: int = 0,
|
|
||||||
) -> Callable:
|
|
||||||
def decorator(handler: Callable[[Any], bool]):
|
|
||||||
self.children[other] = Transition(
|
|
||||||
self,
|
|
||||||
other,
|
|
||||||
handler,
|
|
||||||
priority=priority,
|
|
||||||
)
|
|
||||||
other.parents.add(self)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
# If we're not being called as a decorator, invoke it immediately
|
|
||||||
if handler is not None:
|
|
||||||
decorator(handler)
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def add_manual_transition(self, other: "State"):
|
|
||||||
self.children[other] = ManualTransition(self, other)
|
|
||||||
other.parents.add(self)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def initial(self):
|
def initial(self):
|
||||||
|
@ -117,59 +116,8 @@ class State:
|
||||||
def terminal(self):
|
def terminal(self):
|
||||||
return not self.children
|
return not self.children
|
||||||
|
|
||||||
def transitions(self, automatic_only=False) -> List["Transition"]:
|
@property
|
||||||
"""
|
def handler(self) -> Callable[[Any], Optional[str]]:
|
||||||
Returns all transitions from this State in priority order
|
if self.handler_name is None:
|
||||||
"""
|
raise AttributeError("No handler defined")
|
||||||
if automatic_only:
|
return getattr(self.graph, self.handler_name)
|
||||||
transitions = [t for t in self.children.values() if t.automatic]
|
|
||||||
else:
|
|
||||||
transitions = list(self.children.values())
|
|
||||||
return sorted(transitions, key=lambda t: t.priority, reverse=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Transition:
|
|
||||||
"""
|
|
||||||
A possible transition from one state to another
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
from_state: State,
|
|
||||||
to_state: State,
|
|
||||||
handler: Union[str, Callable],
|
|
||||||
priority: int = 0,
|
|
||||||
):
|
|
||||||
self.from_state = from_state
|
|
||||||
self.to_state = to_state
|
|
||||||
self.handler = handler
|
|
||||||
self.priority = priority
|
|
||||||
self.automatic = True
|
|
||||||
|
|
||||||
def get_handler(self) -> Callable:
|
|
||||||
"""
|
|
||||||
Returns the handler (it might need resolving from a string)
|
|
||||||
"""
|
|
||||||
if isinstance(self.handler, str):
|
|
||||||
self.handler = getattr(self.from_state.graph, self.handler)
|
|
||||||
return cast(Callable, self.handler)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<Transition {self.from_state} -> {self.to_state}>"
|
|
||||||
|
|
||||||
|
|
||||||
class ManualTransition(Transition):
|
|
||||||
"""
|
|
||||||
A possible transition from one state to another that cannot be done by
|
|
||||||
the stator task runner, and must come from an external source.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
from_state: State,
|
|
||||||
to_state: State,
|
|
||||||
):
|
|
||||||
self.from_state = from_state
|
|
||||||
self.to_state = to_state
|
|
||||||
self.priority = 0
|
|
||||||
self.automatic = False
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-10 03:24
|
# Generated by Django 4.1.3 on 2022-11-10 05:56
|
||||||
|
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
@ -24,8 +24,7 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
("model_label", models.CharField(max_length=200)),
|
("model_label", models.CharField(max_length=200)),
|
||||||
("instance_pk", models.CharField(max_length=200)),
|
("instance_pk", models.CharField(max_length=200)),
|
||||||
("from_state", models.CharField(max_length=200)),
|
("state", models.CharField(max_length=200)),
|
||||||
("to_state", models.CharField(max_length=200)),
|
|
||||||
("date", models.DateTimeField(auto_now_add=True)),
|
("date", models.DateTimeField(auto_now_add=True)),
|
||||||
("error", models.TextField()),
|
("error", models.TextField()),
|
||||||
("error_details", models.TextField(blank=True, null=True)),
|
("error_details", models.TextField(blank=True, null=True)),
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
import datetime
|
import datetime
|
||||||
import traceback
|
import traceback
|
||||||
from typing import ClassVar, List, Optional, Type, cast
|
from typing import ClassVar, List, Optional, Type, Union, cast
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.db import models, transaction
|
from django.db import models, transaction
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.functional import classproperty
|
from django.utils.functional import classproperty
|
||||||
|
|
||||||
from stator.graph import State, StateGraph, Transition
|
from stator.graph import State, StateGraph
|
||||||
|
|
||||||
|
|
||||||
class StateField(models.CharField):
|
class StateField(models.CharField):
|
||||||
|
@ -29,16 +29,6 @@ class StateField(models.CharField):
|
||||||
kwargs["graph"] = self.graph
|
kwargs["graph"] = self.graph
|
||||||
return name, path, args, kwargs
|
return name, path, args, kwargs
|
||||||
|
|
||||||
def from_db_value(self, value, expression, connection):
|
|
||||||
if value is None:
|
|
||||||
return value
|
|
||||||
return self.graph.states[value]
|
|
||||||
|
|
||||||
def to_python(self, value):
|
|
||||||
if isinstance(value, State) or value is None:
|
|
||||||
return value
|
|
||||||
return self.graph.states[value]
|
|
||||||
|
|
||||||
def get_prep_value(self, value):
|
def get_prep_value(self, value):
|
||||||
if isinstance(value, State):
|
if isinstance(value, State):
|
||||||
return value.name
|
return value.name
|
||||||
|
@ -95,7 +85,9 @@ class StatorModel(models.Model):
|
||||||
(
|
(
|
||||||
models.Q(
|
models.Q(
|
||||||
state_attempted__lte=timezone.now()
|
state_attempted__lte=timezone.now()
|
||||||
- datetime.timedelta(seconds=state.try_interval)
|
- datetime.timedelta(
|
||||||
|
seconds=cast(float, state.try_interval)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
| models.Q(state_attempted__isnull=True)
|
| models.Q(state_attempted__isnull=True)
|
||||||
),
|
),
|
||||||
|
@ -117,7 +109,7 @@ class StatorModel(models.Model):
|
||||||
].select_for_update()
|
].select_for_update()
|
||||||
)
|
)
|
||||||
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
|
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
|
||||||
state_locked_until=timezone.now()
|
state_locked_until=lock_expiry
|
||||||
)
|
)
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
@ -143,36 +135,36 @@ class StatorModel(models.Model):
|
||||||
self.state_ready = True
|
self.state_ready = True
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
async def atransition_attempt(self) -> bool:
|
async def atransition_attempt(self) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Attempts to transition the current state by running its handler(s).
|
Attempts to transition the current state by running its handler(s).
|
||||||
"""
|
"""
|
||||||
# Try each transition in priority order
|
try:
|
||||||
for transition in self.state.transitions(automatic_only=True):
|
next_state = await self.state_graph.states[self.state].handler(self)
|
||||||
try:
|
except BaseException as e:
|
||||||
success = await transition.get_handler()(self)
|
await StatorError.acreate_from_instance(self, e)
|
||||||
except BaseException as e:
|
traceback.print_exc()
|
||||||
await StatorError.acreate_from_instance(self, transition, e)
|
else:
|
||||||
traceback.print_exc()
|
if next_state:
|
||||||
continue
|
await self.atransition_perform(next_state)
|
||||||
if success:
|
return next_state
|
||||||
await self.atransition_perform(transition.to_state.name)
|
|
||||||
return True
|
|
||||||
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
await self.__class__.objects.filter(pk=self.pk).aupdate(
|
||||||
state_attempted=timezone.now(),
|
state_attempted=timezone.now(),
|
||||||
state_locked_until=None,
|
state_locked_until=None,
|
||||||
state_ready=False,
|
state_ready=False,
|
||||||
)
|
)
|
||||||
return False
|
return None
|
||||||
|
|
||||||
def transition_perform(self, state_name):
|
def transition_perform(self, state: Union[State, str]):
|
||||||
"""
|
"""
|
||||||
Transitions the instance to the given state name, forcibly.
|
Transitions the instance to the given state name, forcibly.
|
||||||
"""
|
"""
|
||||||
if state_name not in self.state_graph.states:
|
if isinstance(state, State):
|
||||||
raise ValueError(f"Invalid state {state_name}")
|
state = state.name
|
||||||
|
if state not in self.state_graph.states:
|
||||||
|
raise ValueError(f"Invalid state {state}")
|
||||||
self.__class__.objects.filter(pk=self.pk).update(
|
self.__class__.objects.filter(pk=self.pk).update(
|
||||||
state=state_name,
|
state=state,
|
||||||
state_changed=timezone.now(),
|
state_changed=timezone.now(),
|
||||||
state_attempted=None,
|
state_attempted=None,
|
||||||
state_locked_until=None,
|
state_locked_until=None,
|
||||||
|
@ -194,11 +186,8 @@ class StatorError(models.Model):
|
||||||
# The primary key of that model (probably int or str)
|
# The primary key of that model (probably int or str)
|
||||||
instance_pk = models.CharField(max_length=200)
|
instance_pk = models.CharField(max_length=200)
|
||||||
|
|
||||||
# The state we moved from
|
# The state we were on
|
||||||
from_state = models.CharField(max_length=200)
|
state = models.CharField(max_length=200)
|
||||||
|
|
||||||
# The state we moved to (or tried to)
|
|
||||||
to_state = models.CharField(max_length=200)
|
|
||||||
|
|
||||||
# When it happened
|
# When it happened
|
||||||
date = models.DateTimeField(auto_now_add=True)
|
date = models.DateTimeField(auto_now_add=True)
|
||||||
|
@ -213,14 +202,12 @@ class StatorError(models.Model):
|
||||||
async def acreate_from_instance(
|
async def acreate_from_instance(
|
||||||
cls,
|
cls,
|
||||||
instance: StatorModel,
|
instance: StatorModel,
|
||||||
transition: Transition,
|
|
||||||
exception: Optional[BaseException] = None,
|
exception: Optional[BaseException] = None,
|
||||||
):
|
):
|
||||||
return await cls.objects.acreate(
|
return await cls.objects.acreate(
|
||||||
model_label=instance._meta.label_lower,
|
model_label=instance._meta.label_lower,
|
||||||
instance_pk=str(instance.pk),
|
instance_pk=str(instance.pk),
|
||||||
from_state=transition.from_state,
|
state=instance.state,
|
||||||
to_state=transition.to_state,
|
|
||||||
error=str(exception),
|
error=str(exception),
|
||||||
error_details=traceback.format_exc(),
|
error_details=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
|
@ -53,7 +54,7 @@ class StatorRunner:
|
||||||
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
|
f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
|
||||||
)
|
)
|
||||||
self.tasks.append(
|
self.tasks.append(
|
||||||
asyncio.create_task(instance.atransition_attempt())
|
asyncio.create_task(self.run_transition(instance))
|
||||||
)
|
)
|
||||||
self.handled += 1
|
self.handled += 1
|
||||||
space_remaining -= 1
|
space_remaining -= 1
|
||||||
|
@ -70,5 +71,17 @@ class StatorRunner:
|
||||||
print("Complete")
|
print("Complete")
|
||||||
return self.handled
|
return self.handled
|
||||||
|
|
||||||
|
async def run_transition(self, instance: StatorModel):
|
||||||
|
"""
|
||||||
|
Wrapper for atransition_attempt with fallback error handling
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await instance.atransition_attempt()
|
||||||
|
except BaseException:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
def remove_completed_tasks(self):
|
def remove_completed_tasks(self):
|
||||||
|
"""
|
||||||
|
Removes all completed asyncio.Tasks from our local in-progress list
|
||||||
|
"""
|
||||||
self.tasks = [t for t in self.tasks if not t.done()]
|
self.tasks = [t for t in self.tasks if not t.done()]
|
||||||
|
|
|
@ -9,39 +9,29 @@ def test_declare():
|
||||||
lookups.
|
lookups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fake_handler = lambda: True
|
|
||||||
|
|
||||||
class TestGraph(StateGraph):
|
class TestGraph(StateGraph):
|
||||||
initial = State()
|
initial = State(try_interval=3600)
|
||||||
second = State()
|
second = State(try_interval=1)
|
||||||
third = State()
|
third = State()
|
||||||
fourth = State()
|
|
||||||
final = State()
|
|
||||||
|
|
||||||
initial.add_transition(second, 60, handler=fake_handler)
|
initial.transitions_to(second)
|
||||||
second.add_transition(third, 60, handler="check_third")
|
second.transitions_to(third)
|
||||||
|
|
||||||
def check_third(cls):
|
@classmethod
|
||||||
return True
|
def handle_initial(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
@third.add_transition(fourth, 60)
|
@classmethod
|
||||||
def check_fourth(cls):
|
def handle_second(cls):
|
||||||
return True
|
pass
|
||||||
|
|
||||||
fourth.add_manual_transition(final)
|
|
||||||
|
|
||||||
assert TestGraph.initial_state == TestGraph.initial
|
assert TestGraph.initial_state == TestGraph.initial
|
||||||
assert TestGraph.terminal_states == {TestGraph.final}
|
assert TestGraph.terminal_states == {TestGraph.third}
|
||||||
|
|
||||||
assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
|
assert TestGraph.initial.handler == TestGraph.handle_initial
|
||||||
assert (
|
assert TestGraph.initial.try_interval == 3600
|
||||||
TestGraph.second.children[TestGraph.third].get_handler()
|
assert TestGraph.second.handler == TestGraph.handle_second
|
||||||
== TestGraph.check_third
|
assert TestGraph.second.try_interval == 1
|
||||||
)
|
|
||||||
assert (
|
|
||||||
TestGraph.third.children[TestGraph.fourth].get_handler().__name__
|
|
||||||
== "check_fourth"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_bad_declarations():
|
def test_bad_declarations():
|
||||||
|
@ -62,5 +52,18 @@ def test_bad_declarations():
|
||||||
loop = State()
|
loop = State()
|
||||||
loop2 = State()
|
loop2 = State()
|
||||||
|
|
||||||
loop.add_transition(loop2, 1, handler="fake")
|
loop.transitions_to(loop2)
|
||||||
loop2.add_transition(loop, 1, handler="fake")
|
loop2.transitions_to(loop)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state():
|
||||||
|
"""
|
||||||
|
Tests basic values of the State class
|
||||||
|
"""
|
||||||
|
|
||||||
|
class TestGraph(StateGraph):
|
||||||
|
initial = State()
|
||||||
|
|
||||||
|
assert "initial" == TestGraph.initial
|
||||||
|
assert TestGraph.initial == "initial"
|
||||||
|
assert TestGraph.initial == TestGraph.initial
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-07 04:19
|
# Generated by Django 4.1.3 on 2022-11-10 05:58
|
||||||
|
|
||||||
import django.db.models.deletion
|
import django.db.models.deletion
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
|
@ -20,7 +20,7 @@ class UserEventAdmin(admin.ModelAdmin):
|
||||||
|
|
||||||
@admin.register(Identity)
|
@admin.register(Identity)
|
||||||
class IdentityAdmin(admin.ModelAdmin):
|
class IdentityAdmin(admin.ModelAdmin):
|
||||||
list_display = ["id", "handle", "actor_uri", "name", "local"]
|
list_display = ["id", "handle", "actor_uri", "state", "local"]
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Follow)
|
@admin.register(Follow)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-07 04:19
|
# Generated by Django 4.1.3 on 2022-11-10 05:58
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
@ -6,7 +6,10 @@ import django.db.models.deletion
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
import stator.models
|
||||||
|
import users.models.follow
|
||||||
import users.models.identity
|
import users.models.identity
|
||||||
|
import users.models.inbox_message
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
@ -77,6 +80,37 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="InboxMessage",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"id",
|
||||||
|
models.BigAutoField(
|
||||||
|
auto_created=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
verbose_name="ID",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("state_ready", models.BooleanField(default=False)),
|
||||||
|
("state_changed", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("state_attempted", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("state_locked_until", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("message", models.JSONField()),
|
||||||
|
(
|
||||||
|
"state",
|
||||||
|
stator.models.StateField(
|
||||||
|
choices=[("received", "received"), ("processed", "processed")],
|
||||||
|
default="received",
|
||||||
|
graph=users.models.inbox_message.InboxMessageStates,
|
||||||
|
max_length=100,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name="UserEvent",
|
name="UserEvent",
|
||||||
fields=[
|
fields=[
|
||||||
|
@ -124,7 +158,20 @@ class Migration(migrations.Migration):
|
||||||
verbose_name="ID",
|
verbose_name="ID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
("state_ready", models.BooleanField(default=False)),
|
||||||
|
("state_changed", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("state_attempted", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("state_locked_until", models.DateTimeField(blank=True, null=True)),
|
||||||
("actor_uri", models.CharField(max_length=500, unique=True)),
|
("actor_uri", models.CharField(max_length=500, unique=True)),
|
||||||
|
(
|
||||||
|
"state",
|
||||||
|
stator.models.StateField(
|
||||||
|
choices=[("outdated", "outdated"), ("updated", "updated")],
|
||||||
|
default="outdated",
|
||||||
|
graph=users.models.identity.IdentityStates,
|
||||||
|
max_length=100,
|
||||||
|
),
|
||||||
|
),
|
||||||
("local", models.BooleanField()),
|
("local", models.BooleanField()),
|
||||||
("username", models.CharField(blank=True, max_length=500, null=True)),
|
("username", models.CharField(blank=True, max_length=500, null=True)),
|
||||||
("name", models.CharField(blank=True, max_length=500, null=True)),
|
("name", models.CharField(blank=True, max_length=500, null=True)),
|
||||||
|
@ -239,10 +286,25 @@ class Migration(migrations.Migration):
|
||||||
verbose_name="ID",
|
verbose_name="ID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
("state_ready", models.BooleanField(default=False)),
|
||||||
|
("state_changed", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("state_attempted", models.DateTimeField(blank=True, null=True)),
|
||||||
|
("state_locked_until", models.DateTimeField(blank=True, null=True)),
|
||||||
("uri", models.CharField(blank=True, max_length=500, null=True)),
|
("uri", models.CharField(blank=True, max_length=500, null=True)),
|
||||||
("note", models.TextField(blank=True, null=True)),
|
("note", models.TextField(blank=True, null=True)),
|
||||||
("requested", models.BooleanField(default=False)),
|
(
|
||||||
("accepted", models.BooleanField(default=False)),
|
"state",
|
||||||
|
stator.models.StateField(
|
||||||
|
choices=[
|
||||||
|
("pending", "pending"),
|
||||||
|
("requested", "requested"),
|
||||||
|
("accepted", "accepted"),
|
||||||
|
],
|
||||||
|
default="pending",
|
||||||
|
graph=users.models.follow.FollowStates,
|
||||||
|
max_length=100,
|
||||||
|
),
|
||||||
|
),
|
||||||
("created", models.DateTimeField(auto_now_add=True)),
|
("created", models.DateTimeField(auto_now_add=True)),
|
||||||
("updated", models.DateTimeField(auto_now=True)),
|
("updated", models.DateTimeField(auto_now=True)),
|
||||||
(
|
(
|
||||||
|
|
|
@ -1,44 +0,0 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-07 19:22
|
|
||||||
|
|
||||||
import django.utils.timezone
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
import stator.models
|
|
||||||
import users.models.follow
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("users", "0001_initial"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state",
|
|
||||||
field=stator.models.StateField(
|
|
||||||
choices=[
|
|
||||||
("pending", "pending"),
|
|
||||||
("requested", "requested"),
|
|
||||||
("accepted", "accepted"),
|
|
||||||
],
|
|
||||||
default="pending",
|
|
||||||
graph=users.models.follow.FollowStates,
|
|
||||||
max_length=100,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_attempted",
|
|
||||||
field=models.DateTimeField(blank=True, null=True),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_changed",
|
|
||||||
field=models.DateTimeField(
|
|
||||||
auto_now_add=True, default=django.utils.timezone.now
|
|
||||||
),
|
|
||||||
preserve_default=False,
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -1,31 +0,0 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-08 03:58
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("users", "0002_follow_state_follow_state_attempted_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RemoveField(
|
|
||||||
model_name="follow",
|
|
||||||
name="accepted",
|
|
||||||
),
|
|
||||||
migrations.RemoveField(
|
|
||||||
model_name="follow",
|
|
||||||
name="requested",
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_locked",
|
|
||||||
field=models.DateTimeField(blank=True, null=True),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_runner",
|
|
||||||
field=models.CharField(blank=True, max_length=100, null=True),
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -1,21 +0,0 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-09 05:15
|
|
||||||
|
|
||||||
from django.db import migrations
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RemoveField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_locked",
|
|
||||||
),
|
|
||||||
migrations.RemoveField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_runner",
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -1,23 +0,0 @@
|
||||||
# Generated by Django 4.1.3 on 2022-11-10 03:24
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("users", "0004_remove_follow_state_locked_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_locked_until",
|
|
||||||
field=models.DateTimeField(blank=True, null=True),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="follow",
|
|
||||||
name="state_ready",
|
|
||||||
field=models.BooleanField(default=False),
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -2,5 +2,6 @@ from .block import Block # noqa
|
||||||
from .domain import Domain # noqa
|
from .domain import Domain # noqa
|
||||||
from .follow import Follow, FollowStates # noqa
|
from .follow import Follow, FollowStates # noqa
|
||||||
from .identity import Identity, IdentityStates # noqa
|
from .identity import Identity, IdentityStates # noqa
|
||||||
|
from .inbox_message import InboxMessage, InboxMessageStates # noqa
|
||||||
from .user import User # noqa
|
from .user import User # noqa
|
||||||
from .user_event import UserEvent # noqa
|
from .user_event import UserEvent # noqa
|
||||||
|
|
|
@ -6,16 +6,20 @@ from stator.models import State, StateField, StateGraph, StatorModel
|
||||||
|
|
||||||
|
|
||||||
class FollowStates(StateGraph):
|
class FollowStates(StateGraph):
|
||||||
pending = State(try_interval=30)
|
unrequested = State(try_interval=30)
|
||||||
requested = State()
|
requested = State(try_interval=24 * 60 * 60)
|
||||||
accepted = State()
|
accepted = State()
|
||||||
|
|
||||||
@pending.add_transition(requested)
|
unrequested.transitions_to(requested)
|
||||||
async def try_request(instance: "Follow"): # type:ignore
|
requested.transitions_to(accepted)
|
||||||
print("Would have tried to follow on", instance)
|
|
||||||
return False
|
|
||||||
|
|
||||||
requested.add_manual_transition(accepted)
|
@classmethod
|
||||||
|
async def handle_unrequested(cls, instance: "Follow"):
|
||||||
|
print("Would have tried to follow on", instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def handle_requested(cls, instance: "Follow"):
|
||||||
|
print("Would have tried to requested on", instance)
|
||||||
|
|
||||||
|
|
||||||
class Follow(StatorModel):
|
class Follow(StatorModel):
|
||||||
|
@ -73,3 +77,17 @@ class Follow(StatorModel):
|
||||||
follow.state = FollowStates.accepted
|
follow.state = FollowStates.accepted
|
||||||
follow.save()
|
follow.save()
|
||||||
return follow
|
return follow
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def remote_created(cls, source, target, uri):
|
||||||
|
follow = cls.maybe_get(source=source, target=target)
|
||||||
|
if follow is None:
|
||||||
|
follow = Follow.objects.create(source=source, target=target, uri=uri)
|
||||||
|
if follow.state == FollowStates.fresh:
|
||||||
|
follow.transition_perform(FollowStates.requested)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def remote_accepted(cls, source, target):
|
||||||
|
follow = cls.maybe_get(source=source, target=target)
|
||||||
|
if follow and follow.state == FollowStates.requested:
|
||||||
|
follow.transition_perform(FollowStates.accepted)
|
||||||
|
|
|
@ -22,11 +22,16 @@ class IdentityStates(StateGraph):
|
||||||
outdated = State(try_interval=3600)
|
outdated = State(try_interval=3600)
|
||||||
updated = State()
|
updated = State()
|
||||||
|
|
||||||
@outdated.add_transition(updated)
|
outdated.transitions_to(updated)
|
||||||
async def fetch_identity(identity: "Identity"): # type:ignore
|
|
||||||
|
@classmethod
|
||||||
|
async def handle_outdated(cls, identity: "Identity"):
|
||||||
|
# Local identities never need fetching
|
||||||
if identity.local:
|
if identity.local:
|
||||||
return True
|
return "updated"
|
||||||
return await identity.fetch_actor()
|
# Run the actor fetch and progress to updated if it succeeds
|
||||||
|
if await identity.fetch_actor():
|
||||||
|
return "updated"
|
||||||
|
|
||||||
|
|
||||||
def upload_namer(prefix, instance, filename):
|
def upload_namer(prefix, instance, filename):
|
||||||
|
|
71
users/models/inbox_message.py
Normal file
71
users/models/inbox_message.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
from stator.models import State, StateField, StateGraph, StatorModel
|
||||||
|
from users.models import Follow, Identity
|
||||||
|
|
||||||
|
|
||||||
|
class InboxMessageStates(StateGraph):
|
||||||
|
received = State(try_interval=300)
|
||||||
|
processed = State()
|
||||||
|
|
||||||
|
received.transitions_to(processed)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def handle_received(cls, instance: "InboxMessage"):
|
||||||
|
type = instance.message["type"].lower()
|
||||||
|
if type == "follow":
|
||||||
|
await instance.follow_request()
|
||||||
|
elif type == "accept":
|
||||||
|
inner_type = instance.message["object"]["type"].lower()
|
||||||
|
if inner_type == "follow":
|
||||||
|
await instance.follow_accepted()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot handle activity of type accept.{inner_type}")
|
||||||
|
elif type == "undo":
|
||||||
|
inner_type = instance.message["object"]["type"].lower()
|
||||||
|
if inner_type == "follow":
|
||||||
|
await instance.follow_undo()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot handle activity of type {type}")
|
||||||
|
|
||||||
|
|
||||||
|
class InboxMessage(StatorModel):
|
||||||
|
"""
|
||||||
|
an incoming inbox message that needs processing.
|
||||||
|
|
||||||
|
Yes, this is kind of its own message queue built on the state graph system.
|
||||||
|
It's fine. It'll scale up to a decent point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message = models.JSONField()
|
||||||
|
|
||||||
|
state = StateField(InboxMessageStates)
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
|
def follow_request(self):
|
||||||
|
"""
|
||||||
|
Handles an incoming follow request
|
||||||
|
"""
|
||||||
|
Follow.remote_created(
|
||||||
|
source=Identity.by_actor_uri_with_create(self.message["actor"]),
|
||||||
|
target=Identity.by_actor_uri(self.message["object"]),
|
||||||
|
uri=self.message["id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
|
def follow_accepted(self):
|
||||||
|
"""
|
||||||
|
Handles an incoming acceptance of one of our follow requests
|
||||||
|
"""
|
||||||
|
Follow.remote_accepted(
|
||||||
|
source=Identity.by_actor_uri_with_create(self.message["actor"]),
|
||||||
|
target=Identity.by_actor_uri(self.message["object"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def follow_undo(self):
|
||||||
|
"""
|
||||||
|
Handles an incoming follow undo
|
||||||
|
"""
|
|
@ -19,10 +19,7 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
|
||||||
else:
|
else:
|
||||||
username, domain = handle.split("@", 1)
|
username, domain = handle.split("@", 1)
|
||||||
# Resolve the domain to the display domain
|
# Resolve the domain to the display domain
|
||||||
domain_instance = Domain.get_domain(domain)
|
domain = Domain.get_remote_domain(domain).domain
|
||||||
if domain_instance is None:
|
|
||||||
raise Http404("No matching domains found")
|
|
||||||
domain = domain_instance.domain
|
|
||||||
identity = Identity.by_username_and_domain(
|
identity = Identity.by_username_and_domain(
|
||||||
username,
|
username,
|
||||||
domain,
|
domain,
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
|
|
||||||
from users.models import Identity
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_identity_fetch(task_handler):
|
|
||||||
# Get the actor URI via webfinger
|
|
||||||
actor_uri, handle = await Identity.fetch_webfinger(task_handler.subject)
|
|
||||||
# Get or create the identity, then fetch
|
|
||||||
identity = await sync_to_async(Identity.by_actor_uri_with_create)(actor_uri)
|
|
||||||
await identity.fetch_actor()
|
|
|
@ -1,56 +0,0 @@
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
|
|
||||||
from users.models import Follow, Identity
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_inbox_item(task_handler):
|
|
||||||
type = task_handler.payload["type"].lower()
|
|
||||||
if type == "follow":
|
|
||||||
await inbox_follow(task_handler.payload)
|
|
||||||
elif type == "accept":
|
|
||||||
inner_type = task_handler.payload["object"]["type"].lower()
|
|
||||||
if inner_type == "follow":
|
|
||||||
await sync_to_async(accept_follow)(task_handler.payload["object"])
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Cannot handle activity of type accept.{inner_type}")
|
|
||||||
elif type == "undo":
|
|
||||||
inner_type = task_handler.payload["object"]["type"].lower()
|
|
||||||
if inner_type == "follow":
|
|
||||||
await inbox_unfollow(task_handler.payload["object"])
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Cannot handle activity of type undo.{inner_type}")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Cannot handle activity of type {inner_type}")
|
|
||||||
|
|
||||||
|
|
||||||
async def inbox_follow(payload):
|
|
||||||
"""
|
|
||||||
Handles an incoming follow request
|
|
||||||
"""
|
|
||||||
# TODO: Manually approved follows
|
|
||||||
source = Identity.by_actor_uri_with_create(payload["actor"])
|
|
||||||
target = Identity.by_actor_uri(payload["object"])
|
|
||||||
# See if this follow already exists
|
|
||||||
try:
|
|
||||||
follow = Follow.objects.get(source=source, target=target)
|
|
||||||
except Follow.DoesNotExist:
|
|
||||||
follow = Follow.objects.create(source=source, target=target, uri=payload["id"])
|
|
||||||
# See if we need to acknowledge it
|
|
||||||
if not follow.acknowledged:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def inbox_unfollow(payload):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def accept_follow(payload):
|
|
||||||
"""
|
|
||||||
Another server has acknowledged our follow request
|
|
||||||
"""
|
|
||||||
source = Identity.by_actor_uri_with_create(payload["actor"])
|
|
||||||
target = Identity.by_actor_uri(payload["object"])
|
|
||||||
follow = Follow.maybe_get(source, target)
|
|
||||||
if follow:
|
|
||||||
follow.accepted = True
|
|
||||||
follow.save()
|
|
|
@ -17,7 +17,7 @@ from core.forms import FormHelper
|
||||||
from core.ld import canonicalise
|
from core.ld import canonicalise
|
||||||
from core.signatures import HttpSignature
|
from core.signatures import HttpSignature
|
||||||
from users.decorators import identity_required
|
from users.decorators import identity_required
|
||||||
from users.models import Domain, Follow, Identity, IdentityStates
|
from users.models import Domain, Follow, Identity, IdentityStates, InboxMessage
|
||||||
from users.shortcuts import by_handle_or_404
|
from users.shortcuts import by_handle_or_404
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,9 +117,13 @@ class CreateIdentity(FormView):
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
# Check for existing users
|
# Check for existing users
|
||||||
username = self.cleaned_data["username"]
|
username = self.cleaned_data.get("username")
|
||||||
domain = self.cleaned_data["domain"]
|
domain = self.cleaned_data.get("domain")
|
||||||
if Identity.objects.filter(username=username, domain=domain).exists():
|
if (
|
||||||
|
username
|
||||||
|
and domain
|
||||||
|
and Identity.objects.filter(username=username, domain=domain).exists()
|
||||||
|
):
|
||||||
raise forms.ValidationError(f"{username}@{domain} is already taken")
|
raise forms.ValidationError(f"{username}@{domain} is already taken")
|
||||||
|
|
||||||
def get_form(self):
|
def get_form(self):
|
||||||
|
@ -219,7 +223,7 @@ class Inbox(View):
|
||||||
):
|
):
|
||||||
return HttpResponseBadRequest("Bad signature")
|
return HttpResponseBadRequest("Bad signature")
|
||||||
# Hand off the item to be processed by the queue
|
# Hand off the item to be processed by the queue
|
||||||
Task.submit("inbox_item", subject=identity.actor_uri, payload=document)
|
InboxMessage.objects.create(message=document)
|
||||||
return HttpResponse(status=202)
|
return HttpResponse(status=202)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue