Commit cb1b2dc8797b17d731d401a9fae8dcd3d7cb329a

Daniel Mendler 2019-10-15T14:04:32

mp_log_u32: return the result as uint32_t instead of mp_int

diff --git a/bn_mp_log_u32.c b/bn_mp_log_u32.c
index 75cf59e..fa60ba3 100644
--- a/bn_mp_log_u32.c
+++ b/bn_mp_log_u32.c
@@ -70,11 +70,11 @@ static mp_digit s_digit_ilogb(mp_digit base, mp_digit n)
          as is the output of mp_bitcount.
          With the same problem: max size is INT_MAX * MP_DIGIT not INT_MAX only!
 */
-mp_err mp_log_u32(const mp_int *a, uint32_t base, mp_int *c)
+mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c)
 {
    mp_err err;
    mp_ord cmp;
-   unsigned int high, low, mid;
+   uint32_t high, low, mid;
    mp_int bracket_low, bracket_high, bracket_mid, t, bi_base;
 
    err = MP_OKAY;
@@ -93,29 +93,23 @@ mp_err mp_log_u32(const mp_int *a, uint32_t base, mp_int *c)
 
    /* A small shortcut for bases that are powers of two. */
    if (!(base & (base - 1u))) {
-      int x, y, bit_count;
+      int y, bit_count;
       for (y=0; (y < 7) && !(base & 1u); y++) {
          base >>= 1;
       }
       bit_count = mp_count_bits(a) - 1;
-      x = bit_count/y;
-      mp_set_u32(c, (uint32_t)(x));
+      *c = (uint32_t)(bit_count/y);
       return MP_OKAY;
    }
 
    if (a->used == 1) {
-      mp_set(c, s_digit_ilogb(base, a->dp[0]));
+      *c = (uint32_t)s_digit_ilogb(base, a->dp[0]);
       return err;
    }
 
    cmp = mp_cmp_d(a, base);
-
-   if (cmp == MP_LT) {
-      mp_zero(c);
-      return err;
-   }
-   if (cmp == MP_EQ) {
-      mp_set(c, 1uL);
+   if (cmp == MP_LT || cmp == MP_EQ) {
+      *c = cmp == MP_EQ;
       return err;
    }
 
@@ -168,16 +162,12 @@ mp_err mp_log_u32(const mp_int *a, uint32_t base, mp_int *c)
          mp_exch(&bracket_mid, &bracket_low);
       }
       if (cmp == MP_EQ) {
-         mp_set_u32(c, mid);
+         *c = mid;
          goto LBL_END;
       }
    }
 
-   if (mp_cmp(&bracket_high, a) == MP_EQ) {
-      mp_set_u32(c, high);
-   } else {
-      mp_set_u32(c, low);
-   }
+   *c = mp_cmp(&bracket_high, a) == MP_EQ ? high : low;
 
 LBL_END:
 LBL_ERR:
diff --git a/demo/test.c b/demo/test.c
index 8282917..7b29a4c 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -1544,19 +1544,20 @@ LBL_ERR:
 }
 /* stripped down version of mp_radix_size. The faster version can be off by up to +3  */
-static mp_err s_rs(const mp_int *a, int radix, int *size)
+/* TODO: This function should be removed, replaced by mp_radix_size, mp_radix_size_overestimate in 2.0 */
+static mp_err s_rs(const mp_int *a, int radix, uint32_t *size)
 {
    mp_err res;
-   int digs = 0;
+   uint32_t digs = 0u;
    mp_int  t;
    mp_digit d;
-   *size = 0;
+   *size = 0u;
    if (mp_iszero(a) == MP_YES) {
-      *size = 2;
+      *size = 2u;
       return MP_OKAY;
    }
    if (radix == 2) {
-      *size = mp_count_bits(a) + 1;
+      *size = (uint32_t)mp_count_bits(a) + 1u;
       return MP_OKAY;
    }
    if ((res = mp_init_copy(&t, a)) != MP_OKAY) {
@@ -1576,13 +1577,12 @@ static mp_err s_rs(const mp_int *a, int radix, int *size)
 }
 static int test_mp_log_u32(void)
 {
-   mp_int a, lb;
+   mp_int a;
    mp_digit d;
-   uint32_t base;
-   int size;
+   uint32_t base, lb, size;
    const uint32_t max_base = MP_MIN(UINT32_MAX, MP_DIGIT_MAX);
 
-   if (mp_init_multi(&a, &lb, NULL) != MP_OKAY) {
+   if (mp_init(&a) != MP_OKAY) {
       goto LBL_ERR;
    }
 
@@ -1618,7 +1618,7 @@ static int test_mp_log_u32(void)
       if (mp_log_u32(&a, base, &lb) != MP_OKAY) {
          goto LBL_ERR;
       }
-      if (mp_cmp_d(&lb, (d == 1)?0uL:1uL) != MP_EQ) {
+      if (lb != ((d == 1)?0uL:1uL)) {
          goto LBL_ERR;
       }
    }
@@ -1639,7 +1639,7 @@ static int test_mp_log_u32(void)
       if (mp_log_u32(&a, base, &lb) != MP_OKAY) {
          goto LBL_ERR;
       }
-      if (mp_cmp_d(&lb, (d < base)?0uL:1uL) != MP_EQ) {
+      if (lb != ((d < base)?0uL:1uL)) {
          goto LBL_ERR;
       }
    }
@@ -1661,7 +1661,7 @@ static int test_mp_log_u32(void)
       }
       /* radix_size includes the memory needed for '\0', too*/
       size -= 2;
-      if (mp_cmp_d(&lb, (mp_digit)size) != MP_EQ) {
+      if (lb != size) {
          goto LBL_ERR;
       }
    }
@@ -1681,7 +1681,7 @@ static int test_mp_log_u32(void)
          goto LBL_ERR;
       }
       size -= 2;
-      if (mp_cmp_d(&lb, (mp_digit)size) != MP_EQ) {
+      if (lb != size) {
          goto LBL_ERR;
       }
    }
@@ -1697,14 +1697,14 @@ static int test_mp_log_u32(void)
    if (mp_log_u32(&a, max_base, &lb) != MP_OKAY) {
       goto LBL_ERR;
    }
-   if (mp_cmp_d(&lb, 10uL) != MP_EQ) {
+   if (lb != 10u) {
       goto LBL_ERR;
    }
 
-   mp_clear_multi(&a, &lb, NULL);
+   mp_clear(&a);
    return EXIT_SUCCESS;
 LBL_ERR:
-   mp_clear_multi(&a, &lb, NULL);
+   mp_clear(&a);
    return EXIT_FAILURE;
 }
 
diff --git a/tommath.h b/tommath.h
index b92b06e..8dd3bb3 100644
--- a/tommath.h
+++ b/tommath.h
@@ -717,7 +717,7 @@ MP_DEPRECATED(mp_prime_rand) mp_err mp_prime_random_ex(mp_int *a, int t, int siz
 mp_err mp_prime_rand(mp_int *a, int t, int size, int flags) MP_WUR;
 
 /* Integer logarithm to integer base */
-mp_err mp_log_u32(const mp_int *a, uint32_t base, mp_int *c) MP_WUR;
+mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c) MP_WUR;
 
 /* c = a**b */
 mp_err mp_expt_u32(const mp_int *a, uint32_t b, mp_int *c) MP_WUR;
diff --git a/tommath_class.h b/tommath_class.h
index d7f71a4..52ba585 100644
--- a/tommath_class.h
+++ b/tommath_class.h
@@ -627,9 +627,7 @@
 #   define BN_MP_INIT_MULTI_C
 #   define BN_MP_MUL_C
 #   define BN_MP_SET_C
-#   define BN_MP_SET_U32_C
 #   define BN_MP_SQR_C
-#   define BN_MP_ZERO_C
 #endif
 
 #if defined(BN_MP_LSHD_C)