Switch to preserving username case

Except on comparisons, for now. Fixes #224
This commit is contained in:
Andrew Godwin 2022-12-21 21:39:56 +00:00
parent 13ebe14cf9
commit 45de0e09b5
4 changed files with 18 additions and 13 deletions

View file

@ -13,7 +13,7 @@ class SearchService:
""" """
def __init__(self, query: str, identity: Identity | None): def __init__(self, query: str, identity: Identity | None):
self.query = query.strip().lower() self.query = query.strip()
self.identity = identity self.identity = identity
def search_identities_handle(self) -> set[Identity]: def search_identities_handle(self) -> set[Identity]:
@ -26,7 +26,7 @@ class SearchService:
return set() return set()
# Try to fetch the user by handle # Try to fetch the user by handle
handle = self.query.lstrip("@") handle = self.query.lstrip("@").lower()
results: set[Identity] = set() results: set[Identity] = set()
if "@" in handle: if "@" in handle:
username, domain = handle.split("@", 1) username, domain = handle.split("@", 1)
@ -37,7 +37,8 @@ class SearchService:
if domain_instance is None: if domain_instance is None:
raise Identity.DoesNotExist() raise Identity.DoesNotExist()
identity = Identity.objects.get( identity = Identity.objects.get(
domain=domain_instance, username=username domain=domain_instance,
username__iexact=username,
) )
except Identity.DoesNotExist: except Identity.DoesNotExist:
if self.identity is not None: if self.identity is not None:
@ -82,11 +83,13 @@ class SearchService:
type = document.get("type", "unknown").lower() type = document.get("type", "unknown").lower()
# Is it an identity? # Is it an identity?
print(type)
if type in Identity.ACTOR_TYPES: if type in Identity.ACTOR_TYPES:
# Try and retrieve the profile by actor URI # Try and retrieve the profile by actor URI
identity = Identity.by_actor_uri(document["id"], create=True) identity = Identity.by_actor_uri(document["id"], create=True)
print("got identity")
if identity and identity.state == IdentityStates.outdated: if identity and identity.state == IdentityStates.outdated:
async_to_sync(identity.fetch_actor)() print(async_to_sync(identity.fetch_actor)())
return identity return identity
# Is it a post? # Is it a post?
@ -112,7 +115,7 @@ class SearchService:
return set() return set()
results: set[Hashtag] = set() results: set[Hashtag] = set()
name = self.query.lstrip("#") name = self.query.lstrip("#").lower()
for hashtag in Hashtag.objects.public().hashtag_or_alias(name)[:10]: for hashtag in Hashtag.objects.public().hashtag_or_alias(name)[:10]:
results.add(hashtag) results.add(hashtag)
for hashtag in Hashtag.objects.public().filter(hashtag__startswith=name)[:10]: for hashtag in Hashtag.objects.public().filter(hashtag__startswith=name)[:10]:

View file

@ -293,18 +293,17 @@ class Identity(StatorModel):
def by_username_and_domain(cls, username, domain, fetch=False, local=False): def by_username_and_domain(cls, username, domain, fetch=False, local=False):
if username.startswith("@"): if username.startswith("@"):
raise ValueError("Username must not start with @") raise ValueError("Username must not start with @")
username = username.lower()
domain = domain.lower() domain = domain.lower()
try: try:
if local: if local:
return cls.objects.get( return cls.objects.get(
username=username, username__iexact=username,
domain_id=domain, domain_id=domain,
local=True, local=True,
) )
else: else:
return cls.objects.get( return cls.objects.get(
username=username, username__iexact=username,
domain_id=domain, domain_id=domain,
) )
except cls.DoesNotExist: except cls.DoesNotExist:
@ -596,7 +595,7 @@ class Identity(StatorModel):
if self.username and "@value" in self.username: if self.username and "@value" in self.username:
self.username = self.username["@value"] self.username = self.username["@value"]
if self.username: if self.username:
self.username = self.username.lower() self.username = self.username
self.manually_approves_followers = document.get("manuallyApprovesFollowers") self.manually_approves_followers = document.get("manuallyApprovesFollowers")
self.public_key = document.get("publicKey", {}).get("publicKeyPem") self.public_key = document.get("publicKey", {}).get("publicKeyPem")
self.public_key_id = document.get("publicKey", {}).get("id") self.public_key_id = document.get("publicKey", {}).get("id")
@ -626,7 +625,7 @@ class Identity(StatorModel):
) )
if webfinger_handle: if webfinger_handle:
webfinger_username, webfinger_domain = webfinger_handle.split("@") webfinger_username, webfinger_domain = webfinger_handle.split("@")
self.username = webfinger_username.lower() self.username = webfinger_username
self.domain = await get_domain(webfinger_domain) self.domain = await get_domain(webfinger_domain)
else: else:
self.domain = await get_domain(actor_url_parts.hostname) self.domain = await get_domain(actor_url_parts.hostname)

View file

@ -35,7 +35,7 @@ class IdentitiesRoot(ListView):
if "@" in query: if "@" in query:
username, domain = query.split("@", 1) username, domain = query.split("@", 1)
identities = identities.filter( identities = identities.filter(
username=username, username__iexact=username,
domain__domain__istartswith=domain, domain__domain__istartswith=domain,
) )
else: else:

View file

@ -285,7 +285,10 @@ class CreateIdentity(FormView):
if ( if (
username username
and domain and domain
and Identity.objects.filter(username=username, domain=domain).exists() and Identity.objects.filter(
username__iexact=username,
domain=domain.lower(),
).exists()
): ):
raise forms.ValidationError(f"{username}@{domain} is already taken") raise forms.ValidationError(f"{username}@{domain} is already taken")
@ -307,7 +310,7 @@ class CreateIdentity(FormView):
domain_instance = Domain.get_domain(domain) domain_instance = Domain.get_domain(domain)
new_identity = Identity.objects.create( new_identity = Identity.objects.create(
actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/", actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/",
username=username.lower(), username=username,
domain_id=domain, domain_id=domain,
name=form.cleaned_data["name"], name=form.cleaned_data["name"],
local=True, local=True,