Commit d4f6b43fa8ab24cd0bd71a31f1a89836bb528929

czurnieden 2019-10-11T00:29:20

use of mp_ilogb in mp_radix_size

diff --git a/bn_mp_radix_size.c b/bn_mp_radix_size.c
index b96f487..33ef7d3 100644
--- a/bn_mp_radix_size.c
+++ b/bn_mp_radix_size.c
@@ -6,12 +6,11 @@
-/* returns size of ASCII representation */
+/* returns in *size the number of characters needed to write a in base radix:
+ * its digits, a leading '-' if a is negative, and the terminating '\0' */
 mp_err mp_radix_size(const mp_int *a, int radix, int *size)
 {
-   mp_err  err;
-   int digs;
-   mp_int   t;
-   mp_digit d;
+   mp_err err;
+   mp_int a_, b;
 
    *size = 0;
 
    /* make sure the radix is in range */
    if ((radix < 2) || (radix > 64)) {
@@ -23,43 +22,26 @@ mp_err mp_radix_size(const mp_int *a, int radix, int *size)
       return MP_OKAY;
    }
 
-   /* special case for binary */
-   if (radix == 2) {
-      *size = (mp_count_bits(a) + ((a->sign == MP_NEG) ? 1 : 0) + 1);
-      return MP_OKAY;
-   }
-
-   /* digs is the digit count */
-   digs = 0;
-
-   /* if it's negative add one for the sign */
-   if (a->sign == MP_NEG) {
-      ++digs;
+   if ((err = mp_init(&b)) != MP_OKAY) {
+      return err;
    }
 
-   /* init a copy of the input */
-   if ((err = mp_init_copy(&t, a)) != MP_OKAY) {
-      return err;
+   /* |a| without copying: a_ shares a->dp, so mp_ilogb() must not modify it */
+   a_ = *a;
+   a_.sign = MP_ZPOS;
+   if ((err = mp_ilogb(&a_, (uint32_t)radix, &b)) != MP_OKAY) {
+      goto LBL_ERR;
    }
 
-   /* force temp to positive */
-   t.sign = MP_ZPOS;
-
-   /* fetch out all of the digits */
-   while (!MP_IS_ZERO(&t)) {
-      if ((err = mp_div_d(&t, (mp_digit)radix, &t, &d)) != MP_OKAY) {
-         goto LBL_ERR;
-      }
-      ++digs;
-   }
+   *size = (int)mp_get_l(&b);
 
-   /* return digs + 1, the 1 is for the NULL byte that would be required. */
-   *size = digs + 1;
-   err = MP_OKAY;
+   /* mp_ilogb() truncates, so floor(log_radix(|a|)) + 1 is the digit count;
+    * add one more for the terminating '\0' and one for '-' if a is negative. */
+   *size += 2 + ((a->sign == MP_NEG) ? 1 : 0);
 
 LBL_ERR:
-   mp_clear(&t);
+   mp_clear(&b);
    return err;
 }
 
 #endif
diff --git a/demo/test.c b/demo/test.c
index c1fc878..4c3f1cc 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -2282,6 +2282,66 @@ LTM_ERR:
    return EXIT_FAILURE;
 }
 
+/* Checks mp_radix_size() in every radix 2..64 for a value and its negation. */
+static int test_mp_radix_size(void)
+{
+   mp_err err;
+   mp_int a;
+   int radix, size;
+   /* expected sizes for 67^268, indexed by radix; entries 0 and 1 are unused */
+/* *INDENT-OFF* */
+   static const int results[65] = {
+       0, 0, 1627, 1027, 814, 702, 630, 581, 543,
+       514, 491, 471, 455, 441, 428, 418, 408, 399,
+       391, 384, 378, 372, 366, 361, 356, 352, 347,
+       343, 340, 336, 333, 330, 327, 324, 321, 318,
+       316, 314, 311, 309, 307, 305, 303, 301, 299,
+       298, 296, 294, 293, 291, 290, 288, 287, 285,
+       284, 283, 281, 280, 279, 278, 277, 276, 275,
+       273, 272
+   };
+/* *INDENT-ON* */
+
+   if (mp_init(&a) != MP_OKAY) {
+      return EXIT_FAILURE;
+   }
+
+   /* 67^268 = 67^(4 * 67) has a different size in every radix 2..64 */
+   mp_set(&a, 67);
+   if ((err = mp_expt_u32(&a, 268u, &a)) != MP_OKAY) {
+      goto LTM_ERR;
+   }
+
+   for (radix = 2; radix < 65; radix++) {
+      /* positive value: digits plus the terminating '\0' */
+      if ((err = mp_radix_size(&a, radix, &size)) != MP_OKAY) {
+         goto LTM_ERR;
+      }
+      if (size != results[radix]) {
+         fprintf(stderr, "mp_radix_size: result for base %d was %d instead of %d\n",
+                 radix, size, results[radix]);
+         goto LTM_ERR;
+      }
+      /* negative value: one extra character for the '-' sign */
+      a.sign = MP_NEG;
+      if ((err = mp_radix_size(&a, radix, &size)) != MP_OKAY) {
+         goto LTM_ERR;
+      }
+      if (size != (results[radix] + 1)) {
+         fprintf(stderr, "mp_radix_size: result for base %d was %d instead of %d\n",
+                 radix, size, results[radix] + 1);
+         goto LTM_ERR;
+      }
+      a.sign = MP_ZPOS;
+   }
+
+   mp_clear(&a);
+   return EXIT_SUCCESS;
+LTM_ERR:
+   mp_clear(&a);
+   return EXIT_FAILURE;
+}
+
 static int test_mp_read_write_ubin(void)
 {
    mp_int a, b, c;
@@ -2446,6 +2506,7 @@ static int unit_tests(int argc, char **argv)
       T1(mp_read_write_sbin, MP_TO_SBIN),
       T1(mp_reduce_2k, MP_REDUCE_2K),
       T1(mp_reduce_2k_l, MP_REDUCE_2K_L),
+      T1(mp_radix_size, MP_RADIX_SIZE),
 #if defined(__STDC_IEC_559__) || defined(__GCC_IEC_559)
       T1(mp_set_double, MP_SET_DOUBLE),
 #endif
diff --git a/tommath_class.h b/tommath_class.h
index b3c5dbd..ec712d1 100644
--- a/tommath_class.h
+++ b/tommath_class.h
@@ -751,9 +751,9 @@
 
 #if defined(BN_MP_RADIX_SIZE_C)
 #   define BN_MP_CLEAR_C
-#   define BN_MP_COUNT_BITS_C
-#   define BN_MP_DIV_D_C
-#   define BN_MP_INIT_COPY_C
+#   define BN_MP_GET_L_C
+#   define BN_MP_ILOGB_C
+#   define BN_MP_INIT_C
 #endif
 
 #if defined(BN_MP_RADIX_SMAP_C)