Commit 857b112ef27efe39735fe1201a86dc8e4afa98e2

Steffen Jaeckel 2019-09-03T01:04:00

Merge pull request #334 from czurnieden/re_issue_332 repair of #333

diff --git a/bn_mp_prime_next_prime.c b/bn_mp_prime_next_prime.c
index aaa821b..17abbf8 100644
--- a/bn_mp_prime_next_prime.c
+++ b/bn_mp_prime_next_prime.c
@@ -24,20 +24,17 @@ mp_err mp_prime_next_prime(mp_int *a, int t, int bbs_style)
       /* find which prime it is bigger than */
       for (x = PRIVATE_MP_PRIME_TAB_SIZE - 2; x >= 0; x--) {
          if (mp_cmp_d(a, s_mp_prime_tab[x]) != MP_LT) {
+            /* ok we found a prime smaller or
+             * equal [so the next is larger]
+             */
             if (bbs_style == 1) {
-               /* ok we found a prime smaller or
-                * equal [so the next is larger]
-                *
-                * however, the prime must be
+               /* ... however, the prime must be
                 * congruent to 3 mod 4
-                */
-               if ((s_mp_prime_tab[x + 1] & 3u) != 3u) {
-                  /* scan upwards for a prime congruent to 3 mod 4 */
-                  for (y = x + 1; y < PRIVATE_MP_PRIME_TAB_SIZE; y++) {
-                     if ((s_mp_prime_tab[y] & 3u) == 3u) {
-                        mp_set(a, s_mp_prime_tab[y]);
-                        return MP_OKAY;
-                     }
+                * so do a scan upwards for such a prime */
+               for (y = x + 1; y < PRIVATE_MP_PRIME_TAB_SIZE; y++) {
+                  if ((s_mp_prime_tab[y] & 3u) == 3u) {
+                     mp_set(a, s_mp_prime_tab[y]);
+                     return MP_OKAY;
                   }
                }
             } else {
@@ -46,9 +43,13 @@ mp_err mp_prime_next_prime(mp_int *a, int t, int bbs_style)
             }
          }
       }
-      /* at this point a maybe 1 */
-      if (mp_cmp_d(a, 1uL) == MP_EQ) {
-         mp_set(a, 2uL);
+      /* at this point a may be smaller than the smallest prime in the table */
+      if (mp_cmp_d(a, 2uL) != MP_GT) {
+         if (bbs_style == 1) {
+            mp_set(a, 3uL);
+         } else {
+            mp_set(a, 2uL);
+         }
          return MP_OKAY;
       }
       /* fall through to the sieve */
diff --git a/demo/test.c b/demo/test.c
index 38a75a9..390d32d 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -1052,6 +1052,136 @@ LBL_ERR:
 
 }
 
+/* Tests mp_prime_next_prime(): tiny edge inputs (0, 2, 8) and a multi-limb 301-bit input, with and without bbs_style */
+static int test_mp_prime_next_prime(void)
+{
+   mp_err err;
+   mp_int a, b, c;
+
+   if ((err = mp_init_multi(&a, &b, &c, NULL)) != MP_OKAY) {
+      return EXIT_FAILURE; /* a, b, c are not (all) initialized here, LBL_ERR must not clear them */
+   }
+   /* edge cases */
+   mp_set(&a, 0u);
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 2u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 2 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_set(&a, 0u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_set(&a, 2u);
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_set(&a, 2u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 3u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 3 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+   mp_set(&a, 8u);
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp_d(&a, 11u) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been 11 but was: ");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+   /* 2^300 + 157 is a 301-bit prime, large enough to guarantee a multi-limb bigint */
+   if ((err = mp_2expt(&a, 300)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   mp_set_u32(&b, 157);
+   if ((err = mp_add(&a, &b, &a)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_copy(&a, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+
+   /* 2^300 + 385 is the next prime */
+   mp_set_u32(&c, 228);
+   if ((err = mp_add(&b, &c, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_prime_next_prime(&a, 5, 0)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp(&a, &b) != MP_EQ) {
+      printf("mp_prime_next_prime: output should have been\n");
+      mp_fwrite(&b,10,stdout);
+      putchar('\n');
+      printf("but was:\n");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   /* a and b were both overwritten above: rebuild a = b = 2^300 + 157 */
+   if ((err = mp_2expt(&a, 300)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   mp_set_u32(&b, 157);
+   if ((err = mp_add(&a, &b, &a)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_copy(&a, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+
+   /* 2^300 + 631 is the next prime congruent to 3 mod 4 */
+   mp_set_u32(&c, 474);
+   if ((err = mp_add(&b, &c, &b)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if ((err = mp_prime_next_prime(&a, 5, 1)) != MP_OKAY) {
+      goto LBL_ERR;
+   }
+   if (mp_cmp(&a, &b) != MP_EQ) {
+      printf("mp_prime_next_prime (bbs): output should have been\n");
+      mp_fwrite(&b,10,stdout);
+      putchar('\n');
+      printf("but was:\n");
+      mp_fwrite(&a,10,stdout);
+      putchar('\n');
+      goto LBL_ERR;
+   }
+
+   mp_clear_multi(&a, &b, &c, NULL);
+   return EXIT_SUCCESS;
+LBL_ERR:
+   mp_clear_multi(&a, &b, &c, NULL);
+   return EXIT_FAILURE;
+}
+
 static int test_mp_montgomery_reduce(void)
 {
    mp_digit mp;
@@ -2113,6 +2243,7 @@ int unit_tests(int argc, char **argv)
       T(mp_root_u32),
       T(mp_or),
       T(mp_prime_is_prime),
+      T(mp_prime_next_prime),
       T(mp_prime_rand),
       T(mp_rand),
       T(mp_read_radix),