Commit 0fa802f24b1bbb5b164b64d07cd8a38ed825b5d3

Daniel Mendler 2019-11-06T16:49:59

make mp_sqr private: fold its squaring dispatch into mp_mul and keep mp_sqr only as a macro over mp_mul

diff --git a/mp_mul.c b/mp_mul.c
index 9a83687..b2dbf7d 100644
--- a/mp_mul.c
+++ b/mp_mul.c
@@ -12,18 +12,34 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
        digs = a->used + b->used + 1;
    mp_sign neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
 
-   if (MP_HAS(S_MP_MUL_BALANCE) &&
-       /* Check sizes. The smaller one needs to be larger than the Karatsuba cut-off.
-        * The bigger one needs to be at least about one MP_MUL_KARATSUBA_CUTOFF bigger
-        * to make some sense, but it depends on architecture, OS, position of the
-        * stars... so YMMV.
-        * Using it to cut the input into slices small enough for s_mp_mul_comba
-        * was actually slower on the author's machine, but YMMV.
-        */
-       (min >= MP_MUL_KARATSUBA_CUTOFF) &&
-       ((max / 2) >= MP_MUL_KARATSUBA_CUTOFF) &&
-       /* Not much effect was observed below a ratio of 1:2, but again: YMMV. */
-       (max >= (2 * min))) {
+   if ((a == b) &&
+       MP_HAS(S_MP_SQR_TOOM) && /* use Toom-Cook? */
+       (a->used >= MP_SQR_TOOM_CUTOFF)) {
+      err = s_mp_sqr_toom(a, c);
+   } else if ((a == b) &&
+              MP_HAS(S_MP_SQR_KARATSUBA) &&  /* Karatsuba? */
+              (a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
+      err = s_mp_sqr_karatsuba(a, c);
+   } else if ((a == b) &&
+              MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
+              (((a->used * 2) + 1) < MP_WARRAY) &&
+              (a->used < (MP_MAX_COMBA / 2))) {
+      err = s_mp_sqr_comba(a, c);
+   } else if ((a == b) &&
+              MP_HAS(S_MP_SQR)) {
+      err = s_mp_sqr(a, c);
+   } else if (MP_HAS(S_MP_MUL_BALANCE) &&
+              /* Check sizes. The smaller one needs to be larger than the Karatsuba cut-off.
+               * The bigger one needs to be at least about one MP_MUL_KARATSUBA_CUTOFF bigger
+               * to make some sense, but it depends on architecture, OS, position of the
+               * stars... so YMMV.
+               * Using it to cut the input into slices small enough for s_mp_mul_comba
+               * was actually slower on the author's machine, but YMMV.
+               */
+              (min >= MP_MUL_KARATSUBA_CUTOFF) &&
+              ((max / 2) >= MP_MUL_KARATSUBA_CUTOFF) &&
+              /* Not much effect was observed below a ratio of 1:2, but again: YMMV. */
+              (max >= (2 * min))) {
       err = s_mp_mul_balance(a,b,c);
    } else if (MP_HAS(S_MP_MUL_TOOM) &&
               (min >= MP_MUL_TOOM_CUTOFF)) {
diff --git a/mp_sqr.c b/mp_sqr.c
deleted file mode 100644
index 67a8224..0000000
--- a/mp_sqr.c
+++ /dev/null
@@ -1,28 +0,0 @@
-#include "tommath_private.h"
-#ifdef MP_SQR_C
-/* LibTomMath, multiple-precision integer library -- Tom St Denis */
-/* SPDX-License-Identifier: Unlicense */
-
-/* computes b = a*a */
-mp_err mp_sqr(const mp_int *a, mp_int *b)
-{
-   mp_err err;
-   if (MP_HAS(S_MP_SQR_TOOM) && /* use Toom-Cook? */
-       (a->used >= MP_SQR_TOOM_CUTOFF)) {
-      err = s_mp_sqr_toom(a, b);
-   } else if (MP_HAS(S_MP_SQR_KARATSUBA) &&  /* Karatsuba? */
-              (a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
-      err = s_mp_sqr_karatsuba(a, b);
-   } else if (MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
-              (((a->used * 2) + 1) < MP_WARRAY) &&
-              (a->used < (MP_MAX_COMBA / 2))) {
-      err = s_mp_sqr_comba(a, b);
-   } else if (MP_HAS(S_MP_SQR)) {
-      err = s_mp_sqr(a, b);
-   } else {
-      err = MP_VAL;
-   }
-   b->sign = MP_ZPOS;
-   return err;
-}
-#endif
diff --git a/tommath.h b/tommath.h
index 68a1592..01a6d34 100644
--- a/tommath.h
+++ b/tommath.h
@@ -366,7 +366,7 @@ mp_err mp_sub(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
 mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
 
 /* b = a*a  */
-mp_err mp_sqr(const mp_int *a, mp_int *b) MP_WUR;
+#define mp_sqr(a, b) mp_mul((a), (a), (b)) /* a is expanded twice: pass a side-effect-free expression so mp_mul sees a == b */
 
 /* a/b => cb + d == a */
 mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) MP_WUR;