88#include <time.h>
99#include "../tommath_private.h"
1010
11- static mp_digit prime_digit (void )
11+ static void mp_print (const char * s , const mp_int * a , int radix , FILE * stream )
12+ {
13+ mp_err err ;
14+ fputs (s , stream );
15+ err = mp_fwrite (a , radix , stream );
16+ if (err != MP_OKAY ) {
17+ fprintf (stderr ,"mp_fwrite in mp_print failed. error = %s\n" , mp_error_to_string (err ));
18+ exit (EXIT_FAILURE );
19+ }
20+ fputc ('\n' ,stream );
21+ }
22+
23+ static mp_digit prime_digit (int bits )
1224{
13- int n ;
1425 mp_digit d = 0 ;
1526 mp_int a ;
1627 mp_err err ;
1728
18- n = abs (rand ()) % MP_MASK ;
19- if ((err = mp_init_ul (& a , (unsigned long )n )) != MP_OKAY ) goto LTM_ERR ;
20- if ((err = mp_prime_next_prime (& a , -1 , false)) != MP_OKAY ) goto LTM_ERR ;
21- while (a .used > 1 ) {
22- if ((err = mp_div_2 (& a , & a )) != MP_OKAY ) goto LTM_ERR ;
23- if ((err = mp_prime_next_prime (& a , -1 , false)) != MP_OKAY ) goto LTM_ERR ;
29+ if ((err = mp_init (& a )) != MP_OKAY ) {
30+ return 0 ;
2431 }
32+
33+ if ((err = mp_prime_rand (& a , 1 , bits , false)) != MP_OKAY ) goto LTM_ERR ;
2534 d = a .dp [0 ];
2635
2736LTM_ERR :
@@ -35,12 +44,14 @@ static mp_err pprime(int k, int li, mp_int *p, mp_int *q)
3544{
3645 mp_int a , b , c , n , x , y , z , v ;
3746 mp_err err = MP_OKAY ;
38- int ii ;
39- static const mp_digit bases [] = { 2 , 3 , 5 , 7 , 11 , 13 , 17 , 19 };
47+ int ii , bits ;
4048
4149 /* single digit ? */
42- if (k <= (int ) MP_DIGIT_BIT ) {
43- mp_set (p , prime_digit ());
50+ if (k < (int ) MP_DIGIT_BIT ) {
51+ mp_set (p , prime_digit (k ));
52+ if (mp_iszero (p )) {
53+ return MP_VAL ;
54+ }
4455 return MP_OKAY ;
4556 }
4657
@@ -54,14 +65,27 @@ static mp_err pprime(int k, int li, mp_int *p, mp_int *q)
5465 }
5566
5667 /* set the prime */
57- mp_set (& a , prime_digit ());
68+ mp_set (& a , prime_digit (MP_DIGIT_BIT ));
69+ if (mp_iszero (& a )) {
70+ err = MP_VAL ;
71+ goto LTM_ERR ;
72+ }
5873
5974 /* now loop making the single digit */
6075 while (mp_count_bits (& a ) < k ) {
61- fprintf (stderr , "prime has %4d bits left\r" , k - mp_count_bits (& a ));
76+ bits = k - mp_count_bits (& a );
77+ fprintf (stderr , "prime has %4d bits left\r" , bits );
6278 fflush (stderr );
6379top :
64- mp_set (& b , prime_digit ());
80+ if (bits < MP_DIGIT_BIT ) {
81+ mp_set (& b , prime_digit (bits ));
82+ } else {
83+ mp_set (& b , prime_digit (MP_DIGIT_BIT ));
84+ }
85+ if (mp_iszero (& b )) {
86+ err = MP_VAL ;
87+ goto LTM_ERR ;
88+ }
6589
6690 /* now compute z = a * b * 2 */
6791 /* z = a * b */
@@ -78,10 +102,10 @@ static mp_err pprime(int k, int li, mp_int *p, mp_int *q)
78102 if (mp_cmp_d (& y , 1uL ) != MP_EQ ) {
79103 goto top ;
80104 }
81-
105+ mp_set ( & x , 2u );
82106 /* now try base x=bases[ii] */
83107 for (ii = 0 ; ii < li ; ii ++ ) {
84- mp_set ( & x , bases [ ii ]) ;
108+ if (( err = mp_prime_next_prime ( & x , -1 , false)) != MP_OKAY ) goto LTM_ERR ;
85109
86110 /* compute x^a mod n; y = x^a mod n */
87111 if ((err = mp_exptmod (& x , & a , & n , & y )) != MP_OKAY ) goto LTM_ERR ;
@@ -132,17 +156,11 @@ static mp_err pprime(int k, int li, mp_int *p, mp_int *q)
132156 goto top ;
133157 }
134158
135- {
136- char buf [4096 ];
137-
138- if ((err = mp_to_decimal (& n , buf , sizeof (buf ))) != MP_OKAY ) goto LTM_ERR ;
139- printf ("Certificate of primality for:\n%s\n\n" , buf );
140- if ((err = mp_to_decimal (& a , buf , sizeof (buf ))) != MP_OKAY ) goto LTM_ERR ;
141- printf ("A == \n%s\n\n" , buf );
142- if ((err = mp_to_decimal (& b , buf , sizeof (buf ))) != MP_OKAY ) goto LTM_ERR ;
143- printf ("B == \n%s\n\nG == %lu\n" , buf , bases [ii ]);
144- printf ("----------------------------------------------------------------\n" );
145- }
159+ mp_print ("Certificate of primality for:\n " , & n , 10 , stdout );
160+ mp_print ("A == " , & a , 10 , stdout );
161+ mp_print ("B == " , & b , 10 , stdout );
162+ mp_print ("G == " , & x , 10 , stdout );
163+ printf ("----------------------------------------------------------------\n" );
146164
147165 /* a = n */
148166 if ((err = mp_copy (& n , & a )) != MP_OKAY ) goto LTM_ERR ;
@@ -170,29 +188,28 @@ int main(void)
170188 int k , li ;
171189 clock_t t1 ;
172190
173- srand (time (NULL ));
174-
175191 printf ("Enter # of bits: \n" );
176192 fgets (buf , sizeof (buf ), stdin );
177193 sscanf (buf , "%d" , & k );
178194
179- printf ("Enter number of bases to try (1 to 8): \n" );
195+ printf ("Enter number of bases to try\n" );
180196 fgets (buf , sizeof (buf ), stdin );
181197 sscanf (buf , "%d" , & li );
182198
183199
184- if ((err = mp_init_multi (& p , & q , NULL )) != MP_OKAY ) goto LTM_ERR ;
200+ if ((err = mp_init_multi (& p , & q , NULL )) != MP_OKAY ) goto LTM_ERR ;
185201
186202 t1 = clock ();
187- pprime (k , li , & p , & q );
203+ if ((err = pprime (k , li , & p , & q )) != MP_OKAY ) {
204+ fprintf (stderr , "Something went wrong in function pprime: %s\n" , mp_error_to_string (err ));
205+ goto LTM_ERR ;
206+ }
188207 t1 = clock () - t1 ;
189208
190209 printf ("\n\nTook %lu ticks, %d bits\n" , t1 , mp_count_bits (& p ));
191210
192- if ((err = mp_to_decimal (& p , buf , sizeof (buf ))) != MP_OKAY ) goto LTM_ERR ;
193- printf ("P == %s\n" , buf );
194- if ((err = mp_to_decimal (& q , buf , sizeof (buf ))) != MP_OKAY ) goto LTM_ERR ;
195- printf ("Q == %s\n" , buf );
211+ mp_print ("P == " , & p , 10 , stdout );
212+ mp_print ("Q == " , & q , 10 , stdout );
196213
197214 mp_clear_multi (& p , & q , NULL );
198215 exit (EXIT_SUCCESS );
0 commit comments