diff --git a/src/cryptography/fernet.py b/src/cryptography/fernet.py index 26178d7b046af04862530d2e19d3093091b61119_c3JjL2NyeXB0b2dyYXBoeS9mZXJuZXQucHk=..7e30f0f9203fe513a3e65e3ee9162954a5cd2891_c3JjL2NyeXB0b2dyYXBoeS9mZXJuZXQucHk= 100644 --- a/src/cryptography/fernet.py +++ b/src/cryptography/fernet.py @@ -73,6 +73,7 @@ return base64.urlsafe_b64encode(basic_parts + hmac) def decrypt(self, token, ttl=None): - return self.decrypt_at_time(token, ttl, int(time.time())) + timestamp, data = Fernet._get_unverified_token_data(token) + return self._decrypt_data(data, timestamp, ttl, int(time.time())) def decrypt_at_time(self, token, ttl, current_time): @@ -77,5 +78,9 @@ def decrypt_at_time(self, token, ttl, current_time): + if ttl is None: + raise ValueError( + "decrypt_at_time() can only be used with a non-None ttl" + ) timestamp, data = Fernet._get_unverified_token_data(token) return self._decrypt_data(data, timestamp, ttl, current_time) @@ -170,7 +175,12 @@ return self._fernets[0]._encrypt_from_parts(p, timestamp, iv) def decrypt(self, msg, ttl=None): - return self.decrypt_at_time(msg, ttl, int(time.time())) + for f in self._fernets: + try: + return f.decrypt(msg, ttl) + except InvalidToken: + pass + raise InvalidToken def decrypt_at_time(self, msg, ttl, current_time): for f in self._fernets: diff --git a/tests/test_fernet.py b/tests/test_fernet.py index 26178d7b046af04862530d2e19d3093091b61119_dGVzdHMvdGVzdF9mZXJuZXQucHk=..7e30f0f9203fe513a3e65e3ee9162954a5cd2891_dGVzdHMvdGVzdF9mZXJuZXQucHk= 100644 --- a/tests/test_fernet.py +++ b/tests/test_fernet.py @@ -117,8 +117,6 @@ token = f.encrypt(pt) ts = "1985-10-26T01:20:01-07:00" current_time = calendar.timegm(iso8601.parse_date(ts).utctimetuple()) - assert f.decrypt_at_time( - token, ttl=None, current_time=current_time) == pt monkeypatch.setattr(time, "time", lambda: current_time) assert f.decrypt(token, ttl=None) == pt @@ -122,6 +120,13 @@ monkeypatch.setattr(time, "time", lambda: current_time) assert f.decrypt(token, ttl=None) == pt + def test_ttl_required_in_decrypt_at_time(self, monkeypatch, backend): + f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend) + pt = b"encrypt me" + token = f.encrypt(pt) + with pytest.raises(ValueError): + f.decrypt_at_time(token, ttl=None, current_time=int(time.time())) + @pytest.mark.parametrize("message", [b"", b"Abc!", b"\x00\xFF\x00\x80"]) def test_roundtrips(self, message, backend): f = Fernet(Fernet.generate_key(), backend=backend) @@ -167,6 +172,17 @@ with pytest.raises(InvalidToken): f.decrypt(b"\x00" * 16) + def test_decrypt_at_time(self, backend): + f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend) + f = MultiFernet([f1]) + pt = b"encrypt me" + token = f.encrypt_at_time(pt, current_time=100) + assert f.decrypt_at_time(token, ttl=1, current_time=100) == pt + with pytest.raises(InvalidToken): + f.decrypt_at_time(token, ttl=1, current_time=102) + with pytest.raises(ValueError): + f.decrypt_at_time(token, ttl=None, current_time=100) + def test_no_fernets(self, backend): with pytest.raises(ValueError): MultiFernet([])