Commit 7b6c6965bb947a92c9d03d9224dcfa379ae72d39

Daniel Mendler 2019-10-29T20:05:30

simplifications: toom, karatsuba and balance_mul

diff --git a/s_mp_balance_mul.c b/s_mp_balance_mul.c
index 4108830..77852a4 100644
--- a/s_mp_balance_mul.c
+++ b/s_mp_balance_mul.c
@@ -6,15 +6,11 @@
 /* single-digit multiplication with the smaller number as the single-digit */
 mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
 {
-   int count, len_a, len_b, nblocks, i, j, bsize;
-   mp_int a0, tmp, A, B, r;
+   mp_int a0, tmp, r;
    mp_err err;
-
-   len_a = a->used;
-   len_b = b->used;
-
-   nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used);
-   bsize = MP_MIN(a->used, b->used) ;
+   int i, j, count,
+       nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used),
+       bsize = MP_MIN(a->used, b->used);
 
    if ((err = mp_init_size(&a0, bsize + 2)) != MP_OKAY) {
       return err;
@@ -25,24 +21,20 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
    }
 
    /* Make sure that A is the larger one*/
-   if (len_a < len_b) {
-      B = *a;
-      A = *b;
-   } else {
-      A = *a;
-      B = *b;
+   if (a->used < b->used) {
+      MP_EXCH(const mp_int *, a, b);
    }
 
    for (i = 0, j=0; i < nblocks; i++) {
       /* Cut a slice off of a */
       a0.used = 0;
       for (count = 0; count < bsize; count++) {
-         a0.dp[count] = A.dp[ j++ ];
+         a0.dp[count] = a->dp[ j++ ];
          a0.used++;
       }
       mp_clamp(&a0);
       /* Multiply with b */
-      if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
+      if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
          goto LBL_ERR;
       }
       /* Shift tmp to the correct position */
@@ -55,14 +47,14 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
       }
    }
    /* The left-overs; there are always left-overs */
-   if (j < A.used) {
+   if (j < a->used) {
       a0.used = 0;
-      for (count = 0; j < A.used; count++) {
-         a0.dp[count] = A.dp[ j++ ];
+      for (count = 0; j < a->used; count++) {
+         a0.dp[count] = a->dp[ j++ ];
          a0.used++;
       }
       mp_clamp(&a0);
-      if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
+      if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
          goto LBL_ERR;
       }
       if ((err = mp_lshd(&tmp, bsize * i)) != MP_OKAY) {
diff --git a/s_mp_karatsuba_mul.c b/s_mp_karatsuba_mul.c
index df3daa7..762e5e2 100644
--- a/s_mp_karatsuba_mul.c
+++ b/s_mp_karatsuba_mul.c
@@ -35,8 +35,8 @@
 mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
 {
    mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
-   int     B;
-   mp_err  err = MP_MEM; /* default the return code to an error */
+   int  B, i;
+   mp_err  err;
 
    /* min # of digits */
    B = MP_MIN(a->used, b->used);
@@ -45,27 +45,27 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
    B = B >> 1;
 
    /* init copy all the temps */
-   if (mp_init_size(&x0, B) != MP_OKAY) {
+   if ((err = mp_init_size(&x0, B)) != MP_OKAY) {
       goto LBL_ERR;
    }
-   if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
+   if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY) {
       goto X0;
    }
-   if (mp_init_size(&y0, B) != MP_OKAY) {
+   if ((err = mp_init_size(&y0, B)) != MP_OKAY) {
       goto X1;
    }
-   if (mp_init_size(&y1, b->used - B) != MP_OKAY) {
+   if ((err = mp_init_size(&y1, b->used - B)) != MP_OKAY) {
       goto Y0;
    }
 
    /* init temps */
-   if (mp_init_size(&t1, B * 2) != MP_OKAY) {
+   if ((err = mp_init_size(&t1, B * 2)) != MP_OKAY) {
       goto Y1;
    }
-   if (mp_init_size(&x0y0, B * 2) != MP_OKAY) {
+   if ((err = mp_init_size(&x0y0, B * 2)) != MP_OKAY) {
       goto T1;
    }
-   if (mp_init_size(&x1y1, B * 2) != MP_OKAY) {
+   if ((err = mp_init_size(&x1y1, B * 2)) != MP_OKAY) {
       goto X0Y0;
    }
 
@@ -74,32 +74,18 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
    x1.used = a->used - B;
    y1.used = b->used - B;
 
-   {
-      int x;
-      mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
-
-      /* we copy the digits directly instead of using higher level functions
-       * since we also need to shift the digits
-       */
-      tmpa = a->dp;
-      tmpb = b->dp;
-
-      tmpx = x0.dp;
-      tmpy = y0.dp;
-      for (x = 0; x < B; x++) {
-         *tmpx++ = *tmpa++;
-         *tmpy++ = *tmpb++;
-      }
-
-      tmpx = x1.dp;
-      for (x = B; x < a->used; x++) {
-         *tmpx++ = *tmpa++;
-      }
-
-      tmpy = y1.dp;
-      for (x = B; x < b->used; x++) {
-         *tmpy++ = *tmpb++;
-      }
+   /* we copy the digits directly instead of using higher level functions
+    * since we also need to shift the digits
+    */
+   for (i = 0; i < B; i++) {
+      x0.dp[i] = a->dp[i];
+      y0.dp[i] = b->dp[i];
+   }
+   for (i = B; i < a->used; i++) {
+      x1.dp[i - B] = a->dp[i];
+   }
+   for (i = B; i < b->used; i++) {
+      y1.dp[i - B] = b->dp[i];
    }
 
    /* only need to clamp the lower words since by definition the
@@ -110,50 +96,47 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
 
    /* now calc the products x0y0 and x1y1 */
    /* after this x0 is no longer required, free temp [x0==t2]! */
-   if (mp_mul(&x0, &y0, &x0y0) != MP_OKAY) {
+   if ((err = mp_mul(&x0, &y0, &x0y0)) != MP_OKAY) {
       goto X1Y1;          /* x0y0 = x0*y0 */
    }
-   if (mp_mul(&x1, &y1, &x1y1) != MP_OKAY) {
+   if ((err = mp_mul(&x1, &y1, &x1y1)) != MP_OKAY) {
       goto X1Y1;          /* x1y1 = x1*y1 */
    }
 
    /* now calc x1+x0 and y1+y0 */
-   if (s_mp_add(&x1, &x0, &t1) != MP_OKAY) {
+   if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY) {
       goto X1Y1;          /* t1 = x1 - x0 */
    }
-   if (s_mp_add(&y1, &y0, &x0) != MP_OKAY) {
+   if ((err = s_mp_add(&y1, &y0, &x0)) != MP_OKAY) {
       goto X1Y1;          /* t2 = y1 - y0 */
    }
-   if (mp_mul(&t1, &x0, &t1) != MP_OKAY) {
+   if ((err = mp_mul(&t1, &x0, &t1)) != MP_OKAY) {
       goto X1Y1;          /* t1 = (x1 + x0) * (y1 + y0) */
    }
 
    /* add x0y0 */
-   if (mp_add(&x0y0, &x1y1, &x0) != MP_OKAY) {
+   if ((err = mp_add(&x0y0, &x1y1, &x0)) != MP_OKAY) {
       goto X1Y1;          /* t2 = x0y0 + x1y1 */
    }
-   if (s_mp_sub(&t1, &x0, &t1) != MP_OKAY) {
+   if ((err = s_mp_sub(&t1, &x0, &t1)) != MP_OKAY) {
       goto X1Y1;          /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
    }
 
    /* shift by B */
-   if (mp_lshd(&t1, B) != MP_OKAY) {
+   if ((err = mp_lshd(&t1, B)) != MP_OKAY) {
       goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
    }
-   if (mp_lshd(&x1y1, B * 2) != MP_OKAY) {
+   if ((err = mp_lshd(&x1y1, B * 2)) != MP_OKAY) {
       goto X1Y1;          /* x1y1 = x1y1 << 2*B */
    }
 
-   if (mp_add(&x0y0, &t1, &t1) != MP_OKAY) {
+   if ((err = mp_add(&x0y0, &t1, &t1)) != MP_OKAY) {
       goto X1Y1;          /* t1 = x0y0 + t1 */
    }
-   if (mp_add(&t1, &x1y1, c) != MP_OKAY) {
+   if ((err = mp_add(&t1, &x1y1, c)) != MP_OKAY) {
       goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
    }
 
-   /* Algorithm succeeded set the return code to MP_OKAY */
-   err = MP_OKAY;
-
 X1Y1:
    mp_clear(&x1y1);
 X0Y0:
diff --git a/s_mp_karatsuba_sqr.c b/s_mp_karatsuba_sqr.c
index 7f22842..824fcdc 100644
--- a/s_mp_karatsuba_sqr.c
+++ b/s_mp_karatsuba_sqr.c
@@ -13,8 +13,8 @@
 mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
 {
    mp_int  x0, x1, t1, t2, x0x0, x1x1;
-   int     B;
-   mp_err  err = MP_MEM;
+   int B, x;
+   mp_err  err;
 
    /* min # of digits */
    B = a->used;
@@ -23,37 +23,27 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
    B = B >> 1;
 
    /* init copy all the temps */
-   if (mp_init_size(&x0, B) != MP_OKAY)
+   if ((err = mp_init_size(&x0, B)) != MP_OKAY)
       goto LBL_ERR;
-   if (mp_init_size(&x1, a->used - B) != MP_OKAY)
+   if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY)
       goto X0;
 
    /* init temps */
-   if (mp_init_size(&t1, a->used * 2) != MP_OKAY)
+   if ((err = mp_init_size(&t1, a->used * 2)) != MP_OKAY)
       goto X1;
-   if (mp_init_size(&t2, a->used * 2) != MP_OKAY)
+   if ((err = mp_init_size(&t2, a->used * 2)) != MP_OKAY)
       goto T1;
-   if (mp_init_size(&x0x0, B * 2) != MP_OKAY)
+   if ((err = mp_init_size(&x0x0, B * 2)) != MP_OKAY)
       goto T2;
-   if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY)
+   if ((err = mp_init_size(&x1x1, (a->used - B) * 2)) != MP_OKAY)
       goto X0X0;
 
-   {
-      int x;
-      mp_digit *dst, *src;
-
-      src = a->dp;
-
-      /* now shift the digits */
-      dst = x0.dp;
-      for (x = 0; x < B; x++) {
-         *dst++ = *src++;
-      }
-
-      dst = x1.dp;
-      for (x = B; x < a->used; x++) {
-         *dst++ = *src++;
-      }
+   /* now shift the digits */
+   for (x = 0; x < B; x++) {
+      x0.dp[x] = a->dp[x];
+   }
+   for (x = B; x < a->used; x++) {
+      x1.dp[x - B] = a->dp[x];
    }
 
    x0.used = B;
@@ -62,36 +52,34 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
    mp_clamp(&x0);
 
    /* now calc the products x0*x0 and x1*x1 */
-   if (mp_sqr(&x0, &x0x0) != MP_OKAY)
+   if ((err = mp_sqr(&x0, &x0x0)) != MP_OKAY)
       goto X1X1;           /* x0x0 = x0*x0 */
-   if (mp_sqr(&x1, &x1x1) != MP_OKAY)
+   if ((err = mp_sqr(&x1, &x1x1)) != MP_OKAY)
       goto X1X1;           /* x1x1 = x1*x1 */
 
    /* now calc (x1+x0)**2 */
-   if (s_mp_add(&x1, &x0, &t1) != MP_OKAY)
+   if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY)
       goto X1X1;           /* t1 = x1 - x0 */
-   if (mp_sqr(&t1, &t1) != MP_OKAY)
+   if ((err = mp_sqr(&t1, &t1)) != MP_OKAY)
       goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
 
    /* add x0y0 */
-   if (s_mp_add(&x0x0, &x1x1, &t2) != MP_OKAY)
+   if ((err = s_mp_add(&x0x0, &x1x1, &t2)) != MP_OKAY)
       goto X1X1;           /* t2 = x0x0 + x1x1 */
-   if (s_mp_sub(&t1, &t2, &t1) != MP_OKAY)
+   if ((err = s_mp_sub(&t1, &t2, &t1)) != MP_OKAY)
       goto X1X1;           /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
 
    /* shift by B */
-   if (mp_lshd(&t1, B) != MP_OKAY)
+   if ((err = mp_lshd(&t1, B)) != MP_OKAY)
       goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
-   if (mp_lshd(&x1x1, B * 2) != MP_OKAY)
+   if ((err = mp_lshd(&x1x1, B * 2)) != MP_OKAY)
       goto X1X1;           /* x1x1 = x1x1 << 2*B */
 
-   if (mp_add(&x0x0, &t1, &t1) != MP_OKAY)
+   if ((err = mp_add(&x0x0, &t1, &t1)) != MP_OKAY)
       goto X1X1;           /* t1 = x0x0 + t1 */
-   if (mp_add(&t1, &x1x1, b) != MP_OKAY)
+   if ((err = mp_add(&t1, &x1x1, b)) != MP_OKAY)
       goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
 
-   err = MP_OKAY;
-
 X1X1:
    mp_clear(&x1x1);
 X0X0:
diff --git a/s_mp_toom_sqr.c b/s_mp_toom_sqr.c
index 67c465c..d8f2f8e 100644
--- a/s_mp_toom_sqr.c
+++ b/s_mp_toom_sqr.c
@@ -21,11 +21,9 @@
 mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
 {
    mp_int S0, a0, a1, a2;
-   mp_digit *tmpa, *tmpc;
    int B, count;
    mp_err err;
 
-
    /* init temps */
    if ((err = mp_init(&S0)) != MP_OKAY) {
       return err;
@@ -42,18 +40,14 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
    a1.used = B;
    if ((err = mp_init_size(&a2, B + (a->used - (3 * B)))) != MP_OKAY) goto LBL_ERRa2;
 
-   tmpa = a->dp;
-   tmpc = a0.dp;
    for (count = 0; count < B; count++) {
-      *tmpc++ = *tmpa++;
+      a0.dp[count] = a->dp[count];
    }
-   tmpc = a1.dp;
    for (; count < (2 * B); count++) {
-      *tmpc++ = *tmpa++;
+      a1.dp[count - B] = a->dp[count];
    }
-   tmpc = a2.dp;
    for (; count < a->used; count++) {
-      *tmpc++ = *tmpa++;
+      a2.dp[count - 2 * B] = a->dp[count];
       a2.used++;
    }
    mp_clamp(&a0);