Tag
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
/* LibTomMath, multiple-precision integer library -- Tom St Denis
*
* LibTomMath is a library that provides for multiple-precision
* integer arithmetic as well as number theoretic functionality.
*
* The library is designed directly after the MPI library by
* Michael Fromberger but has been written from scratch with
* additional optimizations in place.
*
* The library is free for all purposes without any express
* guarantee it works.
*
* Tom St Denis, tomstdenis@iahu.ca, http://libtommath.iahu.ca
*/
#include <tommath.h>
/* c = a * b using Karatsuba multiplication (three half-size multiplications).
 *
 * Let R be the radix (2**DIGIT_BIT) and let n be half the number of digits of
 * the shorter operand, i.e. n = min(a->used, b->used) / 2.  Split
 *
 *    a = a1 * R^n + a0
 *    b = b1 * R^n + b0
 *
 * Then
 *
 *    a * b = a1b1 * R^(2n) + (a0b0 + a1b1 - (a1 - a0)(b1 - b0)) * R^n + a0b0
 *
 * because a1b0 + a0b1 = a0b0 + a1b1 - (a1 - a0)(b1 - b0).  The products a1b1
 * and a0b0 are each used twice but computed only once, so in total three
 * half-size (half # of digits) multiplications are performed: a0b0, a1b1 and
 * (a1 - a0)(b1 - b0).
 *
 * A multiplication of half the digits requires about 1/4 of the single
 * precision multiplications, so one level of this method saves roughly 25% of
 * them.  The calls to mp_mul can end up back in this function if a0, a1, b0 or
 * b1 are above the cutoff; this divide-and-conquer recursion gives the famous
 * O(N^lg(3)) ~ O(N^1.585) work, which is asymptotically lower than the O(N^2)
 * of the baseline/comba methods.  Generally, though, the overhead of this
 * method doesn't pay off until a certain size (N ~ 80 digits) is reached.
 *
 * Sign handling: the halves of a inherit a->sign and the halves of b inherit
 * b->sign, so the pieces recombine to the signed product a * b.
 * NOTE(review): the caller (presumably mp_mul) may overwrite c->sign
 * afterwards -- confirm before relying on the sign produced here.
 *
 * Returns MP_OKAY on success; otherwise the error code reported by the first
 * failing mp_* call (e.g. MP_MEM).  c is only written by the final mp_add.
 */
int
mp_karatsuba_mul (mp_int * a, mp_int * b, mp_int * c)
{
  mp_int x0, x1, y0, y1, t1, t2, x0y0, x1y1;
  int half, err, x;

  /* split point n: half the digit count of the shorter operand */
  half = MIN (a->used, b->used) / 2;

  /* allocate the four halves with room for the digits copied below */
  if ((err = mp_init_size (&x0, half)) != MP_OKAY)
    goto ERR;
  if ((err = mp_init_size (&x1, a->used - half)) != MP_OKAY)
    goto X0;
  if ((err = mp_init_size (&y0, half)) != MP_OKAY)
    goto X1;
  if ((err = mp_init_size (&y1, b->used - half)) != MP_OKAY)
    goto Y0;

  /* temporaries for the three products and the intermediate sums */
  if ((err = mp_init (&t1)) != MP_OKAY)
    goto Y1;
  if ((err = mp_init (&t2)) != MP_OKAY)
    goto T1;
  if ((err = mp_init (&x0y0)) != MP_OKAY)
    goto T2;
  if ((err = mp_init (&x1y1)) != MP_OKAY)
    goto X0Y0;

  /* set up the halves: low digits [0, half) and high digits [half, used) */
  x0.sign = x1.sign = a->sign;
  y0.sign = y1.sign = b->sign;
  x0.used = y0.used = half;
  x1.used = a->used - half;
  y1.used = b->used - half;

  /* copy the digits directly instead of using higher level functions,
   * since the high halves also need to be shifted down by `half` digits */
  for (x = 0; x < half; x++) {
    x0.dp[x] = a->dp[x];
    y0.dp[x] = b->dp[x];
  }
  for (x = half; x < a->used; x++) {
    x1.dp[x - half] = a->dp[x];
  }
  for (x = half; x < b->used; x++) {
    y1.dp[x - half] = b->dp[x];
  }

  /* The low halves may have leading zero digits, so clamp them.  The high
   * halves end at the top digit of a/b, which is nonzero as long as a and b
   * are themselves clamped (assumed here), so they are left as is. */
  mp_clamp (&x0);
  mp_clamp (&y0);

  /* the two "outer" products */
  if ((err = mp_mul (&x0, &y0, &x0y0)) != MP_OKAY)
    goto X1Y1;                  /* x0y0 = x0*y0 */
  if ((err = mp_mul (&x1, &y1, &x1y1)) != MP_OKAY)
    goto X1Y1;                  /* x1y1 = x1*y1 */

  /* the third product, (x1 - x0)*(y1 - y0) */
  if ((err = mp_sub (&x1, &x0, &t1)) != MP_OKAY)
    goto X1Y1;                  /* t1 = x1 - x0 */
  if ((err = mp_sub (&y1, &y0, &t2)) != MP_OKAY)
    goto X1Y1;                  /* t2 = y1 - y0 */
  if ((err = mp_mul (&t1, &t2, &t1)) != MP_OKAY)
    goto X1Y1;                  /* t1 = (x1 - x0)*(y1 - y0) */

  /* middle coefficient: x0y0 + x1y1 - (x1 - x0)*(y1 - y0) == x1y0 + x0y1 */
  if ((err = mp_add (&x0y0, &x1y1, &t2)) != MP_OKAY)
    goto X1Y1;                  /* t2 = x0y0 + x1y1 */
  if ((err = mp_sub (&t2, &t1, &t1)) != MP_OKAY)
    goto X1Y1;                  /* t1 = x0y0 + x1y1 - (x1 - x0)*(y1 - y0) */

  /* scale the middle term by R^half and the high term by R^(2*half) */
  if ((err = mp_lshd (&t1, half)) != MP_OKAY)
    goto X1Y1;                  /* t1 <<= half digits */
  if ((err = mp_lshd (&x1y1, half * 2)) != MP_OKAY)
    goto X1Y1;                  /* x1y1 <<= 2*half digits */

  /* recombine: c = x0y0 + t1 + x1y1 */
  if ((err = mp_add (&x0y0, &t1, &t1)) != MP_OKAY)
    goto X1Y1;                  /* t1 = x0y0 + t1 */
  if ((err = mp_add (&t1, &x1y1, c)) != MP_OKAY)
    goto X1Y1;                  /* c = t1 + x1y1 */

  err = MP_OKAY;

  /* release everything in reverse order of initialization; each label frees
   * only what was successfully initialized before the failing step */
X1Y1:mp_clear (&x1y1);
X0Y0:mp_clear (&x0y0);
T2:mp_clear (&t2);
T1:mp_clear (&t1);
Y1:mp_clear (&y1);
Y0:mp_clear (&y0);
X1:mp_clear (&x1);
X0:mp_clear (&x0);
ERR:
  return err;
}