From 5d2ed9edfe5aa29f84d7618eae6033bb85711589 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Tue, 7 Feb 2023 12:07:15 -0700 Subject: [PATCH] Hatchway API Rewrite (#499) Removes django-ninja and replaces it with a new API framework, "hatchway". I plan to move hatchway into its own project very soon. --- activities/models/post_interaction.py | 2 +- api/decorators.py | 3 + api/pagination.py | 86 +++++++- api/parser.py | 22 -- api/schemas.py | 99 ++++++++- api/urls.py | 72 +++++++ api/views/__init__.py | 12 -- api/views/accounts.py | 296 ++++++++++++------------- api/views/announcements.py | 10 +- api/views/apps.py | 9 +- api/views/base.py | 5 - api/views/emoji.py | 10 +- api/views/filters.py | 6 +- api/views/instance.py | 20 +- api/views/media.py | 41 ++-- api/views/notifications.py | 24 +-- api/views/oauth.py | 21 +- api/views/search.py | 23 +- api/views/statuses.py | 95 ++++---- api/views/timelines.py | 92 ++++---- hatchway/__init__.py | 5 + hatchway/constants.py | 10 + hatchway/http.py | 47 ++++ hatchway/schema.py | 52 +++++ hatchway/tests/test_types.py | 63 ++++++ hatchway/tests/test_view.py | 244 +++++++++++++++++++++ hatchway/types.py | 145 +++++++++++++ hatchway/urls.py | 32 +++ hatchway/view.py | 297 ++++++++++++++++++++++++++ requirements.txt | 1 - takahe/urls.py | 4 +- tests/activities/models/test_post.py | 2 + 32 files changed, 1471 insertions(+), 379 deletions(-) delete mode 100644 api/parser.py create mode 100644 api/urls.py delete mode 100644 api/views/base.py create mode 100644 hatchway/__init__.py create mode 100644 hatchway/constants.py create mode 100644 hatchway/http.py create mode 100644 hatchway/schema.py create mode 100644 hatchway/tests/test_types.py create mode 100644 hatchway/tests/test_view.py create mode 100644 hatchway/types.py create mode 100644 hatchway/urls.py create mode 100644 hatchway/view.py diff --git a/activities/models/post_interaction.py b/activities/models/post_interaction.py index eeeff55..5bc92cf 100644 --- a/activities/models/post_interaction.py +++ b/activities/models/post_interaction.py @@ -188,7 +188,7 @@ class PostInteraction(StatorModel): return result @classmethod - def get_event_interactions(cls, events, identity): + def get_event_interactions(cls, events, identity) -> dict[str, set[str]]: """ Returns a dict of {interaction_type: set(post_ids)} for all the posts within the events and the given identity, for use in templates. diff --git a/api/decorators.py b/api/decorators.py index b60cc05..14215e6 100644 --- a/api/decorators.py +++ b/api/decorators.py @@ -16,4 +16,7 @@ def identity_required(function): return JsonResponse({"error": "identity_token_required"}, status=400) return function(request, *args, **kwargs) + # This is for the API only + inner.csrf_exempt = True + return inner diff --git a/api/pagination.py b/api/pagination.py index 275653a..8645e77 100644 --- a/api/pagination.py +++ b/api/pagination.py @@ -1,22 +1,98 @@ import dataclasses import urllib.parse from collections.abc import Callable -from typing import Any +from typing import Any, Generic, Protocol, TypeVar from django.db import models from django.http import HttpRequest from activities.models import PostInteraction +from hatchway.http import ApiResponse + +T = TypeVar("T") + + +class SchemaWithId(Protocol): + """ + Little protocol type to represent schemas that have an ID attribute + """ + + id: str + + +TI = TypeVar("TI", bound=SchemaWithId) +TM = TypeVar("TM", bound=models.Model) + + +class PaginatingApiResponse(ApiResponse[list[TI]]): + """ + An ApiResponse subclass that also handles pagination link headers + """ + + def __init__( + self, + data: list[TI], + request: HttpRequest, + include_params: list[str], + **kwargs, + ): + # Call superclass + super().__init__(data, **kwargs) + # Figure out if we need link headers + self._request = request + self.extra_params = self.filter_params(self._request, include_params) + link_header = self.build_link_header() + if link_header: + self.headers["link"] = link_header + + @staticmethod + def filter_params(request: HttpRequest, allowed_params: list[str]): + params = {} + for key in allowed_params: + value = request.GET.get(key, None) + if value: + params[key] = value + return params + + def get_part(self, data_index: int, param_name: str, rel: str) -> str | None: + """ + Used to get next/prev URLs + """ + if not self.data: + return None + # Use the ID of the last object for the next page start + params = dict(self.extra_params) + params[param_name] = self.data[data_index].id + return ( + "<" + + self._request.build_absolute_uri(self._request.path) + + "?" + + urllib.parse.urlencode(params) + + f'>; rel="{rel}"' + ) + + def build_link_header(self): + parts = [ + entry + for entry in [ + self.get_part(0, "min_id", "prev"), + self.get_part(-1, "max_id", "next"), + ] + if entry + ] + if not parts: + return None + return ", ".join(parts) @dataclasses.dataclass -class PaginationResult: +class PaginationResult(Generic[T]): """ Represents a pagination result for Mastodon (it does Link header stuff) """ #: A list of objects that matched the pagination query. - results: list[models.Model] + results: list[T] #: The actual applied limit, which may be different from what was requested. limit: int @@ -130,12 +206,12 @@ class MastodonPaginator: def paginate( self, - queryset, + queryset: models.QuerySet[TM], min_id: str | None, max_id: str | None, since_id: str | None, limit: int | None, - ) -> PaginationResult: + ) -> PaginationResult[TM]: # These "does not start with interaction" checks can be removed after a # couple months, when clients have flushed them out. if max_id and not max_id.startswith("interaction"): diff --git a/api/parser.py b/api/parser.py deleted file mode 100644 index 57b420d..0000000 --- a/api/parser.py +++ /dev/null @@ -1,22 +0,0 @@ -import json - -from ninja.parser import Parser - - -class FormOrJsonParser(Parser): - """ - If there's form data in a request, makes it into a JSON dict. - This is needed as the Mastodon API allows form data OR json body as input. - """ - - def parse_body(self, request): - # Did they submit JSON? - if request.content_type == "application/json" and request.body.strip(): - return json.loads(request.body) - # Fall back to form data - value = {} - for key, item in request.POST.items(): - value[key] = item - for key, item in request.GET.items(): - value[key] = item - return value diff --git a/api/schemas.py b/api/schemas.py index 60035fe..2efe4f2 100644 --- a/api/schemas.py +++ b/api/schemas.py @@ -1,6 +1,9 @@ from typing import Literal, Optional, Union -from ninja import Field, Schema +from activities import models as activities_models +from hatchway import Field, Schema +from users import models as users_models +from users.services import IdentityService class Application(Schema): @@ -19,6 +22,10 @@ class CustomEmoji(Schema): visible_in_picker: bool category: str + @classmethod + def from_emoji(cls, emoji: activities_models.Emoji) -> "CustomEmoji": + return cls(**emoji.to_mastodon_json()) + class AccountField(Schema): name: str @@ -53,6 +60,17 @@ class Account(Schema): following_count: int source: dict | None + @classmethod + def from_identity( + cls, + identity: users_models.Identity, + include_counts: bool = True, + source=False, + ) -> "Account": + return cls( + **identity.to_mastodon_json(include_counts=include_counts, source=source) + ) + class MediaAttachment(Schema): id: str @@ -64,6 +82,12 @@ class MediaAttachment(Schema): description: str | None blurhash: str | None + @classmethod + def from_post_attachment( + cls, attachment: activities_models.PostAttachment + ) -> "MediaAttachment": + return cls(**attachment.to_mastodon_json()) + class StatusMention(Schema): id: str @@ -108,6 +132,47 @@ class Status(Schema): bookmarked: bool | None pinned: bool | None + @classmethod + def from_post( + cls, + post: activities_models.Post, + interactions: dict[str, set[str]] | None = None, + ) -> "Status": + return cls(**post.to_mastodon_json(interactions=interactions)) + + @classmethod + def map_from_post( + cls, + posts: list[activities_models.Post], + identity: users_models.Identity, + ) -> list["Status"]: + interactions = activities_models.PostInteraction.get_post_interactions( + posts, identity + ) + return [cls.from_post(post, interactions=interactions) for post in posts] + + @classmethod + def from_timeline_event( + cls, + timeline_event: activities_models.TimelineEvent, + interactions: dict[str, set[str]] | None = None, + ) -> "Status": + return cls(**timeline_event.to_mastodon_status_json(interactions=interactions)) + + @classmethod + def map_from_timeline_event( + cls, + events: list[activities_models.TimelineEvent], + identity: users_models.Identity, + ) -> list["Status"]: + interactions = activities_models.PostInteraction.get_event_interactions( + events, identity + ) + return [ + cls.from_timeline_event(event, interactions=interactions) + for event in events + ] + class Conversation(Schema): id: str @@ -134,12 +199,26 @@ class Notification(Schema): account: Account status: Status | None + @classmethod + def from_timeline_event( + cls, + event: activities_models.TimelineEvent, + ) -> "Notification": + return cls(**event.to_mastodon_notification_json()) + class Tag(Schema): name: str url: str history: dict + @classmethod + def from_hashtag( + cls, + hashtag: activities_models.Hashtag, + ) -> "Tag": + return cls(**hashtag.to_mastodon_json()) + class Search(Schema): accounts: list[Account] @@ -162,6 +241,16 @@ class Relationship(Schema): endorsed: bool note: str + @classmethod + def from_identity_pair( + cls, + identity: users_models.Identity, + from_identity: users_models.Identity, + ) -> "Relationship": + return cls( + **IdentityService(identity).mastodon_json_relationship(from_identity) + ) + class Context(Schema): ancestors: list[Status] @@ -187,3 +276,11 @@ class Announcement(Schema): tags: list[Tag] emojis: list[CustomEmoji] reactions: list + + @classmethod + def from_announcement( + cls, + announcement: users_models.Announcement, + user: users_models.User, + ) -> "Announcement": + return cls(**announcement.to_mastodon_json(user=user)) diff --git a/api/urls.py b/api/urls.py new file mode 100644 index 0000000..8c53222 --- /dev/null +++ b/api/urls.py @@ -0,0 +1,72 @@ +from django.urls import path + +from api.views import ( + accounts, + announcements, + apps, + emoji, + filters, + instance, + media, + notifications, + search, + statuses, + timelines, +) +from hatchway import methods + +urlpatterns = [ + # Accounts + path("v1/accounts/verify_credentials", accounts.verify_credentials), + path("v1/accounts/update_credentials", accounts.update_credentials), + path("v1/accounts/relationships", accounts.account_relationships), + path("v1/accounts/familiar_followers", accounts.familiar_followers), + path("v1/accounts/search", accounts.accounts_search), + path("v1/accounts/lookup", accounts.lookup), + path("v1/accounts/", accounts.account), + path("v1/accounts//statuses", accounts.account_statuses), + path("v1/accounts//follow", accounts.account_follow), + path("v1/accounts//unfollow", accounts.account_unfollow), + path("v1/accounts//block", accounts.account_block), + path("v1/accounts//unblock", accounts.account_unblock), + path("v1/accounts//mute", accounts.account_mute), + path("v1/accounts//unmute", accounts.account_unmute), + path("v1/accounts//following", accounts.account_following), + path("v1/accounts//followers", accounts.account_followers), + # Announcements + path("v1/announcements", announcements.announcement_list), + path("v1/announcements//dismiss", announcements.announcement_dismiss), + # Apps + path("v1/apps", apps.add_app), + # Filters + path("v2/filters", filters.list_filters), + path("v1/filters", filters.list_filters), + # Instance + path("v1/instance", instance.instance_info_v1), + path("v2/instance", instance.instance_info_v2), + path("v1/timelines/home", timelines.home), + path("v1/timelines/public", timelines.public), + path("v1/timelines/tag/", timelines.hashtag), + path("v1/timelines/conversations", timelines.conversations), + path("v1/favourites", timelines.favourites), + path("v1/notifications", notifications.notifications), + path("v1/statuses", statuses.post_status), + path("v1/media", media.upload_media), + path("v2/media", media.upload_media), + path("v1/media/", methods(get=media.get_media, put=media.update_media)), + path( + "v1/statuses/", + methods( + get=statuses.status, + delete=statuses.delete_status, + ), + ), + path("v1/statuses//context", statuses.status_context), + path("v1/statuses//favourite", statuses.favourite_status), + path("v1/statuses//unfavourite", statuses.unfavourite_status), + path("v1/statuses//favourited_by", statuses.favourited_by), + path("v1/statuses//reblog", statuses.reblog_status), + path("v1/statuses//unreblog", statuses.unreblog_status), + path("v1/custom_emojis", emoji.emojis), + path("v2/search", search.search), +] diff --git a/api/views/__init__.py b/api/views/__init__.py index 26fcbaf..e69de29 100644 --- a/api/views/__init__.py +++ b/api/views/__init__.py @@ -1,12 +0,0 @@ -from .accounts import * # noqa -from .announcements import * # noqa -from .apps import * # noqa -from .emoji import * # noqa -from .filters import * # noqa -from .instance import * # noqa -from .media import * # noqa -from .notifications import * # noqa -from .oauth import * # noqa -from .search import * # noqa -from .statuses import * # noqa -from .timelines import * # noqa diff --git a/api/views/accounts.py b/api/views/accounts.py index b401ddd..dc68488 100644 --- a/api/views/accounts.py +++ b/api/views/accounts.py @@ -1,131 +1,129 @@ -from django.http import HttpRequest, HttpResponse, QueryDict -from django.http.multipartparser import MultiPartParser +from typing import Any + +from django.core.files import File +from django.http import HttpRequest from django.shortcuts import get_object_or_404 -from ninja import Field, Schema from activities.models import Post from activities.services import SearchService from api import schemas from api.decorators import identity_required -from api.pagination import MastodonPaginator -from api.views.base import api_router +from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from core.models import Config +from hatchway import ApiResponse, QueryOrBody, api_view from users.models import Identity from users.services import IdentityService from users.shortcuts import by_handle_or_404 -@api_router.get("/v1/accounts/verify_credentials", response=schemas.Account) @identity_required -def verify_credentials(request): - return request.identity.to_mastodon_json(source=True) +@api_view.get +def verify_credentials(request) -> schemas.Account: + return schemas.Account.from_identity(request.identity, source=True) -@api_router.patch("/v1/accounts/update_credentials", response=schemas.Account) @identity_required +@api_view.patch def update_credentials( request, -): - # Django won't load POST and FILES for patch methods, so we do it. - if request.content_type == "multipart/form-data": - POST, FILES = MultiPartParser( - request.META, request, request.upload_handlers, request.encoding - ).parse() - elif request.content_type == "application/x-www-form-urlencoded": - POST = QueryDict(request.body, encoding=request._encoding) - FILES = {} - else: - return HttpResponse(status=400) + display_name: QueryOrBody[str | None] = None, + note: QueryOrBody[str | None] = None, + discoverable: QueryOrBody[bool | None] = None, + source: QueryOrBody[dict[str, Any] | None] = None, + fields_attributes: QueryOrBody[dict[str, dict[str, str]] | None] = None, + avatar: File | None = None, + header: File | None = None, +) -> schemas.Account: identity = request.identity service = IdentityService(identity) - if "display_name" in POST: - identity.name = POST["display_name"] - if "note" in POST: - service.set_summary(POST["note"]) - if "discoverable" in POST: - identity.discoverable = POST["discoverable"] == "checked" - if "source[privacy]" in POST: - privacy_map = { - "public": Post.Visibilities.public, - "unlisted": Post.Visibilities.unlisted, - "private": Post.Visibilities.followers, - "direct": Post.Visibilities.mentioned, - } - Config.set_identity( - identity, - "default_post_visibility", - privacy_map[POST["source[privacy]"]], - ) - if "fields_attributes[0][name]" in POST: + if display_name is not None: + identity.name = display_name + if note is not None: + service.set_summary(note) + if discoverable is not None: + identity.discoverable = discoverable + if source: + if "privacy" in source: + privacy_map = { + "public": Post.Visibilities.public, + "unlisted": Post.Visibilities.unlisted, + "private": Post.Visibilities.followers, + "direct": Post.Visibilities.mentioned, + } + Config.set_identity( + identity, + "default_post_visibility", + privacy_map[source["privacy"]], + ) + if fields_attributes: identity.metadata = [] - for i in range(4): - name_name = f"fields_attributes[{i}][name]" - value_name = f"fields_attributes[{i}][value]" - if name_name and value_name in POST: + for attribute in fields_attributes.values(): + attr_name = attribute.get("name", None) + attr_value = attribute.get("value", None) + if attr_name: # Empty value means delete this item - if not POST[value_name]: + if not attr_value: break - identity.metadata.append( - {"name": POST[name_name], "value": POST[value_name]} - ) - if "avatar" in FILES: - service.set_icon(FILES["avatar"]) - if "header" in FILES: - service.set_image(FILES["header"]) + identity.metadata.append({"name": attr_name, "value": attr_value}) + if avatar: + service.set_icon(avatar) + if header: + service.set_image(header) identity.save() - return identity.to_mastodon_json(source=True) + return schemas.Account.from_identity(identity, source=True) -@api_router.get("/v1/accounts/relationships", response=list[schemas.Relationship]) @identity_required -def account_relationships(request): - ids = request.GET.getlist("id[]") +@api_view.get +def account_relationships(request, id: list[str] | None) -> list[schemas.Relationship]: result = [] - for id in ids: - identity = get_object_or_404(Identity, pk=id) + # ID is actually a list. Thanks Mastodon! + ids = id or [] + for actual_id in ids: + identity = get_object_or_404(Identity, pk=actual_id) result.append( IdentityService(identity).mastodon_json_relationship(request.identity) ) return result -@api_router.get( - "/v1/accounts/familiar_followers", response=list[schemas.FamiliarFollowers] -) @identity_required -def familiar_followers(request): +@api_view.get +def familiar_followers( + request, id: list[str] | None +) -> list[schemas.FamiliarFollowers]: """ Returns people you follow that also follow given account IDs """ - ids = request.GET.getlist("id[]") + ids = id or [] result = [] - for id in ids: - target_identity = get_object_or_404(Identity, pk=id) + for actual_id in ids: + target_identity = get_object_or_404(Identity, pk=actual_id) result.append( - { - "id": id, - "accounts": [ - identity.to_mastodon_json() + schemas.FamiliarFollowers( + id=actual_id, + accounts=[ + schemas.Account.from_identity(identity) for identity in Identity.objects.filter( inbound_follows__source=request.identity, outbound_follows__target=target_identity, )[:20] ], - } + ) ) return result -@api_router.get("/v1/accounts/search", response=list[schemas.Account]) @identity_required -def search( +@api_view.get +def accounts_search( request, q: str, - fetch_identities: bool = Field(False, alias="resolve"), + resolve: bool = False, following: bool = False, limit: int = 20, offset: int = 0, -): +) -> list[schemas.Account]: """ Handles searching for accounts by username or handle """ @@ -135,33 +133,33 @@ def search( return [] searcher = SearchService(q, request.identity) search_result = searcher.search_identities_handle() - return [i.to_mastodon_json() for i in search_result] + return [schemas.Account.from_identity(i) for i in search_result] -@api_router.get("/v1/accounts/lookup", response=schemas.Account) -def lookup(request: HttpRequest, acct: str): +@api_view.get +def lookup(request: HttpRequest, acct: str) -> schemas.Account: """ Quickly lookup a username to see if it is available, skipping WebFinger resolution. """ identity = by_handle_or_404(request, handle=acct, local=False) - return identity.to_mastodon_json() + return schemas.Account.from_identity(identity) -@api_router.get("/v1/accounts/{id}", response=schemas.Account) +@api_view.get @identity_required -def account(request, id: str): +def account(request, id: str) -> schemas.Account: identity = get_object_or_404( - Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id + Identity.objects.exclude(restriction=Identity.Restriction.blocked), + pk=id, ) - return identity.to_mastodon_json() + return schemas.Account.from_identity(identity) -@api_router.get("/v1/accounts/{id}/statuses", response=list[schemas.Status]) +@api_view.get @identity_required def account_statuses( request: HttpRequest, - response: HttpResponse, id: str, exclude_reblogs: bool = False, exclude_replies: bool = False, @@ -172,7 +170,7 @@ def account_statuses( since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Status]]: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) @@ -191,177 +189,163 @@ def account_statuses( .order_by("-created") ) if pinned: - return [] + return ApiResponse([]) if only_media: queryset = queryset.filter(attachments__pk__isnull=False) if tagged: queryset = queryset.tagged_with(tagged) # Get user posts with pagination paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[Post] = paginator.paginate( queryset, min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - # Convert those to the JSON form - pager.jsonify_posts(identity=request.identity) - # Add a link header if we need to - if pager.results: - response.headers["Link"] = pager.link_header( - request, - [ - "limit", - "id", - "exclude_reblogs", - "exclude_replies", - "only_media", - "pinned", - "tagged", - ], - ) - return pager.json_results + return PaginatingApiResponse( + schemas.Status.map_from_post(pager.results, request.identity), + request=request, + include_params=[ + "limit", + "id", + "exclude_reblogs", + "exclude_replies", + "only_media", + "pinned", + "tagged", + ], + ) -@api_router.post("/v1/accounts/{id}/follow", response=schemas.Relationship) +@api_view.post @identity_required -def account_follow(request, id: str, reblogs: bool = True): +def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationship: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) service = IdentityService(identity) service.follow_from(request.identity, boosts=reblogs) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -@api_router.post("/v1/accounts/{id}/unfollow", response=schemas.Relationship) +@api_view.post @identity_required -def account_unfollow(request, id: str): +def account_unfollow(request, id: str) -> schemas.Relationship: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) service = IdentityService(identity) service.unfollow_from(request.identity) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -@api_router.post("/v1/accounts/{id}/block", response=schemas.Relationship) +@api_view.post @identity_required -def account_block(request, id: str): +def account_block(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(identity) service.block_from(request.identity) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -@api_router.post("/v1/accounts/{id}/unblock", response=schemas.Relationship) +@api_view.post @identity_required -def account_unblock(request, id: str): +def account_unblock(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(identity) service.unblock_from(request.identity) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -class MuteDetailsSchema(Schema): - notifications: bool = True - duration: int = 0 - - -@api_router.post("/v1/accounts/{id}/mute", response=schemas.Relationship) @identity_required -def account_mute(request, id: str, details: MuteDetailsSchema): +@api_view.post +def account_mute( + request, + id: str, + notifications: QueryOrBody[bool] = True, + duration: QueryOrBody[int] = 0, +) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(identity) service.mute_from( request.identity, - duration=details.duration, - include_notifications=details.notifications, + duration=duration, + include_notifications=notifications, ) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -@api_router.post("/v1/accounts/{id}/unmute", response=schemas.Relationship) @identity_required -def account_unmute(request, id: str): +@api_view.post +def account_unmute(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) service = IdentityService(identity) service.unmute_from(request.identity) - return service.mastodon_json_relationship(request.identity) + return schemas.Relationship.from_identity_pair(identity, request.identity) -@api_router.get("/v1/accounts/{id}/following", response=list[schemas.Account]) +@api_view.get def account_following( request: HttpRequest, - response: HttpResponse, id: str, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 40, -): +) -> ApiResponse[list[schemas.Account]]: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) if not identity.config_identity.visible_follows and request.identity != identity: - return [] + return ApiResponse([]) service = IdentityService(identity) paginator = MastodonPaginator(max_limit=80) - pager = paginator.paginate( + pager: PaginationResult[Identity] = paginator.paginate( service.following(), min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - pager.jsonify_identities() - - if pager.results: - response.headers["Link"] = pager.link_header( - request, - ["limit"], - ) - - return pager.json_results + return PaginatingApiResponse( + [schemas.Account.from_identity(i) for i in pager.results], + request=request, + include_params=["limit"], + ) -@api_router.get("/v1/accounts/{id}/followers", response=list[schemas.Account]) +@api_view.get def account_followers( request: HttpRequest, - response: HttpResponse, id: str, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 40, -): +) -> ApiResponse[list[schemas.Account]]: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) if not identity.config_identity.visible_follows and request.identity != identity: - return [] + return ApiResponse([]) service = IdentityService(identity) paginator = MastodonPaginator(max_limit=80) - pager = paginator.paginate( + pager: PaginationResult[Identity] = paginator.paginate( service.followers(), min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - pager.jsonify_identities() - - if pager.results: - response.headers["Link"] = pager.link_header( - request, - ["limit"], - ) - - return pager.json_results + return PaginatingApiResponse( + [schemas.Account.from_identity(i) for i in pager.results], + request=request, + include_params=["limit"], + ) diff --git a/api/views/announcements.py b/api/views/announcements.py index 0fd44af..50fe473 100644 --- a/api/views/announcements.py +++ b/api/views/announcements.py @@ -2,22 +2,22 @@ from django.shortcuts import get_object_or_404 from api import schemas from api.decorators import identity_required -from api.views.base import api_router +from hatchway import api_view from users.models import Announcement from users.services import AnnouncementService -@api_router.get("/v1/announcements", response=list[schemas.Announcement]) @identity_required -def announcement_list(request): +@api_view.get +def announcement_list(request) -> list[schemas.Announcement]: return [ - a.to_mastodon_json(request.user) + schemas.Announcement.from_announcement(a, request.user) for a in AnnouncementService(request.user).visible() ] -@api_router.post("/v1/announcements/{pk}/dismiss") @identity_required +@api_view.post def announcement_dismiss(request, pk: str): announcement = get_object_or_404(Announcement, pk=pk) AnnouncementService(request.user).mark_seen(announcement) diff --git a/api/views/apps.py b/api/views/apps.py index 74d8ac8..ebb883e 100644 --- a/api/views/apps.py +++ b/api/views/apps.py @@ -1,10 +1,9 @@ import secrets -from ninja import Schema +from hatchway import Schema, api_view from .. import schemas from ..models import Application -from .base import api_router class CreateApplicationSchema(Schema): @@ -14,8 +13,8 @@ class CreateApplicationSchema(Schema): website: None | str = None -@api_router.post("/v1/apps", response=schemas.Application) -def add_app(request, details: CreateApplicationSchema): +@api_view.post +def add_app(request, details: CreateApplicationSchema) -> schemas.Application: client_id = "tk-" + secrets.token_urlsafe(16) client_secret = secrets.token_urlsafe(40) application = Application.objects.create( @@ -26,4 +25,4 @@ def add_app(request, details: CreateApplicationSchema): redirect_uris=details.redirect_uris, scopes=details.scopes or "read", ) - return application + return schemas.Application.from_orm(application) diff --git a/api/views/base.py b/api/views/base.py deleted file mode 100644 index 33efc47..0000000 --- a/api/views/base.py +++ /dev/null @@ -1,5 +0,0 @@ -from ninja import NinjaAPI - -from api.parser import FormOrJsonParser - -api_router = NinjaAPI(parser=FormOrJsonParser()) diff --git a/api/views/emoji.py b/api/views/emoji.py index a0ffbc3..f909067 100644 --- a/api/views/emoji.py +++ b/api/views/emoji.py @@ -1,8 +1,10 @@ from activities.models import Emoji from api.schemas import CustomEmoji -from api.views.base import api_router +from hatchway import api_view -@api_router.get("/v1/custom_emojis", response=list[CustomEmoji]) -def emojis(request): - return [e.to_mastodon_json() for e in Emoji.objects.usable().filter(local=True)] +@api_view.get +def emojis(request) -> list[CustomEmoji]: + return [ + CustomEmoji.from_emoji(e) for e in Emoji.objects.usable().filter(local=True) + ] diff --git a/api/views/filters.py b/api/views/filters.py index 261247a..5b94f41 100644 --- a/api/views/filters.py +++ b/api/views/filters.py @@ -1,8 +1,8 @@ from api.decorators import identity_required -from api.views.base import api_router +from hatchway import api_view -@api_router.get("/v1/filters") @identity_required -def status(request): +@api_view.get +def list_filters(request): return [] diff --git a/api/views/instance.py b/api/views/instance.py index 0046502..1419038 100644 --- a/api/views/instance.py +++ b/api/views/instance.py @@ -1,15 +1,15 @@ from django.conf import settings from activities.models import Post +from api import schemas from core.models import Config +from hatchway import api_view from takahe import __version__ from users.models import Domain, Identity -from .base import api_router - -@api_router.get("/v1/instance") -def instance_info(request): +@api_view.get +def instance_info_v1(request): return { "uri": request.headers.get("host", settings.SETUP.MAIN_DOMAIN), "title": Config.system.site_name, @@ -53,11 +53,17 @@ def instance_info(request): } -@api_router.get("/v2/instance") -def instance_info_v2(request): +@api_view.get +def instance_info_v2(request) -> dict: current_domain = Domain.get_domain( request.headers.get("host", settings.SETUP.MAIN_DOMAIN) ) + if current_domain is None or not current_domain.local: + current_domain = Domain.get_domain( + request.headers.get(settings.SETUP.MAIN_DOMAIN) + ) + if current_domain is None: + raise ValueError("No domain set up for MAIN_DOMAIN") admin_identity = ( Identity.objects.filter(users__admin=True).order_by("created").first() ) @@ -116,7 +122,7 @@ def instance_info_v2(request): }, "contact": { "email": "", - "account": admin_identity.to_mastodon_json(), + "account": schemas.Account.from_identity(admin_identity), }, "rules": [], } diff --git a/api/views/media.py b/api/views/media.py index 35c0650..a856752 100644 --- a/api/views/media.py +++ b/api/views/media.py @@ -1,28 +1,22 @@ +from django.core.files import File from django.shortcuts import get_object_or_404 -from ninja import File, Schema -from ninja.files import UploadedFile from activities.models import PostAttachment, PostAttachmentStates from api import schemas -from api.views.base import api_router from core.files import blurhash_image, resize_image +from hatchway import QueryOrBody, api_view from ..decorators import identity_required -class UploadMediaSchema(Schema): - description: str = "" - focus: str = "0,0" - - -@api_router.post("/v1/media", response=schemas.MediaAttachment) -@api_router.post("/v2/media", response=schemas.MediaAttachment) @identity_required +@api_view.post def upload_media( request, - file: UploadedFile = File(...), - details: UploadMediaSchema | None = None, -): + file: File, + description: QueryOrBody[str] = "", + focus: QueryOrBody[str] = "0,0", +) -> schemas.MediaAttachment: main_file = resize_image( file, size=(2000, 2000), @@ -38,7 +32,7 @@ def upload_media( mimetype="image/webp", width=main_file.image.width, height=main_file.image.height, - name=details.description if details else None, + name=description or None, state=PostAttachmentStates.fetched, ) attachment.file.save( @@ -50,27 +44,28 @@ def upload_media( thumbnail_file, ) attachment.save() - return attachment.to_mastodon_json() + return schemas.MediaAttachment.from_post_attachment(attachment) -@api_router.get("/v1/media/{id}", response=schemas.MediaAttachment) @identity_required +@api_view.get def get_media( request, id: str, -): +) -> schemas.MediaAttachment: attachment = get_object_or_404(PostAttachment, pk=id) - return attachment.to_mastodon_json() + return schemas.MediaAttachment.from_post_attachment(attachment) -@api_router.put("/v1/media/{id}", response=schemas.MediaAttachment) @identity_required +@api_view.put def update_media( request, id: str, - details: UploadMediaSchema | None = None, -): + description: QueryOrBody[str] = "", + focus: QueryOrBody[str] = "0,0", +) -> schemas.MediaAttachment: attachment = get_object_or_404(PostAttachment, pk=id) - attachment.name = details.description if details else None + attachment.name = description or None attachment.save() - return attachment.to_mastodon_json() + return schemas.MediaAttachment.from_post_attachment(attachment) diff --git a/api/views/notifications.py b/api/views/notifications.py index 98bee04..5553d6a 100644 --- a/api/views/notifications.py +++ b/api/views/notifications.py @@ -1,24 +1,23 @@ -from django.http import HttpRequest, HttpResponse +from django.http import HttpRequest from activities.models import TimelineEvent from activities.services import TimelineService from api import schemas from api.decorators import identity_required -from api.pagination import MastodonPaginator -from api.views.base import api_router +from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult +from hatchway import ApiResponse, api_view -@api_router.get("/v1/notifications", response=list[schemas.Notification]) @identity_required +@api_view.get def notifications( request: HttpRequest, - response: HttpResponse, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 20, account_id: str | None = None, -): +) -> ApiResponse[list[schemas.Notification]]: # Types/exclude_types use weird syntax so we have to handle them manually base_types = { "favourite": TimelineEvent.Types.liked, @@ -37,16 +36,15 @@ def notifications( [base_types[r] for r in requested_types if r in base_types] ) paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[TimelineEvent] = paginator.paginate( queryset, min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - pager.jsonify_notification_events(identity=request.identity) - - if pager.results: - response.headers["Link"] = pager.link_header(request, ["limit", "account_id"]) - - return pager.json_results + return PaginatingApiResponse( + [schemas.Notification.from_timeline_event(event) for event in pager.results], + request=request, + include_params=["limit", "account_id"], + ) diff --git a/api/views/oauth.py b/api/views/oauth.py index e37323b..161e18a 100644 --- a/api/views/oauth.py +++ b/api/views/oauth.py @@ -1,4 +1,5 @@ import base64 +import json import secrets from urllib.parse import urlparse, urlunparse @@ -16,7 +17,6 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View from api.models import Application, Authorization, Token -from api.parser import FormOrJsonParser class OauthRedirect(HttpResponseRedirect): @@ -40,6 +40,19 @@ class OauthRedirect(HttpResponseRedirect): super().__init__(urlunparse(url_parts)) +def get_json_and_formdata(request): + # Did they submit JSON? + if request.content_type == "application/json" and request.body.strip(): + return json.loads(request.body) + # Fall back to form data + value = {} + for key, item in request.POST.items(): + value[key] = item + for key, item in request.GET.items(): + value[key] = item + return value + + class AuthorizationView(LoginRequiredMixin, View): """ Asks the user to authorize access. @@ -87,7 +100,7 @@ class AuthorizationView(LoginRequiredMixin, View): return render(request, "api/oauth_authorize.html", context) def post(self, request): - post_data = FormOrJsonParser().parse_body(request) + post_data = get_json_and_formdata(request) # Grab the application and other details again redirect_uri = post_data["redirect_uri"] scope = post_data["scope"] @@ -141,7 +154,7 @@ class TokenView(View): ) def post(self, request): - post_data = FormOrJsonParser().parse_body(request) + post_data = get_json_and_formdata(request) auth_client_id, auth_client_secret = extract_client_info_from_basic_auth( request ) @@ -221,7 +234,7 @@ class TokenView(View): @method_decorator(csrf_exempt, name="dispatch") class RevokeTokenView(View): def post(self, request): - post_data = FormOrJsonParser().parse_body(request) + post_data = get_json_and_formdata(request) auth_client_id, auth_client_secret = extract_client_info_from_basic_auth( request ) diff --git a/api/views/search.py b/api/views/search.py index 6d91cbe..bca127d 100644 --- a/api/views/search.py +++ b/api/views/search.py @@ -1,16 +1,14 @@ from typing import Literal -from ninja import Field - from activities.models import PostInteraction from activities.services.search import SearchService from api import schemas from api.decorators import identity_required -from api.views.base import api_router +from hatchway import Field, api_view -@api_router.get("/v2/search", response=schemas.Search) @identity_required +@api_view.get def search( request, q: str, @@ -24,26 +22,31 @@ def search( min_id: str | None = None, limit: int = 20, offset: int = 0, -): +) -> schemas.Search: if limit > 40: limit = 40 result: dict[str, list] = {"accounts": [], "statuses": [], "hashtags": []} # We don't support pagination for searches yet if max_id or since_id or min_id or offset: - return result + return schemas.Search(**result) # Run search searcher = SearchService(q, request.identity) search_result = searcher.search_all() if type is None or type == "accounts": - result["accounts"] = [i.to_mastodon_json() for i in search_result["identities"]] + result["accounts"] = [ + schemas.Account.from_identity(i, include_counts=False) + for i in search_result["identities"] + ] if type is None or type == "hashtag": - result["hashtag"] = [h.to_mastodon_json() for h in search_result["hashtags"]] + result["hashtag"] = [ + schemas.Tag.from_hashtag(h) for h in search_result["hashtags"] + ] if type is None or type == "statuses": interactions = PostInteraction.get_post_interactions( search_result["posts"], request.identity ) result["statuses"] = [ - p.to_mastodon_json(interactions=interactions) + schemas.Status.from_post(p, interactions=interactions) for p in search_result["posts"] ] - return result + return schemas.Search(**result) diff --git a/api/views/statuses.py b/api/views/statuses.py index 566cbea..6161c09 100644 --- a/api/views/statuses.py +++ b/api/views/statuses.py @@ -1,9 +1,8 @@ from typing import Literal from django.forms import ValidationError -from django.http import HttpRequest, HttpResponse +from django.http import HttpRequest from django.shortcuts import get_object_or_404 -from ninja import Schema from activities.models import ( Post, @@ -14,11 +13,10 @@ from activities.models import ( ) from activities.services import PostService from api import schemas -from api.views.base import api_router +from api.decorators import identity_required +from api.pagination import MastodonPaginator, PaginationResult from core.models import Config - -from ..decorators import identity_required -from ..pagination import MastodonPaginator +from hatchway import ApiResponse, Schema, api_view class PostStatusSchema(Schema): @@ -32,9 +30,9 @@ class PostStatusSchema(Schema): media_ids: list[str] = [] -@api_router.post("/v1/statuses", response=schemas.Status) @identity_required -def post_status(request, details: PostStatusSchema): +@api_view.post +def post_status(request, details: PostStatusSchema) -> schemas.Status: # Check text length if len(details.status) > Config.system.post_length: raise ValidationError("Status is too long") @@ -66,74 +64,74 @@ def post_status(request, details: PostStatusSchema): ) # Add their own timeline event for immediate visibility TimelineEvent.add_post(request.identity, post) - return post.to_mastodon_json() + return schemas.Status.from_post(post) -@api_router.get("/v1/statuses/{id}", response=schemas.Status) @identity_required -def status(request, id: str): +@api_view.get +def status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) interactions = PostInteraction.get_post_interactions([post], request.identity) - return post.to_mastodon_json(interactions=interactions) + return schemas.Status.from_post(post, interactions=interactions) -@api_router.delete("/v1/statuses/{id}", response=schemas.Status) @identity_required -def delete_status(request, id: str): +@api_view.delete +def delete_status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) PostService(post).delete() - return post.to_mastodon_json() + return schemas.Status.from_post(post) -@api_router.get("/v1/statuses/{id}/context", response=schemas.Context) @identity_required -def status_context(request, id: str): +@api_view.get +def status_context(request, id: str) -> schemas.Context: post = get_object_or_404(Post, pk=id) service = PostService(post) ancestors, descendants = service.context(request.identity) interactions = PostInteraction.get_post_interactions( ancestors + descendants, request.identity ) - return { - "ancestors": [ - p.to_mastodon_json(interactions=interactions) for p in reversed(ancestors) + return schemas.Context( + ancestors=[ + schemas.Status.from_post(p, interactions=interactions) + for p in reversed(ancestors) ], - "descendants": [ - p.to_mastodon_json(interactions=interactions) for p in descendants + descendants=[ + schemas.Status.from_post(p, interactions=interactions) for p in descendants ], - } + ) -@api_router.post("/v1/statuses/{id}/favourite", response=schemas.Status) @identity_required -def favourite_status(request, id: str): +@api_view.post +def favourite_status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) service = PostService(post) service.like_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return post.to_mastodon_json(interactions=interactions) + return schemas.Status.from_post(post, interactions=interactions) -@api_router.post("/v1/statuses/{id}/unfavourite", response=schemas.Status) @identity_required -def unfavourite_status(request, id: str): +@api_view.post +def unfavourite_status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) service = PostService(post) service.unlike_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return post.to_mastodon_json(interactions=interactions) + return schemas.Status.from_post(post, interactions=interactions) -@api_router.get("/v1/statuses/{id}/favourited_by", response=list[schemas.Account]) +@api_view.get def favourited_by( request: HttpRequest, - response: HttpResponse, id: str, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Account]]: """ View who favourited a given status. """ @@ -142,7 +140,7 @@ def favourited_by( post = get_object_or_404(Post, pk=id) paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[PostInteraction] = paginator.paginate( post.interactions.filter( type=PostInteraction.Types.like, state__in=PostInteractionStates.group_active(), @@ -152,32 +150,37 @@ def favourited_by( since_id=since_id, limit=limit, ) - pager.jsonify_results(lambda r: r.identity.to_mastodon_json(include_counts=False)) + headers = {} if pager.results: - response.headers["Link"] = pager.link_header( - request, - ["limit"], - ) - - return pager.json_results + headers = {"link": pager.link_header(request, ["limit"])} + return ApiResponse( + [ + schemas.Account.from_identity( + interaction.identity, + include_counts=False, + ) + for interaction in pager.results + ], + headers=headers, + ) -@api_router.post("/v1/statuses/{id}/reblog", response=schemas.Status) @identity_required -def reblog_status(request, id: str): +@api_view.post +def reblog_status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) service = PostService(post) service.boost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return post.to_mastodon_json(interactions=interactions) + return schemas.Status.from_post(post, interactions=interactions) -@api_router.post("/v1/statuses/{id}/unreblog", response=schemas.Status) @identity_required -def unreblog_status(request, id: str): +@api_view.post +def unreblog_status(request, id: str) -> schemas.Status: post = get_object_or_404(Post, pk=id) service = PostService(post) service.unboost_as(request.identity) interactions = PostInteraction.get_post_interactions([post], request.identity) - return post.to_mastodon_json(interactions=interactions) + return schemas.Status.from_post(post, interactions=interactions) diff --git a/api/views/timelines.py b/api/views/timelines.py index 70d7d3b..fcaf6bc 100644 --- a/api/views/timelines.py +++ b/api/views/timelines.py @@ -1,23 +1,23 @@ -from django.http import HttpRequest, HttpResponse, JsonResponse +from django.http import HttpRequest +from activities.models import Post from activities.services import TimelineService from api import schemas from api.decorators import identity_required -from api.pagination import MastodonPaginator -from api.views.base import api_router +from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from core.models import Config +from hatchway import ApiError, ApiResponse, api_view -@api_router.get("/v1/timelines/home", response=list[schemas.Status]) @identity_required +@api_view.get def home( request: HttpRequest, - response: HttpResponse, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Status]]: # Grab a paginated result set of instances paginator = MastodonPaginator() queryset = TimelineService(request.identity).home() @@ -41,18 +41,16 @@ def home( since_id=since_id, limit=limit, ) - # Convert those to the JSON form - pager.jsonify_status_events(identity=request.identity) - # Add the link header if needed - if pager.results: - response.headers["Link"] = pager.link_header(request, ["limit"]) - return pager.json_results + return PaginatingApiResponse( + schemas.Status.map_from_timeline_event(pager.results, request.identity), + request=request, + include_params=["limit"], + ) -@api_router.get("/v1/timelines/public", response=list[schemas.Status]) +@api_view.get def public( request: HttpRequest, - response: HttpResponse, local: bool = False, remote: bool = False, only_media: bool = False, @@ -60,9 +58,9 @@ def public( since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Status]]: if not request.identity and not Config.system.public_timeline: - return JsonResponse({"error": "public timeline is disabled"}, status=422) + raise ApiError(error="public timeline is disabled", status=422) if local: queryset = TimelineService(request.identity).local() @@ -74,29 +72,24 @@ def public( queryset = queryset.filter(attachments__id__isnull=True) # Grab a paginated result set of instances paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[Post] = paginator.paginate( queryset, min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - # Convert those to the JSON form - pager.jsonify_posts(identity=request.identity) - # Add the link header if needed - if pager.results: - response.headers["Link"] = pager.link_header( - request, - ["limit", "local", "remote", "only_media"], - ) - return pager.json_results + return PaginatingApiResponse( + schemas.Status.map_from_post(pager.results, request.identity), + request=request, + include_params=["limit", "local", "remote", "only_media"], + ) -@api_router.get("/v1/timelines/tag/{hashtag}", response=list[schemas.Status]) @identity_required +@api_view.get def hashtag( request: HttpRequest, - response: HttpResponse, hashtag: str, local: bool = False, only_media: bool = False, @@ -104,7 +97,7 @@ def hashtag( since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Status]]: if limit > 40: limit = 40 queryset = TimelineService(request.identity).hashtag(hashtag) @@ -114,63 +107,54 @@ def hashtag( queryset = queryset.filter(attachments__id__isnull=True) # Grab a paginated result set of instances paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[Post] = paginator.paginate( queryset, min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - # Convert those to the JSON form - pager.jsonify_posts(identity=request.identity) - # Add a link header if we need to - if pager.results: - response.headers["Link"] = pager.link_header( - request, - ["limit", "local", "remote", "only_media"], - ) - return pager.json_results + return PaginatingApiResponse( + schemas.Status.map_from_post(pager.results, request.identity), + request=request, + include_params=["limit", "local", "remote", "only_media"], + ) -@api_router.get("/v1/conversations", response=list[schemas.Status]) @identity_required +@api_view.get def conversations( request: HttpRequest, - response: HttpResponse, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> list[schemas.Status]: # We don't implement this yet return [] -@api_router.get("/v1/favourites", response=list[schemas.Status]) @identity_required +@api_view.get def favourites( request: HttpRequest, - response: HttpResponse, max_id: str | None = None, since_id: str | None = None, min_id: str | None = None, limit: int = 20, -): +) -> ApiResponse[list[schemas.Status]]: queryset = TimelineService(request.identity).likes() paginator = MastodonPaginator() - pager = paginator.paginate( + pager: PaginationResult[Post] = paginator.paginate( queryset, min_id=min_id, max_id=max_id, since_id=since_id, limit=limit, ) - # Convert those to the JSON form - pager.jsonify_posts(identity=request.identity) - - # Add the link header if needed - if pager.results: - response.headers["Link"] = pager.link_header(request, ["limit"]) - - return pager.json_results + return PaginatingApiResponse( + schemas.Status.map_from_post(pager.results, request.identity), + request=request, + include_params=["limit"], + ) diff --git a/hatchway/__init__.py b/hatchway/__init__.py new file mode 100644 index 0000000..318ba73 --- /dev/null +++ b/hatchway/__init__.py @@ -0,0 +1,5 @@ +from .http import ApiError, ApiResponse # noqa +from .schema import Field, Schema # noqa +from .types import Body, BodyDirect, Path, Query, QueryOrBody # noqa +from .urls import methods # noqa +from .view import api_view # noqa diff --git a/hatchway/constants.py b/hatchway/constants.py new file mode 100644 index 0000000..720e31f --- /dev/null +++ b/hatchway/constants.py @@ -0,0 +1,10 @@ +import enum + + +class InputSource(str, enum.Enum): + path = "path" + query = "query" + body = "body" + body_direct = "body_direct" + query_and_body_direct = "query_and_body_direct" + file = "file" diff --git a/hatchway/http.py b/hatchway/http.py new file mode 100644 index 0000000..661cd3f --- /dev/null +++ b/hatchway/http.py @@ -0,0 +1,47 @@ +import json +from typing import Generic, TypeVar + +from django.core.serializers.json import DjangoJSONEncoder +from django.http import HttpResponse + +T = TypeVar("T") + + +class ApiResponse(Generic[T], HttpResponse): + """ + A way to return extra information with a response if you want + headers, etc. + """ + + def __init__( + self, + data: T, + encoder=DjangoJSONEncoder, + json_dumps_params: dict[str, object] | None = None, + finalize: bool = False, + **kwargs + ): + self.data = data + self.encoder = encoder + self.json_dumps_params = json_dumps_params or {} + kwargs.setdefault("content_type", "application/json") + super().__init__(content=b"(unfinalised)", **kwargs) + if finalize: + self.finalize() + + def finalize(self): + """ + Converts whatever our current data is into HttpResponse content + """ + # TODO: Automatically call this when we're asked to write output? + self.content = json.dumps(self.data, cls=self.encoder, **self.json_dumps_params) + + +class ApiError(BaseException): + """ + A handy way to raise an error with JSONable contents + """ + + def __init__(self, status: int, error: str): + self.status = status + self.error = error diff --git a/hatchway/schema.py b/hatchway/schema.py new file mode 100644 index 0000000..fcdb491 --- /dev/null +++ b/hatchway/schema.py @@ -0,0 +1,52 @@ +from typing import Any + +from django.db.models import Manager, QuerySet +from django.db.models.fields.files import FieldFile +from django.template import Variable, VariableDoesNotExist +from pydantic.fields import Field # noqa +from pydantic.main import BaseModel +from pydantic.utils import GetterDict + + +class DjangoGetterDict(GetterDict): + def __init__(self, obj: Any): + self._obj = obj + + def __getitem__(self, key: str) -> Any: + try: + item = getattr(self._obj, key) + except AttributeError: + try: + item = Variable(key).resolve(self._obj) + except VariableDoesNotExist as e: + raise KeyError(key) from e + return self._convert_result(item) + + def get(self, key: Any, default: Any = None) -> Any: + try: + return self[key] + except KeyError: + return default + + def _convert_result(self, result: Any) -> Any: + if isinstance(result, Manager): + return list(result.all()) + + elif isinstance(result, getattr(QuerySet, "__origin__", QuerySet)): + return list(result) + + if callable(result): + return result() + + elif isinstance(result, FieldFile): + if not result: + return None + return result.url + + return result + + +class Schema(BaseModel): + class Config: + orm_mode = True + getter_dict = DjangoGetterDict diff --git a/hatchway/tests/test_types.py b/hatchway/tests/test_types.py new file mode 100644 index 0000000..ddc011f --- /dev/null +++ b/hatchway/tests/test_types.py @@ -0,0 +1,63 @@ +from typing import Literal, Optional, Union + +from django.core.files import File + +from hatchway.http import ApiResponse +from hatchway.types import ( + Query, + QueryType, + acceptable_input, + extract_output_type, + extract_signifier, + is_optional, +) + + +def test_is_optional(): + + assert is_optional(Optional[int]) == (True, int) + assert is_optional(Union[int, None]) == (True, int) + assert is_optional(Union[None, int]) == (True, int) + assert is_optional(int | None) == (True, int) + assert is_optional(None | int) == (True, int) + assert is_optional(int) == (False, int) + assert is_optional(Query[int]) == (False, Query[int]) + + +def test_extract_signifier(): + + assert extract_signifier(int) == (None, int) + assert extract_signifier(Query[int]) == (QueryType, int) + assert extract_signifier(Query[Optional[int]]) == ( # type:ignore + QueryType, + Optional[int], + ) + assert extract_signifier(Query[int | None]) == ( # type:ignore + QueryType, + Optional[int], + ) + assert extract_signifier(Optional[Query[int]]) == (QueryType, Optional[int]) + + +def test_extract_output_type(): + + assert extract_output_type(int) == int + assert extract_output_type(ApiResponse[int]) == int + assert extract_output_type(ApiResponse[int | str]) == int | str + + +def test_acceptable_input(): + + assert acceptable_input(str) is True + assert acceptable_input(int) is True + assert acceptable_input(Query[int]) is True + assert acceptable_input(Optional[int]) is True + assert acceptable_input(int | None) is True + assert acceptable_input(int | str | None) is True + assert acceptable_input(Query[int | None]) is True # type: ignore + assert acceptable_input(File) is True + assert acceptable_input(list[str]) is True + assert acceptable_input(dict[str, int]) is True + assert acceptable_input(Literal["a", "b"]) is True + assert acceptable_input(frozenset) is False + assert acceptable_input(dict[str, frozenset]) is False diff --git a/hatchway/tests/test_view.py b/hatchway/tests/test_view.py new file mode 100644 index 0000000..c4b43e9 --- /dev/null +++ b/hatchway/tests/test_view.py @@ -0,0 +1,244 @@ +import json + +import pytest +from django.core import files +from django.core.files.uploadedfile import SimpleUploadedFile +from django.http import QueryDict +from django.test import RequestFactory +from django.test.client import MULTIPART_CONTENT +from pydantic import BaseModel + +from hatchway import ApiError, Body, QueryOrBody, api_view +from hatchway.view import ApiView + + +def test_basic_view(): + """ + Tests that a view with simple types works correctly + """ + + @api_view + def test_view( + request, + a: int, + b: QueryOrBody[int | None] = None, + c: str = "x", + ) -> str: + if b is None: + return c * a + else: + return c * (a - b) + + # Call it with a few different patterns to verify it's type coercing right + factory = RequestFactory() + + # Implicit query param + response = test_view(factory.get("/test/?a=4")) + assert json.loads(response.content) == "xxxx" + + # QueryOrBody pulling from query + response = test_view(factory.get("/test/?a=4&b=2")) + assert json.loads(response.content) == "xx" + + # QueryOrBody pulling from formdata body + response = test_view(factory.post("/test/?a=4", {"b": "3"})) + assert json.loads(response.content) == "x" + + # QueryOrBody pulling from JSON body + response = test_view( + factory.post( + "/test/?a=4", json.dumps({"b": 3}), content_type="application/json" + ) + ) + assert json.loads(response.content) == "x" + + # Implicit Query not pulling from body + with pytest.raises(TypeError): + test_view(factory.post("/test/", {"a": 4, "b": 3})) + + +def test_body_direct(): + """ + Tests that a Pydantic model with BodyDirect gets its fields from the top level + """ + + class TestModel(BaseModel): + number: int + name: str + + @api_view + def test_view(request, data: TestModel) -> int: + return data.number + + factory = RequestFactory() + + # formdata version + response = test_view(factory.post("/test/", {"number": "123", "name": "Andrew"})) + assert json.loads(response.content) == 123 + + # JSON body version + response = test_view( + factory.post( + "/test/", + json.dumps({"number": "123", "name": "Andrew"}), + content_type="application/json", + ) + ) + assert json.loads(response.content) == 123 + + +def test_list_response(): + """ + Tests that a view with a list response type works correctly with both + dicts and pydantic model instances. + """ + + class TestModel(BaseModel): + number: int + name: str + + @api_view + def test_view_dict(request) -> list[TestModel]: + return [ + {"name": "Andrew", "number": 1}, # type:ignore + {"name": "Alice", "number": 0}, # type:ignore + ] + + @api_view + def test_view_model(request) -> list[TestModel]: + return [TestModel(name="Andrew", number=1), TestModel(name="Alice", number=0)] + + response = test_view_dict(RequestFactory().get("/test/")) + assert json.loads(response.content) == [ + {"name": "Andrew", "number": 1}, + {"name": "Alice", "number": 0}, + ] + + response = test_view_model(RequestFactory().get("/test/")) + assert json.loads(response.content) == [ + {"name": "Andrew", "number": 1}, + {"name": "Alice", "number": 0}, + ] + + +def test_patch_body(): + """ + Tests that PATCH also gets its body parsed + """ + + @api_view.patch + def test_view(request, a: Body[int]): + return a + + factory = RequestFactory() + response = test_view( + factory.patch( + "/test/", + content_type=MULTIPART_CONTENT, + data=factory._encode_data({"a": "42"}, MULTIPART_CONTENT), + ) + ) + assert json.loads(response.content) == 42 + + +def test_file_body(): + """ + Tests that file uploads work right + """ + + @api_view.post + def test_view(request, a: Body[int], b: files.File) -> str: + return str(a) + b.read().decode("ascii") + + factory = RequestFactory() + uploaded_file = SimpleUploadedFile( + "file.txt", + b"MY FILE IS AMAZING", + content_type="text/plain", + ) + response = test_view( + factory.post( + "/test/", + data={"a": 42, "b": uploaded_file}, + ) + ) + assert json.loads(response.content) == "42MY FILE IS AMAZING" + + +def test_no_response(): + """ + Tests that a view with no response type returns the contents verbatim + """ + + @api_view + def test_view(request): + return [1, "woooooo"] + + response = test_view(RequestFactory().get("/test/")) + assert json.loads(response.content) == [1, "woooooo"] + + +def test_wrong_method(): + """ + Tests that a view with a method limiter works + """ + + @api_view.get + def test_view(request): + return "yay" + + response = test_view(RequestFactory().get("/test/")) + assert json.loads(response.content) == "yay" + + response = test_view(RequestFactory().post("/test/")) + assert response.status_code == 405 + + +def test_api_error(): + """ + Tests that ApiError propagates right + """ + + @api_view.get + def test_view(request): + raise ApiError(401, "you did a bad thing") + + response = test_view(RequestFactory().get("/test/")) + assert json.loads(response.content) == {"error": "you did a bad thing"} + assert response.status_code == 401 + + +def test_unusable_type(): + """ + Tests that you get a nice error when you use a type on an input that + Pydantic doesn't understand. + """ + + with pytest.raises(ValueError): + + @api_view.get + def test_view(request, a: RequestFactory): + pass + + +def test_get_values(): + """ + Tests that ApiView.get_values correctly handles lists + """ + + assert ApiView.get_values({"a": 2, "b": [3, 4]}) == {"a": 2, "b": [3, 4]} + assert ApiView.get_values({"a": 2, "b[]": [3, 4]}) == {"a": 2, "b": [3, 4]} + assert ApiView.get_values(QueryDict("a=2&b=3&b=4")) == {"a": "2", "b": ["3", "4"]} + assert ApiView.get_values(QueryDict("a=2&b[]=3&b[]=4")) == { + "a": "2", + "b": ["3", "4"], + } + assert ApiView.get_values(QueryDict("a=2&b=3")) == {"a": "2", "b": "3"} + assert ApiView.get_values(QueryDict("a=2&b[]=3")) == {"a": "2", "b": ["3"]} + assert ApiView.get_values(QueryDict("a[b]=1")) == {"a": {"b": "1"}} + assert ApiView.get_values(QueryDict("a[b]=1&a[c]=2")) == {"a": {"b": "1", "c": "2"}} + assert ApiView.get_values(QueryDict("a[b][c]=1")) == {"a": {"b": {"c": "1"}}} + assert ApiView.get_values(QueryDict("a[b][c][]=1")) == {"a": {"b": {"c": ["1"]}}} + assert ApiView.get_values(QueryDict("a[b][]=1&a[b][]=2")) == { + "a": {"b": ["1", "2"]} + } diff --git a/hatchway/types.py b/hatchway/types.py new file mode 100644 index 0000000..bbd4f2b --- /dev/null +++ b/hatchway/types.py @@ -0,0 +1,145 @@ +from types import NoneType, UnionType +from typing import ( # type: ignore[attr-defined] + Annotated, + Any, + Literal, + Optional, + TypeVar, + Union, + _AnnotatedAlias, + _GenericAlias, + get_args, + get_origin, +) + +from django.core import files +from pydantic import BaseModel + +from .http import ApiResponse + +T = TypeVar("T") + + +class PathType: + """ + An input pulled from the path (url resolver kwargs) + """ + + +class QueryType: + """ + An input pulled from the query parameters (request.GET) + """ + + +class BodyType: + """ + An input pulled from the POST body (request.POST or a JSON body) + """ + + +class FileType: + """ + An input pulled from the POST body (request.POST or a JSON body) + """ + + +class BodyDirectType: + """ + A Pydantic model whose keys are all looked for in the top-level + POST data, rather than in a dict under a key named after the input. + """ + + +class QueryOrBodyType: + """ + An input pulled from either query parameters or post data. + """ + + +Path = Annotated[T, PathType] +Query = Annotated[T, QueryType] +Body = Annotated[T, BodyType] +File = Annotated[T, FileType] +BodyDirect = Annotated[T, BodyDirectType] +QueryOrBody = Annotated[T, QueryOrBodyType] + + +def is_optional(annotation) -> tuple[bool, Any]: + """ + If an annotation is Optional or | None, returns (True, internal type). + Returns (False, annotation) otherwise. + """ + if (isinstance(annotation, _GenericAlias) and annotation.__origin__ is Union) or ( + isinstance(annotation, UnionType) + ): + args = get_args(annotation) + if len(args) > 2: + return False, annotation + if args[0] is NoneType: + return True, args[1] + if args[1] is NoneType: + return True, args[0] + return False, annotation + return False, annotation + + +def extract_signifier(annotation) -> tuple[Any, Any]: + """ + Given a type annotation, looks to see if it can find a input source + signifier (Path, Query, etc.) + + If it can, returns (signifier, annotation_without_signifier) + If not, returns (None, annotation) + """ + our_generics = { + PathType, + QueryType, + BodyType, + FileType, + BodyDirectType, + QueryOrBodyType, + } + # Remove any optional-style wrapper + optional, internal_annotation = is_optional(annotation) + # Is it an annotation? + if isinstance(internal_annotation, _AnnotatedAlias): + args = get_args(internal_annotation) + for arg in args[1:]: + if arg in our_generics: + if optional: + return (arg, Optional[args[0]]) + else: + return (arg, args[0]) + return None, annotation + + +def extract_output_type(annotation): + """ + Returns the right response type for a function + """ + # If the type is ApiResponse, we want to pull out its inside + if isinstance(annotation, _GenericAlias): + if get_origin(annotation) == ApiResponse: + return get_args(annotation)[0] + return annotation + + +def acceptable_input(annotation) -> bool: + """ + Returns if this annotation is something we think we can accept as input + """ + _, inner_type = extract_signifier(annotation) + try: + if issubclass(inner_type, BaseModel): + return True + except TypeError: + pass + if inner_type in [str, int, list, tuple, bool, Any, files.File, type(None)]: + return True + origin = get_origin(inner_type) + if origin == Literal: + return True + if origin in [Union, UnionType, dict, list, tuple]: + return all(acceptable_input(a) for a in get_args(inner_type)) + return False diff --git a/hatchway/urls.py b/hatchway/urls.py new file mode 100644 index 0000000..c531cba --- /dev/null +++ b/hatchway/urls.py @@ -0,0 +1,32 @@ +from collections.abc import Callable +from typing import Any + +from django.http import HttpResponseNotAllowed + + +class Methods: + """ + Allows easy multi-method dispatch to different functions + """ + + csrf_exempt = True + + def __init__(self, **callables: Callable): + self.callables = { + method.lower(): callable for method, callable in callables.items() + } + unknown_methods = set(self.callables.keys()).difference( + {"get", "post", "patch", "put", "delete"} + ) + if unknown_methods: + raise ValueError(f"Cannot route methods: {unknown_methods}") + + def __call__(self, request, *args, **kwargs) -> Any: + method = request.method.lower() + if method in self.callables: + return self.callables[method](request, *args, **kwargs) + else: + return HttpResponseNotAllowed(self.callables.keys()) + + +methods = Methods diff --git a/hatchway/view.py b/hatchway/view.py new file mode 100644 index 0000000..c8a6b1f --- /dev/null +++ b/hatchway/view.py @@ -0,0 +1,297 @@ +import json +from collections.abc import Callable +from typing import Any, Optional, get_type_hints + +from django.core import files +from django.http import HttpRequest, HttpResponseNotAllowed, QueryDict +from django.http.multipartparser import MultiPartParser +from pydantic import BaseModel, create_model + +from .constants import InputSource +from .http import ApiError, ApiResponse +from .types import ( + BodyDirectType, + BodyType, + FileType, + PathType, + QueryOrBodyType, + QueryType, + acceptable_input, + extract_output_type, + extract_signifier, + is_optional, +) + + +class ApiView: + """ + A view 'wrapper' object that replaces the API view for anything further + up the stack. + + Unlike Django's class-based views, we don't need an as_view pattern + as we are careful never to write anything per-request to self. + """ + + csrf_exempt = True + + def __init__( + self, + view: Callable, + input_types: dict[str, Any] | None = None, + output_type: Any = None, + implicit_lists: bool = True, + method: str | None = None, + ): + self.view = view + self.implicit_lists = implicit_lists + self.view_name = getattr(view, "__name__", "unknown_view") + self.method = method + # Extract input/output types from view annotations if we need to + self.input_types = input_types + if self.input_types is None: + self.input_types = get_type_hints(view, include_extras=True) + if "return" in self.input_types: + del self.input_types["return"] + self.output_type = output_type + if self.output_type is None: + try: + self.output_type = extract_output_type( + get_type_hints(view, include_extras=True)["return"] + ) + except KeyError: + self.output_type = None + self.compile() + + @classmethod + def get(cls, view: Callable): + return cls(view=view, method="get") + + @classmethod + def post(cls, view: Callable): + return cls(view=view, method="post") + + @classmethod + def put(cls, view: Callable): + return cls(view=view, method="put") + + @classmethod + def patch(cls, view: Callable): + return cls(view=view, method="patch") + + @classmethod + def delete(cls, view: Callable): + return cls(view=view, method="delete") + + @classmethod + def sources_for_input(cls, input_type) -> tuple[list[InputSource], Any]: + """ + Given a type that can appear as a request parameter type, returns + what sources it can come from, and what its type is as understood + by Pydantic. + """ + signifier, input_type = extract_signifier(input_type) + if signifier is QueryType: + return ([InputSource.query], input_type) + elif signifier is BodyType: + return ([InputSource.body], input_type) + elif signifier is BodyDirectType: + if not issubclass(input_type, BaseModel): + raise ValueError( + "You cannot use BodyDirect on something that is not a Pydantic model" + ) + return ([InputSource.body_direct], input_type) + elif signifier is PathType: + return ([InputSource.path], input_type) + elif ( + signifier is FileType + or input_type is files.File + or is_optional(input_type)[1] is files.File + ): + return ([InputSource.file], input_type) + elif signifier is QueryOrBodyType: + return ([InputSource.query, InputSource.body], input_type) + # Is it a Pydantic model, which means it's implicitly body? + elif isinstance(input_type, type) and issubclass(input_type, BaseModel): + return ([InputSource.body], input_type) + # Otherwise, we look in the path first and then the query + else: + return ([InputSource.path, InputSource.query], input_type) + + @classmethod + def get_values(cls, data, use_square_brackets=True) -> dict[str, Any]: + """ + Given a QueryDict or normal dict, returns data taking into account + lists made by repeated values or by suffixing names with []. + """ + result: dict[str, Any] = {} + for key, value in data.items(): + # If it's a query dict with multiple values, make it a list + if isinstance(data, QueryDict): + values = data.getlist(key) + if len(values) > 1: + value = values + # If it is in dict-ish/list-ish syntax, adhere to that + # TODO: Make this better handle badly formed keys + if "[" in key and use_square_brackets: + parts = key.split("[") + target = result + last_key = parts[0] + for part in parts[1:]: + part = part.rstrip("]") + if not part: + target = target.setdefault(last_key, []) + else: + target = target.setdefault(last_key, {}) + last_key = part + if isinstance(target, list): + if isinstance(value, list): + target.extend(value) + else: + target.append(value) + else: + target[last_key] = value + else: + result[key] = value + return result + + def compile(self): + self.sources: dict[str, list[InputSource]] = {} + amount_from_body = 0 + pydantic_model_dict = {} + self.input_files = set() + last_body_type = None + # For each input item, work out where to pull it from + for name, input_type in self.input_types.items(): + # Do some basic typechecking to stop things that aren't allowed + if isinstance(input_type, type) and issubclass(input_type, HttpRequest): + continue + if not acceptable_input(input_type): + # Strip away any singifiers for the error + _, inner_type = extract_signifier(input_type) + raise ValueError( + f"Input argument {name} has an unsupported type {inner_type}" + ) + sources, pydantic_type = self.sources_for_input(input_type) + self.sources[name] = sources + # Keep count of how many are pulling from the body + if InputSource.body in sources: + amount_from_body += 1 + last_body_type = pydantic_type + if InputSource.file in sources: + self.input_files.add(name) + else: + pydantic_model_dict[name] = (Optional[pydantic_type], ...) + # If there is just one thing pulling from the body and it's a BaseModel, + # signify that it's actually pulling from the body keys directly and + # not a sub-dict + if amount_from_body == 1: + for name, sources in self.sources.items(): + if ( + InputSource.body in sources + and isinstance(last_body_type, type) + and issubclass(last_body_type, BaseModel) + ): + self.sources[name] = [ + x for x in sources if x != InputSource.body + ] + [InputSource.body_direct] + # Turn all the main arguments into Pydantic parsing models + try: + self.input_model = create_model( + f"{self.view_name}_input", **pydantic_model_dict + ) + except RuntimeError: + raise ValueError( + f"One or more inputs on view {self.view_name} have a bad configuration" + ) + if self.output_type is not None: + self.output_model = create_model( + f"{self.view_name}_output", value=(self.output_type, ...) + ) + + def __call__(self, request: HttpRequest, *args, **kwargs): + """ + Entrypoint when this is called as a view. + """ + # Do a method check if we have one set + if self.method and self.method.upper() != request.method: + return HttpResponseNotAllowed([self.method]) + # For each item we can source, go find it if we can + query_values = self.get_values(request.GET) + body_values = self.get_values(request.POST) + files_values = self.get_values(request.FILES) + # If it's a PUT or PATCH method, work around Django not handling FILES + # or POST on those requests + if request.method in ["PATCH", "PUT"]: + if request.content_type == "multipart/form-data": + POST, FILES = MultiPartParser( + request.META, request, request.upload_handlers, request.encoding + ).parse() + body_values = self.get_values(POST) + files_values = self.get_values(FILES) + elif request.content_type == "application/x-www-form-urlencoded": + POST = QueryDict(request.body, encoding=request._encoding) + body_values = self.get_values(POST) + # If there was a JSON body, go load that + if request.content_type == "application/json" and request.body.strip(): + body_values.update(self.get_values(json.loads(request.body))) + values = {} + for name, sources in self.sources.items(): + for source in sources: + if source == InputSource.path: + if name in kwargs: + values[name] = kwargs[name] + break + elif source == InputSource.query: + if name in query_values: + values[name] = query_values[name] + break + elif source == InputSource.body: + if name in body_values: + values[name] = body_values[name] + break + elif source == InputSource.file: + if name in files_values: + values[name] = files_values[name] + break + elif source == InputSource.body_direct: + values[name] = body_values + break + elif source == InputSource.query_and_body_direct: + values[name] = dict(query_values) + values[name].update(body_values) + break + else: + raise ValueError(f"Unknown source {source}") + else: + values[name] = None + # Give that to the Pydantic model to make it handle stuff + model_instance = self.input_model(**values) + kwargs = { + name: getattr(model_instance, name) + for name in model_instance.__fields__ + if values[name] is not None # Trim out missing fields + } + # Add in any files + # TODO: HTTP error if file is not optional + for name in self.input_files: + kwargs[name] = files_values.get(name, None) + # Call the view with those as kwargs + try: + response = self.view(request, **kwargs) + except ApiError as error: + return ApiResponse( + {"error": error.error}, status=error.status, finalize=True + ) + # If it's not an ApiResponse, make it one + if not isinstance(response, ApiResponse): + response = ApiResponse(response) + # Get pydantic to coerce the output response + if self.output_type is not None: + response.data = self.output_model(value=response.data).dict()["value"] + elif isinstance(response.data, BaseModel): + response.data = response.data.dict() + response.finalize() + return response + + +api_view = ApiView diff --git a/requirements.txt b/requirements.txt index 911830f..acc6fe6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ django-cache-url~=3.4.2 django-cors-headers~=3.13.0 django-debug-toolbar~=3.8.1 django-htmx~=1.13.0 -django-ninja~=0.19.1 django-oauth-toolkit~=2.2.0 django-storages[google,boto3]~=1.13.1 django~=4.1 diff --git a/takahe/urls.py b/takahe/urls.py index d437308..456636a 100644 --- a/takahe/urls.py +++ b/takahe/urls.py @@ -3,7 +3,7 @@ from django.contrib import admin as djadmin from django.urls import include, path, re_path from activities.views import compose, debug, explore, follows, posts, search, timelines -from api.views import api_router, oauth +from api.views import oauth from core import views as core from mediaproxy import views as mediaproxy from stator import views as stator @@ -296,7 +296,7 @@ urlpatterns = [ path("actor/outbox/", activitypub.EmptyOutbox.as_view()), path("inbox/", activitypub.Inbox.as_view(), name="shared_inbox"), # API/Oauth - path("api/", api_router.urls), + path("api/", include("api.urls")), path("oauth/authorize", oauth.AuthorizationView.as_view()), path("oauth/token", oauth.TokenView.as_view()), path("oauth/revoke", oauth.RevokeTokenView.as_view()), diff --git a/tests/activities/models/test_post.py b/tests/activities/models/test_post.py index 0b9578b..0eb1fc5 100644 --- a/tests/activities/models/test_post.py +++ b/tests/activities/models/test_post.py @@ -281,7 +281,9 @@ def test_inbound_posts( InboxMessage.objects.create(message=message) # Run stator and ensure that made the post + print("prestat") stator.run_single_cycle_sync() + print("poststat") post = Post.objects.get(object_uri="https://remote.test/test-post") assert post.content == "post version one" assert post.published.day == 13