gpuv.cl (32945B)
/*
 * lib-gpu-verify
 *
 * This software contains code derived from or inspired by the BigDigit library,
 * <http://www.di-mgt.com.au/bigdigits.html>
 * which is distributed under the Mozilla Public License, version 2.0.
 *
 * The original code and modifications made to it are subject to the terms and
 * conditions of the Mozilla Public License, version 2.0. A copy of the
 * MPL license can be obtained at
 * https://www.mozilla.org/en-US/MPL/2.0/.
 *
 * Changes and additions to the original code are as follows:
 * - Copied various parts of the BigDigit library into this kernel.
 * - Some functions and macros were changed to accommodate the architecture of OpenCL.
 * - functions were added to extend the functionality required by this OpenCL kernel.
 *
 * Contributors:
 * - Cedric Zwahlen cedric.zwahlen@bfh.ch
 *
 * Please note that this software is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the Mozilla Public License, version 2.0, for the specific language
 * governing permissions and limitations under the License.
 */

// macros

// only for that string conversion
#define ALLOC_BYTES(b,n) do{assert((n)<=sizeof((b)));zeroise_bytes((b),(n));}while(0)
#define FREE_BYTES(b,n) zeroise_bytes((b),(n))

/* DIGIT_T is a 32-bit uint. These use a 'U' suffix instead of the C
   original's 'UL': in OpenCL C 'long' is always 64 bits, so 'UL' would
   promote every comparison against them to 64-bit arithmetic. */
#define MAX_DIGIT 0xFFFFFFFFU
#define MAX_HALF_DIGIT 0xFFFFU
#define BITS_PER_DIGIT 32
#define HIBITMASK 0x80000000U

#define MAX_FIXED_BIT_LENGTH 8192
#define MAX_FIXED_DIGITS ((MAX_FIXED_BIT_LENGTH + BITS_PER_DIGIT - 1) / BITS_PER_DIGIT)

/* Digits per operand slot in the kernel's input buffers. */
#define MAX_ALLOC_SIZE 64

#define BYTES_PER_DIGIT (BITS_PER_DIGIT / 8)

#define PRIuBIGD PRIu32
#define PRIxBIGD PRIx32
#define PRIXBIGD PRIX32

/* MACROS TO DO MODULAR SQUARING AND MULTIPLICATION USING PRE-ALLOCATED TEMPS */
/* Required lengths |y|=|t1|=|t2|=2*n, |m|=n; but final |y|=n */
/* Square: y = (y * y) mod m */
#define mpMODSQUARETEMP(y,m,n,t1,t2) do{mpSquare((t1),(y),(n));mpDivide((t2),(y),(t1),(n)*2,(m),(n));}while(0)
/* Mult: y = (y * x) mod m */
#define mpMODMULTTEMP(y,x,m,n,t1,t2) do{mpMultiply((t1),(x),(y),(n));mpDivide((t2),(y),(t1),(n)*2,(m),(n));}while(0)

/* Move a single-bit mask one bit lower; when it passes bit 0 it wraps to the
   top bit of the next lower digit and the digit counter n is decremented. */
#define mpNEXTBITMASK(mask, n) do{if((mask)==1){(mask)=HIBITMASK;(n)--;}else{(mask)>>=1;}}while(0)

/* OpenCL C has no assert(): report via printf and carry on. Wrapped in
   do/while(0) so "assert(x);" is a single statement (safe before an else). */
#define assert(x) do{if(!(x)){printf((char __constant *)"assert reached\n");}}while(0)

typedef uint DIGIT_T;

/* Must be a 16-bit scalar. In OpenCL C "uint16" is a 16-component vector of
   uint, not a 16-bit integer, so ushort is used here. */
typedef ushort HALF_DIGIT_T;


// forward definitions

int mpModulo(__global DIGIT_T *r, const DIGIT_T *u, size_t udigits, __global const DIGIT_T *v, size_t vdigits);

int mpMultiply(DIGIT_T *w, __global const DIGIT_T *u, const DIGIT_T *v, size_t ndigits);

DIGIT_T mpAdd(DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits);

int mpDivide(DIGIT_T *q, DIGIT_T *r, const DIGIT_T *u, size_t udigits, __global const DIGIT_T *v, size_t vdigits);

int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2);

DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n);

DIGIT_T mpShiftLeft(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);

void mpSetDigit(DIGIT_T *a, DIGIT_T d, size_t ndigits);

int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits);

DIGIT_T mpShiftRight(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);

int spMultiply(uint p[2], uint x, uint y);

uint spDivide(uint *pq, uint *pr, const uint u[2], uint v);

int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits);

size_t mpBitLength(const DIGIT_T *d, size_t ndigits);

DIGIT_T mpShortDiv(DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, size_t ndigits);

void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits);

volatile uchar zeroise_bytes(volatile void *v, size_t n);

size_t mpSizeof(const DIGIT_T *a, size_t ndigits);

volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits);

int mpIsZero(const DIGIT_T *a, size_t ndigits);

void mpFail(char *msg);

// global memory definitions
// Suffixes name the address space of the pointer arguments in order:
// _g = __global, _l = private, e.g. _lg = (private a, __global b).

size_t mpSizeof_g(__global const DIGIT_T *a, size_t ndigits);

DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);

DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);

DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);

int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);

int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);

void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);

void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits);

DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n);

DIGIT_T mpAdd_llg(DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits);

void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits);

volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits);

// implementation

int mpModulo(__global DIGIT_T *r, const DIGIT_T *u, size_t udigits, __global const DIGIT_T *v, size_t vdigits)
{
    /* Computes r = u mod v
       where r, v are multiprecision integers of length vdigits
       and u is a multiprecision integer of length udigits.
       r may overlap v (v is only read).

       Note that r here is only vdigits long,
       whereas in mpDivide it is udigits long.

       Use remainder from mpDivide function.

       Returns 0, or -1 (r untouched) if udigits exceeds the private
       scratch buffers below.
    */

    DIGIT_T qq[MAX_FIXED_DIGITS*2];
    DIGIT_T rr[MAX_FIXED_DIGITS*2];

    /* mpDivide writes udigits digits into both buffers */
    if (udigits > MAX_FIXED_DIGITS * 2)
        return -1;

    /* rr[nn] = u mod v */
    mpDivide(qq, rr, u, udigits, v, vdigits);

    /* Final r is only vdigits long */
    mpSetEqual_gl(r, rr, vdigits);

    return 0;
}

int mpMultiply(DIGIT_T *w, __global const DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
{
    /* Computes product w = u * v
       where u, v are multiprecision integers of ndigits each
       and w is a multiprecision integer of 2*ndigits.
       u is read from global memory; w must not overlap v.

       Ref: Knuth Vol 2 Ch 4.3.1 p 268 Algorithm M.
    */

    DIGIT_T k, t[2];
    size_t i, j, m, n;

    m = n = ndigits;

    /* Step M1. Initialise */
    for (i = 0; i < 2 * m; i++)
        w[i] = 0;

    for (j = 0; j < n; j++)
    {
        /* Step M2. Zero multiplier? */
        if (v[j] == 0)
        {
            w[j + m] = 0;
        }
        else
        {
            /* Step M3. Initialise i */
            k = 0;
            for (i = 0; i < m; i++)
            {
                /* Step M4. Multiply and add */
                /* t = u_i * v_j + w_(i+j) + k */
                spMultiply(t, u[i], v[j]);

                t[0] += k;
                if (t[0] < k)
                    t[1]++;
                t[0] += w[i+j];
                if (t[0] < w[i+j])
                    t[1]++;

                w[i+j] = t[0];
                k = t[1];
            }
            /* Step M5. Loop on i, set w_(j+m) = k */
            w[j+m] = k;
        }
    }   /* Step M6. Loop on j */

    return 0;
}

DIGIT_T mpAdd(DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
{
    /* Calculates w = u + v
       where w, u, v are multiprecision integers of ndigits each
       Returns carry if overflow. Carry = 0 or 1.

       Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A.
    */

    DIGIT_T k;
    size_t j;

    /* Step A1. Initialise */
    k = 0;

    for (j = 0; j < ndigits; j++)
    {
        /* Step A2. Add digits w_j = (u_j + v_j + k)
           Set k = 1 if carry (overflow) occurs
        */
        w[j] = u[j] + k;
        if (w[j] < k)
            k = 1;
        else
            k = 0;

        w[j] += v[j];
        if (w[j] < v[j])
            k++;

    }   /* Step A3. Loop on j */

    return k;   /* w_n = k */
}

DIGIT_T mpAdd_llg(DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits)
{
    /* Calculates w = u + v (v in global memory)
       where w, u, v are multiprecision integers of ndigits each
       Returns carry if overflow. Carry = 0 or 1.

       Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A.
    */

    DIGIT_T k;
    size_t j;

    /* Step A1. Initialise */
    k = 0;

    for (j = 0; j < ndigits; j++)
    {
        /* Step A2. Add digits w_j = (u_j + v_j + k)
           Set k = 1 if carry (overflow) occurs
        */
        w[j] = u[j] + k;
        if (w[j] < k)
            k = 1;
        else
            k = 0;

        w[j] += v[j];
        if (w[j] < v[j])
            k++;

    }   /* Step A3. Loop on j */

    return k;   /* w_n = k */
}

int mpDivide(DIGIT_T *q, DIGIT_T *r, const DIGIT_T *u, size_t udigits, __global const DIGIT_T *v, size_t vdigits)
{   /* Computes quotient q = u / v and remainder r = u mod v
       where q, r, u are multiple precision digits
       all of udigits and the divisor v is vdigits.

       Ref: Knuth Vol 2 Ch 4.3.1 p 272 Algorithm D.

       Uses r[] for the normalised u[] and copes with the
       extra digit Uj+n added to u after normalisation.

       The divisor is normalised in a PRIVATE copy (vv). The C original
       shifted v in place and shifted it back at the end; here v lives in
       __global memory and is shared by every work-item verifying against
       the same key, so in-place normalisation was a data race. v is now
       never written.

       WARNING: this trashes q and r first, so cannot do
       u = u / v or v = u mod v.

       Returns 0 on success, -1 if v == 0 or v has more than
       MAX_FIXED_DIGITS significant digits.
    */
    size_t shift;
    int n, m, j;
    DIGIT_T bitmask, overflow;
    DIGIT_T qhat, rhat, t[2];
    DIGIT_T *uu, *ww;
    int qhatOK, cmp;
    DIGIT_T vv[MAX_FIXED_DIGITS];   /* private, normalised copy of v */

    /* Clear q and r */
    mpSetZero(q, udigits);
    mpSetZero(r, udigits);

    /* Work out exact sizes of u and v */
    n = (int)mpSizeof_g(v, vdigits);
    m = (int)mpSizeof(u, udigits);
    m -= n;

    /* Catch special cases */
    if (n == 0)
        return -1;  /* Error: divide by zero */

    if (n == 1)
    {   /* Use short division instead */
        r[0] = mpShortDiv(q, u, v[0], udigits);
        return 0;
    }

    if (m < 0)
    {   /* v > u, so just set q = 0 and r = u */
        mpSetEqual(r, u, udigits);
        return 0;
    }

    if (m == 0)
    {   /* u and v are the same length */
        cmp = mpCompare_lg(u, v, (size_t)n);
        if (cmp < 0)
        {   /* v > u, as above */
            mpSetEqual(r, u, udigits);
            return 0;
        }
        else if (cmp == 0)
        {   /* v == u, so set q = 1 and r = 0 */
            mpSetDigit(q, 1, udigits);
            return 0;
        }
    }

    if (n > MAX_FIXED_DIGITS)
        return -1;  /* divisor does not fit the private copy */

    /* In Knuth notation, we have:
       Given
       u = (Um+n-1 ... U1U0)
       v = (Vn-1 ... V1V0)
       Compute
       q = u/v = (QmQm-1 ... Q0)
       r = u mod v = (Rn-1 ... R1R0)
    */

    /* Step D1. Normalise */
    /* Requires high bit of Vn-1
       to be set, so find most signif. bit then shift left,
       i.e. d = 2^shift, u' = u * d, v' = v * d.
    */
    mpSetEqual_lg(vv, v, (size_t)n);

    bitmask = HIBITMASK;
    for (shift = 0; shift < BITS_PER_DIGIT; shift++)
    {
        if (vv[n-1] & bitmask)
            break;
        bitmask >>= 1;
    }

    /* Normalise the private copy - NB only shift non-zero digits.
       No bits are lost: the top 'shift' bits of vv[n-1] are zero. */
    mpShiftLeft(vv, vv, shift, (size_t)n);

    /* Copy normalised dividend u*d into r */
    overflow = mpShiftLeft(r, u, shift, (size_t)(n + m));
    uu = r;     /* Use ptr to keep notation constant */

    t[0] = overflow;    /* Extra digit Um+n */

    /* Step D2. Initialise j. Set j = m */
    for (j = m; j >= 0; j--)
    {
        /* Step D3. Set Qhat = [(b.Uj+n + Uj+n-1)/Vn-1]
           and Rhat = remainder */
        qhatOK = 0;
        t[1] = t[0];    /* This is Uj+n */
        t[0] = uu[j+n-1];
        overflow = spDivide(&qhat, &rhat, t, vv[n-1]);

        /* Test Qhat */
        if (overflow)
        {   /* Qhat == b so set Qhat = b - 1 */
            qhat = MAX_DIGIT;
            rhat = uu[j+n-1];
            rhat += vv[n-1];
            if (rhat < vv[n-1])     /* Rhat >= b, so no re-test */
                qhatOK = 1;
        }
        /* [VERSION 2: Added extra test "qhat && "] */
        if (qhat && !qhatOK && QhatTooBig(qhat, rhat, vv[n-2], uu[j+n-2]))
        {   /* If Qhat.Vn-2 > b.Rhat + Uj+n-2
               decrease Qhat by one, increase Rhat by Vn-1
            */
            qhat--;
            rhat += vv[n-1];
            /* Repeat this test if Rhat < b */
            if (!(rhat < vv[n-1]))
                if (QhatTooBig(qhat, rhat, vv[n-2], uu[j+n-2]))
                    qhat--;
        }

        /* Step D4. Multiply and subtract */
        ww = &uu[j];
        overflow = mpMultSub(t[1], ww, vv, qhat, (size_t)n);

        /* Step D5. Test remainder. Set Qj = Qhat */
        q[j] = qhat;
        if (overflow)
        {   /* Step D6. Add back if D4 was negative */
            q[j]--;
            mpAdd(ww, ww, vv, (size_t)n);
        }

        t[0] = uu[j+n-1];   /* Uj+n on next round */

    }   /* Step D7. Loop on j */

    /* Clear high digits in uu */
    for (j = n; j < m+n; j++)
        uu[j] = 0;

    /* Step D8. Unnormalise the remainder (v itself was never modified). */
    mpShiftRight(r, r, shift, (size_t)n);

    return 0;
}

void mpSetDigit(DIGIT_T *a, DIGIT_T d, size_t ndigits)
{   /* Sets a = d where d is a single digit */
    size_t i;

    for (i = 1; i < ndigits; i++)
    {
        a[i] = 0;
    }
    a[0] = d;
}

void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits)
{   /* Sets a = d where d is a single digit (a in global memory) */
    size_t i;

    for (i = 1; i < ndigits; i++)
    {
        a[i] = 0;
    }
    a[0] = d;
}

DIGIT_T mpShortDiv(DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, size_t ndigits)
{
    /* Calculates quotient q = u div v
       Returns remainder r = u mod v
       where q, u are multiprecision integers of ndigits each
       and r, v are single precision digits.

       Makes no assumptions about normalisation.

       Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625
    */
    size_t j;
    DIGIT_T t[2], r;
    size_t shift;
    DIGIT_T bitmask, overflow, *uu;

    if (ndigits == 0) return 0;
    if (v == 0) return 0;   /* Divide by zero error */

    /* Normalise first */
    /* Requires high bit of V
       to be set, so find most signif. bit then shift left,
       i.e. d = 2^shift, u' = u * d, v' = v * d.
    */
    bitmask = HIBITMASK;
    for (shift = 0; shift < BITS_PER_DIGIT; shift++)
    {
        if (v & bitmask)
            break;
        bitmask >>= 1;
    }

    v <<= shift;
    overflow = mpShiftLeft(q, u, shift, ndigits);
    uu = q;

    /* Step S1 - modified for extra digit. */
    r = overflow;   /* New digit Un */
    j = ndigits;
    while (j--)
    {
        /* Step S2. */
        t[1] = r;
        t[0] = uu[j];
        overflow = spDivide(&q[j], &r, t, v);
    }

    /* Unnormalise */
    r >>= shift;

    return r;
}

int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2)
{   /* Returns true if Qhat is too big
       i.e. if (Qhat * Vn-2) > (b.Rhat + Uj+n-2)
    */
    DIGIT_T t[2];

    spMultiply(t, qhat, vn2);
    if (t[1] < rhat)
        return 0;
    else if (t[1] > rhat)
        return 1;
    else if (t[0] > ujn2)
        return 1;

    return 0;
}

DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n)
{   /* Compute w = w - qv
       where w = (WnW[n-1]...W[0])
       return modified Wn.
    */
    DIGIT_T k, t[2];
    size_t i;

    if (q == 0)     /* No change */
        return wn;

    k = 0;

    for (i = 0; i < n; i++)
    {
        spMultiply(t, q, v[i]);
        w[i] -= k;
        if (w[i] > MAX_DIGIT - k)
            k = 1;
        else
            k = 0;
        w[i] -= t[0];
        if (w[i] > MAX_DIGIT - t[0])
            k++;
        k += t[1];
    }

    /* Cope with Wn not stored in array w[0..n-1] */
    wn -= k;

    return wn;
}

DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n)
{   /* Compute w = w - qv (v in global memory)
       where w = (WnW[n-1]...W[0])
       return modified Wn.
    */
    DIGIT_T k, t[2];
    size_t i;

    if (q == 0)     /* No change */
        return wn;

    k = 0;

    for (i = 0; i < n; i++)
    {
        spMultiply(t, q, v[i]);
        w[i] -= k;
        if (w[i] > MAX_DIGIT - k)
            k = 1;
        else
            k = 0;
        w[i] -= t[0];
        if (w[i] > MAX_DIGIT - t[0])
            k++;
        k += t[1];
    }

    /* Cope with Wn not stored in array w[0..n-1] */
    wn -= k;

    return wn;
}

/* Left shifts.
   a = b << shift over ndigits digits; returns the bits shifted out of the
   top, i.e. digit ndigits of the full-width result. a may alias b and any
   shift (including >= BITS_PER_DIGIT) is accepted.

   The C original recursed for shifts >= BITS_PER_DIGIT. The earlier
   loop-based port never updated 'shift' (infinite loop) and then cleared
   the carry; here the sub-digit pass is done explicitly in place on a.
   The bits == 0 cases are handled separately because OpenCL C masks shift
   counts to the operand width, so "x >> 32" is x, not 0. */

DIGIT_T mpShiftLeft(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
{
    const size_t nw = shift / BITS_PER_DIGIT;    /* whole digits */
    const size_t bits = shift % BITS_PER_DIGIT;  /* bits within a digit */
    DIGIT_T carry = 0, nextcarry, wordcarry;
    size_t i;

    if (nw == 0)
    {
        if (bits == 0)
        {
            for (i = 0; i < ndigits; i++)
                a[i] = b[i];
            return 0;
        }
        for (i = 0; i < ndigits; i++)
        {
            nextcarry = b[i] >> (BITS_PER_DIGIT - bits);
            a[i] = (b[i] << bits) | carry;
            carry = nextcarry;
        }
        return carry;
    }

    /* Whole-digit move. Capture the digit moving to position ndigits
       before writing a, because a may alias b. */
    wordcarry = (nw <= ndigits) ? b[ndigits - nw] : 0;
    i = ndigits;
    while (i--)
        a[i] = (i >= nw) ? b[i - nw] : 0;
    if (bits == 0)
        return wordcarry;

    /* Remaining sub-digit shift, in place on a */
    for (i = 0; i < ndigits; i++)
    {
        nextcarry = a[i] >> (BITS_PER_DIGIT - bits);
        a[i] = (a[i] << bits) | carry;
        carry = nextcarry;
    }
    return (wordcarry << bits) | carry;
}

DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
{   /* As mpShiftLeft, with a and b in global memory. */
    const size_t nw = shift / BITS_PER_DIGIT;
    const size_t bits = shift % BITS_PER_DIGIT;
    DIGIT_T carry = 0, nextcarry, wordcarry;
    size_t i;

    if (nw == 0)
    {
        if (bits == 0)
        {
            for (i = 0; i < ndigits; i++)
                a[i] = b[i];
            return 0;
        }
        for (i = 0; i < ndigits; i++)
        {
            nextcarry = b[i] >> (BITS_PER_DIGIT - bits);
            a[i] = (b[i] << bits) | carry;
            carry = nextcarry;
        }
        return carry;
    }

    wordcarry = (nw <= ndigits) ? b[ndigits - nw] : 0;
    i = ndigits;
    while (i--)
        a[i] = (i >= nw) ? b[i - nw] : 0;
    if (bits == 0)
        return wordcarry;

    for (i = 0; i < ndigits; i++)
    {
        nextcarry = a[i] >> (BITS_PER_DIGIT - bits);
        a[i] = (a[i] << bits) | carry;
        carry = nextcarry;
    }
    return (wordcarry << bits) | carry;
}

DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
{   /* As mpShiftLeft, with a private and b in global memory. */
    const size_t nw = shift / BITS_PER_DIGIT;
    const size_t bits = shift % BITS_PER_DIGIT;
    DIGIT_T carry = 0, nextcarry, wordcarry;
    size_t i;

    if (nw == 0)
    {
        if (bits == 0)
        {
            for (i = 0; i < ndigits; i++)
                a[i] = b[i];
            return 0;
        }
        for (i = 0; i < ndigits; i++)
        {
            nextcarry = b[i] >> (BITS_PER_DIGIT - bits);
            a[i] = (b[i] << bits) | carry;
            carry = nextcarry;
        }
        return carry;
    }

    wordcarry = (nw <= ndigits) ? b[ndigits - nw] : 0;
    i = ndigits;
    while (i--)
        a[i] = (i >= nw) ? b[i - nw] : 0;
    if (bits == 0)
        return wordcarry;

    for (i = 0; i < ndigits; i++)
    {
        nextcarry = a[i] >> (BITS_PER_DIGIT - bits);
        a[i] = (a[i] << bits) | carry;
        carry = nextcarry;
    }
    return (wordcarry << bits) | carry;
}

/* Right shifts.
   a = b >> shift over ndigits digits; returns the bits shifted out of the
   bottom, left-aligned in one digit. a may alias b and any shift is
   accepted. The previous port skipped the sub-digit shift entirely when
   shift >= BITS_PER_DIGIT; it is now performed in place on a. */

DIGIT_T mpShiftRight(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
{
    const size_t nw = shift / BITS_PER_DIGIT;
    const size_t bits = shift % BITS_PER_DIGIT;
    DIGIT_T carry = 0, nextcarry, wordcarry;
    size_t i;

    if (nw == 0)
    {
        if (bits == 0)
        {
            for (i = 0; i < ndigits; i++)
                a[i] = b[i];
            return 0;
        }
        i = ndigits;
        while (i--)
        {
            nextcarry = b[i] << (BITS_PER_DIGIT - bits);
            a[i] = (b[i] >> bits) | carry;
            carry = nextcarry;
        }
        return carry;
    }

    /* Whole-digit move. Capture the digit dropping just below position 0
       before writing a, because a may alias b. */
    wordcarry = (nw <= ndigits) ? b[nw - 1] : 0;
    for (i = 0; i < ndigits; i++)
        a[i] = (i + nw < ndigits) ? b[i + nw] : 0;
    if (bits == 0)
        return wordcarry;

    /* Remaining sub-digit shift, in place on a */
    i = ndigits;
    while (i--)
    {
        nextcarry = a[i] << (BITS_PER_DIGIT - bits);
        a[i] = (a[i] >> bits) | carry;
        carry = nextcarry;
    }
    return (wordcarry >> bits) | carry;
}

DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
{   /* As mpShiftRight, with a and b in global memory. */
    const size_t nw = shift / BITS_PER_DIGIT;
    const size_t bits = shift % BITS_PER_DIGIT;
    DIGIT_T carry = 0, nextcarry, wordcarry;
    size_t i;

    if (nw == 0)
    {
        if (bits == 0)
        {
            for (i = 0; i < ndigits; i++)
                a[i] = b[i];
            return 0;
        }
        i = ndigits;
        while (i--)
        {
            nextcarry = b[i] << (BITS_PER_DIGIT - bits);
            a[i] = (b[i] >> bits) | carry;
            carry = nextcarry;
        }
        return carry;
    }

    wordcarry = (nw <= ndigits) ? b[nw - 1] : 0;
    for (i = 0; i < ndigits; i++)
        a[i] = (i + nw < ndigits) ? b[i + nw] : 0;
    if (bits == 0)
        return wordcarry;

    i = ndigits;
    while (i--)
    {
        nextcarry = a[i] << (BITS_PER_DIGIT - bits);
        a[i] = (a[i] >> bits) | carry;
        carry = nextcarry;
    }
    return (wordcarry >> bits) | carry;
}

int spMultiply(uint p[2], uint x, uint y)
{   /* Computes the 64-bit product p = x * y as (p[1], p[0]) = (high, low). */
    p[1] = mul_hi(x, y);
    p[0] = x * y;

    return 0;
}

uint spDivide(uint *pq, uint *pr, const uint u[2], uint v)
{   /* Divides the 64-bit value (u[1], u[0]) by v.
       Stores the low 32 bits of the quotient in *pq and the remainder in
       *pr; returns the high 32 bits of the quotient (non-zero = overflow). */
    ulong uu, q;
    uu = (ulong)u[1] << 32 | (ulong)u[0];
    q = uu / (ulong)v;
    *pr = (uint)(uu - q * v);
    *pq = (uint)(q & 0xFFFFFFFF);
    return (uint)(q >> 32);
}

int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
{   /* Returns sign of (a - b): 1, -1 or 0. */
    while (ndigits--)
    {
        if (a[ndigits] > b[ndigits])
            return 1;   /* GT */
        if (a[ndigits] < b[ndigits])
            return -1;  /* LT */
    }

    return 0;   /* EQ */
}

int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
{   /* As mpCompare, with a and b in global memory. */
    while (ndigits--)
    {
        if (a[ndigits] > b[ndigits])
            return 1;   /* GT */
        if (a[ndigits] < b[ndigits])
            return -1;  /* LT */
    }

    return 0;   /* EQ */
}

int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
{   /* As mpCompare, with a private and b in global memory. */
    while (ndigits--)
    {
        if (a[ndigits] > b[ndigits])
            return 1;   /* GT */
        if (a[ndigits] < b[ndigits])
            return -1;  /* LT */
    }

    return 0;   /* EQ */
}

void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
{   /* Sets a = b */
    size_t i;

    for (i = 0; i < ndigits; i++)
    {
        a[i] = b[i];
    }
}

void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
{   /* Sets a = b (global -> private) */
    size_t i;

    for (i = 0; i < ndigits; i++)
    {
        a[i] = b[i];
    }
}

void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
{   /* Sets a = b (private -> global) */
    size_t i;

    for (i = 0; i < ndigits; i++)
    {
        a[i] = b[i];
    }
}

volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits)
{   /* Sets a = 0 */

    /* Prevent optimiser ignoring this */
    volatile DIGIT_T optdummy;
    volatile DIGIT_T *p = a;

    while (ndigits--)
        a[ndigits] = 0;

    optdummy = *p;
    return optdummy;
}

volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits)
{   /* Sets a = 0 (a in global memory) */

    /* Prevent optimiser ignoring this */
    volatile DIGIT_T optdummy;
    __global volatile DIGIT_T *p = a;

    while (ndigits--)
        a[ndigits] = 0;

    optdummy = *p;
    return optdummy;
}

size_t mpSizeof(const DIGIT_T *a, size_t ndigits)
{   /* Returns the number of significant (non-zero-padded) digits in a. */
    while (ndigits--)
    {
        if (a[ndigits] != 0)
            return (++ndigits);
    }
    return 0;
}

size_t mpSizeof_g(__global const DIGIT_T *a, size_t ndigits)
{   /* As mpSizeof, with a in global memory. */
    while (ndigits--)
    {
        if (a[ndigits] != 0)
            return (++ndigits);
    }
    return 0;
}

volatile uchar zeroise_bytes(volatile void *v, size_t n)
{   /* Zeroise the n-byte array v and make sure the optimiser does not
       ignore this. Uses uchar: in OpenCL C "uint8" is an 8 x uint vector,
       so indexing it zeroed 32 bytes per element and overran the buffer. */
    volatile uchar optdummy;
    volatile uchar *b = (volatile uchar *)v;
    while (n--)
        b[n] = 0;
    optdummy = *b;
    return optdummy;
}

void mpFail(char *msg)
{
    printf("the program should stop here");
}

size_t mpBitLength(const DIGIT_T *d, size_t ndigits)
/* Returns no of significant bits in d */
{
    size_t n, i, bits;
    DIGIT_T mask;

    if (!d || ndigits == 0)
        return 0;

    n = mpSizeof(d, ndigits);
    if (0 == n) return 0;

    for (i = 0, mask = HIBITMASK; mask > 0; mask >>= 1, i++)
    {
        if (d[n-1] & mask)
            break;
    }

    bits = n * BITS_PER_DIGIT - i;

    return bits;
}

int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits)
/* New in Version 2.0 */
{
    /* Computes square w = x * x
       where x is a multiprecision integer of ndigits
       and w is a multiprecision integer of 2*ndigits

       Ref: Menezes p596 Algorithm 14.16 with errata.
    */

    DIGIT_T k, p[2], u[2], cbit, carry;
    size_t i, j, t, i2, cpos;

    assert(w != x);

    t = ndigits;

    /* 1. For i from 0 to (2t-1) do: w_i = 0 */
    i2 = t << 1;
    for (i = 0; i < i2; i++)
        w[i] = 0;

    carry = 0;
    cpos = i2-1;
    /* 2. For i from 0 to (t-1) do: */
    for (i = 0; i < t; i++)
    {
        /* 2.1 (uv) = w_2i + x_i * x_i, w_2i = v, c = u
           Careful, w_2i may be double-prec
        */
        i2 = i << 1;    /* 2*i */
        spMultiply(p, x[i], x[i]);
        p[0] += w[i2];
        if (p[0] < w[i2])
            p[1]++;
        k = 0;  /* p[1] < b, so no overflow here */
        if (i2 == cpos && carry)
        {
            p[1] += carry;
            if (p[1] < carry)
                k++;
            carry = 0;
        }
        w[i2] = p[0];
        u[0] = p[1];
        u[1] = k;

        /* 2.2 for j from (i+1) to (t-1) do:
           (uv) = w_{i+j} + 2x_j * x_i + c,
           w_{i+j} = v, c = u,
           u is double-prec
           w_{i+j} is dbl if [i+j] == cpos
        */
        k = 0;
        for (j = i+1; j < t; j++)
        {
            /* p = x_j * x_i */
            spMultiply(p, x[j], x[i]);
            /* p = 2p <=> p <<= 1 */
            cbit = (p[0] & HIBITMASK) != 0;
            k = (p[1] & HIBITMASK) != 0;
            p[0] <<= 1;
            p[1] <<= 1;
            p[1] |= cbit;
            /* p = p + c */
            p[0] += u[0];
            if (p[0] < u[0])
            {
                p[1]++;
                if (p[1] == 0)
                    k++;
            }
            p[1] += u[1];
            if (p[1] < u[1])
                k++;
            /* p = p + w_{i+j} */
            p[0] += w[i+j];
            if (p[0] < w[i+j])
            {
                p[1]++;
                if (p[1] == 0)
                    k++;
            }
            if ((i+j) == cpos && carry)
            {   /* catch overflow from last round */
                p[1] += carry;
                if (p[1] < carry)
                    k++;
                carry = 0;
            }
            /* w_{i+j} = v, c = u */
            w[i+j] = p[0];
            u[0] = p[1];
            u[1] = k;
        }
        /* 2.3 w_{i+t} = u */
        w[i+t] = u[0];
        /* remember overflow in w_{i+t} */
        carry = u[1];
        cpos = i+t;
    }

    /* (NB original step 3 deleted in Menezes errata) */

    /* Return w */

    return 0;
}

int mpIsZero(const DIGIT_T *a, size_t ndigits)
{   /* Returns non-zero if all ndigits of a are zero. */
    size_t i;

    for (i = 0; i < ndigits; i++)   /* Start at lsb */
    {
        if (a[i] != 0)
            return 0;   /* False */
    }

    return (!0);    /* True */
}


/* Verifies one signature per work-item: computes x^e mod m for the key
   selected by pks[index] and compares it to the expected value mm.
   On a match, bit (index % 32) of valid[index / 32] is set.

   Buffer layout (as indexed below):
     x, mm : MAX_ALLOC_SIZE digits per work-item (offset index * MAX_ALLOC_SIZE)
     m     : MAX_ALLOC_SIZE digits per key       (offset pk * MAX_ALLOC_SIZE)
     e     : one digit per key                   (offset pk)
     pks   : key index for each work-item
     valid : result bitmap, one bit per work-item */
__kernel void several(__global DIGIT_T *x,
                      __global DIGIT_T *e,
                      __global DIGIT_T *m,
                      __global DIGIT_T *mm,
                      __global uint *valid,
                      __global uint *pks)
{
    const int index = get_global_id(0);
    const int pk = pks[index];

    const int ndigits = MAX_ALLOC_SIZE;
    const int edigits = 1;

    DIGIT_T mask;
    size_t n;

    __global DIGIT_T * __private window_x;  // private scope pointers to global memory
    __global DIGIT_T * __private window_e;
    __global DIGIT_T * __private window_m;
    __global DIGIT_T * __private window_mm;

    window_e = &e[pk];
    window_m = &m[pk * MAX_ALLOC_SIZE];

    window_x = &x[index * MAX_ALLOC_SIZE];
    window_mm = &mm[index * MAX_ALLOC_SIZE];

    /* Double-length work buffers required by mpMODSQUARETEMP / mpMODMULTTEMP
       (|y| = |t1| = |t2| = 2 * ndigits). y ends up holding x^e mod m in its
       low ndigits digits. */
    __private DIGIT_T t1[MAX_ALLOC_SIZE * 2];
    __private DIGIT_T t2[MAX_ALLOC_SIZE * 2];
    __private DIGIT_T y[MAX_ALLOC_SIZE * 2];

    n = mpSizeof_g(window_e, edigits);
    if (0 == n)
    {
        /* e == 0 => x^0 = 1. There are no exponent bits to scan, so the
           square-and-multiply loop must be skipped (window_e[n-1] would be
           out of bounds). */
        mpSetDigit(y, 1, ndigits);
    }
    else
    {
        /* Find second-most significant bit in e */
        for (mask = HIBITMASK; mask > 0; mask >>= 1)
        {
            if (window_e[n-1] & mask)
                break;
        }
        mpNEXTBITMASK(mask, n);

        /* Set y = x */
        mpSetEqual_lg(y, window_x, ndigits);

        /* For bit j = k-2 downto 0 */
        while (n)
        {
            /* Square y = y * y mod m */
            mpMODSQUARETEMP(y, window_m, ndigits, t1, t2);

            /* Test this key's exponent (window_e), not e[] from the start
               of the buffer. */
            if (window_e[n-1] & mask)
            {   /* if e(j) == 1 then multiply y = y * x mod m */
                mpMODMULTTEMP(y, window_x, window_m, ndigits, t1, t2);
            }

            /* Move to next bit */
            mpNEXTBITMASK(mask, n);
        }
    }

    /* Set this work-item's bit on a successful match. atomic_or is needed
       because 32 consecutive work-items share one word of valid[]. */
    if (mpCompare_lg(y, window_mm, MAX_ALLOC_SIZE) == 0)
    {
        const uint word = (uint)index / 32u;
        const uint bit = (uint)index % 32u;    /* keeps the shift < 32 */

        atomic_or(&valid[word], 1u << bit);
    }
}