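Clamp split operands in Toom-Cook and balanced multiplication

When an operand is split into parts by copying digits and incrementing
`used`, the highest copied digits can be zero, leaving the part
unclamped. Call mp_clamp() on those parts (a0 in s_mp_balance_mul,
a2/b2 in s_mp_toom_mul, a2 in s_mp_toom_sqr) before they are used in
the sub-products, and add BN_MP_CLAMP_C to the BN_S_MP_BALANCE_MUL_C
dependencies in tommath_class.h.

Add a regression test to test_s_mp_toom_mul. It compares the product
computed with Toom-Cook disabled (TOOM_MUL_CUTOFF = INT_MAX) against
the product computed with the default cutoff. The test only runs when
MP_DIGIT_BIT == 60, because the operand construction is specific to
that digit size.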
diff --git a/bn_s_mp_balance_mul.c b/bn_s_mp_balance_mul.c
index efc1809..7ece5d7 100644
--- a/bn_s_mp_balance_mul.c
+++ b/bn_s_mp_balance_mul.c
@@ -40,6 +40,7 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
a0.dp[count] = A.dp[ j++ ];
a0.used++;
}
+ mp_clamp(&a0);
/* Multiply with b */
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
goto LBL_ERR;
@@ -60,6 +61,7 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
a0.dp[count] = A.dp[ j++ ];
a0.used++;
}
+ mp_clamp(&a0);
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
goto LBL_ERR;
}
diff --git a/bn_s_mp_toom_mul.c b/bn_s_mp_toom_mul.c
index 35f1e03..8efd803 100644
--- a/bn_s_mp_toom_mul.c
+++ b/bn_s_mp_toom_mul.c
@@ -61,6 +61,7 @@ mp_err s_mp_toom_mul(const mp_int *a, const mp_int *b, mp_int *c)
a2.dp[count - (2 * B)] = a->dp[count];
a2.used++;
}
+ mp_clamp(&a2);
/** b = b2 * x^2 + b1 * x + b0; */
if ((err = mp_init_size(&b0, B)) != MP_OKAY) goto LBL_ERRb0;
@@ -80,6 +81,7 @@ mp_err s_mp_toom_mul(const mp_int *a, const mp_int *b, mp_int *c)
b2.dp[count - (2 * B)] = b->dp[count];
b2.used++;
}
+ mp_clamp(&b2);
/** \\ S1 = (a2+a1+a0) * (b2+b1+b0); */
/** T1 = a2 + a1; */
diff --git a/bn_s_mp_toom_sqr.c b/bn_s_mp_toom_sqr.c
index e6be8a8..9eaa9d0 100644
--- a/bn_s_mp_toom_sqr.c
+++ b/bn_s_mp_toom_sqr.c
@@ -57,6 +57,7 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
}
mp_clamp(&a0);
mp_clamp(&a1);
+ mp_clamp(&a2);
/** S0 = a0^2; */
if ((err = mp_sqr(&a0, &S0)) != MP_OKAY) goto LBL_ERR;
diff --git a/demo/test.c b/demo/test.c
index cfabf37..7d4f065 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -2154,9 +2154,53 @@ static int test_s_mp_toom_mul(void)
mp_int a, b, c, d;
int size, err;
+#if (MP_DIGIT_BIT == 60)
+ int tc_cutoff;
+#endif
+
if ((err = mp_init_multi(&a, &b, &c, &d, NULL)) != MP_OKAY) {
goto LTM_ERR;
}
+ /* This number construction is limb-size specific */
+#if (MP_DIGIT_BIT == 60)
+ if ((err = mp_rand(&a, 1196)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if ((err = mp_mul_2d(&a,71787 - mp_count_bits(&a), &a)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+
+ if ((err = mp_rand(&b, 1338)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if ((err = mp_mul_2d(&b, 80318 - mp_count_bits(&b), &b)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if ((err = mp_mul_2d(&b, 6310, &b)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if ((err = mp_2expt(&c, 99000 - 1000)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if ((err = mp_add(&b, &c, &b)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+
+ tc_cutoff = TOOM_MUL_CUTOFF;
+ TOOM_MUL_CUTOFF = INT_MAX;
+ if ((err = mp_mul(&a, &b, &c)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ TOOM_MUL_CUTOFF = tc_cutoff;
+ if ((err = mp_mul(&a, &b, &d)) != MP_OKAY) {
+ goto LTM_ERR;
+ }
+ if (mp_cmp(&c, &d) != MP_EQ) {
+ fprintf(stderr, "Toom-Cook 3-way multiplication failed for edgecase f1 * f2\n");
+ goto LTM_ERR;
+ }
+#endif
+
for (size = MP_TOOM_MUL_CUTOFF; size < MP_TOOM_MUL_CUTOFF + 20; size++) {
if ((err = mp_rand(&a, size)) != MP_OKAY) {
goto LTM_ERR;
diff --git a/tommath_class.h b/tommath_class.h
index 2eefad7..a60a757 100644
--- a/tommath_class.h
+++ b/tommath_class.h
@@ -1102,6 +1102,7 @@
#if defined(BN_S_MP_BALANCE_MUL_C)
# define BN_MP_ADD_C
+# define BN_MP_CLAMP_C
# define BN_MP_CLEAR_C
# define BN_MP_CLEAR_MULTI_C
# define BN_MP_EXCH_C