Merge branch 'fix/67' into develop

This closes #67
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
diff --git a/bn_fast_mp_invmod.c b/bn_fast_mp_invmod.c
index 08389dd..cabed0c 100644
--- a/bn_fast_mp_invmod.c
+++ b/bn_fast_mp_invmod.c
@@ -46,6 +46,12 @@ int fast_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c)
goto LBL_ERR;
}
+ /* if one of x,y is zero return an error! */
+ if ((mp_iszero(&x) == MP_YES) || (mp_iszero(&y) == MP_YES)) {
+ res = MP_VAL;
+ goto LBL_ERR;
+ }
+
/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
if ((res = mp_copy(&x, &u)) != MP_OKAY) {
goto LBL_ERR;
diff --git a/bn_mp_invmod.c b/bn_mp_invmod.c
index 525493a..528b0c7 100644
--- a/bn_mp_invmod.c
+++ b/bn_mp_invmod.c
@@ -18,14 +18,14 @@
/* hac 14.61, pp608 */
int mp_invmod(const mp_int *a, const mp_int *b, mp_int *c)
{
- /* b cannot be negative */
- if ((b->sign == MP_NEG) || (mp_iszero(b) == MP_YES)) {
+ /* b cannot be negative and has to be >1 */
+ if ((b->sign == MP_NEG) || (mp_cmp_d(b, 1) != MP_GT)) {
return MP_VAL;
}
#ifdef BN_FAST_MP_INVMOD_C
/* if the modulus is odd we can use a faster routine instead */
- if ((mp_isodd(b) == MP_YES) && (mp_cmp_d(b, 1) != MP_EQ)) {
+ if ((mp_isodd(b) == MP_YES)) {
return fast_mp_invmod(a, b, c);
}
#endif
diff --git a/demo/demo.c b/demo/demo.c
index 7136a4c..4e59002 100644
--- a/demo/demo.c
+++ b/demo/demo.c
@@ -229,6 +229,15 @@ int main(void)
return EXIT_FAILURE;
}
+ mp_set_int(&a, 42);
+ mp_set_int(&b, 1);
+ mp_neg(&b, &b);
+ mp_set_int(&c, 1);
+ mp_exptmod(&a, &b, &c, &d);
+
+ mp_set_int(&c, 7);
+ mp_exptmod(&a, &b, &c, &d);
+
mp_set_int(&a, 0);
mp_set_int(&b, 1);