Rework Stator to use a next field and no async

This commit is contained in:
Andrew Godwin 2023-07-07 15:14:06 -06:00
parent e34e4c0c77
commit 7f483af8d3
28 changed files with 1004 additions and 461 deletions

View file

@ -210,8 +210,8 @@ class TimelineEventAdmin(admin.ModelAdmin):
@admin.register(FanOut) @admin.register(FanOut)
class FanOutAdmin(admin.ModelAdmin): class FanOutAdmin(admin.ModelAdmin):
list_display = ["id", "state", "created", "state_attempted", "type", "identity"] list_display = ["id", "state", "created", "state_next_attempt", "type", "identity"]
list_filter = (IdentityLocalFilter, "type", "state", "state_attempted") list_filter = (IdentityLocalFilter, "type", "state")
raw_id_fields = ["subject_post", "subject_post_interaction"] raw_id_fields = ["subject_post", "subject_post_interaction"]
autocomplete_fields = ["identity"] autocomplete_fields = ["identity"]
readonly_fields = ["created", "updated", "state_changed"] readonly_fields = ["created", "updated", "state_changed"]
@ -229,7 +229,7 @@ class FanOutAdmin(admin.ModelAdmin):
@admin.register(PostInteraction) @admin.register(PostInteraction)
class PostInteractionAdmin(admin.ModelAdmin): class PostInteractionAdmin(admin.ModelAdmin):
list_display = ["id", "state", "state_attempted", "type", "identity", "post"] list_display = ["id", "state", "state_next_attempt", "type", "identity", "post"]
list_filter = (IdentityLocalFilter, "type", "state") list_filter = (IdentityLocalFilter, "type", "state")
raw_id_fields = ["post"] raw_id_fields = ["post"]
autocomplete_fields = ["identity"] autocomplete_fields = ["identity"]

View file

@ -0,0 +1,234 @@
# Generated by Django 4.2.1 on 2023-07-05 22:18
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("activities", "0016_index_together_migration"),
]
operations = [
migrations.RemoveIndex(
model_name="emoji",
name="activities__state_r_aa72ec_idx",
),
migrations.RemoveIndex(
model_name="emoji",
name="ix_emoji_state_attempted",
),
migrations.RemoveIndex(
model_name="emoji",
name="ix_emoji_state_locked",
),
migrations.RemoveIndex(
model_name="fanout",
name="ix_fanout_state_attempted",
),
migrations.RemoveIndex(
model_name="fanout",
name="ix_fanout_state_locked",
),
migrations.RemoveIndex(
model_name="fanout",
name="activities__state_r_aae3b4_idx",
),
migrations.RemoveIndex(
model_name="hashtag",
name="ix_hashtag_state_attempted",
),
migrations.RemoveIndex(
model_name="hashtag",
name="ix_hashtag_state_locked",
),
migrations.RemoveIndex(
model_name="hashtag",
name="activities__state_r_5703be_idx",
),
migrations.RemoveIndex(
model_name="post",
name="ix_post_state_attempted",
),
migrations.RemoveIndex(
model_name="post",
name="ix_post_state_locked",
),
migrations.RemoveIndex(
model_name="post",
name="activities__state_r_b8f1ff_idx",
),
migrations.RemoveIndex(
model_name="postattachment",
name="ix_postattachm_state_attempted",
),
migrations.RemoveIndex(
model_name="postattachment",
name="ix_postattachm_state_locked",
),
migrations.RemoveIndex(
model_name="postattachment",
name="activities__state_r_4e981c_idx",
),
migrations.RemoveIndex(
model_name="postinteraction",
name="activities__state_r_981d8c_idx",
),
migrations.RemoveIndex(
model_name="postinteraction",
name="ix_postinterac_state_attempted",
),
migrations.RemoveIndex(
model_name="postinteraction",
name="ix_postinterac_state_locked",
),
migrations.RemoveField(
model_name="emoji",
name="state_attempted",
),
migrations.RemoveField(
model_name="emoji",
name="state_ready",
),
migrations.RemoveField(
model_name="fanout",
name="state_attempted",
),
migrations.RemoveField(
model_name="fanout",
name="state_ready",
),
migrations.RemoveField(
model_name="hashtag",
name="state_attempted",
),
migrations.RemoveField(
model_name="hashtag",
name="state_ready",
),
migrations.RemoveField(
model_name="post",
name="state_attempted",
),
migrations.RemoveField(
model_name="post",
name="state_ready",
),
migrations.RemoveField(
model_name="postattachment",
name="state_attempted",
),
migrations.RemoveField(
model_name="postattachment",
name="state_ready",
),
migrations.RemoveField(
model_name="postinteraction",
name="state_attempted",
),
migrations.RemoveField(
model_name="postinteraction",
name="state_ready",
),
migrations.AddField(
model_name="emoji",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="fanout",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="hashtag",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="post",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="postattachment",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="postinteraction",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AlterField(
model_name="emoji",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="fanout",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="hashtag",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="post",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="postattachment",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="postinteraction",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AddIndex(
model_name="emoji",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_emoji_state_next",
),
),
migrations.AddIndex(
model_name="fanout",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_fanout_state_next",
),
),
migrations.AddIndex(
model_name="hashtag",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_hashtag_state_next",
),
),
migrations.AddIndex(
model_name="post",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_post_state_next",
),
),
migrations.AddIndex(
model_name="postattachment",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_postattachm_state_next",
),
),
migrations.AddIndex(
model_name="postinteraction",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_postinterac_state_next",
),
),
]

View file

@ -127,7 +127,7 @@ class Emoji(StatorModel):
class Meta: class Meta:
unique_together = ("domain", "shortcode") unique_together = ("domain", "shortcode")
indexes = StatorModel.Meta.indexes indexes: list = [] # We need this so Stator can add its own
class urls(urlman.Urls): class urls(urlman.Urls):
admin = "/admin/emoji/" admin = "/admin/emoji/"
@ -314,11 +314,11 @@ class Emoji(StatorModel):
emoji.remote_url = icon["url"] emoji.remote_url = icon["url"]
emoji.mimetype = mimetype emoji.mimetype = mimetype
emoji.category = category emoji.category = category
emoji.transition_set_state("outdated")
if emoji.file: if emoji.file:
emoji.file.delete(save=True) emoji.file.delete(save=True)
else: else:
emoji.save() emoji.save()
emoji.transition_perform("outdated")
return emoji return emoji
emoji = cls.objects.create( emoji = cls.objects.create(

View file

@ -1,5 +1,5 @@
import httpx import httpx
from asgiref.sync import sync_to_async from asgiref.sync import async_to_sync
from django.db import models from django.db import models
from activities.models.timeline_event import TimelineEvent from activities.models.timeline_event import TimelineEvent
@ -19,26 +19,24 @@ class FanOutStates(StateGraph):
new.times_out_to(failed, seconds=86400 * 3) new.times_out_to(failed, seconds=86400 * 3)
@classmethod @classmethod
async def handle_new(cls, instance: "FanOut"): def handle_new(cls, instance: "FanOut"):
""" """
Sends the fan-out to the right inbox. Sends the fan-out to the right inbox.
""" """
fan_out = await instance.afetch_full()
# Don't try to fan out to identities that are not fetched yet # Don't try to fan out to identities that are not fetched yet
if not (fan_out.identity.local or fan_out.identity.inbox_uri): if not (instance.identity.local or instance.identity.inbox_uri):
return return
match (fan_out.type, fan_out.identity.local): match (instance.type, instance.identity.local):
# Handle creating/updating local posts # Handle creating/updating local posts
case ((FanOut.Types.post | FanOut.Types.post_edited), True): case ((FanOut.Types.post | FanOut.Types.post_edited), True):
post = await fan_out.subject_post.afetch_full() post = instance.subject_post
# If the author of the post is blocked or muted, skip out # If the author of the post is blocked or muted, skip out
if ( if (
await Block.objects.active() Block.objects.active()
.filter(source=fan_out.identity, target=post.author) .filter(source=instance.identity, target=post.author)
.aexists() .exists()
): ):
return cls.skipped return cls.skipped
# Make a timeline event directly # Make a timeline event directly
@ -48,42 +46,42 @@ class FanOutStates(StateGraph):
add = True add = True
mentioned = {identity.id for identity in post.mentions.all()} mentioned = {identity.id for identity in post.mentions.all()}
if post.in_reply_to: if post.in_reply_to:
followed = await sync_to_async(set)( followed = set(
fan_out.identity.outbound_follows.filter( instance.identity.outbound_follows.filter(
state__in=FollowStates.group_active() state__in=FollowStates.group_active()
).values_list("target_id", flat=True) ).values_list("target_id", flat=True)
) )
interested_in = followed.union( interested_in = followed.union(
{post.author_id, fan_out.identity_id} {post.author_id, instance.identity_id}
) )
add = (post.author_id in followed) and ( add = (post.author_id in followed) and (
bool(mentioned.intersection(interested_in)) bool(mentioned.intersection(interested_in))
) )
if add: if add:
await sync_to_async(TimelineEvent.add_post)( TimelineEvent.add_post(
identity=fan_out.identity, identity=instance.identity,
post=post, post=post,
) )
# We might have been mentioned # We might have been mentioned
if ( if (
fan_out.identity.id in mentioned instance.identity.id in mentioned
and fan_out.identity_id != post.author_id and instance.identity_id != post.author_id
): ):
await sync_to_async(TimelineEvent.add_mentioned)( TimelineEvent.add_mentioned(
identity=fan_out.identity, identity=instance.identity,
post=post, post=post,
) )
# Handle sending remote posts create # Handle sending remote posts create
case (FanOut.Types.post, False): case (FanOut.Types.post, False):
post = await fan_out.subject_post.afetch_full() post = instance.subject_post
# Sign it and send it # Sign it and send it
try: try:
await post.author.signed_request( async_to_sync(post.author.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(post.to_create_ap()), body=canonicalise(post.to_create_ap()),
) )
@ -92,14 +90,14 @@ class FanOutStates(StateGraph):
# Handle sending remote posts update # Handle sending remote posts update
case (FanOut.Types.post_edited, False): case (FanOut.Types.post_edited, False):
post = await fan_out.subject_post.afetch_full() post = instance.subject_post
# Sign it and send it # Sign it and send it
try: try:
await post.author.signed_request( async_to_sync(post.author.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(post.to_update_ap()), body=canonicalise(post.to_update_ap()),
) )
@ -108,24 +106,24 @@ class FanOutStates(StateGraph):
# Handle deleting local posts # Handle deleting local posts
case (FanOut.Types.post_deleted, True): case (FanOut.Types.post_deleted, True):
post = await fan_out.subject_post.afetch_full() post = instance.subject_post
if fan_out.identity.local: if instance.identity.local:
# Remove all timeline events mentioning it # Remove all timeline events mentioning it
await TimelineEvent.objects.filter( TimelineEvent.objects.filter(
identity=fan_out.identity, identity=instance.identity,
subject_post=post, subject_post=post,
).adelete() ).delete()
# Handle sending remote post deletes # Handle sending remote post deletes
case (FanOut.Types.post_deleted, False): case (FanOut.Types.post_deleted, False):
post = await fan_out.subject_post.afetch_full() post = instance.subject_post
# Send it to the remote inbox # Send it to the remote inbox
try: try:
await post.author.signed_request( async_to_sync(post.author.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(post.to_delete_ap()), body=canonicalise(post.to_delete_ap()),
) )
@ -134,38 +132,38 @@ class FanOutStates(StateGraph):
# Handle local boosts/likes # Handle local boosts/likes
case (FanOut.Types.interaction, True): case (FanOut.Types.interaction, True):
interaction = await fan_out.subject_post_interaction.afetch_full() interaction = instance.subject_post_interaction
# If the author of the interaction is blocked or their notifications # If the author of the interaction is blocked or their notifications
# are muted, skip out # are muted, skip out
if ( if (
await Block.objects.active() Block.objects.active()
.filter( .filter(
models.Q(mute=False) | models.Q(include_notifications=True), models.Q(mute=False) | models.Q(include_notifications=True),
source=fan_out.identity, source=instance.identity,
target=interaction.identity, target=interaction.identity,
) )
.aexists() .exists()
): ):
return cls.skipped return cls.skipped
# If blocked/muted the underlying post author, skip out # If blocked/muted the underlying post author, skip out
if ( if (
await Block.objects.active() Block.objects.active()
.filter( .filter(
source=fan_out.identity, source=instance.identity,
target_id=interaction.post.author_id, target_id=interaction.post.author_id,
) )
.aexists() .exists()
): ):
return cls.skipped return cls.skipped
# Make a timeline event directly # Make a timeline event directly
await sync_to_async(TimelineEvent.add_post_interaction)( TimelineEvent.add_post_interaction(
identity=fan_out.identity, identity=instance.identity,
interaction=interaction, interaction=interaction,
) )
# Handle sending remote boosts/likes/votes/pins # Handle sending remote boosts/likes/votes/pins
case (FanOut.Types.interaction, False): case (FanOut.Types.interaction, False):
interaction = await fan_out.subject_post_interaction.afetch_full() interaction = instance.subject_post_interaction
# Send it to the remote inbox # Send it to the remote inbox
try: try:
if interaction.type == interaction.Types.vote: if interaction.type == interaction.Types.vote:
@ -174,11 +172,11 @@ class FanOutStates(StateGraph):
body = interaction.to_add_ap() body = interaction.to_add_ap()
else: else:
body = interaction.to_create_ap() body = interaction.to_create_ap()
await interaction.identity.signed_request( async_to_sync(interaction.identity.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(body), body=canonicalise(body),
) )
@ -187,28 +185,28 @@ class FanOutStates(StateGraph):
# Handle undoing local boosts/likes # Handle undoing local boosts/likes
case (FanOut.Types.undo_interaction, True): # noqa:F841 case (FanOut.Types.undo_interaction, True): # noqa:F841
interaction = await fan_out.subject_post_interaction.afetch_full() interaction = instance.subject_post_interaction
# Delete any local timeline events # Delete any local timeline events
await sync_to_async(TimelineEvent.delete_post_interaction)( TimelineEvent.delete_post_interaction(
identity=fan_out.identity, identity=instance.identity,
interaction=interaction, interaction=interaction,
) )
# Handle sending remote undoing boosts/likes/pins # Handle sending remote undoing boosts/likes/pins
case (FanOut.Types.undo_interaction, False): # noqa:F841 case (FanOut.Types.undo_interaction, False): # noqa:F841
interaction = await fan_out.subject_post_interaction.afetch_full() interaction = instance.subject_post_interaction
# Send an undo to the remote inbox # Send an undo to the remote inbox
try: try:
if interaction.type == interaction.Types.pin: if interaction.type == interaction.Types.pin:
body = interaction.to_remove_ap() body = interaction.to_remove_ap()
else: else:
body = interaction.to_undo_ap() body = interaction.to_undo_ap()
await interaction.identity.signed_request( async_to_sync(interaction.identity.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(body), body=canonicalise(body),
) )
@ -217,32 +215,30 @@ class FanOutStates(StateGraph):
# Handle sending identity edited to remote # Handle sending identity edited to remote
case (FanOut.Types.identity_edited, False): case (FanOut.Types.identity_edited, False):
identity = await fan_out.subject_identity.afetch_full() identity = instance.subject_identity
try: try:
await identity.signed_request( async_to_sync(identity.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
),
body=canonicalise(
await sync_to_async(fan_out.subject_identity.to_update_ap)()
), ),
body=canonicalise(instance.subject_identity.to_update_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
# Handle sending identity deleted to remote # Handle sending identity deleted to remote
case (FanOut.Types.identity_deleted, False): case (FanOut.Types.identity_deleted, False):
identity = await fan_out.subject_identity.afetch_full() identity = instance.subject_identity
try: try:
await identity.signed_request( async_to_sync(identity.signed_request)(
method="post", method="post",
uri=( uri=(
fan_out.identity.shared_inbox_uri instance.identity.shared_inbox_uri
or fan_out.identity.inbox_uri or instance.identity.inbox_uri
), ),
body=canonicalise(fan_out.subject_identity.to_delete_ap()), body=canonicalise(instance.subject_identity.to_delete_ap()),
) )
except httpx.RequestError: except httpx.RequestError:
return return
@ -255,14 +251,14 @@ class FanOutStates(StateGraph):
# Created identities make a timeline event # Created identities make a timeline event
case (FanOut.Types.identity_created, True): case (FanOut.Types.identity_created, True):
await sync_to_async(TimelineEvent.add_identity_created)( TimelineEvent.add_identity_created(
identity=fan_out.identity, identity=instance.identity,
new_identity=fan_out.subject_identity, new_identity=instance.subject_identity,
) )
case _: case _:
raise ValueError( raise ValueError(
f"Cannot fan out with type {fan_out.type} local={fan_out.identity.local}" f"Cannot fan out with type {instance.type} local={instance.identity.local}"
) )
return cls.sent return cls.sent

View file

@ -22,12 +22,8 @@ class HashtagStates(StateGraph):
""" """
Computes the stats and other things for a Hashtag Computes the stats and other things for a Hashtag
""" """
from time import time
from .post import Post from .post import Post
start = time()
posts_query = Post.objects.local_public().tagged_with(instance) posts_query = Post.objects.local_public().tagged_with(instance)
total = await posts_query.acount() total = await posts_query.acount()
@ -57,7 +53,6 @@ class HashtagStates(StateGraph):
instance.stats_updated = timezone.now() instance.stats_updated = timezone.now()
await sync_to_async(instance.save)() await sync_to_async(instance.save)()
print(f"Updated hashtag {instance.hashtag} in {time() - start:.5f} seconds")
return cls.updated return cls.updated
@ -86,7 +81,6 @@ class HashtagManager(models.Manager):
class Hashtag(StatorModel): class Hashtag(StatorModel):
MAXIMUM_LENGTH = 100 MAXIMUM_LENGTH = 100
# Normalized hashtag without the '#' # Normalized hashtag without the '#'

View file

@ -8,7 +8,7 @@ from urllib.parse import urlparse
import httpx import httpx
import urlman import urlman
from asgiref.sync import async_to_sync, sync_to_async from asgiref.sync import async_to_sync
from django.contrib.postgres.indexes import GinIndex from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVector from django.contrib.postgres.search import SearchVector
from django.db import models, transaction from django.db import models, transaction
@ -63,45 +63,44 @@ class PostStates(StateGraph):
edited_fanned_out.transitions_to(deleted) edited_fanned_out.transitions_to(deleted)
@classmethod @classmethod
async def targets_fan_out(cls, post: "Post", type_: str) -> None: def targets_fan_out(cls, post: "Post", type_: str) -> None:
# Fan out to each target # Fan out to each target
for follow in await post.aget_targets(): for follow in post.get_targets():
await FanOut.objects.acreate( FanOut.objects.create(
identity=follow, identity=follow,
type=type_, type=type_,
subject_post=post, subject_post=post,
) )
@classmethod @classmethod
async def handle_new(cls, instance: "Post"): def handle_new(cls, instance: "Post"):
""" """
Creates all needed fan-out objects for a new Post. Creates all needed fan-out objects for a new Post.
""" """
post = await instance.afetch_full()
# Only fan out if the post was published in the last day or it's local # Only fan out if the post was published in the last day or it's local
# (we don't want to fan out anything older that that which is remote) # (we don't want to fan out anything older that that which is remote)
if post.local or (timezone.now() - post.published) < datetime.timedelta(days=1): if instance.local or (timezone.now() - instance.published) < datetime.timedelta(
await cls.targets_fan_out(post, FanOut.Types.post) days=1
await post.ensure_hashtags() ):
cls.targets_fan_out(instance, FanOut.Types.post)
instance.ensure_hashtags()
return cls.fanned_out return cls.fanned_out
@classmethod @classmethod
async def handle_deleted(cls, instance: "Post"): def handle_deleted(cls, instance: "Post"):
""" """
Creates all needed fan-out objects needed to delete a Post. Creates all needed fan-out objects needed to delete a Post.
""" """
post = await instance.afetch_full() cls.targets_fan_out(instance, FanOut.Types.post_deleted)
await cls.targets_fan_out(post, FanOut.Types.post_deleted)
return cls.deleted_fanned_out return cls.deleted_fanned_out
@classmethod @classmethod
async def handle_edited(cls, instance: "Post"): def handle_edited(cls, instance: "Post"):
""" """
Creates all needed fan-out objects for an edited Post. Creates all needed fan-out objects for an edited Post.
""" """
post = await instance.afetch_full() cls.targets_fan_out(instance, FanOut.Types.post_edited)
await cls.targets_fan_out(post, FanOut.Types.post_edited) instance.ensure_hashtags()
await post.ensure_hashtags()
return cls.edited_fanned_out return cls.edited_fanned_out
@ -324,7 +323,7 @@ class Post(StatorModel):
fields=["visibility", "local", "created"], fields=["visibility", "local", "created"],
name="ix_post_local_public_created", name="ix_post_local_public_created",
), ),
] + StatorModel.Meta.indexes ]
class urls(urlman.Urls): class urls(urlman.Urls):
view = "{self.author.urls.view}posts/{self.id}/" view = "{self.author.urls.view}posts/{self.id}/"
@ -375,8 +374,6 @@ class Post(StatorModel):
.first() .first()
) )
ain_reply_to_post = sync_to_async(in_reply_to_post)
### Content cleanup and extraction ### ### Content cleanup and extraction ###
def clean_type_data(self, value): def clean_type_data(self, value):
PostTypeData.parse_obj(value) PostTypeData.parse_obj(value)
@ -552,6 +549,8 @@ class Post(StatorModel):
attachment.name = attrs.description attachment.name = attrs.description
attachment.save() attachment.save()
self.transition_perform(PostStates.edited)
@classmethod @classmethod
def mentions_from_content(cls, content, author) -> set[Identity]: def mentions_from_content(cls, content, author) -> set[Identity]:
mention_hits = FediverseHtmlParser(content, find_mentions=True).mentions mention_hits = FediverseHtmlParser(content, find_mentions=True).mentions
@ -572,7 +571,7 @@ class Post(StatorModel):
mentions.add(identity) mentions.add(identity)
return mentions return mentions
async def ensure_hashtags(self) -> None: def ensure_hashtags(self) -> None:
""" """
Ensure any of the already parsed hashtags from this Post Ensure any of the already parsed hashtags from this Post
have a corresponding Hashtag record. have a corresponding Hashtag record.
@ -580,10 +579,10 @@ class Post(StatorModel):
# Ensure hashtags # Ensure hashtags
if self.hashtags: if self.hashtags:
for hashtag in self.hashtags: for hashtag in self.hashtags:
tag, _ = await Hashtag.objects.aget_or_create( tag, _ = Hashtag.objects.get_or_create(
hashtag=hashtag[: Hashtag.MAXIMUM_LENGTH], hashtag=hashtag[: Hashtag.MAXIMUM_LENGTH],
) )
await tag.atransition_perform(HashtagStates.outdated) tag.transition_perform(HashtagStates.outdated)
def calculate_stats(self, save=True): def calculate_stats(self, save=True):
""" """
@ -739,33 +738,33 @@ class Post(StatorModel):
"object": object, "object": object,
} }
async def aget_targets(self) -> Iterable[Identity]: def get_targets(self) -> Iterable[Identity]:
""" """
Returns a list of Identities that need to see posts and their changes Returns a list of Identities that need to see posts and their changes
""" """
targets = set() targets = set()
async for mention in self.mentions.all(): for mention in self.mentions.all():
targets.add(mention) targets.add(mention)
# Then, if it's not mentions only, also deliver to followers and all hashtag followers # Then, if it's not mentions only, also deliver to followers and all hashtag followers
if self.visibility != Post.Visibilities.mentioned: if self.visibility != Post.Visibilities.mentioned:
async for follower in self.author.inbound_follows.filter( for follower in self.author.inbound_follows.filter(
state__in=FollowStates.group_active() state__in=FollowStates.group_active()
).select_related("source"): ).select_related("source"):
targets.add(follower.source) targets.add(follower.source)
if self.hashtags: if self.hashtags:
async for follow in HashtagFollow.objects.by_hashtags( for follow in HashtagFollow.objects.by_hashtags(
self.hashtags self.hashtags
).prefetch_related("identity"): ).prefetch_related("identity"):
targets.add(follow.identity) targets.add(follow.identity)
# If it's a reply, always include the original author if we know them # If it's a reply, always include the original author if we know them
reply_post = await self.ain_reply_to_post() reply_post = self.in_reply_to_post()
if reply_post: if reply_post:
targets.add(reply_post.author) targets.add(reply_post.author)
# And if it's a reply to one of our own, we have to re-fan-out to # And if it's a reply to one of our own, we have to re-fan-out to
# the original author's followers # the original author's followers
if reply_post.author.local: if reply_post.author.local:
async for follower in reply_post.author.inbound_follows.filter( for follower in reply_post.author.inbound_follows.filter(
state__in=FollowStates.group_active() state__in=FollowStates.group_active()
).select_related("source"): ).select_related("source"):
targets.add(follower.source) targets.add(follower.source)
@ -782,7 +781,7 @@ class Post(StatorModel):
.filter(mute=False) .filter(mute=False)
.select_related("target") .select_related("target")
) )
async for block in blocks: for block in blocks:
try: try:
targets.remove(block.target) targets.remove(block.target)
except KeyError: except KeyError:

View file

@ -179,9 +179,7 @@ class PostInteraction(StatorModel):
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)
class Meta: class Meta:
indexes = [ indexes = [models.Index(fields=["type", "identity", "post"])]
models.Index(fields=["type", "identity", "post"])
] + StatorModel.Meta.indexes
### Display helpers ### ### Display helpers ###

View file

@ -1,4 +1,5 @@
from django.conf import settings from django.conf import settings
from django.core.cache import cache
from hatchway import api_view from hatchway import api_view
from activities.models import Post from activities.models import Post
@ -10,6 +11,15 @@ from users.models import Domain, Identity
@api_view.get @api_view.get
def instance_info_v1(request): def instance_info_v1(request):
# The stats are expensive to calculate, so don't do it very often
stats = cache.get("instance_info_stats")
if stats is None:
stats = {
"user_count": Identity.objects.filter(local=True).count(),
"status_count": Post.objects.filter(local=True).not_hidden().count(),
"domain_count": Domain.objects.count(),
}
cache.set("instance_info_stats", stats, timeout=300)
return { return {
"uri": request.headers.get("host", settings.SETUP.MAIN_DOMAIN), "uri": request.headers.get("host", settings.SETUP.MAIN_DOMAIN),
"title": Config.system.site_name, "title": Config.system.site_name,
@ -18,11 +28,7 @@ def instance_info_v1(request):
"email": "", "email": "",
"version": f"takahe/{__version__}", "version": f"takahe/{__version__}",
"urls": {}, "urls": {},
"stats": { "stats": stats,
"user_count": Identity.objects.filter(local=True).count(),
"status_count": Post.objects.filter(local=True).not_hidden().count(),
"domain_count": Domain.objects.count(),
},
"thumbnail": Config.system.site_banner, "thumbnail": Config.system.site_banner,
"languages": ["en"], "languages": ["en"],
"registrations": (Config.system.signup_allowed), "registrations": (Config.system.signup_allowed),

View file

@ -1,6 +1,5 @@
import traceback import traceback
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
@ -40,6 +39,3 @@ def capture_exception(exception: BaseException, scope=None, **scope_args):
capture_exception(exception, scope, **scope_args) capture_exception(exception, scope, **scope_args)
elif settings.DEBUG: elif settings.DEBUG:
traceback.print_exc() traceback.print_exc()
acapture_exception = sync_to_async(capture_exception, thread_sensitive=False)

40
docs/releases/0.10.rst Normal file
View file

@ -0,0 +1,40 @@
0.9
===
*Released: Not Yet Released*
This release is a polish release that is prepping us for the road to 1.0.
This release's major changes:
* Stator, the background task system, has been significantly reworked to require
smaller indexes, spend less time scheduling, and has had most of its async
nature removed, as this both reduces deadlocks and improves performance in
most situations (the context switching was costing more than the gains from
talking to other servers asynchronously).
* TBC
If you'd like to help with code, design, or other areas, see
:doc:`/contributing` to see how to get in touch.
You can download images from `Docker Hub <https://hub.docker.com/r/jointakahe/takahe>`_,
or use the image name ``jointakahe/takahe:0.10``.
Upgrade Notes
-------------
Migrations
~~~~~~~~~~
There are new database migrations; they are backwards-compatible, but contain
very significant index changes to all of the main tables that may cause the
PostgreSQL deadlock detector to trigger if you attempt to apply them while your
site is live.
We recommend:
* Temporarily stopping all instances of the webserver and Stator
* Applying the migration (should be less than a few minutes on most installs)
* Restarting the instances of webserver and Stator

View file

@ -12,6 +12,10 @@ addopts = --tb=short --ds=takahe.settings --import-mode=importlib
filterwarnings = filterwarnings =
ignore:There is no current event loop ignore:There is no current event loop
ignore:No directory at ignore:No directory at
ignore:DateTimeField Post.created
ignore:'index_together' is deprecated
ignore:Deprecated call to
ignore:pkg_resources is deprecated as an API
[mypy] [mypy]
warn_unused_ignores = True warn_unused_ignores = True

View file

@ -13,6 +13,7 @@ class StateGraph:
initial_state: ClassVar["State"] initial_state: ClassVar["State"]
terminal_states: ClassVar[set["State"]] terminal_states: ClassVar[set["State"]]
automatic_states: ClassVar[set["State"]] automatic_states: ClassVar[set["State"]]
deletion_states: ClassVar[set["State"]]
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
# Collect state members # Collect state members
@ -33,6 +34,7 @@ class StateGraph:
# Check the graph layout # Check the graph layout
terminal_states = set() terminal_states = set()
automatic_states = set() automatic_states = set()
deletion_states = set()
initial_state = None initial_state = None
for state in cls.states.values(): for state in cls.states.values():
# Check for multiple initial states # Check for multiple initial states
@ -42,6 +44,9 @@ class StateGraph:
f"The graph has more than one initial state: {initial_state} and {state}" f"The graph has more than one initial state: {initial_state} and {state}"
) )
initial_state = state initial_state = state
# Collect states that require deletion handling (they can be terminal or not)
if state.delete_after:
deletion_states.add(state)
# Collect terminal states # Collect terminal states
if state.terminal: if state.terminal:
state.externally_progressed = True state.externally_progressed = True
@ -74,6 +79,7 @@ class StateGraph:
cls.initial_state = initial_state cls.initial_state = initial_state
cls.terminal_states = terminal_states cls.terminal_states = terminal_states
cls.automatic_states = automatic_states cls.automatic_states = automatic_states
cls.deletion_states = deletion_states
# Generate choices # Generate choices
cls.choices = [(name, name) for name in cls.states.keys()] cls.choices = [(name, name) for name in cls.states.keys()]
@ -98,6 +104,9 @@ class State:
self.attempt_immediately = attempt_immediately self.attempt_immediately = attempt_immediately
self.force_initial = force_initial self.force_initial = force_initial
self.delete_after = delete_after self.delete_after = delete_after
# Deletes are also only attempted on try_intervals
if self.delete_after and not self.try_interval:
self.try_interval = self.delete_after
self.parents: set["State"] = set() self.parents: set["State"] = set()
self.children: set["State"] = set() self.children: set["State"] = set()
self.timeout_state: State | None = None self.timeout_state: State | None = None

View file

@ -1,6 +1,5 @@
from typing import cast from typing import cast
from asgiref.sync import async_to_sync
from django.apps import apps from django.apps import apps
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
@ -84,6 +83,6 @@ class Command(BaseCommand):
run_for=run_for, run_for=run_for,
) )
try: try:
async_to_sync(runner.run)() runner.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print("Ctrl-C received") print("Ctrl-C received")

View file

@ -1,8 +1,8 @@
import datetime import datetime
import traceback import traceback
from typing import ClassVar, cast from typing import ClassVar
from asgiref.sync import sync_to_async from asgiref.sync import async_to_sync, iscoroutinefunction
from django.db import models, transaction from django.db import models, transaction
from django.db.models.signals import class_prepared from django.db.models.signals import class_prepared
from django.utils import timezone from django.utils import timezone
@ -47,19 +47,15 @@ def add_stator_indexes(sender, **kwargs):
if issubclass(sender, StatorModel): if issubclass(sender, StatorModel):
indexes = [ indexes = [
models.Index( models.Index(
fields=["state", "state_attempted"], fields=["state", "state_next_attempt", "state_locked_until"],
name=f"ix_{sender.__name__.lower()[:11]}_state_attempted", name=f"ix_{sender.__name__.lower()[:11]}_state_next",
),
models.Index(
fields=["state_locked_until", "state"],
condition=models.Q(state_locked_until__isnull=False),
name=f"ix_{sender.__name__.lower()[:11]}_state_locked",
), ),
] ]
if not sender._meta.indexes: if not sender._meta.indexes:
# Meta.indexes needs to not be None to trigger Django behaviors # Meta.indexes needs to not be None to trigger Django behaviors
sender.Meta.indexes = [] sender.Meta.indexes = []
sender._meta.indexes = []
for idx in indexes: for idx in indexes:
sender._meta.indexes.append(idx) sender._meta.indexes.append(idx)
@ -81,30 +77,26 @@ class StatorModel(models.Model):
concrete model yourself. concrete model yourself.
""" """
SCHEDULE_BATCH_SIZE = 1000 CLEAN_BATCH_SIZE = 1000
DELETE_BATCH_SIZE = 500
state: StateField state: StateField
# If this row is up for transition attempts (which it always is on creation!)
state_ready = models.BooleanField(default=True)
# When the state last actually changed, or the date of instance creation # When the state last actually changed, or the date of instance creation
state_changed = models.DateTimeField(auto_now_add=True) state_changed = models.DateTimeField(auto_now_add=True)
# When the last state change for the current state was attempted # When the next state change should be attempted (null means immediately)
# (and not successful, as this is cleared on transition) state_next_attempt = models.DateTimeField(blank=True, null=True)
state_attempted = models.DateTimeField(blank=True, null=True)
# If a lock is out on this row, when it is locked until # If a lock is out on this row, when it is locked until
# (we don't identify the lock owner, as there's no heartbeats) # (we don't identify the lock owner, as there's no heartbeats)
state_locked_until = models.DateTimeField(null=True, blank=True) state_locked_until = models.DateTimeField(null=True, blank=True, db_index=True)
# Collection of subclasses of us # Collection of subclasses of us
subclasses: ClassVar[list[type["StatorModel"]]] = [] subclasses: ClassVar[list[type["StatorModel"]]] = []
class Meta: class Meta:
abstract = True abstract = True
indexes = [models.Index(fields=["state_ready", "state_locked_until", "state"])]
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
if cls is not StatorModel: if cls is not StatorModel:
@ -118,52 +110,6 @@ class StatorModel(models.Model):
def state_age(self) -> float: def state_age(self) -> float:
return (timezone.now() - self.state_changed).total_seconds() return (timezone.now() - self.state_changed).total_seconds()
@classmethod
async def atransition_schedule_due(cls, now=None):
"""
Finds instances of this model that need to run and schedule them.
"""
if now is None:
now = timezone.now()
q = models.Q()
for state in cls.state_graph.states.values():
state = cast(State, state)
if not state.externally_progressed:
q = q | models.Q(
(
models.Q(
state_attempted__lte=(
now
- datetime.timedelta(
seconds=cast(float, state.try_interval)
)
)
)
| models.Q(state_attempted__isnull=True)
),
state=state.name,
)
select_query = cls.objects.filter(q)[: cls.SCHEDULE_BATCH_SIZE]
await cls.objects.filter(pk__in=select_query).aupdate(state_ready=True)
@classmethod
async def atransition_delete_due(cls, now=None):
"""
Finds instances of this model that need to be deleted and deletes them.
"""
if now is None:
now = timezone.now()
for state in cls.state_graph.states.values():
state = cast(State, state)
if state.delete_after:
select_query = cls.objects.filter(
state=state,
state_changed__lte=(
now - datetime.timedelta(seconds=state.delete_after)
),
)[: cls.SCHEDULE_BATCH_SIZE]
await cls.objects.filter(pk__in=select_query).adelete()
@classmethod @classmethod
def transition_get_with_lock( def transition_get_with_lock(
cls, number: int, lock_expiry: datetime.datetime cls, number: int, lock_expiry: datetime.datetime
@ -172,11 +118,17 @@ class StatorModel(models.Model):
Returns up to `number` tasks for execution, having locked them. Returns up to `number` tasks for execution, having locked them.
""" """
with transaction.atomic(): with transaction.atomic():
# Query for `number` rows that:
# - Have a next_attempt that's either null or in the past
# - Have one of the states we care about
# Then, sort them by next_attempt NULLS FIRST, so that we handle the
# rows in a roughly FIFO order.
selected = list( selected = list(
cls.objects.filter( cls.objects.filter(
state_locked_until__isnull=True, models.Q(state_next_attempt__isnull=True)
state_ready=True, | models.Q(state_next_attempt__lte=timezone.now()),
state__in=cls.state_graph.automatic_states, state__in=cls.state_graph.automatic_states,
state_locked_until__isnull=True,
)[:number].select_for_update() )[:number].select_for_update()
) )
cls.objects.filter(pk__in=[i.pk for i in selected]).update( cls.objects.filter(pk__in=[i.pk for i in selected]).update(
@ -185,44 +137,56 @@ class StatorModel(models.Model):
return selected return selected
@classmethod @classmethod
async def atransition_get_with_lock( def transition_delete_due(cls) -> int | None:
cls, number: int, lock_expiry: datetime.datetime """
) -> list["StatorModel"]: Finds instances of this model that need to be deleted and deletes them
return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry) in small batches. Returns how many were deleted.
"""
if cls.state_graph.deletion_states:
constraints = models.Q()
for state in cls.state_graph.deletion_states:
constraints |= models.Q(
state=state,
state_changed__lte=(
timezone.now() - datetime.timedelta(seconds=state.delete_after)
),
)
select_query = cls.objects.filter(
models.Q(state_next_attempt__isnull=True)
| models.Q(state_next_attempt__lte=timezone.now()),
constraints,
)[: cls.DELETE_BATCH_SIZE]
return cls.objects.filter(pk__in=select_query).delete()[0]
return None
@classmethod @classmethod
async def atransition_ready_count(cls) -> int: def transition_ready_count(cls) -> int:
""" """
Returns how many instances are "queued" Returns how many instances are "queued"
""" """
return await cls.objects.filter( return cls.objects.filter(
models.Q(state_next_attempt__isnull=True)
| models.Q(state_next_attempt__lte=timezone.now()),
state_locked_until__isnull=True, state_locked_until__isnull=True,
state_ready=True,
state__in=cls.state_graph.automatic_states, state__in=cls.state_graph.automatic_states,
).acount() ).count()
@classmethod @classmethod
async def atransition_clean_locks(cls): def transition_clean_locks(cls):
"""
Deletes stale locks (in batches, to avoid a giant query)
"""
select_query = cls.objects.filter(state_locked_until__lte=timezone.now())[ select_query = cls.objects.filter(state_locked_until__lte=timezone.now())[
: cls.SCHEDULE_BATCH_SIZE : cls.CLEAN_BATCH_SIZE
] ]
await cls.objects.filter(pk__in=select_query).aupdate(state_locked_until=None) cls.objects.filter(pk__in=select_query).update(state_locked_until=None)
def transition_schedule(self): def transition_attempt(self) -> State | None:
"""
Adds this instance to the queue to get its state transition attempted.
The scheduler will call this, but you can also call it directly if you
know it'll be ready and want to lower latency.
"""
self.state_ready = True
self.save()
async def atransition_attempt(self) -> State | None:
""" """
Attempts to transition the current state by running its handler(s). Attempts to transition the current state by running its handler(s).
""" """
current_state: State = self.state_graph.states[self.state] current_state: State = self.state_graph.states[self.state]
# If it's a manual progression state don't even try # If it's a manual progression state don't even try
# We shouldn't really be here in this case, but it could be a race condition # We shouldn't really be here in this case, but it could be a race condition
if current_state.externally_progressed: if current_state.externally_progressed:
@ -230,12 +194,17 @@ class StatorModel(models.Model):
f"Warning: trying to progress externally progressed state {self.state}!" f"Warning: trying to progress externally progressed state {self.state}!"
) )
return None return None
# Try running its handler function
try: try:
next_state = await current_state.handler(self) # type: ignore if iscoroutinefunction(current_state.handler):
next_state = async_to_sync(current_state.handler)(self)
else:
next_state = current_state.handler(self)
except TryAgainLater: except TryAgainLater:
pass pass
except BaseException as e: except BaseException as e:
await exceptions.acapture_exception(e) exceptions.capture_exception(e)
traceback.print_exc() traceback.print_exc()
else: else:
if next_state: if next_state:
@ -247,20 +216,24 @@ class StatorModel(models.Model):
raise ValueError( raise ValueError(
f"Cannot transition from {current_state} to {next_state} - not a declared transition" f"Cannot transition from {current_state} to {next_state} - not a declared transition"
) )
await self.atransition_perform(next_state) self.transition_perform(next_state)
return next_state return next_state
# See if it timed out
# See if it timed out since its last state change
if ( if (
current_state.timeout_value current_state.timeout_value
and current_state.timeout_value and current_state.timeout_value
<= (timezone.now() - self.state_changed).total_seconds() <= (timezone.now() - self.state_changed).total_seconds()
): ):
await self.atransition_perform(current_state.timeout_state) self.transition_perform(current_state.timeout_state) # type: ignore
return current_state.timeout_state return current_state.timeout_state
await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now(), # Nothing happened, set next execution and unlock it
self.__class__.objects.filter(pk=self.pk).update(
state_next_attempt=(
timezone.now() + datetime.timedelta(seconds=current_state.try_interval) # type: ignore
),
state_locked_until=None, state_locked_until=None,
state_ready=False,
) )
return None return None
@ -273,27 +246,6 @@ class StatorModel(models.Model):
state, state,
) )
atransition_perform = sync_to_async(transition_perform)
def transition_set_state(self, state: State | str):
"""
Sets the instance to the given state name for when it is saved.
"""
if isinstance(state, State):
state = state.name
if state not in self.state_graph.states:
raise ValueError(f"Invalid state {state}")
self.state = state # type: ignore
self.state_changed = timezone.now()
self.state_locked_until = None
if self.state_graph.states[state].attempt_immediately:
self.state_attempted = None
self.state_ready = True
else:
self.state_attempted = timezone.now()
self.state_ready = False
@classmethod @classmethod
def transition_perform_queryset( def transition_perform_queryset(
cls, cls,
@ -303,26 +255,27 @@ class StatorModel(models.Model):
""" """
Transitions every instance in the queryset to the given state name, forcibly. Transitions every instance in the queryset to the given state name, forcibly.
""" """
# Really ensure we have the right state object
if isinstance(state, State): if isinstance(state, State):
state = state.name state_obj = cls.state_graph.states[state.name]
if state not in cls.state_graph.states: else:
raise ValueError(f"Invalid state {state}") state_obj = cls.state_graph.states[state]
# See if it's ready immediately (if not, delay until first try_interval) # See if it's ready immediately (if not, delay until first try_interval)
if cls.state_graph.states[state].attempt_immediately: if state_obj.attempt_immediately or state_obj.try_interval is None:
queryset.update( queryset.update(
state=state, state=state_obj,
state_changed=timezone.now(), state_changed=timezone.now(),
state_attempted=None, state_next_attempt=None,
state_locked_until=None, state_locked_until=None,
state_ready=True,
) )
else: else:
queryset.update( queryset.update(
state=state, state=state_obj,
state_changed=timezone.now(), state_changed=timezone.now(),
state_attempted=timezone.now(), state_next_attempt=(
timezone.now() + datetime.timedelta(seconds=state_obj.try_interval)
),
state_locked_until=None, state_locked_until=None,
state_ready=False,
) )
@ -355,10 +308,6 @@ class Stats(models.Model):
instance.statistics[key] = {} instance.statistics[key] = {}
return instance return instance
@classmethod
async def aget_for_model(cls, model: type[StatorModel]) -> "Stats":
return await sync_to_async(cls.get_for_model)(model)
def set_queued(self, number: int): def set_queued(self, number: int):
""" """
Sets the current queued amount. Sets the current queued amount.

View file

@ -1,14 +1,13 @@
import asyncio
import datetime import datetime
import os import os
import signal import signal
import time import time
import traceback import traceback
import uuid import uuid
from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings from django.conf import settings
from django.db import close_old_connections
from django.utils import timezone from django.utils import timezone
from core import exceptions, sentry from core import exceptions, sentry
@ -16,26 +15,30 @@ from core.models import Config
from stator.models import StatorModel, Stats from stator.models import StatorModel, Stats
class LoopingTask: class LoopingTimer:
""" """
Wrapper for having a coroutine go in the background and only have one Triggers check() to be true once every `interval`.
copy running at a time.
""" """
def __init__(self, callable: Callable): next_run: float | None = None
self.callable = callable
self.task: asyncio.Task | None = None
def run(self) -> bool: def __init__(self, interval: float, trigger_at_start=True):
# If we have a task object, see if we can clear it up self.interval = interval
if self.task is not None: self.trigger_at_start = trigger_at_start
if self.task.done():
self.task = None def check(self) -> bool:
# See if it's our first time being called
if self.next_run is None:
# Set up the next call based on trigger_at_start
if self.trigger_at_start:
self.next_run = time.monotonic()
else: else:
return False self.next_run = time.monotonic() + self.interval
# OK, launch a new task # See if it's time to run the next call
self.task = asyncio.create_task(self.callable()) if time.monotonic() >= self.next_run:
self.next_run = time.monotonic() + self.interval
return True return True
return False
class StatorRunner: class StatorRunner:
@ -47,12 +50,13 @@ class StatorRunner:
def __init__( def __init__(
self, self,
models: list[type[StatorModel]], models: list[type[StatorModel]],
concurrency: int = getattr(settings, "STATOR_CONCURRENCY", 50), concurrency: int = getattr(settings, "STATOR_CONCURRENCY", 30),
concurrency_per_model: int = getattr( concurrency_per_model: int = getattr(
settings, "STATOR_CONCURRENCY_PER_MODEL", 15 settings, "STATOR_CONCURRENCY_PER_MODEL", 15
), ),
liveness_file: str | None = None, liveness_file: str | None = None,
schedule_interval: int = 30, schedule_interval: int = 60,
delete_interval: int = 30,
lock_expiry: int = 300, lock_expiry: int = 300,
run_for: int = 0, run_for: int = 0,
): ):
@ -62,53 +66,52 @@ class StatorRunner:
self.concurrency_per_model = concurrency_per_model self.concurrency_per_model = concurrency_per_model
self.liveness_file = liveness_file self.liveness_file = liveness_file
self.schedule_interval = schedule_interval self.schedule_interval = schedule_interval
self.delete_interval = delete_interval
self.lock_expiry = lock_expiry self.lock_expiry = lock_expiry
self.run_for = run_for self.run_for = run_for
self.minimum_loop_delay = 0.5 self.minimum_loop_delay = 0.5
self.maximum_loop_delay = 5 self.maximum_loop_delay = 5
self.tasks: list[Future] = []
# Set up SIGALRM handler # Set up SIGALRM handler
signal.signal(signal.SIGALRM, self.alarm_handler) signal.signal(signal.SIGALRM, self.alarm_handler)
async def run(self): def run(self):
sentry.set_takahe_app("stator") sentry.set_takahe_app("stator")
self.handled = {} self.handled = {}
self.started = time.monotonic() self.started = time.monotonic()
self.last_clean = time.monotonic() - self.schedule_interval self.executor = ThreadPoolExecutor(max_workers=self.concurrency)
self.tasks = []
self.loop_delay = self.minimum_loop_delay self.loop_delay = self.minimum_loop_delay
self.schedule_task = LoopingTask(self.run_scheduling) self.scheduling_timer = LoopingTimer(self.schedule_interval)
self.fetch_task = LoopingTask(self.fetch_and_process_tasks) self.deletion_timer = LoopingTimer(self.delete_interval)
self.config_task = LoopingTask(self.load_config)
# For the first time period, launch tasks # For the first time period, launch tasks
print("Running main task loop") print("Running main task loop")
try: try:
with sentry.configure_scope() as scope: with sentry.configure_scope() as scope:
while True: while True:
# Do we need to do cleaning? # See if we need to run cleaning
if (time.monotonic() - self.last_clean) >= self.schedule_interval: if self.scheduling_timer.check():
# Set up the watchdog timer (each time we do this the # Set up the watchdog timer (each time we do this the previous one is cancelled)
# previous one is cancelled)
signal.alarm(self.schedule_interval * 2) signal.alarm(self.schedule_interval * 2)
# Refresh the config
self.config_task.run()
if self.schedule_task.run():
print("Running cleaning and scheduling")
else:
print("Previous scheduling still running...!")
# Write liveness file if configured # Write liveness file if configured
if self.liveness_file: if self.liveness_file:
with open(self.liveness_file, "w") as fh: with open(self.liveness_file, "w") as fh:
fh.write(str(int(time.time()))) fh.write(str(int(time.time())))
self.last_clean = time.monotonic() # Refresh the config
self.load_config()
# Do scheduling (stale lock deletion and stats gathering)
self.run_scheduling()
# Clear the cleaning breadcrumbs/extra for the main part of the loop # Clear the cleaning breadcrumbs/extra for the main part of the loop
sentry.scope_clear(scope) sentry.scope_clear(scope)
self.remove_completed_tasks() self.clean_tasks()
# Fetching is kind of blocking, so we need to do this # See if we need to add deletion tasks
# as a separate coroutine if self.deletion_timer.check():
self.fetch_task.run() self.add_deletion_tasks()
# Fetch and run any new handlers we can fit
self.add_transition_tasks()
# Are we in limited run mode? # Are we in limited run mode?
if ( if (
@ -126,22 +129,19 @@ class StatorRunner:
self.loop_delay * 1.5, self.loop_delay * 1.5,
self.maximum_loop_delay, self.maximum_loop_delay,
) )
await asyncio.sleep(self.loop_delay) time.sleep(self.loop_delay)
# Clear the Sentry breadcrumbs and extra for next loop # Clear the Sentry breadcrumbs and extra for next loop
sentry.scope_clear(scope) sentry.scope_clear(scope)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
# Wait for tasks to finish # Wait for tasks to finish
print("Waiting for tasks to complete") print("Waiting for tasks to complete")
while True: self.executor.shutdown()
self.remove_completed_tasks()
if not self.tasks: # We're done
break
# Prevent busylooping
await asyncio.sleep(0.5)
print("Complete") print("Complete")
return self.handled
def alarm_handler(self, signum, frame): def alarm_handler(self, signum, frame):
""" """
@ -151,70 +151,107 @@ class StatorRunner:
print("Watchdog timeout exceeded") print("Watchdog timeout exceeded")
os._exit(2) os._exit(2)
async def load_config(self): def load_config(self):
""" """
Refreshes config from the DB Refreshes config from the DB
""" """
Config.system = await Config.aload_system() Config.system = Config.load_system()
async def run_scheduling(self): def run_scheduling(self):
""" """
Do any transition cleanup tasks Deletes stale locks for models, and submits their stats.
""" """
if self.handled:
print("Tasks processed since last flush:")
for label, number in self.handled.items():
print(f" {label}: {number}")
else:
print("No tasks handled since last flush.")
with sentry.start_transaction(op="task", name="stator.run_scheduling"): with sentry.start_transaction(op="task", name="stator.run_scheduling"):
for model in self.models: for model in self.models:
print(f"Scheduling {model._meta.label_lower}") print(
await self.submit_stats(model) f"{model._meta.label_lower}: Scheduling ({self.handled.get(model._meta.label_lower, 0)} handled)"
print(" Cleaning locks") )
await model.atransition_clean_locks() self.submit_stats(model)
print(" Scheduling due items") model.transition_clean_locks()
await model.atransition_schedule_due()
print(" Deleting due items")
await model.atransition_delete_due()
async def submit_stats(self, model): def submit_stats(self, model: type[StatorModel]):
""" """
Pop some statistics into the database Pop some statistics into the database from our local info for the given model
""" """
stats_instance = await Stats.aget_for_model(model) stats_instance = Stats.get_for_model(model)
if stats_instance.model_label in self.handled: if stats_instance.model_label in self.handled:
stats_instance.add_handled(self.handled[stats_instance.model_label]) stats_instance.add_handled(self.handled[stats_instance.model_label])
del self.handled[stats_instance.model_label] del self.handled[stats_instance.model_label]
stats_instance.set_queued(await model.atransition_ready_count()) stats_instance.set_queued(model.transition_ready_count())
stats_instance.trim_data() stats_instance.trim_data()
await sync_to_async(stats_instance.save)() stats_instance.save()
async def fetch_and_process_tasks(self): def add_transition_tasks(self, call_inline=False):
"""
Adds a transition thread for as many instances as we can, given capacity
and batch size limits.
"""
# Calculate space left for tasks # Calculate space left for tasks
space_remaining = self.concurrency - len(self.tasks) space_remaining = self.concurrency - len(self.tasks)
# Fetch new tasks # Fetch new tasks
for model in self.models: for model in self.models:
if space_remaining > 0: if space_remaining > 0:
for instance in await model.atransition_get_with_lock( for instance in model.transition_get_with_lock(
number=min(space_remaining, self.concurrency_per_model), number=min(space_remaining, self.concurrency_per_model),
lock_expiry=( lock_expiry=(
timezone.now() + datetime.timedelta(seconds=self.lock_expiry) timezone.now() + datetime.timedelta(seconds=self.lock_expiry)
), ),
): ):
if call_inline:
task_transition(instance, in_thread=False)
else:
self.tasks.append( self.tasks.append(
asyncio.create_task(self.run_transition(instance)) self.executor.submit(task_transition, instance)
) )
self.handled[model._meta.label_lower] = ( self.handled[model._meta.label_lower] = (
self.handled.get(model._meta.label_lower, 0) + 1 self.handled.get(model._meta.label_lower, 0) + 1
) )
space_remaining -= 1 space_remaining -= 1
async def run_transition(self, instance: StatorModel): def add_deletion_tasks(self, call_inline=False):
""" """
Wrapper for atransition_attempt with fallback error handling Adds a deletion thread for each model
""" """
task_name = f"stator.run_transition:{instance._meta.label_lower}#{{id}} from {instance.state}" # Yes, this potentially goes over the capacity limit - it's fine.
for model in self.models:
if model.state_graph.deletion_states:
if call_inline:
task_deletion(model, in_thread=False)
else:
self.tasks.append(self.executor.submit(task_deletion, model))
def clean_tasks(self):
"""
Removes any tasks that are done and handles exceptions if they
raised them.
"""
new_tasks = []
for task in self.tasks:
if task.done():
try:
task.result()
except BaseException as e:
exceptions.capture_exception(e)
traceback.print_exc()
else:
new_tasks.append(task)
self.tasks = new_tasks
def run_single_cycle(self):
"""
Testing entrypoint to advance things just one cycle, and allow errors
to propagate out.
"""
self.add_deletion_tasks(call_inline=True)
self.add_transition_tasks(call_inline=True)
def task_transition(instance: StatorModel, in_thread: bool = True):
"""
Runs one state transition/action.
"""
task_name = f"stator.task_transition:{instance._meta.label_lower}#{{id}} from {instance.state}"
started = time.monotonic()
with sentry.start_transaction(op="task", name=task_name): with sentry.start_transaction(op="task", name=task_name):
sentry.set_context( sentry.set_context(
"instance", "instance",
@ -225,29 +262,30 @@ class StatorRunner:
"state_age": instance.state_age, "state_age": instance.state_age,
}, },
) )
result = instance.transition_attempt()
try: duration = time.monotonic() - started
if result:
print( print(
f"Attempting transition on {instance._meta.label_lower}#{instance.pk} from state {instance.state}" f"{instance._meta.label_lower}: {instance.pk}: {instance.state} -> {result} ({duration:.2f}s)"
) )
await instance.atransition_attempt() else:
except BaseException as e: print(
await exceptions.acapture_exception(e) f"{instance._meta.label_lower}: {instance.pk}: {instance.state} unchanged ({duration:.2f}s)"
traceback.print_exc() )
if in_thread:
close_old_connections()
def remove_completed_tasks(self):
"""
Removes all completed asyncio.Tasks from our local in-progress list
"""
self.tasks = [t for t in self.tasks if not t.done()]
async def run_single_cycle(self): def task_deletion(model: type[StatorModel], in_thread: bool = True):
""" """
Testing entrypoint to advance things just one cycle, and allow errors Runs one model deletion set.
to propagate out.
""" """
await asyncio.wait_for(self.fetch_and_process_tasks(), timeout=1) # Loop, running deletions every second, until there are no more to do
for task in self.tasks: while True:
await task deleted = model.transition_delete_due()
if not deleted:
run_single_cycle_sync = async_to_sync(run_single_cycle) break
print(f"{model._meta.label_lower}: Deleted {deleted} stale items")
time.sleep(1)
if in_thread:
close_old_connections()

View file

@ -1 +1 @@
__version__ = "0.9.0" __version__ = "0.10.0-dev"

View file

@ -68,7 +68,7 @@ def test_ensure_hashtag(identity: Identity, config_system, stator):
author=identity, author=identity,
content="Hello, #testtag", content="Hello, #testtag",
) )
stator.run_single_cycle_sync() stator.run_single_cycle()
assert post.hashtags == ["testtag"] assert post.hashtags == ["testtag"]
assert Hashtag.objects.filter(hashtag="testtag").exists() assert Hashtag.objects.filter(hashtag="testtag").exists()
# Excessively long hashtag # Excessively long hashtag
@ -76,7 +76,7 @@ def test_ensure_hashtag(identity: Identity, config_system, stator):
author=identity, author=identity,
content="Hello, #thisisahashtagthatiswaytoolongandissignificantlyaboveourmaximumlimitofonehundredcharacterswhytheywouldbethislongidontknow", content="Hello, #thisisahashtagthatiswaytoolongandissignificantlyaboveourmaximumlimitofonehundredcharacterswhytheywouldbethislongidontknow",
) )
stator.run_single_cycle_sync() stator.run_single_cycle()
assert post.hashtags == [ assert post.hashtags == [
"thisisahashtagthatiswaytoolongandissignificantlyaboveourmaximumlimitofonehundredcharacterswhytheywou" "thisisahashtagthatiswaytoolongandissignificantlyaboveourmaximumlimitofonehundredcharacterswhytheywou"
] ]
@ -226,19 +226,19 @@ def test_post_transitions(identity, stator):
) )
# Test: | --> new --> fanned_out # Test: | --> new --> fanned_out
assert post.state == str(PostStates.new) assert post.state == str(PostStates.new)
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(id=post.id) post = Post.objects.get(id=post.id)
assert post.state == str(PostStates.fanned_out) assert post.state == str(PostStates.fanned_out)
# Test: fanned_out --> (forced) edited --> edited_fanned_out # Test: fanned_out --> (forced) edited --> edited_fanned_out
Post.transition_perform(post, PostStates.edited) Post.transition_perform(post, PostStates.edited)
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(id=post.id) post = Post.objects.get(id=post.id)
assert post.state == str(PostStates.edited_fanned_out) assert post.state == str(PostStates.edited_fanned_out)
# Test: edited_fanned_out --> (forced) deleted --> deleted_fanned_out # Test: edited_fanned_out --> (forced) deleted --> deleted_fanned_out
Post.transition_perform(post, PostStates.deleted) Post.transition_perform(post, PostStates.deleted)
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(id=post.id) post = Post.objects.get(id=post.id)
assert post.state == str(PostStates.deleted_fanned_out) assert post.state == str(PostStates.deleted_fanned_out)
@ -392,7 +392,7 @@ def test_inbound_posts(
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator and ensure that made the post # Run stator and ensure that made the post
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(object_uri="https://remote.test/test-post") post = Post.objects.get(object_uri="https://remote.test/test-post")
assert post.content == "post version one" assert post.content == "post version one"
assert post.published.day == 13 assert post.published.day == 13
@ -416,7 +416,7 @@ def test_inbound_posts(
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator and ensure that edited the post # Run stator and ensure that edited the post
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(object_uri="https://remote.test/test-post") post = Post.objects.get(object_uri="https://remote.test/test-post")
assert post.content == "post version two" assert post.content == "post version two"
assert post.edited.day == 14 assert post.edited.day == 14
@ -455,7 +455,7 @@ def test_inbound_posts(
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator and ensure that deleted the post # Run stator and ensure that deleted the post
stator.run_single_cycle_sync() stator.run_single_cycle()
assert not Post.objects.filter(object_uri="https://remote.test/test-post").exists() assert not Post.objects.filter(object_uri="https://remote.test/test-post").exists()
# Create an inbound new post message with only contentMap # Create an inbound new post message with only contentMap
@ -474,7 +474,7 @@ def test_inbound_posts(
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator and ensure that made the post # Run stator and ensure that made the post
stator.run_single_cycle_sync() stator.run_single_cycle()
post = Post.objects.get(object_uri="https://remote.test/test-map-only") post = Post.objects.get(object_uri="https://remote.test/test-map-only")
assert post.content == "post with only content map" assert post.content == "post with only content map"
assert post.published.day == 13 assert post.published.day == 13

View file

@ -1,5 +1,4 @@
import pytest import pytest
from asgiref.sync import async_to_sync
from activities.models import Post from activities.models import Post
from users.models import Block, Domain, Follow, Identity from users.models import Block, Domain, Follow, Identity
@ -16,7 +15,7 @@ def test_post_targets_simple(identity, other_identity, remote_identity):
author=identity, author=identity,
local=True, local=True,
) )
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity} assert targets == {identity}
# Test remote reply targets original post author # Test remote reply targets original post author
@ -26,7 +25,7 @@ def test_post_targets_simple(identity, other_identity, remote_identity):
local=False, local=False,
in_reply_to=post.absolute_object_uri(), in_reply_to=post.absolute_object_uri(),
) )
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity} assert targets == {identity}
# Test a post with local and remote mentions # Test a post with local and remote mentions
@ -38,14 +37,14 @@ def test_post_targets_simple(identity, other_identity, remote_identity):
# Mentions are targeted # Mentions are targeted
post.mentions.add(remote_identity) post.mentions.add(remote_identity)
post.mentions.add(other_identity) post.mentions.add(other_identity)
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
# Targets everyone # Targets everyone
assert targets == {identity, other_identity, remote_identity} assert targets == {identity, other_identity, remote_identity}
# Test remote post with mentions # Test remote post with mentions
post.local = False post.local = False
post.save() post.save()
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
# Only targets locals who are mentioned # Only targets locals who are mentioned
assert targets == {other_identity} assert targets == {other_identity}
@ -89,7 +88,7 @@ def test_post_targets_shared(identity, other_identity):
post.mentions.add(other_identity) post.mentions.add(other_identity)
post.mentions.add(remote1) post.mentions.add(remote1)
post.mentions.add(remote2) post.mentions.add(remote2)
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
# We should only have one of remote1 or remote2 in there as they share a # We should only have one of remote1 or remote2 in there as they share a
# shared inbox URI # shared inbox URI
@ -120,13 +119,12 @@ def test_post_local_only(identity, other_identity, remote_identity):
# Remote mention is not targeted # Remote mention is not targeted
post.mentions.add(remote_identity) post.mentions.add(remote_identity)
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity, other_identity} assert targets == {identity, other_identity}
@pytest.mark.django_db @pytest.mark.django_db
def test_post_followers(identity, other_identity, remote_identity): def test_post_followers(identity, other_identity, remote_identity):
Follow.objects.create(source=other_identity, target=identity) Follow.objects.create(source=other_identity, target=identity)
Follow.objects.create(source=remote_identity, target=identity) Follow.objects.create(source=remote_identity, target=identity)
@ -137,26 +135,26 @@ def test_post_followers(identity, other_identity, remote_identity):
local=True, local=True,
visibility=Post.Visibilities.public, visibility=Post.Visibilities.public,
) )
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity, other_identity, remote_identity} assert targets == {identity, other_identity, remote_identity}
# Remote post only targets local followers, not the author # Remote post only targets local followers, not the author
post.local = False post.local = False
post.save() post.save()
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {other_identity} assert targets == {other_identity}
# Local Only post only targets local followers # Local Only post only targets local followers
post.local = True post.local = True
post.visibility = Post.Visibilities.local_only post.visibility = Post.Visibilities.local_only
post.save() post.save()
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity, other_identity} assert targets == {identity, other_identity}
# Mentioned posts do not target unmentioned followers # Mentioned posts do not target unmentioned followers
post.visibility = Post.Visibilities.mentioned post.visibility = Post.Visibilities.mentioned
post.save() post.save()
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity} assert targets == {identity}
@ -179,5 +177,5 @@ def test_post_blocked(identity, other_identity, remote_identity):
post.mentions.add(other_identity) post.mentions.add(other_identity)
# The muted block should be in targets, the full block should not # The muted block should be in targets, the full block should not
targets = async_to_sync(post.aget_targets)() targets = post.get_targets()
assert targets == {identity, other_identity} assert targets == {identity, other_identity}

View file

@ -53,9 +53,10 @@ def test_mentioned(
elif blocked == "mute": elif blocked == "mute":
Block.create_local_mute(identity, author) Block.create_local_mute(identity, author)
# Run stator twice - to make fanouts and then process them # Run stator thrice - to receive the post, make fanouts and then process them
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle()
if blocked in ["full", "mute"]: if blocked in ["full", "mute"]:
# Verify we were not mentioned # Verify we were not mentioned
@ -121,9 +122,10 @@ def test_interaction_local_post(
elif blocked == "mute_with_notifications": elif blocked == "mute_with_notifications":
Block.create_local_mute(identity, interactor, include_notifications=True) Block.create_local_mute(identity, interactor, include_notifications=True)
# Run stator twice - to make fanouts and then process them # Run stator thrice - to receive the post, make fanouts and then process them
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle()
timeline_event_type = ( timeline_event_type = (
TimelineEvent.Types.boosted if type == "boost" else TimelineEvent.Types.liked TimelineEvent.Types.boosted if type == "boost" else TimelineEvent.Types.liked
@ -177,9 +179,10 @@ def test_old_new_post(
} }
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator twice - to make fanouts and then process them # Run stator thrice - to receive the post, make fanouts and then process them
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle()
if old: if old:
# Verify it did not appear on the timeline # Verify it did not appear on the timeline
@ -229,9 +232,10 @@ def test_clear_timeline(
} }
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator twice - to make fanouts and then process them # Run stator thrice - to receive the post, make fanouts and then process them
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle()
# Make sure it appeared on our timeline as a post and a mentioned # Make sure it appeared on our timeline as a post and a mentioned
assert TimelineEvent.objects.filter( assert TimelineEvent.objects.filter(
@ -248,7 +252,7 @@ def test_clear_timeline(
service.unfollow(remote_identity) service.unfollow(remote_identity)
# Run stator once to process the timeline clear message # Run stator once to process the timeline clear message
stator.run_single_cycle_sync() stator.run_single_cycle()
# Verify that the right things vanished # Verify that the right things vanished
assert not TimelineEvent.objects.filter( assert not TimelineEvent.objects.filter(
@ -308,9 +312,10 @@ def test_hashtag_followed(
elif blocked == "mute": elif blocked == "mute":
Block.create_local_mute(identity, author) Block.create_local_mute(identity, author)
# Run stator twice - to make fanouts and then process them # Run stator thrice - to receive the post, make fanouts and then process them
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle_sync() stator.run_single_cycle()
stator.run_single_cycle()
if blocked in ["full", "mute"]: if blocked in ["full", "mute"]:
# Verify post is not in timeline # Verify post is not in timeline

View file

@ -27,7 +27,7 @@ def test_follow(
url="https://remote.test/@test/inbox/", url="https://remote.test/@test/inbox/",
status_code=202, status_code=202,
) )
stator.run_single_cycle_sync() stator.run_single_cycle()
outbound_data = json.loads(httpx_mock.get_request().content) outbound_data = json.loads(httpx_mock.get_request().content)
assert outbound_data["type"] == "Follow" assert outbound_data["type"] == "Follow"
assert outbound_data["actor"] == identity.actor_uri assert outbound_data["actor"] == identity.actor_uri
@ -52,5 +52,5 @@ def test_follow(
} }
InboxMessage.objects.create(message=message) InboxMessage.objects.create(message=message)
# Run stator and ensure that accepted our follow # Run stator and ensure that accepted our follow
stator.run_single_cycle_sync() stator.run_single_cycle()
assert Follow.objects.get(pk=follow.pk).state == FollowStates.accepted assert Follow.objects.get(pk=follow.pk).state == FollowStates.accepted

View file

@ -120,5 +120,5 @@ def test_signup_email(client, config_system, stator):
# Run Stator and verify it sends the email # Run Stator and verify it sends the email
assert len(mail.outbox) == 0 assert len(mail.outbox) == 0
stator.run_single_cycle_sync() stator.run_single_cycle()
assert len(mail.outbox) == 1 assert len(mail.outbox) == 1

View file

@ -37,7 +37,7 @@ def test_import_following(
assert InboxMessage.objects.count() == 1 assert InboxMessage.objects.count() == 1
# Run stator to process it # Run stator to process it
stator.run_single_cycle_sync() stator.run_single_cycle()
# See if we're now following that identity # See if we're now following that identity
assert identity.outbound_follows.filter(target=remote_identity).count() == 1 assert identity.outbound_follows.filter(target=remote_identity).count() == 1

View file

@ -0,0 +1,271 @@
# Generated by Django 4.2.1 on 2023-07-05 22:18
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0018_index_together_migration"),
]
operations = [
migrations.RemoveIndex(
model_name="block",
name="users_block_state_r_e016a2_idx",
),
migrations.RemoveIndex(
model_name="block",
name="ix_block_state_attempted",
),
migrations.RemoveIndex(
model_name="block",
name="ix_block_state_locked",
),
migrations.RemoveIndex(
model_name="domain",
name="ix_domain_state_attempted",
),
migrations.RemoveIndex(
model_name="domain",
name="ix_domain_state_locked",
),
migrations.RemoveIndex(
model_name="domain",
name="users_domai_state_r_42b328_idx",
),
migrations.RemoveIndex(
model_name="follow",
name="users_follo_state_r_d1dbc2_idx",
),
migrations.RemoveIndex(
model_name="follow",
name="ix_follow_state_attempted",
),
migrations.RemoveIndex(
model_name="follow",
name="ix_follow_state_locked",
),
migrations.RemoveIndex(
model_name="identity",
name="users_ident_state_r_6fdeee_idx",
),
migrations.RemoveIndex(
model_name="identity",
name="ix_identity_state_attempted",
),
migrations.RemoveIndex(
model_name="identity",
name="ix_identity_state_locked",
),
migrations.RemoveIndex(
model_name="inboxmessage",
name="ix_inboxmessag_state_attempted",
),
migrations.RemoveIndex(
model_name="inboxmessage",
name="ix_inboxmessag_state_locked",
),
migrations.RemoveIndex(
model_name="inboxmessage",
name="users_inbox_state_r_00fce2_idx",
),
migrations.RemoveIndex(
model_name="passwordreset",
name="ix_passwordres_state_attempted",
),
migrations.RemoveIndex(
model_name="passwordreset",
name="ix_passwordres_state_locked",
),
migrations.RemoveIndex(
model_name="passwordreset",
name="users_passw_state_r_f54f10_idx",
),
migrations.RemoveIndex(
model_name="report",
name="ix_report_state_attempted",
),
migrations.RemoveIndex(
model_name="report",
name="ix_report_state_locked",
),
migrations.RemoveIndex(
model_name="report",
name="users_repor_state_r_345b80_idx",
),
migrations.RemoveField(
model_name="block",
name="state_attempted",
),
migrations.RemoveField(
model_name="block",
name="state_ready",
),
migrations.RemoveField(
model_name="domain",
name="state_attempted",
),
migrations.RemoveField(
model_name="domain",
name="state_ready",
),
migrations.RemoveField(
model_name="follow",
name="state_attempted",
),
migrations.RemoveField(
model_name="follow",
name="state_ready",
),
migrations.RemoveField(
model_name="identity",
name="state_attempted",
),
migrations.RemoveField(
model_name="identity",
name="state_ready",
),
migrations.RemoveField(
model_name="inboxmessage",
name="state_attempted",
),
migrations.RemoveField(
model_name="inboxmessage",
name="state_ready",
),
migrations.RemoveField(
model_name="passwordreset",
name="state_attempted",
),
migrations.RemoveField(
model_name="passwordreset",
name="state_ready",
),
migrations.RemoveField(
model_name="report",
name="state_attempted",
),
migrations.RemoveField(
model_name="report",
name="state_ready",
),
migrations.AddField(
model_name="block",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="domain",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="identity",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="inboxmessage",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="passwordreset",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="report",
name="state_next_attempt",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AlterField(
model_name="block",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="domain",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="follow",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="identity",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="inboxmessage",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="passwordreset",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AlterField(
model_name="report",
name="state_locked_until",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
migrations.AddIndex(
model_name="block",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_block_state_next",
),
),
migrations.AddIndex(
model_name="follow",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_follow_state_next",
),
),
migrations.AddIndex(
model_name="identity",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_identity_state_next",
),
),
migrations.AddIndex(
model_name="inboxmessage",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_inboxmessag_state_next",
),
),
migrations.AddIndex(
model_name="passwordreset",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_passwordres_state_next",
),
),
migrations.AddIndex(
model_name="report",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_report_state_next",
),
),
migrations.AddIndex(
model_name="domain",
index=models.Index(
fields=["state", "state_next_attempt", "state_locked_until"],
name="ix_domain_state_next",
),
),
]

View file

@ -0,0 +1,18 @@
# Generated by Django 4.2.1 on 2023-07-07 20:37
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0019_stator_next_change"),
]
operations = [
migrations.AlterField(
model_name="identity",
name="local",
field=models.BooleanField(db_index=True),
),
]

View file

@ -136,7 +136,7 @@ class Block(StatorModel):
class Meta: class Meta:
unique_together = [("source", "target", "mute")] unique_together = [("source", "target", "mute")]
indexes = StatorModel.Meta.indexes indexes: list = [] # We need this so Stator can add its own
def __str__(self): def __str__(self):
return f"#{self.id}: {self.source} blocks {self.target}" return f"#{self.id}: {self.source} blocks {self.target}"

View file

@ -146,7 +146,7 @@ class Follow(StatorModel):
class Meta: class Meta:
unique_together = [("source", "target")] unique_together = [("source", "target")]
indexes = StatorModel.Meta.indexes indexes: list = [] # We need this so Stator can add its own
def __str__(self): def __str__(self):
return f"#{self.id}: {self.source}{self.target}" return f"#{self.id}: {self.source}{self.target}"

View file

@ -158,7 +158,7 @@ class Identity(StatorModel):
state = StateField(IdentityStates) state = StateField(IdentityStates)
local = models.BooleanField() local = models.BooleanField(db_index=True)
users = models.ManyToManyField( users = models.ManyToManyField(
"users.User", "users.User",
related_name="identities", related_name="identities",
@ -227,7 +227,7 @@ class Identity(StatorModel):
class Meta: class Meta:
verbose_name_plural = "identities" verbose_name_plural = "identities"
unique_together = [("username", "domain")] unique_together = [("username", "domain")]
indexes = StatorModel.Meta.indexes indexes: list = [] # We need this so Stator can add its own
class urls(urlman.Urls): class urls(urlman.Urls):
view = "/@{self.username}@{self.domain_id}/" view = "/@{self.username}@{self.domain_id}/"

View file

@ -1,4 +1,3 @@
from asgiref.sync import sync_to_async
from django.db import models from django.db import models
from stator.models import State, StateField, StateGraph, StatorModel from stator.models import State, StateField, StateGraph, StatorModel
@ -13,61 +12,59 @@ class InboxMessageStates(StateGraph):
processed.transitions_to(purge) # Delete after release (back compat) processed.transitions_to(purge) # Delete after release (back compat)
@classmethod @classmethod
async def handle_received(cls, instance: "InboxMessage"): def handle_received(cls, instance: "InboxMessage"):
from activities.models import Post, PostInteraction, TimelineEvent from activities.models import Post, PostInteraction, TimelineEvent
from users.models import Block, Follow, Identity, Report from users.models import Block, Follow, Identity, Report
from users.services import IdentityService from users.services import IdentityService
match instance.message_type: match instance.message_type:
case "follow": case "follow":
await sync_to_async(Follow.handle_request_ap)(instance.message) Follow.handle_request_ap(instance.message)
case "block": case "block":
await sync_to_async(Block.handle_ap)(instance.message) Block.handle_ap(instance.message)
case "announce": case "announce":
await sync_to_async(PostInteraction.handle_ap)(instance.message) PostInteraction.handle_ap(instance.message)
case "like": case "like":
await sync_to_async(PostInteraction.handle_ap)(instance.message) PostInteraction.handle_ap(instance.message)
case "create": case "create":
match instance.message_object_type: match instance.message_object_type:
case "note": case "note":
if instance.message_object_has_content: if instance.message_object_has_content:
await sync_to_async(Post.handle_create_ap)(instance.message) Post.handle_create_ap(instance.message)
else: else:
# Notes without content are Interaction candidates # Notes without content are Interaction candidates
await sync_to_async(PostInteraction.handle_ap)( PostInteraction.handle_ap(instance.message)
instance.message
)
case "question": case "question":
await sync_to_async(Post.handle_create_ap)(instance.message) Post.handle_create_ap(instance.message)
case unknown: case unknown:
if unknown in Post.Types.names: if unknown in Post.Types.names:
await sync_to_async(Post.handle_create_ap)(instance.message) Post.handle_create_ap(instance.message)
case "update": case "update":
match instance.message_object_type: match instance.message_object_type:
case "note": case "note":
await sync_to_async(Post.handle_update_ap)(instance.message) Post.handle_update_ap(instance.message)
case "person": case "person":
await sync_to_async(Identity.handle_update_ap)(instance.message) Identity.handle_update_ap(instance.message)
case "service": case "service":
await sync_to_async(Identity.handle_update_ap)(instance.message) Identity.handle_update_ap(instance.message)
case "group": case "group":
await sync_to_async(Identity.handle_update_ap)(instance.message) Identity.handle_update_ap(instance.message)
case "organization": case "organization":
await sync_to_async(Identity.handle_update_ap)(instance.message) Identity.handle_update_ap(instance.message)
case "application": case "application":
await sync_to_async(Identity.handle_update_ap)(instance.message) Identity.handle_update_ap(instance.message)
case "question": case "question":
await sync_to_async(Post.handle_update_ap)(instance.message) Post.handle_update_ap(instance.message)
case unknown: case unknown:
if unknown in Post.Types.names: if unknown in Post.Types.names:
await sync_to_async(Post.handle_update_ap)(instance.message) Post.handle_update_ap(instance.message)
case "accept": case "accept":
match instance.message_object_type: match instance.message_object_type:
case "follow": case "follow":
await sync_to_async(Follow.handle_accept_ap)(instance.message) Follow.handle_accept_ap(instance.message)
case None: case None:
# It's a string object, but these will only be for Follows # It's a string object, but these will only be for Follows
await sync_to_async(Follow.handle_accept_ap)(instance.message) Follow.handle_accept_ap(instance.message)
case unknown: case unknown:
raise ValueError( raise ValueError(
f"Cannot handle activity of type accept.{unknown}" f"Cannot handle activity of type accept.{unknown}"
@ -75,10 +72,10 @@ class InboxMessageStates(StateGraph):
case "reject": case "reject":
match instance.message_object_type: match instance.message_object_type:
case "follow": case "follow":
await sync_to_async(Follow.handle_reject_ap)(instance.message) Follow.handle_reject_ap(instance.message)
case None: case None:
# It's a string object, but these will only be for Follows # It's a string object, but these will only be for Follows
await sync_to_async(Follow.handle_reject_ap)(instance.message) Follow.handle_reject_ap(instance.message)
case unknown: case unknown:
raise ValueError( raise ValueError(
f"Cannot handle activity of type reject.{unknown}" f"Cannot handle activity of type reject.{unknown}"
@ -86,17 +83,13 @@ class InboxMessageStates(StateGraph):
case "undo": case "undo":
match instance.message_object_type: match instance.message_object_type:
case "follow": case "follow":
await sync_to_async(Follow.handle_undo_ap)(instance.message) Follow.handle_undo_ap(instance.message)
case "block": case "block":
await sync_to_async(Block.handle_undo_ap)(instance.message) Block.handle_undo_ap(instance.message)
case "like": case "like":
await sync_to_async(PostInteraction.handle_undo_ap)( PostInteraction.handle_undo_ap(instance.message)
instance.message
)
case "announce": case "announce":
await sync_to_async(PostInteraction.handle_undo_ap)( PostInteraction.handle_undo_ap(instance.message)
instance.message
)
case "http://litepub.social/ns#emojireact": case "http://litepub.social/ns#emojireact":
# We're ignoring emoji reactions for now # We're ignoring emoji reactions for now
pass pass
@ -107,31 +100,31 @@ class InboxMessageStates(StateGraph):
case "delete": case "delete":
# If there is no object type, we need to see if it's a profile or a post # If there is no object type, we need to see if it's a profile or a post
if not isinstance(instance.message["object"], dict): if not isinstance(instance.message["object"], dict):
if await Identity.objects.filter( if Identity.objects.filter(
actor_uri=instance.message["object"] actor_uri=instance.message["object"]
).aexists(): ).exists():
await sync_to_async(Identity.handle_delete_ap)(instance.message) Identity.handle_delete_ap(instance.message)
elif await Post.objects.filter( elif Post.objects.filter(
object_uri=instance.message["object"] object_uri=instance.message["object"]
).aexists(): ).exists():
await sync_to_async(Post.handle_delete_ap)(instance.message) Post.handle_delete_ap(instance.message)
else: else:
# It is presumably already deleted # It is presumably already deleted
pass pass
else: else:
match instance.message_object_type: match instance.message_object_type:
case "tombstone": case "tombstone":
await sync_to_async(Post.handle_delete_ap)(instance.message) Post.handle_delete_ap(instance.message)
case "note": case "note":
await sync_to_async(Post.handle_delete_ap)(instance.message) Post.handle_delete_ap(instance.message)
case unknown: case unknown:
raise ValueError( raise ValueError(
f"Cannot handle activity of type delete.{unknown}" f"Cannot handle activity of type delete.{unknown}"
) )
case "add": case "add":
await sync_to_async(PostInteraction.handle_add_ap)(instance.message) PostInteraction.handle_add_ap(instance.message)
case "remove": case "remove":
await sync_to_async(PostInteraction.handle_remove_ap)(instance.message) PostInteraction.handle_remove_ap(instance.message)
case "move": case "move":
# We're ignoring moves for now # We're ignoring moves for now
pass pass
@ -140,19 +133,15 @@ class InboxMessageStates(StateGraph):
pass pass
case "flag": case "flag":
# Received reports # Received reports
await sync_to_async(Report.handle_ap)(instance.message) Report.handle_ap(instance.message)
case "__internal__": case "__internal__":
match instance.message_object_type: match instance.message_object_type:
case "fetchpost": case "fetchpost":
await sync_to_async(Post.handle_fetch_internal)( Post.handle_fetch_internal(instance.message["object"])
instance.message["object"]
)
case "cleartimeline": case "cleartimeline":
await sync_to_async(TimelineEvent.handle_clear_timeline)( TimelineEvent.handle_clear_timeline(instance.message["object"])
instance.message["object"]
)
case "addfollow": case "addfollow":
await sync_to_async(IdentityService.handle_internal_add_follow)( IdentityService.handle_internal_add_follow(
instance.message["object"] instance.message["object"]
) )
case unknown: case unknown: