Add constant-time token comparison and PartialEq trait

Add PartialEq for TOTP<T> (the secret is compared in constant time) and derive PartialEq + Eq for Algorithm.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
diff --git a/Cargo.toml b/Cargo.toml
index 9473725..c44f914 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "totp-rs"
-version = "1.0.0"
+version = "1.1.0"
authors = ["Cleo Rebert <cleo.rebert@gmail.com>"]
edition = "2021"
readme = "README.md"
@@ -25,6 +25,7 @@ sha2 = "~0.10.2"
sha-1 = "~0.10.0"
hmac = "~0.12.1"
base32 = "~0.4"
+constant_time_eq = "~0.2.1"
qrcode = { version = "~0.12", optional = true }
image = { version = "~0.23.14", optional = true}
base64 = { version = "~0.13", optional = true }
diff --git a/README.md b/README.md
index a461319..7102b2b 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ With optional feature "serde_support", library-defined types will be Deserialize
Add it to your `Cargo.toml`:
```toml
[dependencies]
-totp-rs = "~1.0"
+totp-rs = "~1.1"
```
You can then do something like:
```Rust
@@ -45,7 +45,7 @@ println!("{}", token);
Add it to your `Cargo.toml`:
```toml
[dependencies.totp-rs]
-version = "~1.0"
+version = "~1.1"
features = ["qr"]
```
You can then do something like:
@@ -67,6 +67,6 @@ println!("{}", code);
Add it to your `Cargo.toml`:
```toml
[dependencies.totp-rs]
-version = "~1.0"
+version = "~1.1"
features = ["serde_support"]
```
diff --git a/src/lib.rs b/src/lib.rs
index 8bb2c94..0476204 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,6 +44,8 @@
//! # }
//! ```
+use constant_time_eq::constant_time_eq;
+
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
@@ -59,7 +61,7 @@ type HmacSha256 = hmac::Hmac<sha2::Sha256>;
type HmacSha512 = hmac::Hmac<sha2::Sha512>;
/// Algorithm enum holds the three standards algorithms for TOTP as per the [reference implementation](https://tools.ietf.org/html/rfc6238#appendix-A)
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum Algorithm {
SHA1,
@@ -117,6 +119,24 @@ pub struct TOTP<T = Vec<u8>> {
pub secret: T,
}
+impl <T: AsRef<[u8]>> PartialEq for TOTP<T> {
+ fn eq(&self, other: &Self) -> bool {
+ if self.algorithm != other.algorithm {
+ return false;
+ }
+ if self.digits != other.digits {
+ return false;
+ }
+ if self.skew != other.skew {
+ return false;
+ }
+ if self.step != other.step {
+ return false;
+ }
+ constant_time_eq(self.secret.as_ref(), other.secret.as_ref())
+ }
+}
+
impl<T: AsRef<[u8]>> TOTP<T> {
/// Will create a new instance of TOTP with given parameters. See [the doc](struct.TOTP.html#fields) for reference as to how to choose those values
pub fn new(algorithm: Algorithm, digits: usize, skew: u8, step: u64, secret: T) -> TOTP<T> {
@@ -154,7 +174,8 @@ impl<T: AsRef<[u8]>> TOTP<T> {
let basestep = time / self.step - (self.skew as u64);
for i in 0..self.skew * 2 + 1 {
let step_time = (basestep + (i as u64)) * (self.step as u64);
- if self.generate(step_time) == token {
+
+ if constant_time_eq(self.generate(step_time).as_bytes(), token.as_bytes()) {
return true;
}
}
@@ -209,6 +230,48 @@ mod tests {
use super::*;
#[test]
+ fn comparison_ok() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ assert_eq!(reference, test);
+ }
+
+ #[test]
+ fn comparison_different_algo() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA256, 6, 1, 1, "TestSecret");
+ assert_ne!(reference, test);
+ }
+
+ #[test]
+ fn comparison_different_digits() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA1, 8, 1, 1, "TestSecret");
+ assert_ne!(reference, test);
+ }
+
+ #[test]
+ fn comparison_different_skew() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA1, 6, 0, 1, "TestSecret");
+ assert_ne!(reference, test);
+ }
+
+ #[test]
+ fn comparison_different_step() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA1, 6, 1, 30, "TestSecret");
+ assert_ne!(reference, test);
+ }
+
+ #[test]
+ fn comparison_different_secret() {
+ let reference = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
+ let test = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecretL");
+ assert_ne!(reference, test);
+ }
+
+ #[test]
fn url_for_secret_matches_sha1() {
let totp = TOTP::new(Algorithm::SHA1, 6, 1, 1, "TestSecret");
let url = totp.get_url("test_url", "totp-rs");