Prune some unnecessary async usage

This commit is contained in:
Andrew Godwin 2023-07-17 00:18:00 -06:00
parent 31c743319e
commit 0915b17c4b
12 changed files with 63 additions and 171 deletions

View file

@ -319,23 +319,3 @@ class FanOut(StatorModel):
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)
### Async helpers ###
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return (
await FanOut.objects.select_related(
"identity",
"subject_post",
"subject_post_interaction",
"subject_identity",
"subject_identity__domain",
)
.prefetch_related(
"subject_post__emojis",
)
.aget(pk=self.pk)
)

View file

@ -447,18 +447,6 @@ class Post(StatorModel):
"replies": self.stats.get("replies", 0) if self.stats else 0, "replies": self.stats.get("replies", 0) if self.stats else 0,
} }
### Async helpers ###
async def afetch_full(self) -> "Post":
"""
Returns a version of the object with all relations pre-loaded
"""
return (
await Post.objects.select_related("author", "author__domain")
.prefetch_related("mentions", "mentions__domain", "attachments", "emojis")
.aget(pk=self.pk)
)
### Local creation/editing ### ### Local creation/editing ###
@classmethod @classmethod

View file

@ -27,103 +27,89 @@ class PostInteractionStates(StateGraph):
return [cls.new, cls.fanned_out] return [cls.new, cls.fanned_out]
@classmethod @classmethod
async def handle_new(cls, instance: "PostInteraction"): def handle_new(cls, instance: "PostInteraction"):
""" """
Creates all needed fan-out objects for a new PostInteraction. Creates all needed fan-out objects for a new PostInteraction.
""" """
interaction = await instance.afetch_full()
# Boost: send a copy to all people who follow this user (limiting # Boost: send a copy to all people who follow this user (limiting
# to just local follows if it's a remote boost) # to just local follows if it's a remote boost)
# Pin: send Add activity to all people who follow this user # Pin: send Add activity to all people who follow this user
if ( if instance.type == instance.Types.boost or instance.type == instance.Types.pin:
interaction.type == interaction.Types.boost for target in instance.get_targets():
or interaction.type == interaction.Types.pin FanOut.objects.create(
):
for target in await interaction.aget_targets():
await FanOut.objects.acreate(
type=FanOut.Types.interaction, type=FanOut.Types.interaction,
identity=target, identity=target,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
# Like: send a copy to the original post author only, # Like: send a copy to the original post author only,
# if the liker is local or they are # if the liker is local or they are
elif interaction.type == interaction.Types.like: elif instance.type == instance.Types.like:
if interaction.identity.local or interaction.post.local: if instance.identity.local or instance.post.local:
await FanOut.objects.acreate( FanOut.objects.create(
type=FanOut.Types.interaction, type=FanOut.Types.interaction,
identity_id=interaction.post.author_id, identity_id=instance.post.author_id,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
# Vote: send a copy of the vote to the original # Vote: send a copy of the vote to the original
# post author only if it's a local interaction # post author only if it's a local interaction
# to a non local post # to a non local post
elif interaction.type == interaction.Types.vote: elif instance.type == instance.Types.vote:
if interaction.identity.local and not interaction.post.local: if instance.identity.local and not instance.post.local:
await FanOut.objects.acreate( FanOut.objects.create(
type=FanOut.Types.interaction, type=FanOut.Types.interaction,
identity_id=interaction.post.author_id, identity_id=instance.post.author_id,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
else: else:
raise ValueError("Cannot fan out unknown type") raise ValueError("Cannot fan out unknown type")
# And one for themselves if they're local and it's a boost # And one for themselves if they're local and it's a boost
if ( if instance.type == PostInteraction.Types.boost and instance.identity.local:
interaction.type == PostInteraction.Types.boost FanOut.objects.create(
and interaction.identity.local identity_id=instance.identity_id,
):
await FanOut.objects.acreate(
identity_id=interaction.identity_id,
type=FanOut.Types.interaction, type=FanOut.Types.interaction,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
return cls.fanned_out return cls.fanned_out
@classmethod @classmethod
async def handle_undone(cls, instance: "PostInteraction"): def handle_undone(cls, instance: "PostInteraction"):
""" """
Creates all needed fan-out objects to undo a PostInteraction. Creates all needed fan-out objects to undo a PostInteraction.
""" """
interaction = await instance.afetch_full()
# Undo Boost: send a copy to all people who follow this user # Undo Boost: send a copy to all people who follow this user
# Undo Pin: send a Remove activity to all people who follow this user # Undo Pin: send a Remove activity to all people who follow this user
if ( if instance.type == instance.Types.boost or instance.type == instance.Types.pin:
interaction.type == interaction.Types.boost for follow in instance.identity.inbound_follows.select_related(
or interaction.type == interaction.Types.pin
):
async for follow in interaction.identity.inbound_follows.select_related(
"source", "target" "source", "target"
): ):
if follow.source.local or follow.target.local: if follow.source.local or follow.target.local:
await FanOut.objects.acreate( FanOut.objects.create(
type=FanOut.Types.undo_interaction, type=FanOut.Types.undo_interaction,
identity_id=follow.source_id, identity_id=follow.source_id,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
# Undo Like: send a copy to the original post author only # Undo Like: send a copy to the original post author only
elif interaction.type == interaction.Types.like: elif instance.type == instance.Types.like:
await FanOut.objects.acreate( FanOut.objects.create(
type=FanOut.Types.undo_interaction, type=FanOut.Types.undo_interaction,
identity_id=interaction.post.author_id, identity_id=instance.post.author_id,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
else: else:
raise ValueError("Cannot fan out unknown type") raise ValueError("Cannot fan out unknown type")
# And one for themselves if they're local and it's a boost # And one for themselves if they're local and it's a boost
if ( if instance.type == PostInteraction.Types.boost and instance.identity.local:
interaction.type == PostInteraction.Types.boost FanOut.objects.create(
and interaction.identity.local identity_id=instance.identity_id,
):
await FanOut.objects.acreate(
identity_id=interaction.identity_id,
type=FanOut.Types.undo_interaction, type=FanOut.Types.undo_interaction,
subject_post=interaction.post, subject_post=instance.post,
subject_post_interaction=interaction, subject_post_interaction=instance,
) )
return cls.undone_fanned_out return cls.undone_fanned_out
@ -212,17 +198,7 @@ class PostInteraction(StatorModel):
[e.subject_post for e in events if e.subject_post], identity [e.subject_post for e in events if e.subject_post], identity
) )
### Async helpers ### def get_targets(self) -> Iterable[Identity]:
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return await PostInteraction.objects.select_related(
"identity", "post", "post__author"
).aget(pk=self.pk)
async def aget_targets(self) -> Iterable[Identity]:
""" """
Returns an iterable with Identities of followers that have unique Returns an iterable with Identities of followers that have unique
shared_inbox among each other to be used as target. shared_inbox among each other to be used as target.
@ -237,13 +213,15 @@ class PostInteraction(StatorModel):
# Include all followers that are following the boosts # Include all followers that are following the boosts
if self.type == self.Types.boost: if self.type == self.Types.boost:
query = query.filter(boosts=True) query = query.filter(boosts=True)
async for follow in query.select_related("source"): for follow in query.select_related("source"):
targets.add(follow.source) targets.add(follow.source)
# Fetch the full blocks and remove them as targets # Fetch the full blocks and remove them as targets
async for block in self.identity.outbound_blocks.active().filter( for block in (
mute=False self.identity.outbound_blocks.active()
).select_related("target"): .filter(mute=False)
.select_related("target")
):
try: try:
targets.remove(block.target) targets.remove(block.target)
except KeyError: except KeyError:

View file

@ -16,7 +16,6 @@ class PushSubscriptionSchema(BaseModel):
keys: Keys keys: Keys
alerts: dict[str, bool] alerts: dict[str, bool]
policy: str policy: str
server_key: str
class Token(models.Model): class Token(models.Model):

View file

@ -498,7 +498,7 @@ class PushSubscription(Schema):
value = token.push_subscription value = token.push_subscription
if value: if value:
value["id"] = "1" value["id"] = "1"
value["server_key"] = settings.VAPID_PUBLIC_KEY value["server_key"] = settings.SETUP.VAPID_PUBLIC_KEY
del value["keys"] del value["keys"]
return value return value
else: else:

View file

@ -2,7 +2,6 @@ from functools import partial
from typing import ClassVar from typing import ClassVar
import pydantic import pydantic
from asgiref.sync import sync_to_async
from django.core.files import File from django.core.files import File
from django.db import models from django.db import models
from django.utils.functional import lazy from django.utils.functional import lazy
@ -97,16 +96,6 @@ class Config(models.Model):
{"identity__isnull": True, "user__isnull": True, "domain__isnull": True}, {"identity__isnull": True, "user__isnull": True, "domain__isnull": True},
) )
@classmethod
async def aload_system(cls):
"""
Async loads the system config options object
"""
return await sync_to_async(cls.load_values)(
cls.SystemOptions,
{"identity__isnull": True, "user__isnull": True, "domain__isnull": True},
)
@classmethod @classmethod
def load_user(cls, user): def load_user(cls, user):
""" """
@ -117,16 +106,6 @@ class Config(models.Model):
{"identity__isnull": True, "user": user, "domain__isnull": True}, {"identity__isnull": True, "user": user, "domain__isnull": True},
) )
@classmethod
async def aload_user(cls, user):
"""
Async loads the user config options object
"""
return await sync_to_async(cls.load_values)(
cls.UserOptions,
{"identity__isnull": True, "user": user, "domain__isnull": True},
)
@classmethod @classmethod
def load_identity(cls, identity): def load_identity(cls, identity):
""" """
@ -137,16 +116,6 @@ class Config(models.Model):
{"identity": identity, "user__isnull": True, "domain__isnull": True}, {"identity": identity, "user__isnull": True, "domain__isnull": True},
) )
@classmethod
async def aload_identity(cls, identity):
"""
Async loads an identity config options object
"""
return await sync_to_async(cls.load_values)(
cls.IdentityOptions,
{"identity": identity, "user__isnull": True, "domain__isnull": True},
)
@classmethod @classmethod
def load_domain(cls, domain): def load_domain(cls, domain):
""" """
@ -157,16 +126,6 @@ class Config(models.Model):
{"domain": domain, "user__isnull": True, "identity__isnull": True}, {"domain": domain, "user__isnull": True, "identity__isnull": True},
) )
@classmethod
async def aload_domain(cls, domain):
"""
Async loads an domain config options object
"""
return await sync_to_async(cls.load_values)(
cls.DomainOptions,
{"domain": domain, "user__isnull": True, "identity__isnull": True},
)
@classmethod @classmethod
def set_value(cls, key, value, options_class, filters): def set_value(cls, key, value, options_class, filters):
config_field = options_class.__fields__[key] config_field = options_class.__fields__[key]

View file

@ -16,7 +16,7 @@ class Command(BaseCommand):
"--concurrency", "--concurrency",
"-c", "-c",
type=int, type=int,
default=30, default=15,
help="How many tasks to run at once", help="How many tasks to run at once",
) )
parser.add_argument( parser.add_argument(

View file

@ -12,7 +12,7 @@ class RequestRunner(View):
For when you're on something serverless. For when you're on something serverless.
""" """
async def get(self, request): def get(self, request):
# Check the token, if supplied # Check the token, if supplied
if not settings.STATOR_TOKEN: if not settings.STATOR_TOKEN:
return HttpResponseForbidden("No token set") return HttpResponseForbidden("No token set")
@ -20,5 +20,5 @@ class RequestRunner(View):
return HttpResponseForbidden("Invalid token") return HttpResponseForbidden("Invalid token")
# Run on all models # Run on all models
runner = StatorRunner(StatorModel.subclasses, run_for=2) runner = StatorRunner(StatorModel.subclasses, run_for=2)
handled = await runner.run() handled = runner.run()
return HttpResponse(f"Handled {handled}") return HttpResponse(f"Handled {handled}")

View file

@ -58,7 +58,7 @@ class BlockStates(StateGraph):
return cls.sent return cls.sent
@classmethod @classmethod
async def handle_awaiting_expiry(cls, instance: "Block"): def handle_awaiting_expiry(cls, instance: "Block"):
""" """
Checks to see if there is an expiry we should undo Checks to see if there is an expiry we should undo
""" """

View file

@ -60,7 +60,7 @@ class FollowStates(StateGraph):
return cls.local_requested return cls.local_requested
@classmethod @classmethod
async def handle_local_requested(cls, instance: "Follow"): def handle_local_requested(cls, instance: "Follow"):
# TODO: Resend follow requests occasionally # TODO: Resend follow requests occasionally
pass pass

View file

@ -119,7 +119,7 @@ class IdentityStates(StateGraph):
return cls.updated return cls.updated
@classmethod @classmethod
async def handle_updated(cls, instance: "Identity"): def handle_updated(cls, instance: "Identity"):
if instance.state_age > Config.system.identity_max_age: if instance.state_age > Config.system.identity_max_age:
return cls.outdated return cls.outdated

View file

@ -1,7 +1,6 @@
import random import random
import string import string
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.core.mail import send_mail from django.core.mail import send_mail
from django.db import models from django.db import models
@ -18,18 +17,17 @@ class PasswordResetStates(StateGraph):
new.transitions_to(sent) new.transitions_to(sent)
@classmethod @classmethod
async def handle_new(cls, instance: "PasswordReset"): def handle_new(cls, instance: "PasswordReset"):
""" """
Sends the password reset email. Sends the password reset email.
""" """
reset = await instance.afetch_full() if instance.new_account:
if reset.new_account: send_mail(
await sync_to_async(send_mail)(
subject=f"{Config.system.site_name}: Confirm new account", subject=f"{Config.system.site_name}: Confirm new account",
message=render_to_string( message=render_to_string(
"emails/account_new.txt", "emails/account_new.txt",
{ {
"reset": reset, "reset": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
@ -37,21 +35,21 @@ class PasswordResetStates(StateGraph):
html_message=render_to_string( html_message=render_to_string(
"emails/account_new.html", "emails/account_new.html",
{ {
"reset": reset, "reset": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
), ),
from_email=settings.SERVER_EMAIL, from_email=settings.SERVER_EMAIL,
recipient_list=[reset.user.email], recipient_list=[instance.user.email],
) )
else: else:
await sync_to_async(send_mail)( send_mail(
subject=f"{Config.system.site_name}: Reset password", subject=f"{Config.system.site_name}: Reset password",
message=render_to_string( message=render_to_string(
"emails/password_reset.txt", "emails/password_reset.txt",
{ {
"reset": reset, "reset": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
@ -59,13 +57,13 @@ class PasswordResetStates(StateGraph):
html_message=render_to_string( html_message=render_to_string(
"emails/password_reset.html", "emails/password_reset.html",
{ {
"reset": reset, "reset": instance,
"config": Config.system, "config": Config.system,
"settings": settings, "settings": settings,
}, },
), ),
from_email=settings.SERVER_EMAIL, from_email=settings.SERVER_EMAIL,
recipient_list=[reset.user.email], recipient_list=[instance.user.email],
) )
return cls.sent return cls.sent
@ -96,13 +94,3 @@ class PasswordReset(StatorModel):
token="".join(random.choice(string.ascii_lowercase) for i in range(42)), token="".join(random.choice(string.ascii_lowercase) for i in range(42)),
new_account=not user.password, new_account=not user.password,
) )
### Async helpers ###
async def afetch_full(self):
"""
Returns a version of the object with all relations pre-loaded
"""
return await PasswordReset.objects.select_related(
"user",
).aget(pk=self.pk)