Handle corner cases of invmod(). This fixes #67.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
diff --git a/bn_fast_mp_invmod.c b/bn_fast_mp_invmod.c
index 08389dd..cabed0c 100644
--- a/bn_fast_mp_invmod.c
+++ b/bn_fast_mp_invmod.c
@@ -46,6 +46,12 @@ int fast_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c)
goto LBL_ERR;
}
+ /* if one of x,y is zero return an error! */
+ if ((mp_iszero(&x) == MP_YES) || (mp_iszero(&y) == MP_YES)) {
+ res = MP_VAL;
+ goto LBL_ERR;
+ }
+
/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
if ((res = mp_copy(&x, &u)) != MP_OKAY) {
goto LBL_ERR;
diff --git a/bn_mp_invmod.c b/bn_mp_invmod.c
index 525493a..528b0c7 100644
--- a/bn_mp_invmod.c
+++ b/bn_mp_invmod.c
@@ -18,14 +18,14 @@
/* hac 14.61, pp608 */
int mp_invmod(const mp_int *a, const mp_int *b, mp_int *c)
{
- /* b cannot be negative */
- if ((b->sign == MP_NEG) || (mp_iszero(b) == MP_YES)) {
+ /* b cannot be negative and has to be >1 */
+ if ((b->sign == MP_NEG) || (mp_cmp_d(b, 1) != MP_GT)) {
return MP_VAL;
}
#ifdef BN_FAST_MP_INVMOD_C
/* if the modulus is odd we can use a faster routine instead */
- if ((mp_isodd(b) == MP_YES) && (mp_cmp_d(b, 1) != MP_EQ)) {
+ if ((mp_isodd(b) == MP_YES)) {
return fast_mp_invmod(a, b, c);
}
#endif