libgpuverify

Signature verification on GPUs (WiP)
Log | Files | Refs | README | LICENSE

gpuv.cl (32945B)


      1 /*
      2  * lib-gpu-verify
      3  *
      4  * This software contains code derived from or inspired by the BigDigits library,
      5  * <http://www.di-mgt.com.au/bigdigits.html>
      6  * which is distributed under the Mozilla Public License, version 2.0.
      7  *
      8  * The original code and modifications made to it are subject to the terms and
      9  * conditions of the Mozilla Public License, version 2.0. A copy of the
     10  * MPL license can be obtained at
     11  * https://www.mozilla.org/en-US/MPL/2.0/.
     12  *
     13  * Changes and additions to the original code are as follows:
     14  *   - Copied various parts of the BigDigits library into this kernel.
     15  *   - Some functions and macros were changed to accommodate the architecture of OpenCL.
     16  *   - Functions were added to extend the functionality required by this OpenCL kernel.
     17  *
     18  * Contributors:
     19  *   - Cedric Zwahlen cedric.zwahlen@bfh.ch
     20  *
     21  * Please note that this software is distributed on an "AS IS" BASIS,
     22  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     23  * See the Mozilla Public License, version 2.0, for the specific language
     24  * governing permissions and limitations under the License.
     25  */
     26 
     27 // macros
     28 
     29 //#define max(a,b)            (((a) > (b)) ? (a) : (b))
     30 
     31 // only for that string conversion
     32 #define ALLOC_BYTES(b,n) do{assert((n)<=sizeof((b)));zeroise_bytes((b),(n));}while(0)
     33 #define FREE_BYTES(b,n) zeroise_bytes((b),(n))
     34 
     35 
     36 #define MAX_DIGIT 0xFFFFFFFFUL
     37 #define MAX_HALF_DIGIT 0xFFFFUL    /* NB 'L' */
     38 #define BITS_PER_DIGIT 32
     39 #define HIBITMASK 0x80000000UL
     40 
     41 #define MAX_FIXED_BIT_LENGTH 8192
     42 #define MAX_FIXED_DIGITS ((MAX_FIXED_BIT_LENGTH + BITS_PER_DIGIT - 1) / BITS_PER_DIGIT)
     43 
     44 #define MAX_ALLOC_SIZE 64
     45 
     46 #define BYTES_PER_DIGIT (BITS_PER_DIGIT / 8)
     47 
     48 #define PRIuBIGD PRIu32
     49 #define PRIxBIGD PRIx32
     50 #define PRIXBIGD PRIX32
     51 
     52 /* MACROS TO DO MODULAR SQUARING AND MULTIPLICATION USING PRE-ALLOCATED TEMPS */
     53 /* Required lengths |y|=|t1|=|t2|=2*n, |m|=n; but final |y|=n */
     54 /* Square: y = (y * y) mod m */
     55 #define mpMODSQUARETEMP(y,m,n,t1,t2) do{mpSquare(t1,y,n);mpDivide(t2,y,t1,n*2,m,n);}while(0)
     56 /* Mult:   y = (y * x) mod m */
     57 #define mpMODMULTTEMP(y,x,m,n,t1,t2) do{mpMultiply(t1,x,y,n);mpDivide(t2,y,t1,n*2,m,n);}while(0)
     58 
     59 #define mpNEXTBITMASK(mask, n) do{if(mask==1){mask=HIBITMASK;n--;}else{mask>>=1;}}while(0)
     60 
     61 #define assert(x){if((x)==0){printf((char __constant *)"assert reached\n");}}
     62 
     63 typedef uint DIGIT_T;
     64 
     65 typedef uint16 HALF_DIGIT_T;
     66 
     67 
     68 // forward definitions
     69 
     70 int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits);
     71 
     72 //int mpModMult(__global DIGIT_T *a, __global DIGIT_T *x, const DIGIT_T *y, DIGIT_T *m, size_t ndigits);
     73 
     74 int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits);
     75 
     76 DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits);
     77 
     78 int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits);
     79 
     80 int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2);
     81 
     82 DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n);
     83 
     84 DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);
     85 
     86 
     87 void mpSetDigit(DIGIT_T *a, DIGIT_T d, size_t ndigits);
     88 
     89 int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
     90 
     91 
     92 DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);
     93 int spMultiply(uint p[2], uint x, uint y);
     94 uint spDivide(uint *pq, uint *pr, const uint u[2], uint v);
     95 
     96 int mpSquare( DIGIT_T *w, const DIGIT_T *x, size_t ndigits);
     97 
     98 size_t mpBitLength(const DIGIT_T *d, size_t ndigits);
     99 
    100 DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v,
    101                    size_t ndigits);
    102 
    103 void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
    104 
    105 size_t uiceil(float x);
    106 volatile uint8 zeroise_bytes(volatile void *v, size_t n);
    107 
    108 size_t mpSizeof(const DIGIT_T *a, size_t ndigits);
    109 
    110 volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits);
    111 
    112 int mpIsZero( const DIGIT_T *a, size_t ndigits);
    113 
    114 void mpFail(char *msg);
    115 
    116 //int mpModExpO( DIGIT_T *yout,  DIGIT_T *x,  DIGIT_T *e,  DIGIT_T *m, size_t ndigits);
    117 
    118 //void assert(bool precondition);
    119 
    120 // global memory definitions
    121 
    122 size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits);
    123 
    124 DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
    125 
    126 DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
    127 
    128 DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
    129 
    130 int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
    131 
    132 int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
    133 
    134 void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
    135 void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
    136 
    137 DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n);
    138 
    139 DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits);
    140 
    141 void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits);
    142 
    143 volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits);
    144 
    145 // implementation
    146 
    147 int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits)
    148 {
    149     /*    Computes r = u mod v
    150         where r, v are multiprecision integers of length vdigits
    151         and u is a multiprecision integer of length udigits.
    152         r may overlap v.
    153 
    154         Note that r here is only vdigits long,
    155         whereas in mpDivide it is udigits long.
    156 
    157         Use remainder from mpDivide function.
    158     */
    159 
    160     DIGIT_T qq[MAX_FIXED_DIGITS*2];
    161     DIGIT_T rr[MAX_FIXED_DIGITS*2];
    162 
    163     /* rr[nn] = u mod v */
    164     mpDivide(qq, rr, u, udigits, v, vdigits);
    165 
    166     /* Final r is only vdigits long */
    167     mpSetEqual_gl(r, rr, vdigits);
    168 
    169     return 0;
    170 }
    171 
    172 int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
    173 {
    174     /*    Computes product w = u * v
    175         where u, v are multiprecision integers of ndigits each
    176         and w is a multiprecision integer of 2*ndigits
    177 
    178         Ref: Knuth Vol 2 Ch 4.3.1 p 268 Algorithm M.
    179     */
    180 
    181     DIGIT_T k, t[2];
    182     size_t i, j, m, n;
    183 
    184     //assert(w != u && w != v);
    185 
    186     m = n = ndigits;
    187 
    188     /* Step M1. Initialise */
    189     for (i = 0; i < 2 * m; i++)
    190         w[i] = 0;
    191 
    192     for (j = 0; j < n; j++)
    193     {
    194         /* Step M2. Zero multiplier? */
    195         if (v[j] == 0)
    196         {
    197             w[j + m] = 0;
    198         }
    199         else
    200         {
    201             /* Step M3. Initialise i */
    202             k = 0;
    203             for (i = 0; i < m; i++)
    204             {
    205                 /* Step M4. Multiply and add */
    206                 /* t = u_i * v_j + w_(i+j) + k */
    207                 spMultiply(t, u[i], v[j]);
    208 
    209                 t[0] += k;
    210                 if (t[0] < k)
    211                     t[1]++;
    212                 t[0] += w[i+j];
    213                 if (t[0] < w[i+j])
    214                     t[1]++;
    215 
    216                 w[i+j] = t[0];
    217                 k = t[1];
    218             }
    219             /* Step M5. Loop on i, set w_(j+m) = k */
    220             w[j+m] = k;
    221         }
    222     }    /* Step M6. Loop on j */
    223 
    224     return 0;
    225 }
    226 
    227 DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
    228 {
    229     /*    Calculates w = u + v
    230         where w, u, v are multiprecision integers of ndigits each
    231         Returns carry if overflow. Carry = 0 or 1.
    232 
    233         Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A.
    234     */
    235 
    236     DIGIT_T k;
    237     size_t j;
    238 
    239     //assert(w != v);
    240 
    241     /* Step A1. Initialise */
    242     k = 0;
    243 
    244     for (j = 0; j < ndigits; j++)
    245     {
    246         /*    Step A2. Add digits w_j = (u_j + v_j + k)
    247             Set k = 1 if carry (overflow) occurs
    248         */
    249         w[j] = u[j] + k;
    250         if (w[j] < k)
    251             k = 1;
    252         else
    253             k = 0;
    254 
    255         w[j] += v[j];
    256         if (w[j] < v[j])
    257             k++;
    258 
    259     }    /* Step A3. Loop on j */
    260 
    261     return k;    /* w_n = k */
    262 }
    263 
    264 DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits)
    265 {
    266     /*    Calculates w = u + v
    267         where w, u, v are multiprecision integers of ndigits each
    268         Returns carry if overflow. Carry = 0 or 1.
    269 
    270         Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A.
    271     */
    272 
    273     DIGIT_T k;
    274     size_t j;
    275 
    276     //assert(w != v);
    277 
    278     /* Step A1. Initialise */
    279     k = 0;
    280 
    281     for (j = 0; j < ndigits; j++)
    282     {
    283         /*    Step A2. Add digits w_j = (u_j + v_j + k)
    284             Set k = 1 if carry (overflow) occurs
    285         */
    286         w[j] = u[j] + k;
    287         if (w[j] < k)
    288             k = 1;
    289         else
    290             k = 0;
    291 
    292         w[j] += v[j];
    293         if (w[j] < v[j])
    294             k++;
    295 
    296     }    /* Step A3. Loop on j */
    297 
    298     return k;    /* w_n = k */
    299 }
    300 
    301 int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits)
    302 {    /*    Computes quotient q = u / v and remainder r = u mod v
    303         where q, r, u are multiple precision digits
    304         all of udigits and the divisor v is vdigits.
    305 
    306         Ref: Knuth Vol 2 Ch 4.3.1 p 272 Algorithm D.
    307 
    308         Do without extra storage space, i.e. use r[] for
    309         normalised u[], unnormalise v[] at end, and cope with
    310         extra digit Uj+n added to u after normalisation.
    311 
    312         WARNING: this trashes q and r first, so cannot do
    313         u = u / v or v = u mod v.
    314         It also changes v temporarily so cannot make it const.
    315     */
    316     size_t shift;
    317     int n, m, j;
    318     DIGIT_T bitmask, overflow;
    319     DIGIT_T qhat, rhat, t[2];
    320     DIGIT_T *uu, *ww;
    321     int qhatOK, cmp;
    322 
    323     /* Clear q and r */
    324     mpSetZero(q, udigits);
    325     mpSetZero(r, udigits);
    326 
    327     /* Work out exact sizes of u and v */
    328     n = (int)mpSizeof_g(v, vdigits);
    329     m = (int)mpSizeof(u, udigits);
    330     m -= n;
    331 
    332     /* Catch special cases */
    333     if (n == 0)
    334         return -1;    /* Error: divide by zero */
    335 
    336     if (n == 1)
    337     {    /* Use short division instead */
    338         r[0] = mpShortDiv(q, u, v[0], udigits);
    339         return 0;
    340     }
    341 
    342     if (m < 0)
    343     {    /* v > u, so just set q = 0 and r = u */
    344         mpSetEqual(r, u, udigits);
    345         return 0;
    346     }
    347 
    348     if (m == 0)
    349     {    /* u and v are the same length */
    350         cmp = mpCompare_lg(u, v, (size_t)n);
    351         if (cmp < 0)
    352         {    /* v > u, as above */
    353             mpSetEqual(r, u, udigits);
    354             return 0;
    355         }
    356         else if (cmp == 0)
    357         {    /* v == u, so set q = 1 and r = 0 */
    358             mpSetDigit(q, 1, udigits);
    359             return 0;
    360         }
    361     }
    362 
    363     /*    In Knuth notation, we have:
    364         Given
    365         u = (Um+n-1 ... U1U0)
    366         v = (Vn-1 ... V1V0)
    367         Compute
    368         q = u/v = (QmQm-1 ... Q0)
    369         r = u mod v = (Rn-1 ... R1R0)
    370     */
    371 
    372     /*    Step D1. Normalise */
    373     /*    Requires high bit of Vn-1
    374         to be set, so find most signif. bit then shift left,
    375         i.e. d = 2^shift, u' = u * d, v' = v * d.
    376     */
    377     bitmask = HIBITMASK;
    378     for (shift = 0; shift < BITS_PER_DIGIT; shift++)
    379     {
    380         if (v[n-1] & bitmask)
    381             break;
    382         bitmask >>= 1;
    383     }
    384 
    385     /* Normalise v in situ - NB only shift non-zero digits */
    386     overflow = mpShiftLeft_gg(v, v, shift, n);
    387 
    388     /* Copy normalised dividend u*d into r */
    389     overflow = mpShiftLeft(r, u, shift, n + m);
    390     uu = r;    /* Use ptr to keep notation constant */
    391 
    392     t[0] = overflow;    /* Extra digit Um+n */
    393 
    394 
    395     /* Step D2. Initialise j. Set j = m */
    396 
    397     for (j = m; j >= 0; j--)
    398     {
    399         /* Step D3. Set Qhat = [(b.Uj+n + Uj+n-1)/Vn-1]
    400            and Rhat = remainder */
    401         qhatOK = 0;
    402         t[1] = t[0];    /* This is Uj+n */
    403         t[0] = uu[j+n-1];
    404         overflow = spDivide(&qhat, &rhat, t, v[n-1]);
    405 
    406         /* Test Qhat */
    407         if (overflow)
    408         {    /* Qhat == b so set Qhat = b - 1 */
    409             qhat = MAX_DIGIT;
    410             rhat = uu[j+n-1];
    411             rhat += v[n-1];
    412             if (rhat < v[n-1])    /* Rhat >= b, so no re-test */
    413                 qhatOK = 1;
    414         }
    415         /* [VERSION 2: Added extra test "qhat && "] */
    416         if (qhat && !qhatOK && QhatTooBig(qhat, rhat, v[n-2], uu[j+n-2]))
    417         {    /* If Qhat.Vn-2 > b.Rhat + Uj+n-2
    418                decrease Qhat by one, increase Rhat by Vn-1
    419             */
    420             qhat--;
    421             rhat += v[n-1];
    422             /* Repeat this test if Rhat < b */
    423             if (!(rhat < v[n-1]))
    424                 if (QhatTooBig(qhat, rhat, v[n-2], uu[j+n-2]))
    425                     qhat--;
    426         }
    427 
    428 
    429         /* Step D4. Multiply and subtract */
    430         ww = &uu[j];
    431         overflow = mpMultSub_lg(t[1], ww, v, qhat, (size_t)n);
    432 
    433         /* Step D5. Test remainder. Set Qj = Qhat */
    434         q[j] = qhat;
    435         if (overflow)
    436         {    /* Step D6. Add back if D4 was negative */
    437             q[j]--;
    438             overflow = mpAdd_llg(ww, ww, v, (size_t)n);
    439         }
    440 
    441         t[0] = uu[j+n-1];    /* Uj+n on next round */
    442 
    443     }    /* Step D7. Loop on j */
    444 
    445     /* Clear high digits in uu */
    446     for (j = n; j < m+n; j++)
    447         uu[j] = 0;
    448 
    449     /* Step D8. Unnormalise. */
    450 
    451     mpShiftRight(r, r, shift, n);
    452     mpShiftRight_gg(v, v, shift, n);
    453 
    454     return 0;
    455 }
    456 
    457 void mpSetDigit( DIGIT_T *a, DIGIT_T d, size_t ndigits)
    458 {    /* Sets a = d where d is a single digit */
    459     size_t i;
    460 
    461     for (i = 1; i < ndigits; i++)
    462     {
    463         a[i] = 0;
    464     }
    465     a[0] = d;
    466 }
    467 
    468 void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits)
    469 {    /* Sets a = d where d is a single digit */
    470     size_t i;
    471 
    472     for (i = 1; i < ndigits; i++)
    473     {
    474         a[i] = 0;
    475     }
    476     a[0] = d;
    477 }
    478 
    479 DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v,
    480                    size_t ndigits)
    481 {
    482     /*    Calculates quotient q = u div v
    483         Returns remainder r = u mod v
    484         where q, u are multiprecision integers of ndigits each
    485         and r, v are single precision digits.
    486 
    487         Makes no assumptions about normalisation.
    488 
    489         Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625
    490     */
    491     size_t j;
    492     DIGIT_T t[2], r;
    493     size_t shift;
    494     DIGIT_T bitmask, overflow, *uu;
    495 
    496     if (ndigits == 0) return 0;
    497     if (v == 0)    return 0;    /* Divide by zero error */
    498 
    499     /*    Normalise first */
    500     /*    Requires high bit of V
    501         to be set, so find most signif. bit then shift left,
    502         i.e. d = 2^shift, u' = u * d, v' = v * d.
    503     */
    504     bitmask = HIBITMASK;
    505     for (shift = 0; shift < BITS_PER_DIGIT; shift++)
    506     {
    507         if (v & bitmask)
    508             break;
    509         bitmask >>= 1;
    510     }
    511 
    512     v <<= shift;
    513     overflow = mpShiftLeft(q, u, shift, ndigits);
    514     uu = q;
    515 
    516     /* Step S1 - modified for extra digit. */
    517     r = overflow;    /* New digit Un */
    518     j = ndigits;
    519     while (j--)
    520     {
    521         /* Step S2. */
    522         t[1] = r;
    523         t[0] = uu[j];
    524         overflow = spDivide(&q[j], &r, t, v);
    525     }
    526 
    527     /* Unnormalise */
    528     r >>= shift;
    529 
    530     return r;
    531 }
    532 
    533 int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat,
    534                       DIGIT_T vn2, DIGIT_T ujn2)
    535 {    /*    Returns true if Qhat is too big
    536         i.e. if (Qhat * Vn-2) > (b.Rhat + Uj+n-2)
    537     */
    538     DIGIT_T t[2];
    539 
    540     spMultiply(t, qhat, vn2);
    541     if (t[1] < rhat)
    542         return 0;
    543     else if (t[1] > rhat)
    544         return 1;
    545     else if (t[0] > ujn2)
    546         return 1;
    547 
    548     return 0;
    549 }
    550 
    551 DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v,
    552                        DIGIT_T q, size_t n)
    553 {    /*    Compute w = w - qv
    554         where w = (WnW[n-1]...W[0])
    555         return modified Wn.
    556     */
    557     DIGIT_T k, t[2];
    558     size_t i;
    559 
    560     if (q == 0)    /* No change */
    561         return wn;
    562 
    563     k = 0;
    564 
    565     for (i = 0; i < n; i++)
    566     {
    567         spMultiply(t, q, v[i]);
    568         w[i] -= k;
    569         if (w[i] > MAX_DIGIT - k)
    570             k = 1;
    571         else
    572             k = 0;
    573         w[i] -= t[0];
    574         if (w[i] > MAX_DIGIT - t[0])
    575             k++;
    576         k += t[1];
    577     }
    578 
    579     /* Cope with Wn not stored in array w[0..n-1] */
    580     wn -= k;
    581 
    582     return wn;
    583 }
    584 
    585 DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n)
    586 {    /*    Compute w = w - qv
    587         where w = (WnW[n-1]...W[0])
    588         return modified Wn.
    589     */
    590     DIGIT_T k, t[2];
    591     size_t i;
    592 
    593     if (q == 0)    /* No change */
    594         return wn;
    595 
    596     k = 0;
    597 
    598     for (i = 0; i < n; i++)
    599     {
    600         spMultiply(t, q, v[i]);
    601         w[i] -= k;
    602         if (w[i] > MAX_DIGIT - k)
    603             k = 1;
    604         else
    605             k = 0;
    606         w[i] -= t[0];
    607         if (w[i] > MAX_DIGIT - t[0])
    608             k++;
    609         k += t[1];
    610     }
    611 
    612     /* Cope with Wn not stored in array w[0..n-1] */
    613     wn -= k;
    614 
    615     return wn;
    616 }
    617 
    618 DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
    619 {    /* Computes a = b << shift */
    620     /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
    621 
    622     DIGIT_T carry = 0;
    623 
    624     // this replaces the recursion
    625     while (1) {
    626 
    627         size_t i, y, nw, bits;
    628         DIGIT_T mask, tempCarry, nextcarry;
    629 
    630         /* Do we shift whole digits? */
    631         if (shift >= BITS_PER_DIGIT)
    632         {
    633             nw = shift / BITS_PER_DIGIT;
    634             i = ndigits;
    635             while (i--)
    636             {
    637                 if (i >= nw)
    638                     a[i] = b[i-nw];
    639                 else
    640                     a[i] = 0;
    641             }
    642             /* Call again to shift bits inside digits */
    643             bits = shift % BITS_PER_DIGIT;
    644             tempCarry = b[ndigits-nw] << bits;
    645             if (bits) {
    646                 carry |= tempCarry;
    647                 continue;
    648             }
    649             return carry;
    650         }
    651         else
    652         {
    653             bits = shift;
    654         }
    655 
    656         /* Construct mask = high bits set */
    657         mask = ~(~(DIGIT_T)0 >> bits);
    658 
    659         y = BITS_PER_DIGIT - bits;
    660         carry = 0;
    661         for (i = 0; i < ndigits; i++)
    662         {
    663             nextcarry = (b[i] & mask) >> y;
    664             a[i] = b[i] << bits | carry;
    665             carry = nextcarry;
    666         }
    667 
    668         return carry;
    669 
    670     }
    671 }
    672 
    673 DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
    674 {    /* Computes a = b << shift */
    675     /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
    676 
    677     DIGIT_T carry = 0;
    678 
    679     // this replaces the recursion
    680     while (1) {
    681 
    682         size_t i, y, nw, bits;
    683         DIGIT_T mask, tempCarry, nextcarry;
    684 
    685         /* Do we shift whole digits? */
    686         if (shift >= BITS_PER_DIGIT)
    687         {
    688             nw = shift / BITS_PER_DIGIT;
    689             i = ndigits;
    690             while (i--)
    691             {
    692                 if (i >= nw)
    693                     a[i] = b[i-nw];
    694                 else
    695                     a[i] = 0;
    696             }
    697             /* Call again to shift bits inside digits */
    698             bits = shift % BITS_PER_DIGIT;
    699             tempCarry = b[ndigits-nw] << bits;
    700             if (bits) {
    701                 carry |= tempCarry;
    702                 continue;
    703             }
    704             return carry;
    705         }
    706         else
    707         {
    708             bits = shift;
    709         }
    710 
    711         /* Construct mask = high bits set */
    712         mask = ~(~(DIGIT_T)0 >> bits);
    713 
    714         y = BITS_PER_DIGIT - bits;
    715         carry = 0;
    716         for (i = 0; i < ndigits; i++)
    717         {
    718             nextcarry = (b[i] & mask) >> y;
    719             a[i] = b[i] << bits | carry;
    720             carry = nextcarry;
    721         }
    722 
    723         return carry;
    724 
    725     }
    726 }
    727 
    728 DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
    729 {    /* Computes a = b << shift */
    730     /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
    731 
    732     DIGIT_T carry = 0;
    733 
    734     // this replaces the recursion
    735     while (1) {
    736 
    737         size_t i, y, nw, bits;
    738         DIGIT_T mask, tempCarry, nextcarry;
    739 
    740         /* Do we shift whole digits? */
    741         if (shift >= BITS_PER_DIGIT)
    742         {
    743             nw = shift / BITS_PER_DIGIT;
    744             i = ndigits;
    745             while (i--)
    746             {
    747                 if (i >= nw)
    748                     a[i] = b[i-nw];
    749                 else
    750                     a[i] = 0;
    751             }
    752             /* Call again to shift bits inside digits */
    753             bits = shift % BITS_PER_DIGIT;
    754             tempCarry = b[ndigits-nw] << bits;
    755             if (bits) {
    756                 carry |= tempCarry;
    757                 continue;
    758             }
    759             return carry;
    760         }
    761         else
    762         {
    763             bits = shift;
    764         }
    765 
    766         /* Construct mask = high bits set */
    767         mask = ~(~(DIGIT_T)0 >> bits);
    768 
    769         y = BITS_PER_DIGIT - bits;
    770         carry = 0;
    771         for (i = 0; i < ndigits; i++)
    772         {
    773             nextcarry = (b[i] & mask) >> y;
    774             a[i] = b[i] << bits | carry;
    775             carry = nextcarry;
    776         }
    777 
    778         return carry;
    779 
    780     }
    781 }
    782 
    783 DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
    784 {    /* Computes a = b >> shift */
    785     /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
    786 
    787     DIGIT_T carry = 0;
    788 
    789     while (1) {
    790 
    791         size_t i, y, nw, bits;
    792         DIGIT_T mask, tempCarry, nextcarry;
    793 
    794         /* Do we shift whole digits? */
    795         if (shift >= BITS_PER_DIGIT)
    796         {
    797             nw = shift / BITS_PER_DIGIT;
    798             for (i = 0; i < ndigits; i++)
    799             {
    800                 if ((i+nw) < ndigits)
    801                     a[i] = b[i+nw];
    802                 else
    803                     a[i] = 0;
    804             }
    805             /* Call again to shift bits inside digits */
    806             bits = shift % BITS_PER_DIGIT;
    807             tempCarry = b[nw-1] >> bits;
    808             if (bits)
    809                 carry |= tempCarry;
    810             return carry;
    811         }
    812         else
    813         {
    814             bits = shift;
    815         }
    816 
    817         /* Construct mask to set low bits */
    818         /* (thanks to Jesse Chisholm for suggesting this improved technique) */
    819         mask = ~(~(DIGIT_T)0 << bits);
    820 
    821         y = BITS_PER_DIGIT - bits;
    822         carry = 0;
    823         i = ndigits;
    824         while (i--)
    825         {
    826             nextcarry = (b[i] & mask) << y;
    827             a[i] = b[i] >> bits | carry;
    828             carry = nextcarry;
    829         }
    830 
    831         return carry;
    832 
    833     }
    834 }
    835 
    836 DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
    837 {    /* Computes a = b >> shift */
    838     /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
    839 
    840     DIGIT_T carry = 0;
    841 
    842     while (1) {
    843 
    844         size_t i, y, nw, bits;
    845         DIGIT_T mask, tempCarry, nextcarry;
    846 
    847         /* Do we shift whole digits? */
    848         if (shift >= BITS_PER_DIGIT)
    849         {
    850             nw = shift / BITS_PER_DIGIT;
    851             for (i = 0; i < ndigits; i++)
    852             {
    853                 if ((i+nw) < ndigits)
    854                     a[i] = b[i+nw];
    855                 else
    856                     a[i] = 0;
    857             }
    858             /* Call again to shift bits inside digits */
    859             bits = shift % BITS_PER_DIGIT;
    860             tempCarry = b[nw-1] >> bits;
    861             if (bits)
    862                 carry |= tempCarry;
    863             return carry;
    864         }
    865         else
    866         {
    867             bits = shift;
    868         }
    869 
    870         /* Construct mask to set low bits */
    871         /* (thanks to Jesse Chisholm for suggesting this improved technique) */
    872         mask = ~(~(DIGIT_T)0 << bits);
    873 
    874         y = BITS_PER_DIGIT - bits;
    875         carry = 0;
    876         i = ndigits;
    877         while (i--)
    878         {
    879             nextcarry = (b[i] & mask) << y;
    880             a[i] = b[i] >> bits | carry;
    881             carry = nextcarry;
    882         }
    883 
    884         return carry;
    885 
    886     }
    887 }
    888 
    889 int spMultiply(uint p[2], uint x, uint y)
    890 {
    891 
    892 
    893 
    894 
    895     /* Use a 64-bit temp for product */
    896     //ulong t = (ulong)x * (ulong)y;
    897     /* then split into two parts */
    898     p[1] = mul_hi(x,y);
    899     p[0] = x * y;
    900 
    901     return 0;
    902 }
    903 
    904 uint spDivide(uint *pq, uint *pr, const uint u[2], uint v)
    905 {
    906     ulong uu, q;
    907     uu = (ulong)u[1] << 32 | (ulong)u[0];
    908     q = uu / (ulong)v;
    909     //r = uu % (uint64_t)v;
    910     *pr = (uint)(uu - q * v);
    911     *pq = (uint)(q & 0xFFFFFFFF);
    912     return (uint)(q >> 32);
    913 }
    914 
    915 int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
    916 {
    917     /* if (ndigits == 0) return 0; // deleted [v2.5] */
    918 
    919     while (ndigits--)
    920     {
    921         if (a[ndigits] > b[ndigits])
    922             return 1;    /* GT */
    923         if (a[ndigits] < b[ndigits])
    924             return -1;    /* LT */
    925     }
    926 
    927     return 0;    /* EQ */
    928 }
    929 
    930 int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
    931 {
    932     /* if (ndigits == 0) return 0; // deleted [v2.5] */
    933 
    934     while (ndigits--)
    935     {
    936         if (a[ndigits] > b[ndigits])
    937             return 1;    /* GT */
    938         if (a[ndigits] < b[ndigits])
    939             return -1;    /* LT */
    940     }
    941 
    942     return 0;    /* EQ */
    943 }
    944 
    945 int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
    946 {
    947     /* if (ndigits == 0) return 0; // deleted [v2.5] */
    948 
    949     while (ndigits--)
    950     {
    951         if (a[ndigits] > b[ndigits])
    952             return 1;    /* GT */
    953         if (a[ndigits] < b[ndigits])
    954             return -1;    /* LT */
    955     }
    956 
    957     return 0;    /* EQ */
    958 }
    959 
    960 void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
    961 {    /* Sets a = b */
    962     size_t i;
    963 
    964     for (i = 0; i < ndigits; i++)
    965     {
    966         a[i] = b[i];
    967     }
    968 }
    969 
    970 void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
    971 {    /* Sets a = b */
    972     size_t i;
    973 
    974     for (i = 0; i < ndigits; i++)
    975     {
    976         a[i] = b[i];
    977     }
    978 }
    979 
    980 void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
    981 {    /* Sets a = b */
    982     size_t i;
    983 
    984     for (i = 0; i < ndigits; i++)
    985     {
    986         a[i] = b[i];
    987     }
    988 }
    989 
    990 volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits)
    991 {    /* Sets a = 0 */
    992 
    993     /* Prevent optimiser ignoring this */
    994     volatile DIGIT_T optdummy;
    995     volatile DIGIT_T *p = a;
    996 
    997     while (ndigits--)
    998         a[ndigits] = 0;
    999 
   1000     optdummy = *p;
   1001     return optdummy;
   1002 }
   1003 
   1004 volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits)
   1005 {    /* Sets a = 0 */
   1006 
   1007     /* Prevent optimiser ignoring this */
   1008     volatile DIGIT_T optdummy;
   1009     __global volatile DIGIT_T *p = a;
   1010 
   1011     while (ndigits--)
   1012         a[ndigits] = 0;
   1013 
   1014     optdummy = *p;
   1015     return optdummy;
   1016 }
   1017 
   1018 size_t mpSizeof(const DIGIT_T *a, size_t ndigits)
   1019 {
   1020     while(ndigits--)
   1021     {
   1022         if (a[ndigits] != 0)
   1023             return (++ndigits);
   1024     }
   1025     return 0;
   1026 }
   1027 
   1028 size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits)
   1029 {
   1030     while(ndigits--)
   1031     {
   1032         if (a[ndigits] != 0)
   1033             return (++ndigits);
   1034     }
   1035     return 0;
   1036 }
   1037 
   1038 volatile uint8 zeroise_bytes(volatile void *v, size_t n)
   1039 {    /* Zeroise byte array b and make sure optimiser does not ignore this */
   1040     volatile uint8 optdummy;
   1041     volatile uint8 *b = (uint8*)v;
   1042     while(n--)
   1043         b[n] = 0;
   1044     optdummy = *b;
   1045     return optdummy;
   1046 }
   1047 
   1048 void mpFail(char *msg)
   1049 {
   1050     //perror(msg);
   1051     printf("the program should stop here");
   1052 }
   1053 
   1054 size_t mpBitLength( const DIGIT_T *d, size_t ndigits)
   1055 /* Returns no of significant bits in d */
   1056 {
   1057     size_t n, i, bits;
   1058     DIGIT_T mask;
   1059 
   1060     if (!d || ndigits == 0)
   1061         return 0;
   1062 
   1063     n = mpSizeof(d, ndigits);
   1064     if (0 == n) return 0;
   1065 
   1066     for (i = 0, mask = HIBITMASK; mask > 0; mask >>= 1, i++)
   1067     {
   1068         if (d[n-1] & mask)
   1069             break;
   1070     }
   1071 
   1072     bits = n * BITS_PER_DIGIT - i;
   1073 
   1074     return bits;
   1075 }
   1076 /*
   1077 void mpModSquareTemp(DIGIT_T *y,DIGIT_T *m,size_t n,DIGIT_T *t1,DIGIT_T *t2) {
   1078 
   1079     mpSquare(t1,y,n);
   1080     mpDivide(t2,y,t1,n*2,m,n);
   1081 
   1082 }
   1083 
   1084 void mpModMultTemp(DIGIT_T *y, DIGIT_T *x, DIGIT_T *m, size_t n, DIGIT_T* t1, DIGIT_T *t2) {
   1085 
   1086     mpMultiply(t1,x,y,n);
   1087     mpDivide(t2,y,t1,n*2,m,n);
   1088 
   1089 }
   1090 
   1091 */
   1092 int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits)
   1093 /* New in Version 2.0 */
   1094 {
   1095     /*    Computes square w = x * x
   1096         where x is a multiprecision integer of ndigits
   1097         and w is a multiprecision integer of 2*ndigits
   1098 
   1099         Ref: Menezes p596 Algorithm 14.16 with errata.
   1100     */
   1101 
   1102     DIGIT_T k, p[2], u[2], cbit, carry;
   1103     size_t i, j, t, i2, cpos;
   1104 
   1105     assert(w != x);
   1106 
   1107     t = ndigits;
   1108 
   1109     /* 1. For i from 0 to (2t-1) do: w_i = 0 */
   1110     i2 = t << 1;
   1111     for (i = 0; i < i2; i++)
   1112         w[i] = 0;
   1113 
   1114     carry = 0;
   1115     cpos = i2-1;
   1116     /* 2. For i from 0 to (t-1) do: */
   1117     for (i = 0; i < t; i++)
   1118     {
   1119         /* 2.1 (uv) = w_2i + x_i * x_i, w_2i = v, c = u
   1120            Careful, w_2i may be double-prec
   1121         */
   1122         i2 = i << 1; /* 2*i */
   1123         spMultiply(p, x[i], x[i]);
   1124         p[0] += w[i2];
   1125         if (p[0] < w[i2])
   1126             p[1]++;
   1127         k = 0;    /* p[1] < b, so no overflow here */
   1128         if (i2 == cpos && carry)
   1129         {
   1130             p[1] += carry;
   1131             if (p[1] < carry)
   1132                 k++;
   1133             carry = 0;
   1134         }
   1135         w[i2] = p[0];
   1136         u[0] = p[1];
   1137         u[1] = k;
   1138 
   1139         /* 2.2 for j from (i+1) to (t-1) do:
   1140            (uv) = w_{i+j} + 2x_j * x_i + c,
   1141            w_{i+j} = v, c = u,
   1142            u is double-prec
   1143            w_{i+j} is dbl if [i+j] == cpos
   1144         */
   1145         k = 0;
   1146         for (j = i+1; j < t; j++)
   1147         {
   1148             /* p = x_j * x_i */
   1149             spMultiply(p, x[j], x[i]);
   1150             /* p = 2p <=> p <<= 1 */
   1151             cbit = (p[0] & HIBITMASK) != 0;
   1152             k =  (p[1] & HIBITMASK) != 0;
   1153             p[0] <<= 1;
   1154             p[1] <<= 1;
   1155             p[1] |= cbit;
   1156             /* p = p + c */
   1157             p[0] += u[0];
   1158             if (p[0] < u[0])
   1159             {
   1160                 p[1]++;
   1161                 if (p[1] == 0)
   1162                     k++;
   1163             }
   1164             p[1] += u[1];
   1165             if (p[1] < u[1])
   1166                 k++;
   1167             /* p = p + w_{i+j} */
   1168             p[0] += w[i+j];
   1169             if (p[0] < w[i+j])
   1170             {
   1171                 p[1]++;
   1172                 if (p[1] == 0)
   1173                     k++;
   1174             }
   1175             if ((i+j) == cpos && carry)
   1176             {    /* catch overflow from last round */
   1177                 p[1] += carry;
   1178                 if (p[1] < carry)
   1179                     k++;
   1180                 carry = 0;
   1181             }
   1182             /* w_{i+j} = v, c = u */
   1183             w[i+j] = p[0];
   1184             u[0] = p[1];
   1185             u[1] = k;
   1186         }
   1187         /* 2.3 w_{i+t} = u */
   1188         w[i+t] = u[0];
   1189         /* remember overflow in w_{i+t} */
   1190         carry = u[1];
   1191         cpos = i+t;
   1192     }
   1193 
   1194     /* (NB original step 3 deleted in Menezes errata) */
   1195 
   1196     /* Return w */
   1197 
   1198     return 0;
   1199 }
   1200 
   1201 int mpIsZero( const DIGIT_T *a, size_t ndigits)
   1202 {
   1203     size_t i;
   1204 
   1205     /* if (ndigits == 0) return -1; // deleted [v2.5] */
   1206 
   1207     for (i = 0; i < ndigits; i++)    /* Start at lsb */
   1208     {
   1209         if (a[i] != 0)
   1210             return 0;    /* False */
   1211     }
   1212 
   1213     return (!0);    /* True */
   1214 }
   1215 
   1216 
   1217 
   1218 __kernel void several(__global DIGIT_T* x, //__global const unsigned long *s_len,
   1219                       __global DIGIT_T* e, //__global const unsigned long *e_len,
   1220                       __global DIGIT_T* m, //__global const unsigned long *n_len,
   1221                       __global DIGIT_T *mm,// __global const unsigned long *mm_len,
   1222                       __global uint* valid,
   1223                       __global uint *pks
   1224                       ) {
   1225 
   1226     int index = get_global_id(0);
   1227     
   1228     int pk = 0;
   1229     
   1230     pk = pks[index];
   1231     
   1232     int ndigits = MAX_ALLOC_SIZE;
   1233     int edigits = 1;
   1234     
   1235     //printf((char __constant *)"%i\n", ndigits);
   1236     
   1237     // the result is copied in here, compare it to mm
   1238     DIGIT_T yout[MAX_ALLOC_SIZE * 2];
   1239     
   1240     DIGIT_T mask;
   1241     unsigned long n;
   1242     
   1243     __global DIGIT_T * __private window_x; // private scope pointers to global memory
   1244     __global DIGIT_T * __private window_e;
   1245     __global DIGIT_T * __private window_m;
   1246     __global DIGIT_T * __private window_mm;
   1247     
   1248     window_e = &e[pk];
   1249     window_m = &m[pk * MAX_ALLOC_SIZE];
   1250     
   1251     window_x = &x[index * MAX_ALLOC_SIZE];
   1252     window_mm = &mm[index * MAX_ALLOC_SIZE];
   1253     
   1254     
   1255     __private DIGIT_T t1[MAX_ALLOC_SIZE *2]; // obsolete?
   1256     __private DIGIT_T t2[MAX_ALLOC_SIZE *2]; // obsolete?
   1257     __private DIGIT_T y[MAX_ALLOC_SIZE *2]; // obsolete?
   1258     
   1259     n = mpSizeof_g(window_e, edigits);
   1260     /* Catch e==0 => x^0=1 */
   1261     if (0 == n)
   1262     {
   1263         mpSetDigit(yout, 1, ndigits);
   1264         
   1265     }
   1266     /* Find second-most significant bit in e */
   1267     for (mask = HIBITMASK; mask > 0; mask >>= 1)
   1268     {
   1269         if (window_e[n-1] & mask)
   1270             break;
   1271     }
   1272     mpNEXTBITMASK(mask, n);
   1273     
   1274     /* Set y = x */
   1275     mpSetEqual_lg(y, window_x, ndigits);
   1276     
   1277     /* For bit j = k-2 downto 0 */
   1278     while (n)
   1279     {
   1280         /* Square y = y * y mod n */
   1281         mpMODSQUARETEMP(y, window_m, ndigits, t1, t2);
   1282         
   1283         
   1284         if (e[n-1] & mask)
   1285         {    /*    if e(j) == 1 then multiply
   1286               y = y * x mod n */
   1287             mpMODMULTTEMP(y, window_x, window_m, ndigits, t1, t2);
   1288             
   1289         }
   1290         
   1291         /* Move to next bit */
   1292         mpNEXTBITMASK(mask, n);
   1293     }
   1294     
   1295     mpSetEqual(yout, y, ndigits);
   1296     
   1297     int len = MAX_ALLOC_SIZE;
   1298     
   1299     
   1300     // only increase if there was an error verifying – because otherwise several kernels might write simultaneously, and not every success would be counted
   1301     if (mpCompare_lg(yout,window_mm,len) == 0) {
   1302         
   1303         uint out_offset = index / (sizeof(uint) * 8); // 32 bit
   1304         
   1305         uint mv = 1 << index;
   1306         
   1307         atomic_or(&valid[out_offset], mv);
   1308     }
   1309     
   1310 }