diff --git a/fedireads/signatures.py b/fedireads/signatures.py index 8fab5c31..49d1a2d3 100644 --- a/fedireads/signatures.py +++ b/fedireads/signatures.py @@ -1,5 +1,6 @@ import hashlib from urllib.parse import urlparse +import datetime from base64 import b64encode, b64decode from Crypto import Random @@ -7,6 +8,7 @@ from Crypto.PublicKey import RSA from Crypto.Signature import pkcs1_15 #pylint: disable=no-name-in-module from Crypto.Hash import SHA256 +MAX_SIGNATURE_AGE = 300 def create_key_pair(): random_generator = Random.new().read @@ -77,6 +79,9 @@ class Signature: def verify(self, public_key, request): ''' verify rsa signature ''' + if http_date_age(request.headers['date']) > MAX_SIGNATURE_AGE: + raise ValueError( + "Request too old: %s" % (request.headers['date'],)) public_key = RSA.import_key(public_key) comparison_string = [] @@ -99,3 +104,8 @@ class Signature: # raises a ValueError if it fails signer.verify(digest, self.signature) + +def http_date_age(datestr): + parsed = datetime.datetime.strptime(datestr, '%a, %d %b %Y %H:%M:%S GMT') + delta = datetime.datetime.utcnow() - parsed + return delta.total_seconds() diff --git a/fedireads/tests/test_signing.py b/fedireads/tests/test_signing.py index 7310a445..b93d263e 100644 --- a/fedireads/tests/test_signing.py +++ b/fedireads/tests/test_signing.py @@ -1,3 +1,4 @@ +import time from collections import namedtuple from urllib.parse import urlsplit @@ -51,8 +52,9 @@ class Signature(TestCase): sender, signer=None, send_data=None, - digest=None): - now = http_date() + digest=None, + date=None): + now = date or http_date() data = get_follow_data(sender, self.rat) signature = make_signature( signer or sender, self.rat.inbox, now, digest or make_digest(data)) @@ -105,3 +107,11 @@ class Signature(TestCase): self.mouse, digest='SHA-256=AAAAAAAAAAAAAAAAAA') self.assertEqual(response.status_code, 401) + + def test_old_message(self): + '''Old messages should be rejected to prevent replay attacks.''' + response = self.send_test_request( + self.mouse, + date=http_date(time.time() - 301) + ) + self.assertEqual(response.status_code, 401)