diff --git a/activities/models/post.py b/activities/models/post.py index 482aea4..62c5983 100644 --- a/activities/models/post.py +++ b/activities/models/post.py @@ -503,6 +503,7 @@ class Post(StatorModel): self, content: str, summary: str | None = None, + sensitive: bool | None = None, visibility: int = Visibilities.public, attachments: list | None = None, ): @@ -512,7 +513,7 @@ class Post(StatorModel): self.content = parser.html self.hashtags = sorted(parser.hashtags) or None self.summary = summary or None - self.sensitive = bool(summary) + self.sensitive = bool(summary) if sensitive is None else sensitive self.visibility = visibility self.edited = timezone.now() self.mentions.set(self.mentions_from_content(content, self.author)) diff --git a/api/schemas.py b/api/schemas.py index 2efe4f2..5f07dd9 100644 --- a/api/schemas.py +++ b/api/schemas.py @@ -1,7 +1,9 @@ from typing import Literal, Optional, Union -from activities import models as activities_models from hatchway import Field, Schema + +from activities import models as activities_models +from core.html import FediverseHtmlParser from users import models as users_models from users.services import IdentityService @@ -174,6 +176,20 @@ class Status(Schema): ] +class StatusSource(Schema): + id: str + text: str + spoiler_text: str + + @classmethod + def from_post(cls, post: activities_models.Post): + return cls( + id=post.id, + text=FediverseHtmlParser(post.content).plain_text, + spoiler_text=post.summary or "", + ) + + class Conversation(Schema): id: str unread: bool diff --git a/api/urls.py b/api/urls.py index 2261648..a2ea0a6 100644 --- a/api/urls.py +++ b/api/urls.py @@ -1,4 +1,5 @@ from django.urls import path +from hatchway import methods from api.views import ( accounts, @@ -14,7 +15,6 @@ from api.views import ( timelines, trends, ) -from hatchway import methods urlpatterns = [ # Accounts @@ -55,9 +55,11 @@ urlpatterns = [ "v1/statuses/", methods( get=statuses.status, + put=statuses.edit_status, delete=statuses.delete_status, ), ), + path("v1/statuses//source", statuses.status_source), # Notifications path("v1/notifications", notifications.notifications), # Search diff --git a/api/views/media.py b/api/views/media.py index a856752..3ec9918 100644 --- a/api/views/media.py +++ b/api/views/media.py @@ -1,10 +1,10 @@ from django.core.files import File from django.shortcuts import get_object_or_404 +from hatchway import ApiError, QueryOrBody, api_view from activities.models import PostAttachment, PostAttachmentStates from api import schemas from core.files import blurhash_image, resize_image -from hatchway import QueryOrBody, api_view from ..decorators import identity_required @@ -54,6 +54,8 @@ def get_media( id: str, ) -> schemas.MediaAttachment: attachment = get_object_or_404(PostAttachment, pk=id) + if attachment.post.author != request.identity: + raise ApiError(401, "Not the author of this attachment") return schemas.MediaAttachment.from_post_attachment(attachment) @@ -66,6 +68,8 @@ def update_media( focus: QueryOrBody[str] = "0,0", ) -> schemas.MediaAttachment: attachment = get_object_or_404(PostAttachment, pk=id) + if attachment.post.author != request.identity: + raise ApiError(401, "Not the author of this attachment") attachment.name = description or None attachment.save() return schemas.MediaAttachment.from_post_attachment(attachment) diff --git a/api/views/statuses.py b/api/views/statuses.py index 6161c09..0492c80 100644 --- a/api/views/statuses.py +++ b/api/views/statuses.py @@ -1,8 +1,8 @@ from typing import Literal -from django.forms import ValidationError from django.http import HttpRequest from django.shortcuts import get_object_or_404 +from hatchway import ApiError, ApiResponse, Schema, api_view from activities.models import ( Post, @@ -16,7 +16,6 @@ from api import schemas from api.decorators import identity_required from api.pagination import MastodonPaginator, PaginationResult from core.models import Config -from hatchway import ApiResponse, Schema, api_view class PostStatusSchema(Schema): @@ -30,14 +29,36 @@ class PostStatusSchema(Schema): media_ids: list[str] = [] +class EditStatusSchema(Schema): + status: str + sensitive: bool = False + spoiler_text: str | None = None + language: str | None = None + media_ids: list[str] = [] + + +def post_for_id(request: HttpRequest, id: str) -> Post: + """ + Common logic to get a Post object for an ID, taking visibility into + account. + """ + if request.identity: + queryset = Post.objects.not_hidden().visible_to( + request.identity, include_replies=True + ) + else: + queryset = Post.objects.not_hidden().unlisted() + return get_object_or_404(queryset, pk=id) + + @identity_required @api_view.post def post_status(request, details: PostStatusSchema) -> schemas.Status: # Check text length if len(details.status) > Config.system.post_length: - raise ValidationError("Status is too long") + raise ApiError(400, "Status is too long") if len(details.status) == 0 and not details.media_ids: - raise ValidationError("Status is empty") + raise ApiError(400, "Status is empty") # Grab attachments attachments = [get_object_or_404(PostAttachment, pk=id) for id in details.media_ids] # Create the Post @@ -70,23 +91,50 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status: @identity_required @api_view.get def status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) interactions = PostInteraction.get_post_interactions([post], request.identity) return schemas.Status.from_post(post, interactions=interactions) +@identity_required +@api_view.put +def edit_status(request, id: str, details: EditStatusSchema) -> schemas.Status: + post = post_for_id(request, id) + if post.author != request.identity: + raise ApiError(401, "Not the author of this status") + # Grab attachments + attachments = [get_object_or_404(PostAttachment, pk=id) for id in details.media_ids] + # Update all details, as the client must provide them all + post.edit_local( + content=details.status, + summary=details.spoiler_text, + sensitive=details.sensitive, + attachments=attachments, + ) + return schemas.Status.from_post(post) + + @identity_required @api_view.delete def delete_status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) + if post.author != request.identity: + raise ApiError(401, "Not the author of this status") PostService(post).delete() return schemas.Status.from_post(post) +@identity_required +@api_view.get +def status_source(request, id: str) -> schemas.StatusSource: + post = post_for_id(request, id) + return schemas.StatusSource.from_post(post) + + @identity_required @api_view.get def status_context(request, id: str) -> schemas.Context: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) service = PostService(post) ancestors, descendants = service.context(request.identity) interactions = PostInteraction.get_post_interactions( @@ -106,7 +154,7 @@ def status_context(request, id: str) -> schemas.Context: @identity_required @api_view.post def favourite_status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) service = PostService(post) service.like_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) @@ -116,7 +164,7 @@ def favourite_status(request, id: str) -> schemas.Status: @identity_required @api_view.post def unfavourite_status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) service = PostService(post) service.unlike_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) @@ -135,9 +183,7 @@ def favourited_by( """ View who favourited a given status. """ - # This method should filter out private statuses, but we don't really have - # a concept of "private status" yet. - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) paginator = MastodonPaginator() pager: PaginationResult[PostInteraction] = paginator.paginate( @@ -169,7 +215,7 @@ def favourited_by( @identity_required @api_view.post def reblog_status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) service = PostService(post) service.boost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) @@ -179,7 +225,7 @@ def reblog_status(request, id: str) -> schemas.Status: @identity_required @api_view.post def unreblog_status(request, id: str) -> schemas.Status: - post = get_object_or_404(Post, pk=id) + post = post_for_id(request, id) service = PostService(post) service.unboost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) diff --git a/tests/api/test_accounts.py b/tests/api/test_accounts.py index fe4bfb7..2ddd010 100644 --- a/tests/api/test_accounts.py +++ b/tests/api/test_accounts.py @@ -2,22 +2,14 @@ import pytest @pytest.mark.django_db -def test_verify_credentials(api_token, identity, client): - response = client.get( - "/api/v1/accounts/verify_credentials", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - ).json() +def test_verify_credentials(api_client, identity): + response = api_client.get("/api/v1/accounts/verify_credentials").json() assert response["id"] == str(identity.pk) assert response["username"] == identity.username @pytest.mark.django_db -def test_account_search(api_token, identity, client): - response = client.get( - "/api/v1/accounts/search?q=test", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - ).json() +def test_account_search(api_client, identity): + response = api_client.get("/api/v1/accounts/search?q=test").json() assert response[0]["id"] == str(identity.pk) assert response[0]["username"] == identity.username diff --git a/tests/api/test_instance.py b/tests/api/test_instance.py index 9fd0af2..c85224f 100644 --- a/tests/api/test_instance.py +++ b/tests/api/test_instance.py @@ -2,10 +2,6 @@ import pytest @pytest.mark.django_db -def test_instance(api_token, client): - response = client.get( - "/api/v1/instance", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - ).json() +def test_instance(api_client): + response = api_client.get("/api/v1/instance").json() assert response["uri"] == "example.com" diff --git a/tests/api/test_likes.py b/tests/api/test_likes.py index db31d8c..a7b689d 100644 --- a/tests/api/test_likes.py +++ b/tests/api/test_likes.py @@ -2,12 +2,10 @@ import pytest @pytest.mark.django_db -def test_likes_flow(api_token, client): +def test_likes_flow(api_client): # Add a post - response = client.post( + response = api_client.post( "/api/v1/statuses", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", content_type="application/json", data={ "status": "Like test.", @@ -19,17 +17,9 @@ def test_likes_flow(api_token, client): status_id = response["id"] # Like it - response = client.post( - f"/api/v1/statuses/{status_id}/favourite", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - ).json() + response = api_client.post(f"/api/v1/statuses/{status_id}/favourite").json() assert response["favourited"] is True # Check if it's displaying at likes endpoint - response = client.get( - "/api/v1/favourites", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - ).json() + response = api_client.get("/api/v1/favourites").json() assert response[0]["id"] == status_id diff --git a/tests/api/test_statuses.py b/tests/api/test_statuses.py index df576e9..05ede30 100644 --- a/tests/api/test_statuses.py +++ b/tests/api/test_statuses.py @@ -4,11 +4,13 @@ from activities.models import Post @pytest.mark.django_db -def test_post_status(api_token, identity, client): - response = client.post( +def test_post_status(api_client): + """ + Tests posting, editing and deleting a status + """ + # Post new one + response = api_client.post( "/api/v1/statuses", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", content_type="application/json", data={ "status": "Hello, world!", @@ -17,19 +19,37 @@ def test_post_status(api_token, identity, client): ).json() assert response["content"] == "

Hello, world!

" assert response["visibility"] == "unlisted" + status_id = response["id"] + # Retrieve "source" version an edit would use + response = api_client.get(f"/api/v1/statuses/{status_id}/source").json() + assert response["text"] == "Hello, world!" + # Post an edit + response = api_client.put( + f"/api/v1/statuses/{status_id}", + content_type="application/json", + data={ + "status": "Hello, world! Again!", + }, + ).json() + # Check it stuck + response = api_client.get(f"/api/v1/statuses/{status_id}").json() + assert response["content"] == "

Hello, world! Again!

" + # Delete it + response = api_client.delete(f"/api/v1/statuses/{status_id}") + assert response.status_code == 200 + # Check it's gone + response = api_client.get(f"/api/v1/statuses/{status_id}") + assert response.status_code == 404 @pytest.mark.django_db -def test_mention_format(api_token, identity, remote_identity, client): +def test_mention_format(api_client, identity, remote_identity): """ Ensures mentions work, and only have one link around them. """ # Make a local post and check it - response = client.post( + response = api_client.post( "/api/v1/statuses", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - content_type="application/json", data={ "status": "Hello, @test!", "visibility": "unlisted", @@ -49,11 +69,8 @@ def test_mention_format(api_token, identity, remote_identity, client): object_uri="https://remote.test/status/12345", ) post.mentions.add(identity) - response = client.get( + response = api_client.get( f"/api/v1/statuses/{post.id}", - HTTP_AUTHORIZATION=f"Bearer {api_token.token}", - HTTP_ACCEPT="application/json", - content_type="application/json", ).json() assert ( response["text"] == '

Hey @test

' diff --git a/tests/conftest.py b/tests/conftest.py index 51cc9a8..3a03f4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import time import pytest from django.conf import settings +from django.test import Client from api.models import Application, Token from core.models import Config @@ -218,6 +219,14 @@ def api_token(identity) -> Token: ) +@pytest.fixture +def api_client(api_token): + return Client( + HTTP_AUTHORIZATION=f"Bearer {api_token.token}", + HTTP_ACCEPT="application/json", + ) + + @pytest.fixture def stator(config_system) -> StatorRunner: """