Merge pull request #879 from mouse-reeve/redis-stores

Created generic redis class for activitystreams
This commit is contained in:
Mouse Reeve 2021-04-06 08:27:21 -07:00 committed by GitHub
commit b96a092e71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 178 additions and 111 deletions

View file

@ -1,18 +1,13 @@
""" access the activity streams stored in redis """ """ access the activity streams stored in redis """
from abc import ABC
from django.dispatch import receiver from django.dispatch import receiver
from django.db.models import signals, Q from django.db.models import signals, Q
import redis
from bookwyrm import models, settings from bookwyrm import models
from bookwyrm.redis_store import RedisStore, r
from bookwyrm.views.helpers import privacy_filter from bookwyrm.views.helpers import privacy_filter
r = redis.Redis(
host=settings.REDIS_ACTIVITY_HOST, port=settings.REDIS_ACTIVITY_PORT, db=0
)
class ActivityStream(RedisStore):
class ActivityStream(ABC):
""" a category of activity stream (like home, local, federated) """ """ a category of activity stream (like home, local, federated) """
def stream_id(self, user): def stream_id(self, user):
@ -23,58 +18,40 @@ class ActivityStream(ABC):
""" the redis key for this user's unread count for this stream """ """ the redis key for this user's unread count for this stream """
return "{}-unread".format(self.stream_id(user)) return "{}-unread".format(self.stream_id(user))
def get_value(self, status): # pylint: disable=no-self-use def get_rank(self, obj): # pylint: disable=no-self-use
""" the status id and the rank (ie, published date) """ """ statuses are sorted by date published """
return {status.id: status.published_date.timestamp()} return obj.published_date.timestamp()
def add_status(self, status): def add_status(self, status):
""" add a status to users' feeds """ """ add a status to users' feeds """
value = self.get_value(status) # the pipeline contains all the add-to-stream activities
# we want to do this as a bulk operation, hence "pipeline" pipeline = self.add_object_to_related_stores(status, execute=False)
pipeline = r.pipeline()
for user in self.stream_users(status): for user in self.get_audience(status):
# add the status to the feed
pipeline.zadd(self.stream_id(user), value)
pipeline.zremrangebyrank(
self.stream_id(user), 0, -1 * settings.MAX_STREAM_LENGTH
)
# add to the unread status count # add to the unread status count
pipeline.incr(self.unread_id(user)) pipeline.incr(self.unread_id(user))
# and go!
pipeline.execute()
def remove_status(self, status): # and go!
""" remove a status from all feeds """
pipeline = r.pipeline()
for user in self.stream_users(status):
pipeline.zrem(self.stream_id(user), -1, status.id)
pipeline.execute() pipeline.execute()
def add_user_statuses(self, viewer, user): def add_user_statuses(self, viewer, user):
""" add a user's statuses to another user's feed """ """ add a user's statuses to another user's feed """
pipeline = r.pipeline() # only add the statuses that the viewer should be able to see (ie, not dms)
statuses = user.status_set.all()[: settings.MAX_STREAM_LENGTH] statuses = privacy_filter(viewer, user.status_set.all())
for status in statuses: self.bulk_add_objects_to_store(statuses, self.stream_id(viewer))
pipeline.zadd(self.stream_id(viewer), self.get_value(status))
if statuses:
pipeline.zremrangebyrank(
self.stream_id(user), 0, -1 * settings.MAX_STREAM_LENGTH
)
pipeline.execute()
def remove_user_statuses(self, viewer, user): def remove_user_statuses(self, viewer, user):
""" remove a user's status from another user's feed """ """ remove a user's status from another user's feed """
pipeline = r.pipeline() # remove all so that followers only statuses are removed
for status in user.status_set.all()[: settings.MAX_STREAM_LENGTH]: statuses = user.status_set.all()
pipeline.lrem(self.stream_id(viewer), -1, status.id) self.bulk_remove_objects_from_store(statuses, self.stream_id(viewer))
pipeline.execute()
def get_activity_stream(self, user): def get_activity_stream(self, user):
""" load the ids for statuses to be displayed """ """ load the statuses to be displayed """
# clear unreads for this feed # clear unreads for this feed
r.set(self.unread_id(user), 0) r.set(self.unread_id(user), 0)
statuses = r.zrevrange(self.stream_id(user), 0, -1) statuses = self.get_store(self.stream_id(user))
return ( return (
models.Status.objects.select_subclasses() models.Status.objects.select_subclasses()
.filter(id__in=statuses) .filter(id__in=statuses)
@ -85,23 +62,11 @@ class ActivityStream(ABC):
""" get the unread status count for this user's feed """ """ get the unread status count for this user's feed """
return int(r.get(self.unread_id(user)) or 0) return int(r.get(self.unread_id(user)) or 0)
def populate_stream(self, user): def populate_streams(self, user):
""" go from zero to a timeline """ """ go from zero to a timeline """
pipeline = r.pipeline() self.populate_store(self.stream_id(user))
statuses = self.stream_statuses(user)
stream_id = self.stream_id(user) def get_audience(self, status): # pylint: disable=no-self-use
for status in statuses.all()[: settings.MAX_STREAM_LENGTH]:
pipeline.zadd(stream_id, self.get_value(status))
# only trim the stream if statuses were added
if statuses.exists():
pipeline.zremrangebyrank(
self.stream_id(user), 0, -1 * settings.MAX_STREAM_LENGTH
)
pipeline.execute()
def stream_users(self, status): # pylint: disable=no-self-use
""" given a status, what users should see it """ """ given a status, what users should see it """
# direct messages don't appeard in feeds, direct comments/reviews/etc do # direct messages don't appeard in feeds, direct comments/reviews/etc do
if status.privacy == "direct" and status.status_type == "Note": if status.privacy == "direct" and status.status_type == "Note":
@ -129,7 +94,10 @@ class ActivityStream(ABC):
) )
return audience.distinct() return audience.distinct()
def stream_statuses(self, user): # pylint: disable=no-self-use def get_stores_for_object(self, obj):
return [self.stream_id(u) for u in self.get_audience(obj)]
def get_statuses_for_user(self, user): # pylint: disable=no-self-use
""" given a user, what statuses should they see on this stream """ """ given a user, what statuses should they see on this stream """
return privacy_filter( return privacy_filter(
user, user,
@ -137,14 +105,18 @@ class ActivityStream(ABC):
privacy_levels=["public", "unlisted", "followers"], privacy_levels=["public", "unlisted", "followers"],
) )
def get_objects_for_store(self, store):
user = models.User.objects.get(id=store.split("-")[0])
return self.get_statuses_for_user(user)
class HomeStream(ActivityStream): class HomeStream(ActivityStream):
""" users you follow """ """ users you follow """
key = "home" key = "home"
def stream_users(self, status): def get_audience(self, status):
audience = super().stream_users(status) audience = super().get_audience(status)
if not audience: if not audience:
return [] return []
return audience.filter( return audience.filter(
@ -152,7 +124,7 @@ class HomeStream(ActivityStream):
| Q(following=status.user) # if the user is following the author | Q(following=status.user) # if the user is following the author
).distinct() ).distinct()
def stream_statuses(self, user): def get_statuses_for_user(self, user):
return privacy_filter( return privacy_filter(
user, user,
models.Status.objects.select_subclasses(), models.Status.objects.select_subclasses(),
@ -166,13 +138,13 @@ class LocalStream(ActivityStream):
key = "local" key = "local"
def stream_users(self, status): def get_audience(self, status):
# this stream wants no part in non-public statuses # this stream wants no part in non-public statuses
if status.privacy != "public" or not status.user.local: if status.privacy != "public" or not status.user.local:
return [] return []
return super().stream_users(status) return super().get_audience(status)
def stream_statuses(self, user): def get_statuses_for_user(self, user):
# all public statuses by a local user # all public statuses by a local user
return privacy_filter( return privacy_filter(
user, user,
@ -186,13 +158,13 @@ class FederatedStream(ActivityStream):
key = "federated" key = "federated"
def stream_users(self, status): def get_audience(self, status):
# this stream wants no part in non-public statuses # this stream wants no part in non-public statuses
if status.privacy != "public": if status.privacy != "public":
return [] return []
return super().stream_users(status) return super().get_audience(status)
def stream_statuses(self, user): def get_statuses_for_user(self, user):
return privacy_filter( return privacy_filter(
user, user,
models.Status.objects.select_subclasses(), models.Status.objects.select_subclasses(),
@ -217,7 +189,7 @@ def add_status_on_create(sender, instance, created, *args, **kwargs):
if instance.deleted: if instance.deleted:
for stream in streams.values(): for stream in streams.values():
stream.remove_status(instance) stream.remove_object_from_related_stores(instance)
return return
if not created: if not created:
@ -234,7 +206,7 @@ def remove_boost_on_delete(sender, instance, *args, **kwargs):
""" boosts are deleted """ """ boosts are deleted """
# we're only interested in new statuses # we're only interested in new statuses
for stream in streams.values(): for stream in streams.values():
stream.remove_status(instance) stream.remove_object_from_related_stores(instance)
@receiver(signals.post_save, sender=models.UserFollows) @receiver(signals.post_save, sender=models.UserFollows)
@ -294,4 +266,4 @@ def populate_streams_on_account_create(sender, instance, created, *args, **kwarg
return return
for stream in streams.values(): for stream in streams.values():
stream.populate_stream(instance) stream.populate_streams(instance)

86
bookwyrm/redis_store.py Normal file
View file

@ -0,0 +1,86 @@
""" access the activity stores stored in redis """
from abc import ABC, abstractmethod
import redis
from bookwyrm import settings
r = redis.Redis(
host=settings.REDIS_ACTIVITY_HOST, port=settings.REDIS_ACTIVITY_PORT, db=0
)
class RedisStore(ABC):
""" sets of ranked, related objects, like statuses for a user's feed """
max_length = settings.MAX_STREAM_LENGTH
def get_value(self, obj):
""" the object and rank """
return {obj.id: self.get_rank(obj)}
def add_object_to_related_stores(self, obj, execute=True):
""" add an object to all suitable stores """
value = self.get_value(obj)
# we want to do this as a bulk operation, hence "pipeline"
pipeline = r.pipeline()
for store in self.get_stores_for_object(obj):
# add the status to the feed
pipeline.zadd(store, value)
# trim the store
pipeline.zremrangebyrank(store, 0, -1 * self.max_length)
if not execute:
return pipeline
# and go!
return pipeline.execute()
def remove_object_from_related_stores(self, obj):
""" remove an object from all stores """
pipeline = r.pipeline()
for store in self.get_stores_for_object(obj):
pipeline.zrem(store, -1, obj.id)
pipeline.execute()
def bulk_add_objects_to_store(self, objs, store):
""" add a list of objects to a given store """
pipeline = r.pipeline()
for obj in objs[: self.max_length]:
pipeline.zadd(store, self.get_value(obj))
if objs:
pipeline.zremrangebyrank(store, 0, -1 * self.max_length)
pipeline.execute()
def bulk_remove_objects_from_store(self, objs, store):
""" remoev a list of objects from a given store """
pipeline = r.pipeline()
for obj in objs[: self.max_length]:
pipeline.zrem(store, -1, obj.id)
pipeline.execute()
def get_store(self, store): # pylint: disable=no-self-use
""" load the values in a store """
return r.zrevrange(store, 0, -1)
def populate_store(self, store):
""" go from zero to a store """
pipeline = r.pipeline()
queryset = self.get_objects_for_store(store)
for obj in queryset[: self.max_length]:
pipeline.zadd(store, self.get_value(obj))
# only trim the store if objects were added
if queryset.exists():
pipeline.zremrangebyrank(store, 0, -1 * self.max_length)
pipeline.execute()
@abstractmethod
def get_objects_for_store(self, store):
""" a queryset of what should go in a store, used for populating it """
@abstractmethod
def get_stores_for_object(self, obj):
""" the stores that an object belongs in """
@abstractmethod
def get_rank(self, obj):
""" how to rank an object """

View file

@ -116,7 +116,9 @@ class Status(TestCase):
def test_status_to_activity_tombstone(self, *_): def test_status_to_activity_tombstone(self, *_):
""" subclass of the base model version with a "pure" serializer """ """ subclass of the base model version with a "pure" serializer """
with patch("bookwyrm.activitystreams.ActivityStream.remove_status"): with patch(
"bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
):
status = models.Status.objects.create( status = models.Status.objects.create(
content="test content", content="test content",
user=self.local_user, user=self.local_user,

View file

@ -47,18 +47,18 @@ class Activitystreams(TestCase):
"{}-test-unread".format(self.local_user.id), "{}-test-unread".format(self.local_user.id),
) )
def test_abstractstream_stream_users(self, *_): def test_abstractstream_get_audience(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
# remote users don't have feeds # remote users don't have feeds
self.assertFalse(self.remote_user in users) self.assertFalse(self.remote_user in users)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertTrue(self.another_user in users) self.assertTrue(self.another_user in users)
def test_abstractstream_stream_users_direct(self, *_): def test_abstractstream_get_audience_direct(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, user=self.remote_user,
@ -66,7 +66,7 @@ class Activitystreams(TestCase):
privacy="direct", privacy="direct",
) )
status.mention_users.add(self.local_user) status.mention_users.add(self.local_user)
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertEqual(users, []) self.assertEqual(users, [])
status = models.Comment.objects.create( status = models.Comment.objects.create(
@ -76,22 +76,22 @@ class Activitystreams(TestCase):
book=self.book, book=self.book,
) )
status.mention_users.add(self.local_user) status.mention_users.add(self.local_user)
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
self.assertFalse(self.remote_user in users) self.assertFalse(self.remote_user in users)
def test_abstractstream_stream_users_followers_remote_user(self, *_): def test_abstractstream_get_audience_followers_remote_user(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, user=self.remote_user,
content="hi", content="hi",
privacy="followers", privacy="followers",
) )
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertFalse(users.exists()) self.assertFalse(users.exists())
def test_abstractstream_stream_users_followers_self(self, *_): def test_abstractstream_get_audience_followers_self(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Comment.objects.create( status = models.Comment.objects.create(
user=self.local_user, user=self.local_user,
@ -99,12 +99,12 @@ class Activitystreams(TestCase):
privacy="direct", privacy="direct",
book=self.book, book=self.book,
) )
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
self.assertFalse(self.remote_user in users) self.assertFalse(self.remote_user in users)
def test_abstractstream_stream_users_followers_with_mention(self, *_): def test_abstractstream_get_audience_followers_with_mention(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Comment.objects.create( status = models.Comment.objects.create(
user=self.remote_user, user=self.remote_user,
@ -114,12 +114,12 @@ class Activitystreams(TestCase):
) )
status.mention_users.add(self.local_user) status.mention_users.add(self.local_user)
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
self.assertFalse(self.remote_user in users) self.assertFalse(self.remote_user in users)
def test_abstractstream_stream_users_followers_with_relationship(self, *_): def test_abstractstream_get_audience_followers_with_relationship(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
self.remote_user.followers.add(self.local_user) self.remote_user.followers.add(self.local_user)
status = models.Comment.objects.create( status = models.Comment.objects.create(
@ -128,77 +128,77 @@ class Activitystreams(TestCase):
privacy="direct", privacy="direct",
book=self.book, book=self.book,
) )
users = self.test_stream.stream_users(status) users = self.test_stream.get_audience(status)
self.assertFalse(self.local_user in users) self.assertFalse(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
self.assertFalse(self.remote_user in users) self.assertFalse(self.remote_user in users)
def test_homestream_stream_users(self, *_): def test_homestream_get_audience(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
users = activitystreams.HomeStream().stream_users(status) users = activitystreams.HomeStream().get_audience(status)
self.assertFalse(users.exists()) self.assertFalse(users.exists())
def test_homestream_stream_users_with_mentions(self, *_): def test_homestream_get_audience_with_mentions(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
status.mention_users.add(self.local_user) status.mention_users.add(self.local_user)
users = activitystreams.HomeStream().stream_users(status) users = activitystreams.HomeStream().get_audience(status)
self.assertFalse(self.local_user in users) self.assertFalse(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
def test_homestream_stream_users_with_relationship(self, *_): def test_homestream_get_audience_with_relationship(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
self.remote_user.followers.add(self.local_user) self.remote_user.followers.add(self.local_user)
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
users = activitystreams.HomeStream().stream_users(status) users = activitystreams.HomeStream().get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertFalse(self.another_user in users) self.assertFalse(self.another_user in users)
def test_localstream_stream_users_remote_status(self, *_): def test_localstream_get_audience_remote_status(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
users = activitystreams.LocalStream().stream_users(status) users = activitystreams.LocalStream().get_audience(status)
self.assertEqual(users, []) self.assertEqual(users, [])
def test_localstream_stream_users_local_status(self, *_): def test_localstream_get_audience_local_status(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.local_user, content="hi", privacy="public" user=self.local_user, content="hi", privacy="public"
) )
users = activitystreams.LocalStream().stream_users(status) users = activitystreams.LocalStream().get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertTrue(self.another_user in users) self.assertTrue(self.another_user in users)
def test_localstream_stream_users_unlisted(self, *_): def test_localstream_get_audience_unlisted(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.local_user, content="hi", privacy="unlisted" user=self.local_user, content="hi", privacy="unlisted"
) )
users = activitystreams.LocalStream().stream_users(status) users = activitystreams.LocalStream().get_audience(status)
self.assertEqual(users, []) self.assertEqual(users, [])
def test_federatedstream_stream_users(self, *_): def test_federatedstream_get_audience(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="public" user=self.remote_user, content="hi", privacy="public"
) )
users = activitystreams.FederatedStream().stream_users(status) users = activitystreams.FederatedStream().get_audience(status)
self.assertTrue(self.local_user in users) self.assertTrue(self.local_user in users)
self.assertTrue(self.another_user in users) self.assertTrue(self.another_user in users)
def test_federatedstream_stream_users_unlisted(self, *_): def test_federatedstream_get_audience_unlisted(self, *_):
""" get a list of users that should see a status """ """ get a list of users that should see a status """
status = models.Status.objects.create( status = models.Status.objects.create(
user=self.remote_user, content="hi", privacy="unlisted" user=self.remote_user, content="hi", privacy="unlisted"
) )
users = activitystreams.FederatedStream().stream_users(status) users = activitystreams.FederatedStream().get_audience(status)
self.assertEqual(users, []) self.assertEqual(users, [])

View file

@ -85,7 +85,9 @@ class TemplateTags(TestCase):
second_child = models.Status.objects.create( second_child = models.Status.objects.create(
reply_parent=parent, user=self.user, content="hi" reply_parent=parent, user=self.user, content="hi"
) )
with patch("bookwyrm.activitystreams.ActivityStream.remove_status"): with patch(
"bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
):
third_child = models.Status.objects.create( third_child = models.Status.objects.create(
reply_parent=parent, reply_parent=parent,
user=self.user, user=self.user,

View file

@ -444,7 +444,7 @@ class Inbox(TestCase):
"object": {"id": self.status.remote_id, "type": "Tombstone"}, "object": {"id": self.status.remote_id, "type": "Tombstone"},
} }
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
views.inbox.activity_task(activity) views.inbox.activity_task(activity)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)
@ -477,7 +477,7 @@ class Inbox(TestCase):
"object": {"id": self.status.remote_id, "type": "Tombstone"}, "object": {"id": self.status.remote_id, "type": "Tombstone"},
} }
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
views.inbox.activity_task(activity) views.inbox.activity_task(activity)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)
@ -666,7 +666,7 @@ class Inbox(TestCase):
}, },
} }
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
views.inbox.activity_task(activity) views.inbox.activity_task(activity)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)

View file

@ -164,7 +164,7 @@ class InteractionViews(TestCase):
self.assertEqual(models.Boost.objects.count(), 1) self.assertEqual(models.Boost.objects.count(), 1)
self.assertEqual(models.Notification.objects.count(), 1) self.assertEqual(models.Notification.objects.count(), 1)
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
view(request, status.id) view(request, status.id)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)

View file

@ -177,7 +177,9 @@ class StatusViews(TestCase):
content="hi", book=self.book, user=self.local_user content="hi", book=self.book, user=self.local_user
) )
with patch("bookwyrm.activitystreams.ActivityStream.remove_status") as mock: with patch(
"bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as mock:
result = view(request, status.id) result = view(request, status.id)
self.assertTrue(mock.called) self.assertTrue(mock.called)
result.render() result.render()
@ -196,7 +198,9 @@ class StatusViews(TestCase):
book=self.book, rating=2.0, user=self.local_user book=self.book, rating=2.0, user=self.local_user
) )
with patch("bookwyrm.activitystreams.ActivityStream.remove_status") as mock: with patch(
"bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as mock:
result = view(request, status.id) result = view(request, status.id)
self.assertFalse(mock.called) self.assertFalse(mock.called)
self.assertEqual(result.status_code, 400) self.assertEqual(result.status_code, 400)
@ -214,7 +218,9 @@ class StatusViews(TestCase):
content="hi", user=self.local_user content="hi", user=self.local_user
) )
with patch("bookwyrm.activitystreams.ActivityStream.remove_status") as mock: with patch(
"bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as mock:
result = view(request, status.id) result = view(request, status.id)
self.assertFalse(mock.called) self.assertFalse(mock.called)
self.assertEqual(result.status_code, 400) self.assertEqual(result.status_code, 400)
@ -316,7 +322,7 @@ class StatusViews(TestCase):
request.user = self.local_user request.user = self.local_user
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
view(request, status.id) view(request, status.id)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)
@ -351,7 +357,7 @@ class StatusViews(TestCase):
request.user.is_superuser = True request.user.is_superuser = True
with patch( with patch(
"bookwyrm.activitystreams.ActivityStream.remove_status" "bookwyrm.activitystreams.ActivityStream.remove_object_from_related_stores"
) as redis_mock: ) as redis_mock:
view(request, status.id) view(request, status.id)
self.assertTrue(redis_mock.called) self.assertTrue(redis_mock.called)

View file

@ -31,7 +31,6 @@ class Feed(View):
tab = "home" tab = "home"
activities = activitystreams.streams[tab].get_activity_stream(request.user) activities = activitystreams.streams[tab].get_activity_stream(request.user)
paginated = Paginator(activities, PAGE_LENGTH) paginated = Paginator(activities, PAGE_LENGTH)
suggested_users = get_suggested_users(request.user) suggested_users = get_suggested_users(request.user)