Commit 08d281c46218ef6ff30a512470813c785403897b

Daniel Mendler 2019-12-05T00:48:25

introduce MP_MAX_DIGIT_COUNT to prevent overflow

diff --git a/demo/shared.c b/demo/shared.c
index 85c26ed..536a1de 100644
--- a/demo/shared.c
+++ b/demo/shared.c
@@ -45,5 +45,5 @@ void print_header(void)
    printf("Size of mp_digit: %u\n", (unsigned int)sizeof(mp_digit));
    printf("Size of mp_word: %u\n", (unsigned int)sizeof(mp_word));
    printf("MP_DIGIT_BIT: %d\n", MP_DIGIT_BIT);
-   printf("MP_PREC: %d\n", MP_PREC);
+   printf("MP_DEFAULT_DIGIT_COUNT: %d\n", MP_DEFAULT_DIGIT_COUNT);
 }
diff --git a/demo/test.c b/demo/test.c
index 731c398..5009a8d 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -2220,6 +2220,8 @@ static int test_s_mp_radix_size_overestimate(void)
        284u, 283u, 281u, 280u, 279u, 278u, 277u, 276u, 275u,
        273u, 272u
    };
+
+#if 0 /* disabled together with the mp_2expt(&a, INT_MAX - 1) test below */
    size_t big_results[65] = {
               0u,         0u,         0u,  1354911329u, 1073741825u,
       924870867u, 830760078u, 764949110u,   715827883u,  677455665u,
@@ -2235,6 +2237,7 @@ static int test_s_mp_radix_size_overestimate(void)
      371449582u, 369786879u, 368168034u,   366591092u,  365054217u,
      363555684u, 362093873u, 360667257u,   359274399u,  357913942
    };
+#endif
 
 /* *INDENT-ON* */
    if ((err = mp_init(&a)) != MP_OKAY)        goto LBL_ERR;
@@ -2265,6 +2268,8 @@ static int test_s_mp_radix_size_overestimate(void)
       }
       a.sign = MP_ZPOS;
    }
+
+#if 0 /* 2^(INT_MAX - 1) needs more than MP_MAX_DIGIT_COUNT digits */
    if ((err = mp_2expt(&a, INT_MAX - 1)) != MP_OKAY) {
       goto LBL_ERR;
    }
@@ -2292,6 +2297,8 @@ static int test_s_mp_radix_size_overestimate(void)
       }
       a.sign = MP_ZPOS;
    }
+#endif
+
    mp_clear(&a);
    return EXIT_SUCCESS;
 LBL_ERR:
diff --git a/mp_grow.c b/mp_grow.c
index 344a5a8..ff3b96d 100644
--- a/mp_grow.c
+++ b/mp_grow.c
@@ -8,16 +8,21 @@ mp_err mp_grow(mp_int *a, int size)
 {
    /* if the alloc size is smaller alloc more ram */
    if (a->alloc < size) {
-           /* TODO */
+      mp_digit *dp;
+
+      if (size > MP_MAX_DIGIT_COUNT) {
+         return MP_MEM; /* too many digits: bit counts would overflow int */
+      }
+
       /* reallocate the array a->dp
        *
        * We store the return in a temporary variable
        * in case the operation failed we don't want
        * to overwrite the dp member of a.
        */
-      mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
-                                             (size_t)a->alloc * sizeof(mp_digit),
-                                             (size_t)size * sizeof(mp_digit));
+      dp = (mp_digit *) MP_REALLOC(a->dp,
+                                   (size_t)a->alloc * sizeof(mp_digit),
+                                   (size_t)size * sizeof(mp_digit));
       if (dp == NULL) {
          /* reallocation failed but "a" is still valid [can be freed] */
          return MP_MEM;
diff --git a/mp_init.c b/mp_init.c
index 9b82282..af16744 100644
--- a/mp_init.c
+++ b/mp_init.c
@@ -7,7 +7,7 @@
 mp_err mp_init(mp_int *a)
 {
    /* allocate memory required and clear it */
-   a->dp = (mp_digit *) MP_CALLOC((size_t)MP_PREC, sizeof(mp_digit));
+   a->dp = (mp_digit *) MP_CALLOC((size_t)MP_DEFAULT_DIGIT_COUNT, sizeof(mp_digit));
    if (a->dp == NULL) {
       return MP_MEM;
    }
@@ -15,7 +15,7 @@ mp_err mp_init(mp_int *a)
    /* set the used to zero, allocated digits to the default precision
     * and sign to positive */
    a->used  = 0;
-   a->alloc = MP_PREC;
+   a->alloc = MP_DEFAULT_DIGIT_COUNT; /* same count as the MP_CALLOC above */
    a->sign  = MP_ZPOS;
 
    return MP_OKAY;
diff --git a/mp_init_size.c b/mp_init_size.c
index fb7a37d..979a0b7 100644
--- a/mp_init_size.c
+++ b/mp_init_size.c
@@ -6,9 +6,12 @@
 /* init an mp_init for a given size */
 mp_err mp_init_size(mp_int *a, int size)
 {
-   size = MP_MAX(MP_MIN_PREC, size);
+   size = MP_MAX(MP_MIN_DIGIT_COUNT, size);
+
+   if (size > MP_MAX_DIGIT_COUNT) {
+      return MP_MEM; /* too many digits: bit counts would overflow int */
+   }
 
-   /*TODO*/
    /* alloc mem */
    a->dp = (mp_digit *) MP_CALLOC((size_t)size, sizeof(mp_digit));
    if (a->dp == NULL) {
diff --git a/mp_shrink.c b/mp_shrink.c
index e5814cb..3d9b162 100644
--- a/mp_shrink.c
+++ b/mp_shrink.c
@@ -6,7 +6,7 @@
 /* shrink a bignum */
 mp_err mp_shrink(mp_int *a)
 {
-   int alloc = MP_MAX(MP_MIN_PREC, a->used);
+   int alloc = MP_MAX(MP_MIN_DIGIT_COUNT, a->used); /* never below the minimum digit count */
    if (a->alloc != alloc) {
       mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
                                              (size_t)a->alloc * sizeof(mp_digit),
diff --git a/tommath_private.h b/tommath_private.h
index 24e7781..0096479 100644
--- a/tommath_private.h
+++ b/tommath_private.h
@@ -140,22 +140,32 @@ typedef uint64_t mp_word;
 
 MP_STATIC_ASSERT(correct_word_size, sizeof(mp_word) == (2u * sizeof(mp_digit)))
 
-/* default precision */
-#ifndef MP_PREC
+/* default number of digits */
+#ifndef MP_DEFAULT_DIGIT_COUNT
 #   ifndef MP_LOW_MEM
-#      define MP_PREC 32        /* default digits of precision */
+#      define MP_DEFAULT_DIGIT_COUNT 32
 #   else
-#      define MP_PREC 8         /* default digits of precision */
+#      define MP_DEFAULT_DIGIT_COUNT 8
 #   endif
 #endif
 
-/* Minimum number of available digits in mp_int, MP_PREC >= MP_MIN_PREC
+/* Minimum number of available digits in mp_int, MP_DEFAULT_DIGIT_COUNT >= MP_MIN_DIGIT_COUNT
  * - Must be at least 3 for s_mp_div_school.
  * - Must be large enough such that the mp_set_u64 setter can
  *   store uint64_t in the mp_int without growing
  */
-#define MP_MIN_PREC MP_MAX(3, (((int)MP_SIZEOF_BITS(uint64_t) + MP_DIGIT_BIT) - 1) / MP_DIGIT_BIT)
-MP_STATIC_ASSERT(prec_geq_min_prec, MP_PREC >= MP_MIN_PREC)
+#define MP_MIN_DIGIT_COUNT MP_MAX(3, (((int)MP_SIZEOF_BITS(uint64_t) + MP_DIGIT_BIT) - 1) / MP_DIGIT_BIT)
+MP_STATIC_ASSERT(default_digit_count_geq_min, MP_DEFAULT_DIGIT_COUNT >= MP_MIN_DIGIT_COUNT)
+
+/* Maximum number of digits.
+ * - Must be small enough such that mp_bit_count does not overflow.
+ * - Must be small enough such that mp_radix_size for base 2 does not overflow.
+ *   mp_radix_size needs two additional bytes for zero termination and sign.
+ * - Must not be smaller than MP_DEFAULT_DIGIT_COUNT, since mp_init allocates
+ *   MP_DEFAULT_DIGIT_COUNT digits without checking this limit.
+ */
+#define MP_MAX_DIGIT_COUNT ((INT_MAX - 2) / MP_DIGIT_BIT)
+MP_STATIC_ASSERT(default_digit_count_leq_max, MP_DEFAULT_DIGIT_COUNT <= MP_MAX_DIGIT_COUNT)
 
 /* random number source */
 extern MP_PRIVATE mp_err(*s_mp_rand_source)(void *out, size_t size);