Merge pull request #487 from czurnieden/DoS_sqrt_mod: Added input checks to mp_sqrtmod_prime for invalid inputs that caused infinite loops
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
diff --git a/demo/test.c b/demo/test.c
index 2d1d774..f6b3c36 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -707,9 +707,9 @@ static int test_mp_sqrtmod_prime(void)
};
static struct mp_sqrtmod_prime_st sqrtmod_prime[] = {
- { 5, 14, 3 },
- { 7, 9, 4 },
- { 113, 2, 62 }
+ { 5, 14, 3 }, /* 5 \cong 1 (mod 4) */
+ { 7, 9, 4 }, /* 7 \cong 3 (mod 4) */
+ { 113, 2, 62 } /* 113 \cong 1 (mod 4) */
};
int i;
@@ -723,6 +723,14 @@ static int test_mp_sqrtmod_prime(void)
DO(mp_sqrtmod_prime(&b, &a, &c));
EXPECT(mp_cmp_d(&c, sqrtmod_prime[i].r) == MP_EQ);
}
+ /* Check handling of invalid input (here: the modulus 25 is a perfect square and congruent to 1 mod 4 and mod 24) */
+ mp_set_ul(&a, 25);
+ mp_set_ul(&b, 2);
+ EXPECT(mp_sqrtmod_prime(&b, &a, &c) == MP_VAL);
+ /* b shares a factor with a (gcd(3, 45) = 3), so the Kronecker symbol (b|a) is 0 */
+ mp_set_ul(&a, 45);
+ mp_set_ul(&b, 3);
+ EXPECT(mp_sqrtmod_prime(&b, &a, &c) == MP_VAL);
mp_clear_multi(&a, &b, &c, NULL);
return EXIT_SUCCESS;
diff --git a/mp_sqrtmod_prime.c b/mp_sqrtmod_prime.c
index 8930184..0fae1d0 100644
--- a/mp_sqrtmod_prime.c
+++ b/mp_sqrtmod_prime.c
@@ -13,19 +13,23 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
{
mp_err err;
int legendre;
- mp_int t1, C, Q, S, Z, M, T, R, two;
- mp_digit i;
+ /* The type is "int" because of the types in the mp_int struct.
+ Don't forget to change them here when you change them there! */
+ int S, M, i;
+ mp_int t1, C, Q, Z, T, R, two;
/* first handle the simple cases */
if (mp_cmp_d(n, 0uL) == MP_EQ) {
mp_zero(ret);
return MP_OKAY;
}
- if (mp_cmp_d(prime, 2uL) == MP_EQ) return MP_VAL; /* prime must be odd */
- if ((err = mp_kronecker(n, prime, &legendre)) != MP_OKAY) return err;
- if (legendre == -1) return MP_VAL; /* quadratic non-residue mod prime */
+ /* "prime" must be odd and > 2 */
+ if (mp_iseven(prime) || (mp_cmp_d(prime, 3uL) == MP_LT)) return MP_VAL;
+ if ((err = mp_kronecker(n, prime, &legendre)) != MP_OKAY) return err;
+ /* n \not\cong 0 (mod p) and n \cong r^2 (mod p) for some r \in N^+ */
+ if (legendre != 1) return MP_VAL;
- if ((err = mp_init_multi(&t1, &C, &Q, &S, &Z, &M, &T, &R, &two, NULL)) != MP_OKAY) {
+ if ((err = mp_init_multi(&t1, &C, &Q, &Z, &T, &R, &two, NULL)) != MP_OKAY) {
return err;
}
@@ -33,8 +37,8 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
* compute directly: err = n^(prime+1)/4 mod prime
* Handbook of Applied Cryptography algorithm 3.36
*/
- if ((err = mp_mod_d(prime, 4uL, &i)) != MP_OKAY) goto LBL_END;
- if (i == 3u) {
+ /* x%4 == x&3 for x in N and x>0 */
+ if ((prime->dp[0] & 3u) == 3u) {
if ((err = mp_add_d(prime, 1uL, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END;
@@ -49,12 +53,12 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
if ((err = mp_copy(prime, &Q)) != MP_OKAY) goto LBL_END;
if ((err = mp_sub_d(&Q, 1uL, &Q)) != MP_OKAY) goto LBL_END;
/* Q = prime - 1 */
- mp_zero(&S);
+ S = 0;
/* S = 0 */
while (mp_iseven(&Q)) {
if ((err = mp_div_2(&Q, &Q)) != MP_OKAY) goto LBL_END;
/* Q = Q / 2 */
- if ((err = mp_add_d(&S, 1uL, &S)) != MP_OKAY) goto LBL_END;
+ S++;
/* S = S + 1 */
}
@@ -63,6 +67,12 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
/* Z = 2 */
for (;;) {
if ((err = mp_kronecker(&Z, prime, &legendre)) != MP_OKAY) goto LBL_END;
+ /* If "prime" (p) really is an odd prime, Jacobi(k|p) = 0 only for k \cong 0 (mod p), */
+ /* and a quadratic non-residue is always found before k reaches p; a 0 here means p is not prime. */
+ if (legendre == 0) {
+ err = MP_VAL;
+ goto LBL_END;
+ }
if (legendre == -1) break;
if ((err = mp_add_d(&Z, 1uL, &Z)) != MP_OKAY) goto LBL_END;
/* Z = Z + 1 */
@@ -77,7 +87,7 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
/* R = n ^ ((Q + 1) / 2) mod prime */
if ((err = mp_exptmod(n, &Q, prime, &T)) != MP_OKAY) goto LBL_END;
/* T = n ^ Q mod prime */
- if ((err = mp_copy(&S, &M)) != MP_OKAY) goto LBL_END;
+ M = S;
/* M = S */
mp_set(&two, 2uL);
@@ -86,16 +96,21 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
i = 0;
for (;;) {
if (mp_cmp_d(&t1, 1uL) == MP_EQ) break;
+ /* No exponent in the range 0 < i < M found
+ (M is at least 1 in the first round because "prime" > 2) */
+ if (M == i) {
+ err = MP_VAL;
+ goto LBL_END;
+ }
if ((err = mp_exptmod(&t1, &two, prime, &t1)) != MP_OKAY) goto LBL_END;
i++;
}
- if (i == 0u) {
+ if (i == 0) {
if ((err = mp_copy(&R, ret)) != MP_OKAY) goto LBL_END;
err = MP_OKAY;
goto LBL_END;
}
- if ((err = mp_sub_d(&M, i, &t1)) != MP_OKAY) goto LBL_END;
- if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_END;
+ mp_set_i32(&t1, M - i - 1);
if ((err = mp_exptmod(&two, &t1, prime, &t1)) != MP_OKAY) goto LBL_END;
/* t1 = 2 ^ (M - i - 1) */
if ((err = mp_exptmod(&C, &t1, prime, &t1)) != MP_OKAY) goto LBL_END;
@@ -106,12 +121,12 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
/* R = (R * t1) mod prime */
if ((err = mp_mulmod(&T, &C, prime, &T)) != MP_OKAY) goto LBL_END;
/* T = (T * C) mod prime */
- mp_set(&M, i);
+ M = i;
/* M = i */
}
LBL_END:
- mp_clear_multi(&t1, &C, &Q, &S, &Z, &M, &T, &R, &two, NULL);
+ mp_clear_multi(&t1, &C, &Q, &Z, &T, &R, &two, NULL);
return err;
}
diff --git a/tommath_class.h b/tommath_class.h
index 0fe046f..68055cc 100644
--- a/tommath_class.h
+++ b/tommath_class.h
@@ -872,12 +872,12 @@
# define MP_CMP_D_C
# define MP_COPY_C
# define MP_DIV_2_C
-# define MP_DIV_D_C
# define MP_EXPTMOD_C
# define MP_INIT_MULTI_C
# define MP_KRONECKER_C
# define MP_MULMOD_C
# define MP_SET_C
+# define MP_SET_I32_C
# define MP_SQRMOD_C
# define MP_SUB_D_C
# define MP_ZERO_C