Commit 8bb11ded99dc7c9d9cb12e91d01696b2cad3cf6b

czurnieden 2019-09-02T21:05:04

Include tests for mp_prime_next_prime and handle inputs below the smallest table prime

diff --git a/bn_mp_prime_next_prime.c b/bn_mp_prime_next_prime.c
index 330dfa7..81b781b 100644
--- a/bn_mp_prime_next_prime.c
+++ b/bn_mp_prime_next_prime.c
@@ -49,9 +49,13 @@ mp_err mp_prime_next_prime(mp_int *a, int t, int bbs_style)
             }
          }
       }
-      /* at this point a maybe 1 */
-      if (mp_cmp_d(a, 1uL) == MP_EQ) {
-         mp_set(a, 2uL);
+      /* at this point a maybe smaller than the smallest prime in the table */
+      if (mp_cmp_d(a, 2uL) != MP_GT) {
+         if (bbs_style == 1) {
+            mp_set(a, 3uL);
+         } else {
+            mp_set(a, 2uL);
+         }
          return MP_OKAY;
       }
       /* fall through to the sieve */
diff --git a/demo/test.c b/demo/test.c
index 38a75a9..390d32d 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -1052,6 +1052,136 @@ LBL_ERR:
 
 }
 
+
+/* Tests mp_prime_next_prime() on small edge cases and on a multi-limb input.
+ * Returns EXIT_SUCCESS if every result matches, EXIT_FAILURE otherwise. */
+static int test_mp_prime_next_prime(void)
+{
+   mp_err err;
+   mp_int a, b, c;
+
+   /* on failure the ints may not all be initialized: do not reach LBL_ERR */
+   if ((err = mp_init_multi(&a, &b, &c, NULL)) != MP_OKAY) {
+      return EXIT_FAILURE;
+   }
+
+   /* edge cases */
+   mp_set(&a, 0u);
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 2u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 2 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_set(&a, 0u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   /* 2 is prime itself; the result must be the strictly larger prime 3 */
+   mp_set(&a, 2u);
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_set(&a, 2u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   /* bbs_style: 11 is the first prime > 8 that is congruent to 3 mod 4 */
+   mp_set(&a, 8u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 11u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 11 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   /* c = 2^300 + 157, a 301-bit prime (multi-limb); kept to restart from */
+   if ((err = mp_2expt(&c, 300)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   mp_set_u32(&b, 157);
+   if ((err = mp_add(&c, &b, &c)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_copy(&c, &a)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+
+   /* 2^300 + 385 is the next prime */
+   mp_set_u32(&b, 228);
+   if ((err = mp_add(&c, &b, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp(&a, &b) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been\n");
+      mp_fwrite(&b,10,stdout);
+      putchar('\n');
+      printf("but was:\n");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   /* restart from 2^300 + 157 (still held in c) */
+   if ((err = mp_copy(&c, &a)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   /* 2^300 + 631 is the next prime congruent to 3 mod 4 */
+   mp_set_u32(&b, 474);
+   if ((err = mp_add(&c, &b, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp(&a, &b) != MP_EQ) {
+      printf("mp_prime_next_prime (bbs): output should have been\n");
+      mp_fwrite(&b,10,stdout);
+      putchar('\n');
+      printf("but was:\n");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_clear_multi(&a, &b, &c, NULL);
+   return EXIT_SUCCESS;
+LBL_ERR:
+   mp_clear_multi(&a, &b, &c, NULL);
+   return EXIT_FAILURE;
+}
+
 static int test_mp_montgomery_reduce(void)
 {
    mp_digit mp;
@@ -2113,6 +2243,7 @@ int unit_tests(int argc, char **argv)
       T(mp_root_u32),
       T(mp_or),
       T(mp_prime_is_prime),
+      T(mp_prime_next_prime),
       T(mp_prime_rand),
       T(mp_rand),
       T(mp_read_radix),