Implement poll AP and API backend
This commit is contained in:
Humberto Rocha 2023-02-15 12:42:56 -05:00 committed by GitHub
parent cc0bbfb93b
commit e382994e73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 868 additions and 43 deletions

View file

@ -0,0 +1,26 @@
# Generated by Django 4.1.7 on 2023-02-14 22:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("activities", "0010_stator_indexes"),
]
operations = [
migrations.AddField(
model_name="postinteraction",
name="value",
field=models.CharField(blank=True, max_length=50, null=True),
),
migrations.AlterField(
model_name="postinteraction",
name="type",
field=models.CharField(
choices=[("like", "Like"), ("boost", "Boost"), ("vote", "Vote")],
max_length=100,
),
),
]

View file

@ -163,7 +163,7 @@ class FanOutStates(StateGraph):
interaction=interaction,
)
# Handle sending remote boosts/likes
# Handle sending remote boosts/likes/votes
case (FanOut.Types.interaction, False):
interaction = await fan_out.subject_post_interaction.afetch_full()
# Send it to the remote inbox
@ -174,7 +174,11 @@ class FanOutStates(StateGraph):
fan_out.identity.shared_inbox_uri
or fan_out.identity.inbox_uri
),
body=canonicalise(interaction.to_ap()),
body=canonicalise(
interaction.to_create_ap()
if interaction.type == interaction.Types.vote
else interaction.to_ap()
),
)
except httpx.RequestError:
return

View file

@ -23,6 +23,7 @@ from activities.models.post_types import (
PostTypeData,
PostTypeDataDecoder,
PostTypeDataEncoder,
QuestionData,
)
from core.exceptions import capture_message
from core.html import ContentRenderer, FediverseHtmlParser
@ -377,12 +378,16 @@ class Post(StatorModel):
def _safe_content_note(self, *, local: bool = True):
return ContentRenderer(local=local).render_post(self.content, self)
# def _safe_content_question(self, *, local: bool = True):
# context = {
# "post": self,
# "typed_data": PostTypeData(self.type_data),
# }
# return loader.render_to_string("activities/_type_question.html", context)
def _safe_content_question(self, *, local: bool = True):
if local:
context = {
"post": self,
"sanitized_content": self._safe_content_note(local=local),
"local_display": local,
}
return loader.render_to_string("activities/_type_question.html", context)
else:
return ContentRenderer(local=local).render_post(self.content, self)
def _safe_content_typed(self, *, local: bool = True):
context = {
@ -461,6 +466,7 @@ class Post(StatorModel):
visibility: int = Visibilities.public,
reply_to: Optional["Post"] = None,
attachments: list | None = None,
question: dict | None = None,
) -> "Post":
with transaction.atomic():
# Find mentions in this post
@ -493,6 +499,9 @@ class Post(StatorModel):
post.emojis.set(emojis)
if attachments:
post.attachments.set(attachments)
if question:
post.type = question["type"]
post.type_data = PostTypeData(__root__=question).__root__
post.save()
# Recalculate parent stats for replies
if reply_to:
@ -574,6 +583,30 @@ class Post(StatorModel):
if save:
self.save()
def calculate_type_data(self, save=True):
"""
Recalculate type_data (used mostly for poll votes)
"""
from activities.models import PostInteraction
if self.local and isinstance(self.type_data, QuestionData):
self.type_data.voter_count = (
self.interactions.filter(
type=PostInteraction.Types.vote,
)
.values("identity")
.distinct()
.count()
)
for option in self.type_data.options:
option.votes = self.interactions.filter(
type=PostInteraction.Types.vote,
value=option.name,
).count()
if save:
self.save()
### ActivityPub (outbound) ###
def to_ap(self) -> dict:
@ -802,8 +835,7 @@ class Post(StatorModel):
if update or created:
post.type = data["type"]
if post.type in (cls.Types.article, cls.Types.question):
type_data = PostTypeData(__root__=data).__root__
post.type_data = type_data.dict()
post.type_data = PostTypeData(__root__=data).__root__
post.content = get_value_or_map(data, "content", "contentMap")
post.summary = data.get("summary")
post.sensitive = data.get("sensitive", False)
@ -1027,7 +1059,7 @@ class Post(StatorModel):
### Mastodon API ###
def to_mastodon_json(self, interactions=None):
def to_mastodon_json(self, interactions=None, identity=None):
reply_parent = None
if self.in_reply_to:
# Load the PK and author.id explicitly to prevent a SELECT on the entire author Identity
@ -1085,7 +1117,9 @@ class Post(StatorModel):
reply_parent.author_id if reply_parent else None
),
"reblog": None,
"poll": None,
"poll": self.type_data.to_mastodon_json(self, identity)
if isinstance(self.type_data, QuestionData)
else None,
"card": None,
"language": None,
"text": self.safe_content_remote(),

View file

@ -3,6 +3,7 @@ from django.utils import timezone
from activities.models.fan_out import FanOut
from activities.models.post import Post
from activities.models.post_types import QuestionData
from core.ld import format_ld_date, get_str_or_id, parse_ld_date
from core.snowflake import Snowflake
from stator.models import State, StateField, StateGraph, StatorModel
@ -60,6 +61,17 @@ class PostInteractionStates(StateGraph):
subject_post=interaction.post,
subject_post_interaction=interaction,
)
# Vote: send a copy of the vote to the original
# post author only if it's a local interaction
# to a non local post
elif interaction.type == interaction.Types.vote:
if interaction.identity.local and not interaction.post.local:
await FanOut.objects.acreate(
type=FanOut.Types.interaction,
identity_id=interaction.post.author_id,
subject_post=interaction.post,
subject_post_interaction=interaction,
)
else:
raise ValueError("Cannot fan out unknown type")
# And one for themselves if they're local and it's a boost
@ -125,6 +137,7 @@ class PostInteraction(StatorModel):
class Types(models.TextChoices):
like = "like"
boost = "boost"
vote = "vote"
id = models.BigIntegerField(
primary_key=True,
@ -154,6 +167,10 @@ class PostInteraction(StatorModel):
related_name="interactions",
)
# Used to store any interaction extra text value like the vote
# in the question/poll case
value = models.CharField(max_length=50, blank=True, null=True)
# When the activity was originally created (as opposed to when we received it)
# Mastodon only seems to send this for boosts, not likes
published = models.DateTimeField(default=timezone.now)
@ -203,9 +220,44 @@ class PostInteraction(StatorModel):
"""
Returns a version of the object with all relations pre-loaded
"""
return await PostInteraction.objects.select_related("identity", "post").aget(
pk=self.pk
)
return await PostInteraction.objects.select_related(
"identity", "post", "post__author"
).aget(pk=self.pk)
### Create helpers ###
@classmethod
def create_votes(cls, post, identity, choices) -> list["PostInteraction"]:
question = post.type_data
if question.end_time and timezone.now() > question.end_time:
raise ValueError("Validation failed: The poll has already ended")
if post.interactions.filter(identity=identity, type=cls.Types.vote).exists():
raise ValueError("Validation failed: You have already voted on this poll")
votes = []
with transaction.atomic():
for choice in set(choices):
vote = cls.objects.create(
identity=identity,
post=post,
type=PostInteraction.Types.vote,
value=question.options[choice].name,
)
vote.object_uri = f"{identity.actor_uri}#votes/{vote.id}"
vote.save()
votes.append(vote)
if not post.local:
question.options[choice].votes += 1
if not post.local:
question.voter_count += 1
post.calculate_type_data()
return votes
### ActivityPub (outbound) ###
@ -233,10 +285,33 @@ class PostInteraction(StatorModel):
"actor": self.identity.actor_uri,
"object": self.post.object_uri,
}
elif self.type == self.Types.vote:
value = {
"type": "Note",
"id": self.object_uri,
"to": self.post.author.actor_uri,
"name": self.value,
"inReplyTo": self.post.object_uri,
"attributedTo": self.identity.actor_uri,
}
else:
raise ValueError("Cannot turn into AP")
return value
def to_create_ap(self):
"""
Returns the AP JSON to create this object
"""
object = self.to_ap()
return {
"to": object.get("to", []),
"cc": object.get("cc", []),
"type": "Create",
"id": self.object_uri,
"actor": self.identity.actor_uri,
"object": object,
}
def to_undo_ap(self) -> dict:
"""
Returns the AP JSON to undo this object
@ -267,12 +342,40 @@ class PostInteraction(StatorModel):
# Resolve the author
identity = Identity.by_actor_uri(data["actor"], create=True)
# Resolve the post
post = Post.by_object_uri(get_str_or_id(data["object"]), fetch=True)
object = data["object"]
target = get_str_or_id(object, "inReplyTo") or get_str_or_id(object)
post = Post.by_object_uri(target, fetch=True)
value = None
# Get the right type
if data["type"].lower() == "like":
type = cls.Types.like
elif data["type"].lower() == "announce":
type = cls.Types.boost
elif (
data["type"].lower() == "create"
and object["type"].lower() == "note"
and isinstance(post.type_data, QuestionData)
):
type = cls.Types.vote
question = post.type_data
value = object["name"]
if question.end_time and timezone.now() > question.end_time:
# TODO: Maybe create an expecific expired exception?
raise cls.DoesNotExist(
f"Cannot create a vote to the expired question {post.id}"
)
already_voted = (
post.type_data.mode == "oneOf"
and post.interactions.filter(
type=cls.Types.vote, identity=identity
).exists()
)
if already_voted:
raise cls.DoesNotExist(
f"The identity {identity.handle} already voted in question {post.id}"
)
else:
raise ValueError(f"Cannot handle AP type {data['type']}")
# Make the actual interaction
@ -283,6 +386,7 @@ class PostInteraction(StatorModel):
published=parse_ld_date(data.get("published", None))
or timezone.now(),
type=type,
value=value,
)
else:
raise cls.DoesNotExist(f"No interaction with ID {data['id']}", data)
@ -301,7 +405,9 @@ class PostInteraction(StatorModel):
# That post is gone, boss
# TODO: Limited retry state?
return
interaction.post.calculate_stats()
interaction.post.calculate_type_data()
@classmethod
def handle_undo_ap(cls, data):
@ -324,10 +430,11 @@ class PostInteraction(StatorModel):
interaction.transition_perform(PostInteractionStates.undone_fanned_out)
# Recalculate post stats
interaction.post.calculate_stats()
interaction.post.calculate_type_data()
### Mastodon API ###
def to_mastodon_status_json(self, interactions=None):
def to_mastodon_status_json(self, interactions=None, identity=None):
"""
This wraps Posts in a fake Status for boost interactions.
"""
@ -336,7 +443,9 @@ class PostInteraction(StatorModel):
f"Cannot make status JSON for interaction of type {self.type}"
)
# Make a fake post for this boost (because mastodon treats boosts as posts)
post_json = self.post.to_mastodon_json(interactions=interactions)
post_json = self.post.to_mastodon_json(
interactions=interactions, identity=identity
)
return {
"id": f"{self.pk}",
"uri": post_json["uri"],
@ -356,7 +465,7 @@ class PostInteraction(StatorModel):
"url": post_json["url"],
"in_reply_to_id": None,
"in_reply_to_account_id": None,
"poll": None,
"poll": post_json["poll"],
"card": None,
"language": None,
"text": "",

View file

@ -2,8 +2,11 @@ import json
from datetime import datetime
from typing import Literal
from django.utils import timezone
from pydantic import BaseModel, Field
from core.ld import format_ld_date
class BasePostDataType(BaseModel):
pass
@ -45,6 +48,51 @@ class QuestionData(BasePostDataType):
data["options"] = options
super().__init__(**data)
def to_mastodon_json(self, post, identity=None):
from activities.models import PostInteraction
multiple = self.mode == "anyOf"
value = {
"id": post.id,
"expires_at": None,
"expired": False,
"multiple": multiple,
"votes_count": 0,
"voters_count": self.voter_count if multiple else None,
"voted": False,
"own_votes": [],
"options": [],
"emojis": [],
}
if self.end_time:
value["expires_at"] = format_ld_date(self.end_time)
value["expired"] = timezone.now() >= self.end_time
options = self.options or []
option_map = {}
for index, option in enumerate(options):
value["options"].append(
{
"title": option.name,
"votes_count": option.votes,
}
)
value["votes_count"] += option.votes
option_map[option.name] = index
if identity:
votes = post.interactions.filter(
identity=identity,
type=PostInteraction.Types.vote,
)
value["voted"] = post.author == identity or votes.exists()
value["own_votes"] = [
option_map[vote.value] for vote in votes if vote.value in option_map
]
return value
class ArticleData(BasePostDataType):
type: Literal["Article"]

View file

@ -221,12 +221,14 @@ class TimelineEvent(models.Model):
raise ValueError(f"Cannot convert {self.type} to notification JSON")
return result
def to_mastodon_status_json(self, interactions=None):
def to_mastodon_status_json(self, interactions=None, identity=None):
if self.type == self.Types.post:
return self.subject_post.to_mastodon_json(interactions=interactions)
return self.subject_post.to_mastodon_json(
interactions=interactions, identity=identity
)
elif self.type == self.Types.boost:
return self.subject_post_interaction.to_mastodon_status_json(
interactions=interactions
interactions=interactions, identity=identity
)
else:
raise ValueError(f"Cannot make status JSON for type {self.type}")

View file

@ -153,7 +153,9 @@ class PaginationResult(Generic[T]):
"""
interactions = PostInteraction.get_post_interactions(self.results, identity)
self.jsonify_results(
lambda post: post.to_mastodon_json(interactions=interactions)
lambda post: post.to_mastodon_json(
interactions=interactions, identity=identity
)
)
def jsonify_status_events(self, identity):
@ -162,7 +164,9 @@ class PaginationResult(Generic[T]):
"""
interactions = PostInteraction.get_event_interactions(self.results, identity)
self.jsonify_results(
lambda event: event.to_mastodon_status_json(interactions=interactions)
lambda event: event.to_mastodon_status_json(
interactions=interactions, identity=identity
)
)
def jsonify_notification_events(self, identity):

View file

@ -91,6 +91,32 @@ class MediaAttachment(Schema):
return cls(**attachment.to_mastodon_json())
class PollOptions(Schema):
title: str
votes_count: int | None
class Poll(Schema):
id: str
expires_at: str | None
expired: bool
multiple: bool
votes_count: int
voters_count: int | None
voted: bool
own_votes: list[int]
options: list[PollOptions]
emojis: list[CustomEmoji]
@classmethod
def from_post(
cls,
post: activities_models.Post,
identity: users_models.Identity | None = None,
) -> "Poll":
return cls(**post.type_data.to_mastodon_json(post, identity=identity))
class StatusMention(Schema):
id: str
username: str
@ -123,7 +149,7 @@ class Status(Schema):
in_reply_to_id: str | None = Field(...)
in_reply_to_account_id: str | None = Field(...)
reblog: Optional["Status"] = Field(...)
poll: None = Field(...)
poll: Poll | None = Field(...)
card: None = Field(...)
language: None = Field(...)
text: str | None = Field(...)
@ -139,8 +165,11 @@ class Status(Schema):
cls,
post: activities_models.Post,
interactions: dict[str, set[str]] | None = None,
identity: users_models.Identity | None = None,
) -> "Status":
return cls(**post.to_mastodon_json(interactions=interactions))
return cls(
**post.to_mastodon_json(interactions=interactions, identity=identity)
)
@classmethod
def map_from_post(
@ -151,15 +180,23 @@ class Status(Schema):
interactions = activities_models.PostInteraction.get_post_interactions(
posts, identity
)
return [cls.from_post(post, interactions=interactions) for post in posts]
return [
cls.from_post(post, interactions=interactions, identity=identity)
for post in posts
]
@classmethod
def from_timeline_event(
cls,
timeline_event: activities_models.TimelineEvent,
interactions: dict[str, set[str]] | None = None,
identity: users_models.Identity | None = None,
) -> "Status":
return cls(**timeline_event.to_mastodon_status_json(interactions=interactions))
return cls(
**timeline_event.to_mastodon_status_json(
interactions=interactions, identity=identity
)
)
@classmethod
def map_from_timeline_event(
@ -171,7 +208,7 @@ class Status(Schema):
events, identity
)
return [
cls.from_timeline_event(event, interactions=interactions)
cls.from_timeline_event(event, interactions=interactions, identity=identity)
for event in events
]

View file

@ -10,6 +10,7 @@ from api.views import (
instance,
media,
notifications,
polls,
search,
statuses,
timelines,
@ -62,6 +63,9 @@ urlpatterns = [
path("v1/statuses/<id>/source", statuses.status_source),
# Notifications
path("v1/notifications", notifications.notifications),
# Polls
path("v1/polls/<id>", polls.get_poll),
path("v1/polls/<id>/votes", polls.vote_poll),
# Search
path("v2/search", search.search),
# Statuses

26
api/views/polls.py Normal file
View file

@ -0,0 +1,26 @@
from django.shortcuts import get_object_or_404
from hatchway import Schema, api_view
from activities.models import Post, PostInteraction
from api import schemas
from api.decorators import identity_required
class PostVoteSchema(Schema):
choices: list[int]
@identity_required
@api_view.get
def get_poll(request, id: str) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
return schemas.Poll.from_post(post, identity=request.identity)
@identity_required
@api_view.post
def vote_poll(request, id: str, details: PostVoteSchema) -> schemas.Poll:
post = get_object_or_404(Post, pk=id, type=Post.Types.question)
PostInteraction.create_votes(post, request.identity, details.choices)
post.refresh_from_db()
return schemas.Poll.from_post(post, identity=request.identity)

View file

@ -47,7 +47,9 @@ def search(
search_result["posts"], request.identity
)
result["statuses"] = [
schemas.Status.from_post(p, interactions=interactions)
schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in search_result["posts"]
]
return schemas.Search(**result)

View file

@ -1,7 +1,9 @@
from datetime import timedelta
from typing import Literal
from django.http import HttpRequest
from django.shortcuts import get_object_or_404
from django.utils import timezone
from hatchway import ApiError, ApiResponse, Schema, api_view
from activities.models import (
@ -18,6 +20,24 @@ from api.pagination import MastodonPaginator, PaginationResult
from core.models import Config
class PostPollSchema(Schema):
options: list[str]
expires_in: int
multiple: bool = False
hide_totals: bool = False
def dict(self):
return {
"type": "Question",
"mode": "anyOf" if self.multiple else "oneOf",
"options": [
{"name": name, "type": "Note", "votes": 0} for name in self.options
],
"voter_count": 0,
"end_time": timezone.now() + timedelta(seconds=self.expires_in),
}
class PostStatusSchema(Schema):
status: str
in_reply_to_id: str | None = None
@ -27,6 +47,7 @@ class PostStatusSchema(Schema):
language: str | None = None
scheduled_at: str | None = None
media_ids: list[str] = []
poll: PostPollSchema | None = None
class EditStatusSchema(Schema):
@ -82,10 +103,11 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
visibility=visibility_map[details.visibility],
reply_to=reply_post,
attachments=attachments,
question=details.poll.dict() if details.poll else None,
)
# Add their own timeline event for immediate visibility
TimelineEvent.add_post(request.identity, post)
return schemas.Status.from_post(post)
return schemas.Status.from_post(post, identity=request.identity)
@identity_required
@ -93,7 +115,9 @@ def post_status(request, details: PostStatusSchema) -> schemas.Status:
def status(request, id: str) -> schemas.Status:
post = post_for_id(request, id)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions)
return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required
@ -121,7 +145,7 @@ def delete_status(request, id: str) -> schemas.Status:
if post.author != request.identity:
raise ApiError(401, "Not the author of this status")
PostService(post).delete()
return schemas.Status.from_post(post)
return schemas.Status.from_post(post, identity=request.identity)
@identity_required
@ -142,11 +166,16 @@ def status_context(request, id: str) -> schemas.Context:
)
return schemas.Context(
ancestors=[
schemas.Status.from_post(p, interactions=interactions)
schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in reversed(ancestors)
],
descendants=[
schemas.Status.from_post(p, interactions=interactions) for p in descendants
schemas.Status.from_post(
p, interactions=interactions, identity=request.identity
)
for p in descendants
],
)
@ -158,7 +187,9 @@ def favourite_status(request, id: str) -> schemas.Status:
service = PostService(post)
service.like_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions)
return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required
@ -168,7 +199,9 @@ def unfavourite_status(request, id: str) -> schemas.Status:
service = PostService(post)
service.unlike_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions)
return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@api_view.get
@ -219,7 +252,9 @@ def reblog_status(request, id: str) -> schemas.Status:
service = PostService(post)
service.boost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions)
return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)
@identity_required
@ -229,4 +264,6 @@ def unreblog_status(request, id: str) -> schemas.Status:
service = PostService(post)
service.unboost_as(request.identity)
interactions = PostInteraction.get_post_interactions([post], request.identity)
return schemas.Status.from_post(post, interactions=interactions)
return schemas.Status.from_post(
post, interactions=interactions, identity=request.identity
)

View file

@ -626,14 +626,14 @@ def get_list(container, key) -> list:
return value
def get_str_or_id(value: str | dict | None) -> str | None:
def get_str_or_id(value: str | dict | None, key: str = "id") -> str | None:
"""
Given a value that could be a str or {"id": str}, return the str
"""
if isinstance(value, str):
return value
elif isinstance(value, dict):
return value.get("id")
return value.get(key)
return None

View file

@ -2,6 +2,7 @@ import pytest
from pytest_httpx import HTTPXMock
from activities.models import Post, PostStates
from activities.models.post_types import QuestionData
from users.models import Identity, InboxMessage
@ -254,6 +255,46 @@ def test_content_map(remote_identity):
assert post3.content == "Hello World"
@pytest.mark.django_db
def test_content_map_question(remote_identity: Identity):
"""
Tests post contentmap for questions
"""
post = Post.by_ap(
data={
"id": "https://remote.test/posts/1/",
"type": "Question",
"votersCount": 10,
"closed": "2023-01-01T26:04:45Z",
"content": "Test Question",
"attributedTo": "https://remote.test/test-actor/",
"published": "2022-12-23T10:50:54Z",
"endTime": "2023-01-01T20:04:45Z",
"oneOf": [
{
"type": "Note",
"name": "Option 1",
"replies": {
"type": "Collection",
"totalItems": 6,
},
},
{
"type": "Note",
"name": "Option 2",
"replies": {
"type": "Collection",
"totalItems": 4,
},
},
],
},
create=True,
)
assert post.content == "Test Question"
assert isinstance(post.type_data, QuestionData)
@pytest.mark.django_db
@pytest.mark.parametrize("delete_type", ["note", "tombstone", "ref"])
def test_inbound_posts(

View file

@ -0,0 +1,314 @@
from datetime import timedelta
import pytest
from django.utils import timezone
from activities.models import Post, PostInteraction
from activities.models.post_types import QuestionData
from core.ld import format_ld_date
from users.models import Identity
@pytest.mark.django_db
def test_vote_in_question(identity: Identity, remote_identity: Identity, config_system):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/11/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/11",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
post.refresh_from_db()
assert isinstance(post.type_data, QuestionData)
assert post.type_data.voter_count == 1
assert post.type_data.options
assert len(post.type_data.options) == 2
assert post.type_data.options[0].votes == 1
assert post.type_data.options[1].votes == 0
@pytest.mark.django_db
def test_vote_in_multiple_choice_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "anyOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
{"name": "Option 3", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/12/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/12",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
PostInteraction.handle_ap(
data={
"id": "https://remote.test/test-actor#votes/13/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/13",
"to": "https://example.com/@test@example.com/",
"name": "Option 2",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
)
post.refresh_from_db()
assert isinstance(post.type_data, QuestionData)
assert post.type_data.voter_count == 1
assert post.type_data.options
assert len(post.type_data.options) == 3
assert post.type_data.options[0].votes == 1
assert post.type_data.options[1].votes == 1
assert post.type_data.options[2].votes == 0
@pytest.mark.django_db
def test_multiple_votes_to_single_vote_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/14/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/14",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
with pytest.raises(PostInteraction.DoesNotExist) as ex:
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/15/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/15",
"to": "https://example.com/@test@example.com/",
"name": "Option 2",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
assert "already voted" in str(ex)
@pytest.mark.django_db
def test_vote_in_expired_question(
identity: Identity, remote_identity: Identity, config_system
):
post = Post.create_local(
author=identity,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": "2022-01-01T23:04:45+00:00",
},
)
with pytest.raises(PostInteraction.DoesNotExist) as ex:
PostInteraction.by_ap(
data={
"id": "https://remote.test/test-actor#votes/16/activity",
"to": "https://example.com/@test@example.com/",
"type": "Create",
"actor": "https://remote.test/test-actor/",
"object": {
"id": "https://remote.test/users/test-actor#votes/16",
"to": "https://example.com/@test@example.com/",
"name": "Option 1",
"type": "Note",
"inReplyTo": post.object_uri,
"attributedTo": "https://remote.test/test-actor/",
},
"@context": [
"https://www.w3.org/ns/activitystreams",
{
"toot": "http://joinmastodon.org/ns#",
"Emoji": "toot:Emoji",
"Public": "as:Public",
"Hashtag": "as:Hashtag",
"votersCount": "toot:votersCount",
},
"https://w3id.org/security/v1",
],
},
create=True,
)
assert "Cannot create a vote to the expired question" in str(ex)
@pytest.mark.django_db
def test_vote_to_ap(identity: Identity, remote_identity: Identity, config_system):
post = Post.objects.create(
author=remote_identity,
local=False,
content="<p>Test Question</p>",
type_data={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 6},
{"name": "Option 2", "type": "Note", "votes": 4},
],
"voter_count": 10,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
post.refresh_from_db()
interaction = PostInteraction.create_votes(
post=post,
identity=identity,
choices=[0],
)[0]
data = interaction.to_create_ap()
assert data["object"]["to"] == remote_identity.actor_uri
assert data["object"]["attributedTo"] == identity.actor_uri
assert data["object"]["name"] == "Option 1"
assert data["object"]["inReplyTo"] == post.object_uri

62
tests/api/test_polls.py Normal file
View file

@ -0,0 +1,62 @@
from datetime import timedelta
import pytest
from django.utils import timezone
from activities.models import Post
from core.ld import format_ld_date
@pytest.mark.django_db
def test_get_poll(api_client):
response = api_client.post(
"/api/v1/statuses",
content_type="application/json",
data={
"status": "Hello, world!",
"poll": {
"options": ["Option 1", "Option 2"],
"expires_in": 300,
},
},
).json()
id = response["id"]
response = api_client.get(
f"/api/v1/polls/{id}",
).json()
assert response["id"] == id
assert response["voted"]
@pytest.mark.django_db
def test_vote_poll(api_client, identity2):
post = Post.create_local(
author=identity2,
content="<p>Test Question</p>",
question={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 0},
{"name": "Option 2", "type": "Note", "votes": 0},
],
"voter_count": 0,
"end_time": format_ld_date(timezone.now() + timedelta(1)),
},
)
response = api_client.post(
f"/api/v1/polls/{post.id}/votes",
content_type="application/json",
data={
"choices": [0],
},
).json()
assert response["id"] == str(post.id)
assert response["voted"]
assert response["votes_count"] == 1
assert response["own_votes"] == [0]

View file

@ -75,3 +75,68 @@ def test_mention_format(api_client, identity, remote_identity):
assert (
response["text"] == '<p>Hey <a href="https://example.com/@test/">@test</a></p>'
)
@pytest.mark.django_db
def test_post_question_status(api_client):
response = api_client.post(
"/api/v1/statuses",
content_type="application/json",
data={
"status": "Hello, world!",
"poll": {
"options": ["Option 1", "Option 2"],
"expires_in": 300,
},
},
).json()
assert response["poll"]["id"] == response["id"]
assert response["poll"]["options"] == [
{"title": "Option 1", "votes_count": 0},
{"title": "Option 2", "votes_count": 0},
]
assert not response["poll"]["expired"]
assert not response["poll"]["multiple"]
@pytest.mark.django_db
def test_question_format(api_client, remote_identity):
"""
Ensures incoming questions are property parsed.
"""
# Make a remote question post and check it
post = Post.objects.create(
local=False,
author=remote_identity,
content="<p>Test Question</p>",
object_uri="https://remote.test/status/123456",
type=Post.Types.question,
type_data={
"type": "Question",
"mode": "oneOf",
"options": [
{"name": "Option 1", "type": "Note", "votes": 10},
{"name": "Option 2", "type": "Note", "votes": 20},
],
"voter_count": 30,
"end_time": "2022-01-01T23:04:45+00:00",
},
)
response = api_client.get(f"/api/v1/statuses/{post.id}").json()
assert response["text"] == "<p>Test Question</p>"
assert response["poll"] == {
"id": str(post.id),
"expires_at": "2022-01-01T23:04:45.000Z",
"expired": True,
"multiple": False,
"votes_count": 30,
"voters_count": None,
"voted": False,
"own_votes": [],
"options": [
{"title": "Option 1", "votes_count": 10},
{"title": "Option 2", "votes_count": 20},
],
"emojis": [],
}

View file

@ -30,9 +30,15 @@ class InboxMessageStates(StateGraph):
case "create":
match instance.message_object_type:
case "note":
await sync_to_async(Post.handle_create_ap)(instance.message)
if instance.message_object_has_content:
await sync_to_async(Post.handle_create_ap)(instance.message)
else:
# Notes without content are Interaction candidates
await sync_to_async(PostInteraction.handle_ap)(
instance.message
)
case "question":
pass # Drop for now
await sync_to_async(Post.handle_create_ap)(instance.message)
case unknown:
if unknown in Post.Types.names:
await sync_to_async(Post.handle_create_ap)(instance.message)
@ -213,3 +219,7 @@ class InboxMessage(StatorModel):
@property
def message_actor(self):
return self.message.get("actor")
@property
def message_object_has_content(self):
return "content" in self.message.get("object", {})