forked from mirrors/bookwyrm
Smarter way of inferring serializers (which are explicitly present)
This commit is contained in:
parent
a3b7063e4b
commit
91908eb1b6
3 changed files with 39 additions and 25 deletions
|
@ -40,8 +40,9 @@ class Signature:
|
||||||
signatureValue: str
|
signatureValue: str
|
||||||
type: str = 'RsaSignature2017'
|
type: str = 'RsaSignature2017'
|
||||||
|
|
||||||
def naive_parse(activity_objects, activity_json):
|
def naive_parse(activity_objects, activity_json, serializer=None):
|
||||||
''' this navigates circular import issues '''
|
''' this navigates circular import issues '''
|
||||||
|
if not serializer:
|
||||||
if activity_json.get('publicKeyPem'):
|
if activity_json.get('publicKeyPem'):
|
||||||
# ugh
|
# ugh
|
||||||
activity_json['type'] = 'PublicKey'
|
activity_json['type'] = 'PublicKey'
|
||||||
|
@ -71,8 +72,9 @@ class ActivityObject:
|
||||||
is_subclass = issubclass(field.type, ActivityObject)
|
is_subclass = issubclass(field.type, ActivityObject)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
is_subclass = False
|
is_subclass = False
|
||||||
if is_subclass and activity_objects:
|
if is_subclass:
|
||||||
value = naive_parse(activity_objects, value)
|
value = naive_parse(
|
||||||
|
activity_objects, value, serializer=field.type)
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if field.default == MISSING and \
|
if field.default == MISSING and \
|
||||||
|
@ -89,7 +91,8 @@ class ActivityObject:
|
||||||
|
|
||||||
# only reject statuses if we're potentially creating them
|
# only reject statuses if we're potentially creating them
|
||||||
if allow_create and \
|
if allow_create and \
|
||||||
hasattr(model, 'ignore_activity') and model.ignore_activity(self):
|
hasattr(model, 'ignore_activity') and \
|
||||||
|
model.ignore_activity(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check for an existing instance
|
# check for an existing instance
|
||||||
|
@ -219,11 +222,12 @@ def get_model_from_type(activity_type):
|
||||||
model = [m for m in models if hasattr(m, 'activity_serializer') and \
|
model = [m for m in models if hasattr(m, 'activity_serializer') and \
|
||||||
hasattr(m.activity_serializer, 'type') and \
|
hasattr(m.activity_serializer, 'type') and \
|
||||||
m.activity_serializer.type == activity_type]
|
m.activity_serializer.type == activity_type]
|
||||||
if not len(model):
|
if not model:
|
||||||
raise ActivitySerializerError(
|
raise ActivitySerializerError(
|
||||||
'No model found for activity type "%s"' % activity_type)
|
'No model found for activity type "%s"' % activity_type)
|
||||||
return model[0]
|
return model[0]
|
||||||
|
|
||||||
|
|
||||||
def resolve_remote_id(remote_id, model=None, refresh=False, save=True):
|
def resolve_remote_id(remote_id, model=None, refresh=False, save=True):
|
||||||
''' take a remote_id and return an instance, creating if necessary '''
|
''' take a remote_id and return an instance, creating if necessary '''
|
||||||
if model:# a bonus check we can do if we already know the model
|
if model:# a bonus check we can do if we already know the model
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
''' getting and verifying signatures '''
|
||||||
import time
|
import time
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
@ -12,31 +13,33 @@ import pytest
|
||||||
from django.test import TestCase, Client
|
from django.test import TestCase, Client
|
||||||
from django.utils.http import http_date
|
from django.utils.http import http_date
|
||||||
|
|
||||||
from bookwyrm.models import User
|
from bookwyrm import models
|
||||||
from bookwyrm.activitypub import Follow
|
from bookwyrm.activitypub import Follow
|
||||||
from bookwyrm.settings import DOMAIN
|
from bookwyrm.settings import DOMAIN
|
||||||
from bookwyrm.signatures import create_key_pair, make_signature, make_digest
|
from bookwyrm.signatures import create_key_pair, make_signature, make_digest
|
||||||
|
|
||||||
def get_follow_data(follower, followee):
|
def get_follow_activity(follower, followee):
|
||||||
follow_activity = Follow(
|
''' generates a test activity '''
|
||||||
|
return Follow(
|
||||||
id='https://test.com/user/follow/id',
|
id='https://test.com/user/follow/id',
|
||||||
actor=follower.remote_id,
|
actor=follower.remote_id,
|
||||||
object=followee.remote_id,
|
object=followee.remote_id,
|
||||||
).serialize()
|
).serialize()
|
||||||
return json.dumps(follow_activity)
|
|
||||||
|
|
||||||
KeyPair = namedtuple('KeyPair', ('private_key', 'public_key'))
|
KeyPair = namedtuple('KeyPair', ('private_key', 'public_key'))
|
||||||
Sender = namedtuple('Sender', ('remote_id', 'key_pair'))
|
Sender = namedtuple('Sender', ('remote_id', 'key_pair'))
|
||||||
|
|
||||||
class Signature(TestCase):
|
class Signature(TestCase):
|
||||||
|
''' signature test '''
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.mouse = User.objects.create_user(
|
''' create users and test data '''
|
||||||
|
self.mouse = models.User.objects.create_user(
|
||||||
'mouse@%s' % DOMAIN, 'mouse@example.com', '',
|
'mouse@%s' % DOMAIN, 'mouse@example.com', '',
|
||||||
local=True, localname='mouse')
|
local=True, localname='mouse')
|
||||||
self.rat = User.objects.create_user(
|
self.rat = models.User.objects.create_user(
|
||||||
'rat@%s' % DOMAIN, 'rat@example.com', '',
|
'rat@%s' % DOMAIN, 'rat@example.com', '',
|
||||||
local=True, localname='rat')
|
local=True, localname='rat')
|
||||||
self.cat = User.objects.create_user(
|
self.cat = models.User.objects.create_user(
|
||||||
'cat@%s' % DOMAIN, 'cat@example.com', '',
|
'cat@%s' % DOMAIN, 'cat@example.com', '',
|
||||||
local=True, localname='cat')
|
local=True, localname='cat')
|
||||||
|
|
||||||
|
@ -47,6 +50,8 @@ class Signature(TestCase):
|
||||||
KeyPair(private_key, public_key)
|
KeyPair(private_key, public_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
models.SiteSettings.objects.create()
|
||||||
|
|
||||||
def send(self, signature, now, data, digest):
|
def send(self, signature, now, data, digest):
|
||||||
''' test request '''
|
''' test request '''
|
||||||
c = Client()
|
c = Client()
|
||||||
|
@ -63,7 +68,7 @@ class Signature(TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_test_request(
|
def send_test_request(#pylint: disable=too-many-arguments
|
||||||
self,
|
self,
|
||||||
sender,
|
sender,
|
||||||
signer=None,
|
signer=None,
|
||||||
|
@ -72,7 +77,7 @@ class Signature(TestCase):
|
||||||
date=None):
|
date=None):
|
||||||
''' sends a follow request to the "rat" user '''
|
''' sends a follow request to the "rat" user '''
|
||||||
now = date or http_date()
|
now = date or http_date()
|
||||||
data = json.dumps(get_follow_data(sender, self.rat))
|
data = json.dumps(get_follow_activity(sender, self.rat))
|
||||||
digest = digest or make_digest(data)
|
digest = digest or make_digest(data)
|
||||||
signature = make_signature(
|
signature = make_signature(
|
||||||
signer or sender, self.rat.inbox, now, digest)
|
signer or sender, self.rat.inbox, now, digest)
|
||||||
|
@ -81,6 +86,7 @@ class Signature(TestCase):
|
||||||
return self.send(signature, now, send_data or data, digest)
|
return self.send(signature, now, send_data or data, digest)
|
||||||
|
|
||||||
def test_correct_signature(self):
|
def test_correct_signature(self):
|
||||||
|
''' this one should just work '''
|
||||||
response = self.send_test_request(sender=self.mouse)
|
response = self.send_test_request(sender=self.mouse)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@ -120,6 +126,7 @@ class Signature(TestCase):
|
||||||
|
|
||||||
@responses.activate
|
@responses.activate
|
||||||
def test_key_needs_refresh(self):
|
def test_key_needs_refresh(self):
|
||||||
|
''' an out of date key should be updated and the new key work '''
|
||||||
datafile = pathlib.Path(__file__).parent.joinpath('data/ap_user.json')
|
datafile = pathlib.Path(__file__).parent.joinpath('data/ap_user.json')
|
||||||
data = json.loads(datafile.read_bytes())
|
data = json.loads(datafile.read_bytes())
|
||||||
data['id'] = self.fake_remote.remote_id
|
data['id'] = self.fake_remote.remote_id
|
||||||
|
@ -165,6 +172,7 @@ class Signature(TestCase):
|
||||||
|
|
||||||
@responses.activate
|
@responses.activate
|
||||||
def test_nonexistent_signer(self):
|
def test_nonexistent_signer(self):
|
||||||
|
''' fail when unable to look up signer '''
|
||||||
responses.add(
|
responses.add(
|
||||||
responses.GET,
|
responses.GET,
|
||||||
self.fake_remote.remote_id,
|
self.fake_remote.remote_id,
|
||||||
|
@ -180,11 +188,12 @@ class Signature(TestCase):
|
||||||
with patch('bookwyrm.activitypub.resolve_remote_id'):
|
with patch('bookwyrm.activitypub.resolve_remote_id'):
|
||||||
response = self.send_test_request(
|
response = self.send_test_request(
|
||||||
self.mouse,
|
self.mouse,
|
||||||
send_data=get_follow_data(self.mouse, self.cat))
|
send_data=get_follow_activity(self.mouse, self.cat))
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_invalid_digest(self):
|
def test_invalid_digest(self):
|
||||||
|
''' signature digest must be valid '''
|
||||||
with patch('bookwyrm.activitypub.resolve_remote_id'):
|
with patch('bookwyrm.activitypub.resolve_remote_id'):
|
||||||
response = self.send_test_request(
|
response = self.send_test_request(
|
||||||
self.mouse,
|
self.mouse,
|
||||||
|
|
|
@ -75,7 +75,8 @@ def has_valid_signature(request, activity):
|
||||||
if key_actor != activity.get('actor'):
|
if key_actor != activity.get('actor'):
|
||||||
raise ValueError("Wrong actor created signature.")
|
raise ValueError("Wrong actor created signature.")
|
||||||
|
|
||||||
remote_user = activitypub.resolve_remote_id(models.User, key_actor)
|
remote_user = activitypub.resolve_remote_id(
|
||||||
|
key_actor, model=models.User)
|
||||||
if not remote_user:
|
if not remote_user:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -84,7 +85,7 @@ def has_valid_signature(request, activity):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
old_key = remote_user.key_pair.public_key
|
old_key = remote_user.key_pair.public_key
|
||||||
remote_user = activitypub.resolve_remote_id(
|
remote_user = activitypub.resolve_remote_id(
|
||||||
models.User, remote_user.remote_id, refresh=True
|
remote_user.remote_id, model=models.User, refresh=True
|
||||||
)
|
)
|
||||||
if remote_user.key_pair.public_key == old_key:
|
if remote_user.key_pair.public_key == old_key:
|
||||||
raise # Key unchanged.
|
raise # Key unchanged.
|
||||||
|
|
Loading…
Reference in a new issue