diff --git a/activities/views/compose.py b/activities/views/compose.py index 8e3c96d..6889864 100644 --- a/activities/views/compose.py +++ b/activities/views/compose.py @@ -2,6 +2,7 @@ from django import forms from django.conf import settings from django.core.exceptions import PermissionDenied from django.shortcuts import get_object_or_404, redirect, render +from django.utils import timezone from django.utils.decorators import method_decorator from django.views.generic import FormView @@ -54,8 +55,9 @@ class Compose(FormView): ) reply_to = forms.CharField(widget=forms.HiddenInput(), required=False) - def __init__(self, *args, **kwargs): + def __init__(self, request, *args, **kwargs): super().__init__(*args, **kwargs) + self.request = request self.fields["text"].widget.attrs[ "_" ] = f""" @@ -74,8 +76,20 @@ class Compose(FormView): def clean_text(self): text = self.cleaned_data.get("text") + # Check minimum interval + last_post = self.request.identity.posts.order_by("-created").first() + if ( + last_post + and (timezone.now() - last_post.created).total_seconds() + < Config.system.post_minimum_interval + ): + raise forms.ValidationError( + f"You must wait at least {Config.system.post_minimum_interval} seconds between posts" + ) + print(last_post) if not text: return text + # Check post length length = len(text) if length > Config.system.post_length: raise forms.ValidationError( @@ -83,6 +97,9 @@ class Compose(FormView): ) return text + def get_form(self, form_class=None): + return self.form_class(request=self.request, **self.get_form_kwargs()) + def get_initial(self): initial = super().get_initial() if self.post_obj: diff --git a/activities/views/timelines.py b/activities/views/timelines.py index f55e331..d797cfd 100644 --- a/activities/views/timelines.py +++ b/activities/views/timelines.py @@ -17,6 +17,9 @@ class Home(FormView): form_class = Compose.form_class + def get_form(self, form_class=None): + return self.form_class(request=self.request, **self.get_form_kwargs()) + def get_context_data(self): context = super().get_context_data() context["events"] = list( diff --git a/core/models/config.py b/core/models/config.py index 9670c11..48da7de 100644 --- a/core/models/config.py +++ b/core/models/config.py @@ -216,6 +216,7 @@ class Config(models.Model): content_warning_text: str = "Content Warning" post_length: int = 500 + post_minimum_interval: int = 3 # seconds identity_min_length: int = 2 identity_max_per_user: int = 5 identity_max_age: int = 24 * 60 * 60 diff --git a/takahe/settings.py b/takahe/settings.py index 98bc9dd..3017293 100644 --- a/takahe/settings.py +++ b/takahe/settings.py @@ -282,6 +282,8 @@ STATICFILES_DIRS = [BASE_DIR / "static"] STATICFILES_STORAGE = "django.contrib.staticfiles.storage.ManifestStaticFilesStorage" +SESSION_ENGINE = "django.contrib.sessions.backends.signed_cookies" + WHITENOISE_MAX_AGE = 3600 STATIC_ROOT = BASE_DIR / "static-collected" diff --git a/tests/activities/templatetags/test_activity_tags.py b/tests/activities/templatetags/test_activity_tags.py index 85d8cdf..9426337 100644 --- a/tests/activities/templatetags/test_activity_tags.py +++ b/tests/activities/templatetags/test_activity_tags.py @@ -5,7 +5,10 @@ from django.utils import timezone from activities.templatetags.activity_tags import linkify_hashtags, timedeltashort -def test_timedeltashort_regress(): +def test_timedeltashort(): + """ + Tests that timedeltashort works correctly + """ assert timedeltashort(None) == "" assert timedeltashort("") == "" @@ -21,7 +24,11 @@ def test_timedeltashort_regress(): assert timedeltashort(value - timedelta(days=366)) == "1y" -def test_linkify_hashtags_regres(): +def test_linkify_hashtags(): + """ + Tests that linkify_hashtags works correctly + """ + assert linkify_hashtags(None) == "" assert linkify_hashtags("") == "" diff --git a/tests/activities/views/test_compose.py b/tests/activities/views/test_compose.py index 2b8c4ea..2fdec3a 100644 --- a/tests/activities/views/test_compose.py +++ b/tests/activities/views/test_compose.py @@ -1,46 +1,58 @@ -import re -from unittest import mock - import pytest -from django.core.exceptions import PermissionDenied +from django.test.client import Client +from pytest_django.asserts import assertContains from activities.models import Post -from activities.views.compose import Compose +from core.models import Config +from users.models import Identity @pytest.mark.django_db -def test_content_warning_text(identity, user, rf, config_system): - request = rf.get("/compose/") - request.user = user - request.identity = identity - +def test_content_warning_text( + client_with_identity: Client, + config_system: Config.SystemOptions, +): + """ + Tests that changing the content warning name works + """ config_system.content_warning_text = "Content Summary" - with mock.patch("core.models.Config.load_system", return_value=config_system): - view = Compose.as_view() - resp = view(request) - assert resp.status_code == 200 - content = str(resp.rendered_content) - assert 'placeholder="Content Summary"' in content - assert re.search( - r"\s*Content Summary\s*", content, flags=re.MULTILINE - ) + response = client_with_identity.get("/compose/") + assertContains(response, 'placeholder="Content Summary"', status_code=200) + assertContains( + response, "", html=True + ) @pytest.mark.django_db -def test_post_edit_security(identity, user, rf, other_identity): - # Create post +def test_post_edit_security(client_with_identity: Client, other_identity: Identity): + """ + Tests that you can't edit other users' posts with URL fiddling + """ other_post = Post.objects.create( content="

OTHER POST!

", author=other_identity, local=True, visibility=Post.Visibilities.public, ) + response = client_with_identity.get(other_post.urls.action_edit) + assert response.status_code == 403 - request = rf.get(other_post.get_absolute_url() + "edit/") - request.user = user - request.identity = identity - view = Compose.as_view() - with pytest.raises(PermissionDenied) as ex: - view(request, handle=other_identity.handle.lstrip("@"), post_id=other_post.id) - assert str(ex.value) == "Post author is not requestor" +@pytest.mark.django_db +def test_rate_limit(identity: Identity, client_with_identity: Client): + """ + Tests that the posting rate limit comes into force + """ + # First post should go through + assert identity.posts.count() == 0 + response = client_with_identity.post( + "/compose/", data={"text": "post 1", "visibility": "0"} + ) + assert response.status_code == 302 + assert identity.posts.count() == 1 + # Second should not + response = client_with_identity.post( + "/compose/", data={"text": "post 2", "visibility": "0"} + ) + assertContains(response, "You must wait at least", status_code=200) + assert identity.posts.count() == 1 diff --git a/tests/activities/views/test_posts.py b/tests/activities/views/test_posts.py index ac148d0..a3dbc64 100644 --- a/tests/activities/views/test_posts.py +++ b/tests/activities/views/test_posts.py @@ -1,25 +1,20 @@ import pytest -from django.core.exceptions import PermissionDenied +from django.test.client import Client from activities.models import Post -from activities.views.posts import Delete +from users.models import Identity @pytest.mark.django_db -def test_post_delete_security(identity, user, rf, other_identity): - # Create post +def test_post_delete_security(client_with_identity: Client, other_identity: Identity): + """ + Tests that you can't delete other users' posts with URL fiddling + """ other_post = Post.objects.create( content="

OTHER POST!

", author=other_identity, local=True, visibility=Post.Visibilities.public, ) - - request = rf.post(other_post.get_absolute_url() + "delete/") - request.user = user - request.identity = identity - - view = Delete.as_view() - with pytest.raises(PermissionDenied) as ex: - view(request, handle=other_identity.handle.lstrip("@"), post_id=other_post.id) - assert str(ex.value) == "Post author is not requestor" + response = client_with_identity.get(other_post.urls.action_delete) + assert response.status_code == 403 diff --git a/tests/activities/views/test_timelines.py b/tests/activities/views/test_timelines.py index 74bf43d..a2cbb32 100644 --- a/tests/activities/views/test_timelines.py +++ b/tests/activities/views/test_timelines.py @@ -1,19 +1,12 @@ -from unittest import mock - import pytest -from activities.views.timelines import Home - @pytest.mark.django_db -def test_content_warning_text(identity, user, rf, config_system): - request = rf.get("/") - request.user = user - request.identity = identity +def test_content_warning_text(client_with_identity, config_system): config_system.content_warning_text = "Content Summary" - with mock.patch("core.models.Config.load_system", return_value=config_system): - view = Home.as_view() - resp = view(request) - assert resp.status_code == 200 - assert 'placeholder="Content Summary"' in str(resp.rendered_content) + + response = client_with_identity.get("/") + + assert response.status_code == 200 + assert 'placeholder="Content Summary"' in str(response.rendered_content) diff --git a/tests/conftest.py b/tests/conftest.py index 9127466..a86f2f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import time import pytest +from django.conf import settings from activities.models import Emoji from api.models import Application, Token @@ -73,6 +74,19 @@ def config_system(keypair): del Config.system +@pytest.fixture +def client_with_identity(client, identity, user): + """ + Provides a logged-in test client with an identity selected + """ + client.force_login(user) + session = client.session + session["identity_id"] = identity.id + session.save() + client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key + return client + + @pytest.fixture @pytest.mark.django_db def emoji_locals(): diff --git a/tests/users/views/test_domains.py b/tests/users/views/test_domains.py index 13d4b11..5752125 100644 --- a/tests/users/views/test_domains.py +++ b/tests/users/views/test_domains.py @@ -31,10 +31,16 @@ INVALID_DOMAINS = [ @pytest.mark.parametrize("domain", VALID_DOMAINS) def test_domain_validation_accepts_valid_domains(domain): + """ + Tests that the domain validator works in positive cases + """ DomainValidator()(domain) @pytest.mark.parametrize("domain", INVALID_DOMAINS) def test_domain_validation_raises_exception_for_invalid_domains(domain): + """ + Tests that the domain validator works in negative cases + """ with pytest.raises(ValidationError): DomainValidator()(domain) diff --git a/users/views/admin/settings.py b/users/views/admin/settings.py index b9e2543..4786d83 100644 --- a/users/views/admin/settings.py +++ b/users/views/admin/settings.py @@ -38,6 +38,10 @@ class BasicSettings(AdminSettingsPage): "title": "Maximum Post Length", "help_text": "The maximum number of characters allowed per post", }, + "post_minimum_interval": { + "title": "Minimum Posting Interval", + "help_text": "The minimum number of seconds a user must wait between posts", + }, "content_warning_text": { "title": "Content Warning Feature Name", "help_text": "What the feature that lets users provide post summaries is called", @@ -102,6 +106,7 @@ class BasicSettings(AdminSettingsPage): "Signups": ["signup_allowed", "signup_invite_only", "signup_text"], "Posts": [ "post_length", + "post_minimum_interval", "content_warning_text", "hashtag_unreviewed_are_public", "emoji_unreviewed_are_public",