# HG changeset patch
# User Jakub Stasiak <jakub@stasiak.at>
# Date 1593049914 -7200
#      Thu Jun 25 03:51:54 2020 +0200
# Node ID 7e30f0f9203fe513a3e65e3ee9162954a5cd2891
# Parent  26178d7b046af04862530d2e19d3093091b61119
Disallow ttl=None in (Multi)Fernet.decrypt_at_time() (#5280)

* Disallow ttl=None in (Multi)Fernet.decrypt_at_time()

Since the introduction of the _at_time() methods in #5256[1] there's
been this little voice in the back of my mind telling me that maybe it's
not the best idea to allow ttl=None in decrypt_at_time(). It's been like
this for convenience and code reuse reasons.

Then I submitted a patch for cryptography stubs in typeshed[2] and I had
to decide whether to define decrypt_at_time()'s ttl as int and be
incompatible with cryptography's behavior or Optional[int] and advertise
an API that can be misused much too easily. I went ahead with int.

Considering the above I decided to propose this patch. Some amount of
redundancy (and a new test to properly cover the
MultiFernet.decrypt_at_time() implementation) is a price to prevent
clients from shooting themselves in the foot with the tll=None gun since
setting ttl to None disabled timestamp checks even if current_time was
provided.

[1] https://github.com/pyca/cryptography/pull/5256
[2] https://github.com/python/typeshed/pull/4238

* Actually test the return value here

* Fix formatting

diff --git a/src/cryptography/fernet.py b/src/cryptography/fernet.py
--- a/src/cryptography/fernet.py
+++ b/src/cryptography/fernet.py
@@ -73,9 +73,14 @@
         return base64.urlsafe_b64encode(basic_parts + hmac)
 
     def decrypt(self, token, ttl=None):
-        return self.decrypt_at_time(token, ttl, int(time.time()))
+        timestamp, data = Fernet._get_unverified_token_data(token)
+        return self._decrypt_data(data, timestamp, ttl, int(time.time()))
 
     def decrypt_at_time(self, token, ttl, current_time):
+        if ttl is None:
+            raise ValueError(
+                "decrypt_at_time() can only be used with a non-None ttl"
+            )
         timestamp, data = Fernet._get_unverified_token_data(token)
         return self._decrypt_data(data, timestamp, ttl, current_time)
 
@@ -170,7 +175,12 @@
         return self._fernets[0]._encrypt_from_parts(p, timestamp, iv)
 
     def decrypt(self, msg, ttl=None):
-        return self.decrypt_at_time(msg, ttl, int(time.time()))
+        for f in self._fernets:
+            try:
+                return f.decrypt(msg, ttl)
+            except InvalidToken:
+                pass
+        raise InvalidToken
 
     def decrypt_at_time(self, msg, ttl, current_time):
         for f in self._fernets:
diff --git a/tests/test_fernet.py b/tests/test_fernet.py
--- a/tests/test_fernet.py
+++ b/tests/test_fernet.py
@@ -117,11 +117,16 @@
         token = f.encrypt(pt)
         ts = "1985-10-26T01:20:01-07:00"
         current_time = calendar.timegm(iso8601.parse_date(ts).utctimetuple())
-        assert f.decrypt_at_time(
-            token, ttl=None, current_time=current_time) == pt
         monkeypatch.setattr(time, "time", lambda: current_time)
         assert f.decrypt(token, ttl=None) == pt
 
+    def test_ttl_required_in_decrypt_at_time(self, monkeypatch, backend):
+        f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
+        pt = b"encrypt me"
+        token = f.encrypt(pt)
+        with pytest.raises(ValueError):
+            f.decrypt_at_time(token, ttl=None, current_time=int(time.time()))
+
     @pytest.mark.parametrize("message", [b"", b"Abc!", b"\x00\xFF\x00\x80"])
     def test_roundtrips(self, message, backend):
         f = Fernet(Fernet.generate_key(), backend=backend)
@@ -167,6 +172,17 @@
         with pytest.raises(InvalidToken):
             f.decrypt(b"\x00" * 16)
 
+    def test_decrypt_at_time(self, backend):
+        f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend)
+        f = MultiFernet([f1])
+        pt = b"encrypt me"
+        token = f.encrypt_at_time(pt, current_time=100)
+        assert f.decrypt_at_time(token, ttl=1, current_time=100) == pt
+        with pytest.raises(InvalidToken):
+            f.decrypt_at_time(token, ttl=1, current_time=102)
+        with pytest.raises(ValueError):
+            f.decrypt_at_time(token, ttl=None, current_time=100)
+
     def test_no_fernets(self, backend):
         with pytest.raises(ValueError):
             MultiFernet([])