Commit 323366984799a7d573b29fe2fd31c962974dd2d0

wyhaya 2022-05-06T20:36:43

Add TOTP::from_url

diff --git a/Cargo.toml b/Cargo.toml
index 9696ff3..d4a6387 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,12 +12,13 @@ keywords = ["authentication", "2fa", "totp", "hmac", "otp"]
 categories = ["authentication", "web-programming"]
 
 [package.metadata.docs.rs]
-features = [ "qr", "serde_support" ]
+features = [ "qr", "serde_support", "otpauth" ]
 
 [features]
 default = []
 qr = ["qrcodegen", "image", "base64"]
 serde_support = ["serde"]
+otpauth = ["url"]
 
 [dependencies]
 serde = { version = "1.0", features = ["derive"], optional = true }
@@ -29,3 +30,4 @@ constant_time_eq = "~0.2.1"
 qrcodegen = { version = "~1.8", optional = true }
 image = { version = "~0.24.2", features = ["png"], optional = true, default-features = false}
 base64 = { version = "~0.13", optional = true }
+url = { version = "2.2.2", optional = true }
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 64ecda5..ff14699 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -51,6 +51,9 @@ use core::fmt;
 #[cfg(feature = "qr")]
 use {base64, image::Luma, qrcodegen};
 
+#[cfg(feature = "otpauth")]
+use url::{Host, ParseError, Url};
+
 use hmac::Mac;
 use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH};
 
@@ -108,6 +111,18 @@ fn system_time() -> Result<u64, SystemTimeError> {
     Ok(t)
 }
 
+/// Reasons why an `otpauth://` key URI could not be turned into a [`TOTP`].
+#[cfg(feature = "otpauth")]
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum TotpUrlError {
+    /// The input could not be parsed as a URL at all.
+    Url(ParseError),
+    /// The URL scheme is not `otpauth`.
+    Scheme,
+    /// The URL host (the OTP type) is not `totp`.
+    Host,
+    /// The `secret` parameter is missing, empty, or not valid base32.
+    Secret,
+    /// The `algorithm` parameter names an unsupported hash algorithm.
+    Algorithm,
+    /// The `digits` parameter is not an acceptable number of digits.
+    Digits,
+    /// The `period` parameter is not an acceptable time step.
+    Step,
+}
+
+#[cfg(feature = "otpauth")]
+impl fmt::Display for TotpUrlError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            TotpUrlError::Url(err) => write!(f, "invalid otpauth URL: {}", err),
+            TotpUrlError::Scheme => f.write_str("URL scheme must be `otpauth`"),
+            TotpUrlError::Host => f.write_str("URL host must be `totp`"),
+            TotpUrlError::Secret => f.write_str("missing or invalid base32 `secret` parameter"),
+            TotpUrlError::Algorithm => f.write_str("unsupported `algorithm` parameter"),
+            TotpUrlError::Digits => f.write_str("invalid `digits` parameter"),
+            TotpUrlError::Step => f.write_str("invalid `period` parameter"),
+        }
+    }
+}
+
+#[cfg(feature = "otpauth")]
+impl std::error::Error for TotpUrlError {
+    // Expose the underlying URL parse failure so error reporters can print the full chain.
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match self {
+            TotpUrlError::Url(err) => Some(err),
+            _ => None,
+        }
+    }
+}
+
 /// TOTP holds informations as to how to generate an auth code and validate it. Its [secret](struct.TOTP.html#structfield.secret) field is sensitive data, treat it accordingly
 #[derive(Debug, Clone)]
 #[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
@@ -206,6 +221,54 @@ impl<T: AsRef<[u8]>> TOTP<T> {
             self.secret.as_ref(),
         )
     }
+    
+    /// Build a TOTP from a standard `otpauth://totp/...` key URI.
+    ///
+    /// Recognised query parameters are `secret` (required, base32), `algorithm`
+    /// (`SHA1`, `SHA256` or `SHA512`; default `SHA1`), `digits` (default 6) and
+    /// `period` (default 30). Any other parameter (e.g. `issuer`) is ignored, and
+    /// the resulting TOTP always uses a skew of 1.
+    ///
+    /// # Errors
+    ///
+    /// Returns the [`TotpUrlError`] variant for the first problem encountered:
+    /// an unparsable URL (`Url`), a scheme other than `otpauth` (`Scheme`), a host
+    /// other than `totp` (`Host`), an unknown algorithm (`Algorithm`), a
+    /// non-numeric or zero `digits` (`Digits`) or `period` (`Step`), or a missing,
+    /// empty or non-base32 `secret` (`Secret`).
+    #[cfg(feature = "otpauth")]
+    pub fn from_url<S: AsRef<str>>(url: S) -> Result<TOTP<Vec<u8>>, TotpUrlError> {
+        let url = Url::parse(url.as_ref()).map_err(TotpUrlError::Url)?;
+        if url.scheme() != "otpauth" {
+            return Err(TotpUrlError::Scheme);
+        }
+        if url.host() != Some(Host::Domain("totp")) {
+            return Err(TotpUrlError::Host);
+        }
+
+        let mut algorithm = Algorithm::SHA1;
+        let mut digits = 6;
+        let mut step = 30;
+        let mut secret = Vec::new();
+
+        for (key, value) in url.query_pairs() {
+            match key.as_ref() {
+                "algorithm" => {
+                    algorithm = match value.as_ref() {
+                        "SHA1" => Algorithm::SHA1,
+                        "SHA256" => Algorithm::SHA256,
+                        "SHA512" => Algorithm::SHA512,
+                        _ => return Err(TotpUrlError::Algorithm),
+                    }
+                }
+                "digits" => {
+                    // A zero-length code is meaningless, so reject it here rather than
+                    // handing back a TOTP that can never produce a usable token.
+                    digits = value
+                        .parse::<usize>()
+                        .ok()
+                        .filter(|&d| d > 0)
+                        .ok_or(TotpUrlError::Digits)?;
+                }
+                "period" => {
+                    // The time step is a divisor (RFC 6238: T = (now - T0) / X), so a
+                    // zero period can never describe a valid TOTP.
+                    step = value
+                        .parse::<u64>()
+                        .ok()
+                        .filter(|&s| s > 0)
+                        .ok_or(TotpUrlError::Step)?;
+                }
+                "secret" => {
+                    secret =
+                        base32::decode(base32::Alphabet::RFC4648 { padding: false }, value.as_ref())
+                            .ok_or(TotpUrlError::Secret)?;
+                }
+                _ => {}
+            }
+        }
+
+        // Covers both a missing `secret` parameter and an explicitly empty one.
+        if secret.is_empty() {
+            return Err(TotpUrlError::Secret);
+        }
+
+        Ok(TOTP::new(algorithm, digits, 1, step, secret))
+    }
 
     /// Will generate a standard URL used to automatically add TOTP auths. Usually used with qr codes
     pub fn get_url(&self, label: &str, issuer: &str) -> String {
@@ -416,6 +479,35 @@ mod tests {
     }
 
     #[test]
+    #[cfg(feature = "otpauth")]
+    fn from_url_err() {
+        use crate::TotpUrlError;
+
+        // Each malformed URL must be rejected for the right reason, not merely rejected.
+        let parse = |url: &str| TOTP::<Vec<u8>>::from_url(url);
+
+        assert!(matches!(parse("not a url"), Err(TotpUrlError::Url(_))));
+        assert!(matches!(parse("http://totp/GitHub:test?secret=ABC"), Err(TotpUrlError::Scheme)));
+        assert!(matches!(parse("otpauth://hotp/123"), Err(TotpUrlError::Host)));
+        assert!(matches!(parse("otpauth://totp/GitHub:test"), Err(TotpUrlError::Secret)));
+        assert!(matches!(parse("otpauth://totp/GitHub:test?secret="), Err(TotpUrlError::Secret)));
+        assert!(matches!(parse("otpauth://totp/GitHub:test?secret=0189"), Err(TotpUrlError::Secret)));
+        assert!(matches!(
+            parse("otpauth://totp/GitHub:test?secret=ABC&algorithm=MD5"),
+            Err(TotpUrlError::Algorithm)
+        ));
+        assert!(matches!(
+            parse("otpauth://totp/GitHub:test?secret=ABC&digits=six"),
+            Err(TotpUrlError::Digits)
+        ));
+        assert!(matches!(
+            parse("otpauth://totp/GitHub:test?secret=ABC&period=-30"),
+            Err(TotpUrlError::Step)
+        ));
+    }
+
+    #[test]
+    #[cfg(feature = "otpauth")]
+    fn from_url_default() {
+        // Only `secret` is supplied, so every other field must take its default value.
+        let expected_secret =
+            base32::decode(base32::Alphabet::RFC4648 { padding: false }, "ABC").unwrap();
+        let parsed = TOTP::<Vec<u8>>::from_url("otpauth://totp/GitHub:test?secret=ABC").unwrap();
+        assert_eq!(parsed.secret, expected_secret);
+        assert_eq!(parsed.algorithm, Algorithm::SHA1);
+        assert_eq!((parsed.digits, parsed.skew, parsed.step), (6, 1, 30));
+    }
+
+    #[test]
+    #[cfg(feature = "otpauth")]
+    fn from_url_query() {
+        // Every supported query parameter is set, so each must override its default.
+        let link = "otpauth://totp/GitHub:test?secret=ABC&digits=8&period=60&algorithm=SHA256";
+        let expected_secret =
+            base32::decode(base32::Alphabet::RFC4648 { padding: false }, "ABC").unwrap();
+        let parsed = TOTP::<Vec<u8>>::from_url(link).unwrap();
+        assert_eq!(parsed.secret, expected_secret);
+        assert_eq!(parsed.algorithm, Algorithm::SHA256);
+        assert_eq!((parsed.digits, parsed.skew, parsed.step), (8, 1, 60));
+    }
+
+    #[test]
     #[cfg(feature = "qr")]
     fn generates_qr() {
         use sha1::{Digest, Sha1};