Refactors how mentions are collected

This should be quicker, because it queries the known users in one go
instead of looking each mention up individually while iterating over the
matches, and it checks if a user blocks the poster before allowing them
to be tagged.
This commit is contained in:
Mouse Reeve 2022-08-05 14:04:16 -07:00
parent bdf563748c
commit 5c3bb2da13
4 changed files with 96 additions and 53 deletions

View file

@ -6,6 +6,7 @@ from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from bookwyrm import forms, models, views from bookwyrm import forms, models, views
from bookwyrm.views.status import find_mentions
from bookwyrm.settings import DOMAIN from bookwyrm.settings import DOMAIN
from bookwyrm.tests.validate_html import validate_html from bookwyrm.tests.validate_html import validate_html
@ -34,6 +35,13 @@ class StatusViews(TestCase):
localname="mouse", localname="mouse",
remote_id="https://example.com/users/mouse", remote_id="https://example.com/users/mouse",
) )
self.another_user = models.User.objects.create_user(
f"nutria@{DOMAIN}",
"nutria@nutria.com",
"password",
local=True,
localname="nutria",
)
with patch("bookwyrm.models.user.set_remote_server"): with patch("bookwyrm.models.user.set_remote_server"):
self.remote_user = models.User.objects.create_user( self.remote_user = models.User.objects.create_user(
"rat", "rat",
@ -211,51 +219,66 @@ class StatusViews(TestCase):
self.assertFalse(self.remote_user in reply.mention_users.all()) self.assertFalse(self.remote_user in reply.mention_users.all())
self.assertTrue(self.local_user in reply.mention_users.all()) self.assertTrue(self.local_user in reply.mention_users.all())
def test_find_mentions(self, *_): def test_find_mentions_local(self, *_):
"""detect and look up @ mentions of users""" """detect and look up @ mentions of users"""
user = models.User.objects.create_user( result = find_mentions(self.local_user, "@nutria")
f"nutria@{DOMAIN}", self.assertEqual(result["@nutria"], self.another_user)
"nutria@nutria.com", self.assertEqual(result[f"@nutria@{DOMAIN}"], self.another_user)
"password",
local=True,
localname="nutria",
)
self.assertEqual(user.username, f"nutria@{DOMAIN}")
result = find_mentions(self.local_user, f"@nutria@{DOMAIN}")
self.assertEqual(result["@nutria"], self.another_user)
self.assertEqual(result[f"@nutria@{DOMAIN}"], self.another_user)
result = find_mentions(self.local_user, "leading text @nutria")
self.assertEqual(result["@nutria"], self.another_user)
self.assertEqual(result[f"@nutria@{DOMAIN}"], self.another_user)
result = find_mentions(self.local_user, "leading @nutria trailing")
self.assertEqual(result["@nutria"], self.another_user)
self.assertEqual(result[f"@nutria@{DOMAIN}"], self.another_user)
self.assertEqual(find_mentions(self.local_user, "leading@nutria"), {})
def test_find_mentions_remote(self, *_):
"""detect and look up @ mentions of users"""
self.assertEqual( self.assertEqual(
list(views.status.find_mentions("@nutria"))[0], ("@nutria", user) find_mentions(self.local_user, "@rat@example.com"),
) {"@rat@example.com": self.remote_user},
self.assertEqual(
list(views.status.find_mentions("leading text @nutria"))[0],
("@nutria", user),
)
self.assertEqual(
list(views.status.find_mentions("leading @nutria trailing text"))[0],
("@nutria", user),
)
self.assertEqual(
list(views.status.find_mentions("@rat@example.com"))[0],
("@rat@example.com", self.remote_user),
) )
multiple = list(views.status.find_mentions("@nutria and @rat@example.com")) def test_find_mentions_multiple(self, *_):
self.assertEqual(multiple[0], ("@nutria", user)) """detect and look up @ mentions of users"""
self.assertEqual(multiple[1], ("@rat@example.com", self.remote_user)) multiple = find_mentions(self.local_user, "@nutria and @rat@example.com")
self.assertEqual(multiple["@nutria"], self.another_user)
self.assertEqual(multiple[f"@nutria@{DOMAIN}"], self.another_user)
self.assertEqual(multiple["@rat@example.com"], self.remote_user)
self.assertIsNone(multiple.get("@rat"))
def test_find_mentions_unknown(self, *_):
"""detect and look up @ mentions of users"""
multiple = find_mentions(self.local_user, "@nutria and @rdkjfgh")
self.assertEqual(multiple["@nutria"], self.another_user)
self.assertEqual(multiple[f"@nutria@{DOMAIN}"], self.another_user)
def test_find_mentions_blocked(self, *_):
"""detect and look up @ mentions of users"""
self.another_user.blocks.add(self.local_user)
result = find_mentions(self.local_user, "@nutria hello")
self.assertEqual(result, {})
def test_find_mentions_unknown_remote(self, *_):
"""mention a user that isn't in the database"""
with patch("bookwyrm.views.status.handle_remote_webfinger") as rw: with patch("bookwyrm.views.status.handle_remote_webfinger") as rw:
rw.return_value = self.local_user rw.return_value = self.another_user
self.assertEqual( result = find_mentions(self.local_user, "@beep@beep.com")
list(views.status.find_mentions("@beep@beep.com"))[0], self.assertEqual(result["@nutria"], self.another_user)
("@beep@beep.com", self.local_user), self.assertEqual(result[f"@nutria@{DOMAIN}"], self.another_user)
)
with patch("bookwyrm.views.status.handle_remote_webfinger") as rw: with patch("bookwyrm.views.status.handle_remote_webfinger") as rw:
rw.return_value = None rw.return_value = None
self.assertEqual(list(views.status.find_mentions("@beep@beep.com")), []) result = find_mentions(self.local_user, "@beep@beep.com")
self.assertEqual(result, {})
self.assertEqual(
list(views.status.find_mentions(f"@nutria@{DOMAIN}"))[0],
(f"@nutria@{DOMAIN}", user),
)
def test_format_links_simple_url(self, *_): def test_format_links_simple_url(self, *_):
"""find and format urls into a tags""" """find and format urls into a tags"""

View file

@ -4,7 +4,7 @@ DOMAIN = r"[\w_\-\.]+\.[a-z\-]{2,}"
LOCALNAME = r"@?[a-zA-Z_\-\.0-9]+" LOCALNAME = r"@?[a-zA-Z_\-\.0-9]+"
STRICT_LOCALNAME = r"@[a-zA-Z_\-\.0-9]+" STRICT_LOCALNAME = r"@[a-zA-Z_\-\.0-9]+"
USERNAME = rf"{LOCALNAME}(@{DOMAIN})?" USERNAME = rf"{LOCALNAME}(@{DOMAIN})?"
STRICT_USERNAME = rf"\B{STRICT_LOCALNAME}(@{DOMAIN})?\b" STRICT_USERNAME = rf"(\B{STRICT_LOCALNAME}(@{DOMAIN})?\b)"
FULL_USERNAME = rf"{LOCALNAME}@{DOMAIN}\b" FULL_USERNAME = rf"{LOCALNAME}@{DOMAIN}\b"
SLUG = r"/s/(?P<slug>[-_a-z0-9]*)" SLUG = r"/s/(?P<slug>[-_a-z0-9]*)"
# should match (BookWyrm/1.0.0; or (BookWyrm/99.1.2; # should match (BookWyrm/1.0.0; or (BookWyrm/99.1.2;

View file

@ -59,7 +59,7 @@ def is_bookwyrm_request(request):
return True return True
def handle_remote_webfinger(query): def handle_remote_webfinger(query, unknown_only=False):
"""webfingerin' other servers""" """webfingerin' other servers"""
user = None user = None
@ -75,6 +75,11 @@ def handle_remote_webfinger(query):
try: try:
user = models.User.objects.get(username__iexact=query) user = models.User.objects.get(username__iexact=query)
if unknown_only:
# In this case, we only want to know about previously undiscovered users
# So the fact that we found a match in the database means no results
return None
except models.User.DoesNotExist: except models.User.DoesNotExist:
url = f"https://{domain}/.well-known/webfinger?resource=acct:{query}" url = f"https://{domain}/.well-known/webfinger?resource=acct:{query}"
try: try:

View file

@ -6,6 +6,7 @@ from urllib.parse import urlparse
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.core.validators import URLValidator from django.core.validators import URLValidator
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models import Q
from django.http import HttpResponse, HttpResponseBadRequest, Http404 from django.http import HttpResponse, HttpResponseBadRequest, Http404
from django.shortcuts import get_object_or_404, redirect from django.shortcuts import get_object_or_404, redirect
from django.template.response import TemplateResponse from django.template.response import TemplateResponse
@ -16,7 +17,6 @@ from django.views.decorators.http import require_POST
from markdown import markdown from markdown import markdown
from bookwyrm import forms, models from bookwyrm import forms, models
from bookwyrm.settings import DOMAIN
from bookwyrm.utils import regex, sanitizer from bookwyrm.utils import regex, sanitizer
from .helpers import handle_remote_webfinger, is_api_request from .helpers import handle_remote_webfinger, is_api_request
from .helpers import load_date_in_user_tz_as_utc from .helpers import load_date_in_user_tz_as_utc
@ -96,14 +96,16 @@ class CreateStatus(View):
# inspect the text for user tags # inspect the text for user tags
content = status.content content = status.content
for (mention_text, mention_user) in find_mentions(content): for (mention_text, mention_user) in find_mentions(
request.user, content
).items():
# add them to status mentions fk # add them to status mentions fk
status.mention_users.add(mention_user) status.mention_users.add(mention_user)
# turn the mention into a link # turn the mention into a link
content = re.sub( content = re.sub(
rf"{mention_text}([^@]|$)", rf"{mention_text}\b(?!@)",
rf'<a href="{mention_user.remote_id}">{mention_text}</a>\g<1>', rf'<a href="{mention_user.remote_id}">{mention_text}</a>',
content, content,
) )
# add reply parent to mentions # add reply parent to mentions
@ -199,22 +201,35 @@ def edit_readthrough(request):
return redirect("/") return redirect("/")
def find_mentions(content): def find_mentions(user, content):
"""detect @mentions in raw status content""" """detect @mentions in raw status content"""
if not content: if not content:
return return {}
for match in re.finditer(regex.STRICT_USERNAME, content): # The regex has nested match groups, so the 0th entry has the full (outer) match
username = match.group().strip().split("@")[1:] # And because the strict username starts with @, the username is 1st char onward
if len(username) == 1: usernames = [m[0][1:] for m in re.findall(regex.STRICT_USERNAME, content)]
# this looks like a local user (@user), fill in the domain
username.append(DOMAIN)
username = "@".join(username)
mention_user = handle_remote_webfinger(username) known_users = (
models.User.viewer_aware_objects(user)
.filter(Q(username__in=usernames) | Q(localname__in=usernames))
.distinct()
)
# Prepare a lookup based on both username and localname
username_dict = {
**{f"@{u.username}": u for u in known_users},
**{f"@{u.localname}": u for u in known_users.filter(local=True)},
}
# Users not captured here could be blocked or not yet loaded on the server
not_found = set(usernames) - set(username_dict.keys())
for username in not_found:
mention_user = handle_remote_webfinger(username, unknown_only=True)
if not mention_user: if not mention_user:
# we can ignore users we don't know about # this user is blocked or can't be found
continue continue
yield (match.group(), mention_user) username_dict[f"@{mention_user.username}"] = mention_user
username_dict[f"@{mention_user.localname}"] = mention_user
return username_dict
def format_links(content): def format_links(content):