Commit 1f1e1a6fe722deb1656f483b1367ea4be978db5b

Cleo Rebert 2022-04-24T16:41:56

Add constant-time token comparison and PartialEq impls. Add PartialEq for TOTP<T> and PartialEq + Eq for Algorithm.

diff --git a/Cargo.toml b/Cargo.toml
index 9473725..c44f914 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "totp-rs"
-version = "1.0.0"
+version = "1.1.0"
 authors = ["Cleo Rebert <cleo.rebert@gmail.com>"]
 edition = "2021"
 readme = "README.md"
@@ -25,6 +25,7 @@ sha2 = "~0.10.2"
 sha-1 = "~0.10.0"
 hmac = "~0.12.1"
 base32 = "~0.4"
+constant_time_eq = "~0.2.1"
 qrcode = { version = "~0.12", optional = true }
 image = { version = "~0.23.14", optional = true}
 base64 = { version = "~0.13", optional = true }
diff --git a/README.md b/README.md
index a461319..7102b2b 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ With optional feature "serde_support", library-defined types will be Deserialize
 Add it to your `Cargo.toml`:
 ```toml
 [dependencies]
-totp-rs = "~1.0"
+totp-rs = "~1.1"
 ```
 You can then do something like:
 ```Rust
@@ -45,7 +45,7 @@ println!("{}", token);
 Add it to your `Cargo.toml`:
 ```toml
 [dependencies.totp-rs]
-version = "~1.0"
+version = "~1.1"
 features = ["qr"]
 ```
 You can then do something like:
@@ -67,6 +67,6 @@ println!("{}", code);
 Add it to your `Cargo.toml`:
 ```toml
 [dependencies.totp-rs]
-version = "~1.0"
+version = "~1.1"
 features = ["serde_support"]
 ```
diff --git a/src/lib.rs b/src/lib.rs
index 8bb2c94..0476204 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,6 +44,8 @@
 //! # }
 //! ```
 
+use constant_time_eq::constant_time_eq;
+
 #[cfg(feature = "serde_support")]
 use serde::{Deserialize, Serialize};
 
@@ -59,7 +61,7 @@ type HmacSha256 = hmac::Hmac<sha2::Sha256>;
 type HmacSha512 = hmac::Hmac<sha2::Sha512>;
 
 /// Algorithm enum holds the three standards algorithms for TOTP as per the [reference implementation](https://tools.ietf.org/html/rfc6238#appendix-A)
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
 #[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
 pub enum Algorithm {
     SHA1,
@@ -117,6 +119,24 @@ pub struct TOTP<T = Vec<u8>> {
     pub secret: T,
 }
 
+/// Two TOTP instances are equal when algorithm, digits, skew, step and secret all match.
+/// The secret is compared with `constant_time_eq` so equality checks do not leak it via timing.
+impl<T: AsRef<[u8]>> PartialEq for TOTP<T> {
+    fn eq(&self, other: &Self) -> bool {
+        // Public parameters are not secret, so short-circuiting on them is harmless.
+        let same_params = self.algorithm == other.algorithm
+            && self.digits == other.digits
+            && self.skew == other.skew
+            && self.step == other.step;
+        if !same_params {
+            return false;
+        }
+        // NOTE: `constant_time_eq` only hides content for equal-length inputs; a length
+        // mismatch is rejected immediately, so the secret's length is not protected.
+        constant_time_eq(self.secret.as_ref(), other.secret.as_ref())
+    }
+}
+
 impl<T: AsRef<[u8]>> TOTP<T> {
     /// Will create a new instance of TOTP with given parameters. See [the doc](struct.TOTP.html#fields) for reference as to how to choose those values
     pub fn new(algorithm: Algorithm, digits: usize, skew: u8, step: u64, secret: T) -> TOTP<T> {
@@ -154,7 +174,8 @@ impl<T: AsRef<[u8]>> TOTP<T> {
         let basestep = time / self.step - (self.skew as u64);
         for i in 0..self.skew * 2 + 1 {
             let step_time = (basestep + (i as u64)) * (self.step as u64);
-            if self.generate(step_time) == token {
+            // Compare in constant time so timing does not reveal how much of `token` matched.
+            if constant_time_eq(self.generate(step_time).as_bytes(), token.as_bytes()) {
                 return true;
             }
         }
@@ -209,6 +230,48 @@ mod tests {
     use super::*;
 
     #[test]
+    fn comparison_ok() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        assert_eq!(reference, TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret"));
+    }
+
+    #[test]
+    fn comparison_different_algo() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        let test = TOTP::new(Algorithm::SHA256, 6, 1, 1, "TestSecret");
+        assert_ne!(reference, test);
+    }
+
+    #[test]
+    fn comparison_different_digits() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        let test = TOTP::new(Algorithm::SHA1, 8, 1, 1, "TestSecret");
+        assert_ne!(reference, test);
+    }
+
+    #[test]
+    fn comparison_different_skew() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        let test = TOTP::new(Algorithm::SHA1, 6, 0, 1, "TestSecret");
+        assert_ne!(reference, test);
+    }
+
+    #[test]
+    fn comparison_different_step() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        let test = TOTP::new(Algorithm::SHA1, 6, 1, 30, "TestSecret");
+        assert_ne!(reference, test);
+    }
+
+    #[test]
+    fn comparison_different_secret() {
+        let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+        // An equal-length secret reaches the byte-wise comparison; a longer one fails on length.
+        assert_ne!(reference, TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecreT"));
+        assert_ne!(reference, TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecretL"));
+    }
+
+    #[test]
     fn url_for_secret_matches_sha1() {
         let totp = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
         let url = totp.get_url("test_url", "totp-rs");