Fix tests

This commit is contained in:
Andrew Godwin 2023-04-30 11:54:09 -06:00
parent ef08013541
commit 55de31e3de
16 changed files with 68 additions and 102 deletions

View file

@ -58,6 +58,7 @@ class Individual(TemplateView):
"link_original": True, "link_original": True,
"ancestors": ancestors, "ancestors": ancestors,
"descendants": descendants, "descendants": descendants,
"public_styling": True,
} }
) )

View file

@ -9,7 +9,7 @@ django-hatchway~=0.5.1
django-htmx~=1.13.0 django-htmx~=1.13.0
django-oauth-toolkit~=2.2.0 django-oauth-toolkit~=2.2.0
django-storages[google,boto3]~=1.13.1 django-storages[google,boto3]~=1.13.1
django~=4.1 django~=4.1.0
email-validator~=1.3.0 email-validator~=1.3.0
gunicorn~=20.1.0 gunicorn~=20.1.0
httpx~=0.23 httpx~=0.23

View file

@ -737,7 +737,7 @@ form.inline {
div.follow { div.follow {
float: right; float: right;
margin: 20px 0 0 0; margin: 30px 0 0 0;
font-size: 16px; font-size: 16px;
text-align: center; text-align: center;
} }
@ -1124,7 +1124,7 @@ section.identity .banner {
object-fit: cover; object-fit: cover;
display: block; display: block;
width: calc(100% + 30px); width: calc(100% + 30px);
margin: -5px -15px 20px -15px; margin: -5px -15px 0px -15px;
border-radius: 5px 0 0 0; border-radius: 5px 0 0 0;
} }
@ -1132,7 +1132,7 @@ section.identity .icon {
width: 80px; width: 80px;
height: 80px; height: 80px;
float: left; float: left;
margin: 0 20px 15px 0; margin: 15px 20px 15px 0;
cursor: pointer; cursor: pointer;
} }
@ -1141,7 +1141,7 @@ section.identity .emoji {
} }
section.identity h1 { section.identity h1 {
margin: 30px 0 0 0; margin: 25px 0 0 0;
} }
section.identity small { section.identity small {

View file

@ -3,7 +3,7 @@
{% block subtitle %}{{ domain.domain }}{% endblock %} {% block subtitle %}{{ domain.domain }}{% endblock %}
{% block settings_content %} {% block settings_content %}
<form action="." method="POST"> <form action="." method="POST" enctype="multipart/form-data">
{% csrf_token %} {% csrf_token %}
<fieldset> <fieldset>
<legend>Domain Details</legend> <legend>Domain Details</legend>

View file

@ -18,8 +18,14 @@
--color-text-link: {{ config.highlight_color }}; --color-text-link: {{ config.highlight_color }};
} }
</style> </style>
{% if config_identity.custom_css %} {% if identity and public_styling %}
<style>{{ config_identity.custom_css|safe }}</style> {% if identity.domain.config_domain.custom_css %}
<style>{{ identity.domain.config_domain.custom_css|safe }}</style>
{% endif %}
{% else %}
{% if request.domain.config_domain.custom_css %}
<style>{{ request.domain.config_domain.custom_css|safe }}</style>
{% endif %}
{% endif %} {% endif %}
{% block opengraph %} {% block opengraph %}
{% include "_opengraph.html" with opengraph_local=opengraph_defaults %} {% include "_opengraph.html" with opengraph_local=opengraph_defaults %}
@ -35,8 +41,13 @@
<header> <header>
<menu> <menu>
<a class="logo" href="/"> <a class="logo" href="/">
<img src="{{ config.site_icon }}" width="32"> {% if identity and public_styling %}
<img src="{{ identity.domain.config_domain.site_icon|default:config.site_icon }}" width="32">
{{ identity.domain.config_domain.site_name|default:config.site_name }}
{% else %}
<img src="{{ request.domain.config_domain.site_icon|default:config.site_icon }}" width="32">
{{ request.domain.config_domain.site_name|default:config.site_name }} {{ request.domain.config_domain.site_name|default:config.site_name }}
{% endif %}
</a> </a>
{% if user.is_authenticated %} {% if user.is_authenticated %}
<a href="/" title="My Account"><i class="fa-solid fa-user"></i></a> <a href="/" title="My Account"><i class="fa-solid fa-user"></i></a>
@ -58,6 +69,8 @@
{% include "_announcements.html" %} {% include "_announcements.html" %}
{% endblock %} {% endblock %}
{% include "activities/_image_viewer.html" %}
{% block content %} {% block content %}
{% endblock %} {% endblock %}
{% endblock %} {% endblock %}

View file

@ -1,6 +1,6 @@
{% extends "identity/view.html" %} {% extends "identity/view.html" %}
{% block title %}{% if self.inbound %}Followers{% else %}Following{% endif %} - {{ identity }}{% endblock %} {% block title %}{% if inbound %}Followers{% else %}Following{% endif %} - {{ identity }}{% endblock %}
{% block subcontent %} {% block subcontent %}
@ -8,7 +8,7 @@
{% include "activities/_identity.html" %} {% include "activities/_identity.html" %}
{% empty %} {% empty %}
<span class="empty"> <span class="empty">
This person has no {% if self.inbound %}followers{% else %}follows{% endif %} yet. This person has no {% if inbound %}followers{% else %}follows{% endif %} yet.
</span> </span>
{% endfor %} {% endfor %}

View file

@ -8,51 +8,20 @@ from users.models import Identity
@pytest.mark.django_db @pytest.mark.django_db
def test_content_warning_text( def test_rate_limit(identity: Identity, client_with_user: Client):
client_with_identity: Client,
config_system: Config.SystemOptions,
):
"""
Tests that changing the content warning name works
"""
config_system.content_warning_text = "Content Summary"
response = client_with_identity.get("/compose/")
assertContains(response, 'placeholder="Content Summary"', status_code=200)
assertContains(
response, "<label for='id_content_warning'>Content Summary</label>", html=True
)
@pytest.mark.django_db
def test_post_edit_security(client_with_identity: Client, other_identity: Identity):
"""
Tests that you can't edit other users' posts with URL fiddling
"""
other_post = Post.objects.create(
content="<p>OTHER POST!</p>",
author=other_identity,
local=True,
visibility=Post.Visibilities.public,
)
response = client_with_identity.get(other_post.urls.action_edit)
assert response.status_code == 403
@pytest.mark.django_db
def test_rate_limit(identity: Identity, client_with_identity: Client):
""" """
Tests that the posting rate limit comes into force Tests that the posting rate limit comes into force
""" """
# First post should go through # First post should go through
assert identity.posts.count() == 0 assert identity.posts.count() == 0
response = client_with_identity.post( response = client_with_user.post(
"/compose/", data={"text": "post 1", "visibility": "0"} f"/@{identity.handle}/compose/", data={"text": "post 1", "visibility": "0"}
) )
assert response.status_code == 302 assert response.status_code == 302
assert identity.posts.count() == 1 assert identity.posts.count() == 1
# Second should not # Second should not
response = client_with_identity.post( response = client_with_user.post(
"/compose/", data={"text": "post 2", "visibility": "0"} f"/@{identity.handle}/compose/", data={"text": "post 2", "visibility": "0"}
) )
assertContains(response, "You must wait at least", status_code=200) assertContains(response, "You must wait at least", status_code=200)
assert identity.posts.count() == 1 assert identity.posts.count() == 1

View file

@ -1,20 +0,0 @@
import pytest
from django.test.client import Client
from activities.models import Post
from users.models import Identity
@pytest.mark.django_db
def test_post_delete_security(client_with_identity: Client, other_identity: Identity):
"""
Tests that you can't delete other users' posts with URL fiddling
"""
other_post = Post.objects.create(
content="<p>OTHER POST!</p>",
author=other_identity,
local=True,
visibility=Post.Visibilities.public,
)
response = client_with_identity.get(other_post.urls.action_delete)
assert response.status_code == 403

View file

@ -1,12 +0,0 @@
import pytest
@pytest.mark.django_db
def test_content_warning_text(client_with_identity, config_system):
config_system.content_warning_text = "Content Summary"
response = client_with_identity.get("/")
assert response.status_code == 200
assert 'placeholder="Content Summary"' in str(response.rendered_content)

View file

@ -62,6 +62,8 @@ def _test_settings(settings):
settings.STATICFILES_STORAGE = ( settings.STATICFILES_STORAGE = (
"django.contrib.staticfiles.storage.StaticFilesStorage" "django.contrib.staticfiles.storage.StaticFilesStorage"
) )
settings.SETUP.MAIN_DOMAIN = "example.com"
settings.MAIN_DOMAIN = "example.com"
@pytest.fixture @pytest.fixture
@ -77,15 +79,11 @@ def config_system(keypair):
@pytest.fixture @pytest.fixture
def client_with_identity(client, identity, user): def client_with_user(client, user):
""" """
Provides a logged-in test client with an identity selected Provides a logged-in test client
""" """
client.force_login(user) client.force_login(user)
session = client.session
session["identity_id"] = identity.id
session.save()
client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
return client return client

View file

@ -13,7 +13,7 @@ def test_stats(client, identity, other_identity):
Follow.objects.create(source=other_identity, target=identity) Follow.objects.create(source=other_identity, target=identity)
Config.set_identity(identity, "visible_follows", True) Config.set_identity(identity, "visible_follows", True)
response = client.get(identity.urls.view) response = client.get(identity.urls.view)
assertContains(response, "<strong>1</strong> follower", status_code=200) assertContains(response, "<strong>1</strong> Follower", status_code=200)
@pytest.mark.django_db @pytest.mark.django_db
@ -23,7 +23,7 @@ def test_visible_follows_disabled(client, identity):
""" """
Config.set_identity(identity, "visible_follows", True) Config.set_identity(identity, "visible_follows", True)
response = client.get(identity.urls.view) response = client.get(identity.urls.view)
assertContains(response, "follower", status_code=200) assertContains(response, "Follower", status_code=200)
Config.set_identity(identity, "visible_follows", False) Config.set_identity(identity, "visible_follows", False)
response = client.get(identity.urls.view) response = client.get(identity.urls.view)
assertNotContains(response, "follower", status_code=200) assertNotContains(response, "Follower", status_code=200)

View file

@ -10,7 +10,7 @@ from users.services import IdentityService
@pytest.mark.django_db @pytest.mark.django_db
def test_import_following( def test_import_following(
client_with_identity: Client, client_with_user: Client,
identity: Identity, identity: Identity,
remote_identity: Identity, remote_identity: Identity,
stator: StatorRunner, stator: StatorRunner,
@ -24,8 +24,8 @@ def test_import_following(
"follows.csv", "follows.csv",
b"Account address,Show boosts,Notify on new posts,Languages\ntest@remote.test,true,false,", b"Account address,Show boosts,Notify on new posts,Languages\ntest@remote.test,true,false,",
) )
response = client_with_identity.post( response = client_with_user.post(
"/settings/import_export/", f"/@{identity.handle}/settings/import_export/",
{ {
"csv": csv_file, "csv": csv_file,
"import_type": "following", "import_type": "following",
@ -45,7 +45,7 @@ def test_import_following(
@pytest.mark.django_db @pytest.mark.django_db
def test_export_following( def test_export_following(
client_with_identity: Client, client_with_user: Client,
identity: Identity, identity: Identity,
remote_identity: Identity, remote_identity: Identity,
stator: StatorRunner, stator: StatorRunner,
@ -58,7 +58,9 @@ def test_export_following(
IdentityService(identity).follow(remote_identity) IdentityService(identity).follow(remote_identity)
# Download the CSV # Download the CSV
response = client_with_identity.get("/settings/import_export/following.csv") response = client_with_user.get(
f"/@{identity.handle}/settings/import_export/following.csv"
)
assert response.status_code == 200 assert response.status_code == 200
assert ( assert (
response.content.strip() response.content.strip()
@ -68,7 +70,7 @@ def test_export_following(
@pytest.mark.django_db @pytest.mark.django_db
def test_export_followers( def test_export_followers(
client_with_identity: Client, client_with_user: Client,
identity: Identity, identity: Identity,
identity2: Identity, identity2: Identity,
stator: StatorRunner, stator: StatorRunner,
@ -81,6 +83,8 @@ def test_export_followers(
IdentityService(identity2).follow(identity) IdentityService(identity2).follow(identity)
# Download the CSV # Download the CSV
response = client_with_identity.get("/settings/import_export/followers.csv") response = client_with_user.get(
f"/@{identity.handle}/settings/import_export/followers.csv"
)
assert response.status_code == 200 assert response.status_code == 200
assert response.content.strip() == b"Account address\r\ntest@example2.com" assert response.content.strip() == b"Account address\r\ntest@example2.com"

View file

@ -10,6 +10,8 @@ class DomainMiddleware:
self.get_response = get_response self.get_response = get_response
def __call__(self, request): def __call__(self, request):
request.domain = None
if "HTTP_HOST" in request.META:
request.domain = Domain.get_domain(request.META["HTTP_HOST"]) request.domain = Domain.get_domain(request.META["HTTP_HOST"])
response = self.get_response(request) response = self.get_response(request)
return response return response

View file

@ -4,6 +4,7 @@ from django.db import models
from django.shortcuts import get_object_or_404, redirect from django.shortcuts import get_object_or_404, redirect
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.generic import FormView, TemplateView from django.views.generic import FormView, TemplateView
from django.core.files import File
from core.models import Config from core.models import Config
from users.decorators import admin_required from users.decorators import admin_required
@ -214,6 +215,7 @@ class DomainEdit(FormView):
Domain.objects.exclude(pk=self.domain.pk).update(default=False) Domain.objects.exclude(pk=self.domain.pk).update(default=False)
Config.set_domain(self.domain, "hide_login", form.cleaned_data["hide_login"]) Config.set_domain(self.domain, "hide_login", form.cleaned_data["hide_login"])
Config.set_domain(self.domain, "site_name", form.cleaned_data["site_name"]) Config.set_domain(self.domain, "site_name", form.cleaned_data["site_name"])
if isinstance(form.cleaned_data["site_icon"], File):
Config.set_domain(self.domain, "site_icon", form.cleaned_data["site_icon"]) Config.set_domain(self.domain, "site_icon", form.cleaned_data["site_icon"])
Config.set_domain(self.domain, "custom_css", form.cleaned_data["custom_css"]) Config.set_domain(self.domain, "custom_css", form.cleaned_data["custom_css"])
return redirect(Domain.urls.root) return redirect(Domain.urls.root)

View file

@ -71,6 +71,7 @@ class ViewIdentity(ListView):
def get_context_data(self): def get_context_data(self):
context = super().get_context_data() context = super().get_context_data()
context["identity"] = self.identity context["identity"] = self.identity
context["public_styling"] = True
context["post_count"] = self.identity.posts.count() context["post_count"] = self.identity.posts.count()
if self.identity.config_identity.visible_follows: if self.identity.config_identity.visible_follows:
context["followers_count"] = self.identity.inbound_follows.filter( context["followers_count"] = self.identity.inbound_follows.filter(
@ -214,11 +215,18 @@ class IdentityFollows(ListView):
raise Http404("Hidden follows") raise Http404("Hidden follows")
return super().get(request, identity=self.identity) return super().get(request, identity=self.identity)
def get_queryset(self):
if self.inbound:
return IdentityService(self.identity).followers()
else:
return IdentityService(self.identity).following()
def get_context_data(self): def get_context_data(self):
context = super().get_context_data() context = super().get_context_data()
context["identity"] = self.identity context["identity"] = self.identity
context["inbound"] = self.inbound context["inbound"] = self.inbound
context["section"] = "follows" context["section"] = "follows"
context["public_styling"] = True
context["followers_count"] = self.identity.inbound_follows.filter( context["followers_count"] = self.identity.inbound_follows.filter(
state__in=FollowStates.group_active() state__in=FollowStates.group_active()
).count() ).count()
@ -255,6 +263,7 @@ class IdentitySearch(FormView):
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)
context["identity"] = self.identity context["identity"] = self.identity
context["section"] = "search" context["section"] = "search"
context["public_styling"] = True
context["followers_count"] = self.identity.inbound_follows.filter( context["followers_count"] = self.identity.inbound_follows.filter(
state__in=FollowStates.group_active() state__in=FollowStates.group_active()
).count() ).count()

View file

@ -70,7 +70,7 @@ class ImportExportPage(IdentityViewMixin, FormView):
return context return context
class CsvView(View): class CsvView(IdentityViewMixin, View):
""" """
Generic view that exports a queryset as a CSV Generic view that exports a queryset as a CSV
""" """
@ -85,7 +85,7 @@ class CsvView(View):
def get_queryset(self): def get_queryset(self):
raise NotImplementedError() raise NotImplementedError()
def get(self, request): def get(self, request, *args, **kwargs):
response = HttpResponse( response = HttpResponse(
content_type="text/csv", content_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="{self.filename}"'}, headers={"Content-Disposition": f'attachment; filename="{self.filename}"'},
@ -113,7 +113,7 @@ class CsvView(View):
return response return response
class CsvFollowing(IdentityViewMixin, CsvView): class CsvFollowing(CsvView):
columns = { columns = {
"Account address": "get_handle", "Account address": "get_handle",
"Show boosts": "boosts", "Show boosts": "boosts",
@ -136,7 +136,7 @@ class CsvFollowing(IdentityViewMixin, CsvView):
return "" return ""
class CsvFollowers(IdentityViewMixin, CsvView): class CsvFollowers(CsvView):
columns = { columns = {
"Account address": "get_handle", "Account address": "get_handle",
} }