libgpuverify

Signature verification on GPUs (WiP)
Log | Files | Refs | README | LICENSE

commit e94b65d99d607b663b4dc78103429a7ab34f5966
parent ee7c77de0e470ff8800459d0ebc91c6f4865acb3
Author: Christian Grothoff <christian@grothoff.org>
Date:   Sat, 18 Nov 2023 21:15:28 +0100

-commenting on problems

Diffstat:
M source/rsa-test.c | 7 +++++--
M xcode/verify.cl | 173 ++++++++++++++++++++++++++++++++++++++++++-------------------------------------
2 files changed, 96 insertions(+), 84 deletions(-)

diff --git a/source/rsa-test.c b/source/rsa-test.c @@ -816,7 +816,9 @@ int rsa_tests(void) { } // Set the arguments to our compute kernel - // + // FIXME: this is for some different version of the compute kernel, + // arguments do not match current compute kernel. Aborting... + abort (); err = 0; { unsigned int ctr = 0; @@ -832,13 +834,14 @@ int rsa_tests(void) { ctr++; \ } while (0) SET_ARG (s_mem); - SET_ARG (s_len); + SET_ARG (s_len); /* FIXME: treated as an array of size_t! */ SET_ARG (e_mem); SET_ARG (e_len); SET_ARG (n_mem); SET_ARG (n_len); SET_ARG (res_mem); SET_ARG (res_len); + SET_ARG (max_len); SET_ARG (valid); #undef SET_ARG } diff --git a/xcode/verify.cl b/xcode/verify.cl @@ -132,7 +132,7 @@ int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T * DIGIT_T qq[MAX_FIXED_DIGITS*2]; DIGIT_T rr[MAX_FIXED_DIGITS*2]; - + /* rr[nn] = u mod v */ mpDivide(qq, rr, u, udigits, v, vdigits); @@ -224,7 +224,7 @@ DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits) k = 1; else k = 0; - + w[j] += v[j]; if (w[j] < v[j]) k++; @@ -261,7 +261,7 @@ DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size k = 1; else k = 0; - + w[j] += v[j]; if (w[j] < v[j]) k++; @@ -364,9 +364,9 @@ int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_ t[0] = overflow; /* Extra digit Um+n */ - + /* Step D2. Initialise j. Set j = m */ - + for (j = m; j >= 0; j--) { /* Step D3. 
Set Qhat = [(b.Uj+n + Uj+n-1)/Vn-1] @@ -430,7 +430,7 @@ int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_ void mpSetDigit( DIGIT_T *a, DIGIT_T d, size_t ndigits) { /* Sets a = d where d is a single digit */ size_t i; - + for (i = 1; i < ndigits; i++) { a[i] = 0; @@ -441,7 +441,7 @@ void mpSetDigit( DIGIT_T *a, DIGIT_T d, size_t ndigits) void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits) { /* Sets a = d where d is a single digit */ size_t i; - + for (i = 1; i < ndigits; i++) { a[i] = 0; @@ -458,7 +458,7 @@ DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, and r, v are single precision digits. Makes no assumptions about normalisation. - + Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625 */ size_t j; @@ -485,7 +485,7 @@ DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, v <<= shift; overflow = mpShiftLeft(q, u, shift, ndigits); uu = q; - + /* Step S1 - modified for extra digit. */ r = overflow; /* New digit Un */ j = ndigits; @@ -499,7 +499,7 @@ DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, /* Unnormalise */ r >>= shift; - + return r; } @@ -591,12 +591,12 @@ DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b << shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ - + DIGIT_T carry = 0; - + // this replaces the recursion while (1) { - + size_t i, y, nw, bits; DIGIT_T mask, tempCarry, nextcarry; @@ -628,7 +628,7 @@ DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) /* Construct mask = high bits set */ mask = ~(~(DIGIT_T)0 >> bits); - + y = BITS_PER_DIGIT - bits; carry = 0; for (i = 0; i < ndigits; i++) @@ -639,19 +639,19 @@ DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) } return carry; - + } } DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) 
{ /* Computes a = b << shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ - + DIGIT_T carry = 0; - + // this replaces the recursion while (1) { - + size_t i, y, nw, bits; DIGIT_T mask, tempCarry, nextcarry; @@ -683,7 +683,7 @@ DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t sh /* Construct mask = high bits set */ mask = ~(~(DIGIT_T)0 >> bits); - + y = BITS_PER_DIGIT - bits; carry = 0; for (i = 0; i < ndigits; i++) @@ -694,19 +694,19 @@ DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t sh } return carry; - + } } DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b << shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ - + DIGIT_T carry = 0; - + // this replaces the recursion while (1) { - + size_t i, y, nw, bits; DIGIT_T mask, tempCarry, nextcarry; @@ -738,7 +738,7 @@ DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size /* Construct mask = high bits set */ mask = ~(~(DIGIT_T)0 >> bits); - + y = BITS_PER_DIGIT - bits; carry = 0; for (i = 0; i < ndigits; i++) @@ -749,18 +749,18 @@ DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size } return carry; - + } } DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b >> shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ - + DIGIT_T carry = 0; - + while (1) { - + size_t i, y, nw, bits; DIGIT_T mask, tempCarry, nextcarry; @@ -790,7 +790,7 @@ DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits /* Construct mask to set low bits */ /* (thanks to Jesse Chisholm for suggesting this improved technique) */ mask = ~(~(DIGIT_T)0 << bits); - + y = BITS_PER_DIGIT - bits; carry = 0; i = ndigits; @@ -802,18 +802,18 @@ DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits } return carry; - + } } DIGIT_T 
mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b >> shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ - + DIGIT_T carry = 0; - + while (1) { - + size_t i, y, nw, bits; DIGIT_T mask, tempCarry, nextcarry; @@ -843,7 +843,7 @@ DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t s /* Construct mask to set low bits */ /* (thanks to Jesse Chisholm for suggesting this improved technique) */ mask = ~(~(DIGIT_T)0 << bits); - + y = BITS_PER_DIGIT - bits; carry = 0; i = ndigits; @@ -855,16 +855,16 @@ DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t s } return carry; - + } } int spMultiply(uint p[2], uint x, uint y) { - - - - + + + + /* Use a 64-bit temp for product */ //ulong t = (ulong)x * (ulong)y; /* then split into two parts */ @@ -933,7 +933,7 @@ int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits) { /* Sets a = b */ size_t i; - + for (i = 0; i < ndigits; i++) { a[i] = b[i]; @@ -943,7 +943,7 @@ void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits) void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) { /* Sets a = b */ size_t i; - + for (i = 0; i < ndigits; i++) { a[i] = b[i]; @@ -953,7 +953,7 @@ void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits) { /* Sets a = b */ size_t i; - + for (i = 0; i < ndigits; i++) { a[i] = b[i]; @@ -969,7 +969,7 @@ volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits) while (ndigits--) a[ndigits] = 0; - + optdummy = *p; return optdummy; } @@ -983,7 +983,7 @@ volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits) while (ndigits--) a[ndigits] = 0; - + optdummy = *p; return optdummy; } @@ -1048,19 +1048,19 @@ size_t mpBitLength( const DIGIT_T *d, size_t ndigits) } /* void 
mpModSquareTemp(DIGIT_T *y,DIGIT_T *m,size_t n,DIGIT_T *t1,DIGIT_T *t2) { - + mpSquare(t1,y,n); mpDivide(t2,y,t1,n*2,m,n); - + } void mpModMultTemp(DIGIT_T *y, DIGIT_T *x, DIGIT_T *m, size_t n, DIGIT_T* t1, DIGIT_T *t2) { - + mpMultiply(t1,x,y,n); mpDivide(t2,y,t1,n*2,m,n); - + } - + */ int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits) /* New in Version 2.0 */ @@ -1187,15 +1187,24 @@ int mpIsZero( const DIGIT_T *a, size_t ndigits) } void assert(bool precondition) { - + char str[] = "assert reached, also this message leaks memory"; - + if (!precondition) mpFail(str); - - + + } + +// UGH: size_t could be 32-bit or 64-bit +// depending on your GPU. Doubt you want +// to adjust the client side every time. +// Maybe _always_ use 32-bits for the sizes? +// http://man.opencl.org/scalarDataTypes.html +// Moreover, s_len is initialized as a *single* +// size_t (not even a pointer!), +// but treated as an array here!??!? __kernel void several(__global DIGIT_T* x, __global const size_t *s_len, __global DIGIT_T* e, __global const size_t *e_len, __global DIGIT_T* m, __global const size_t *n_len, @@ -1203,53 +1212,53 @@ __kernel void several(__global DIGIT_T* x, __global const size_t *s_len, __global unsigned long* valid, const int count ) { - + int index = get_global_id(0); - + if (index < count) { - + int ndigits = max( max( n_len[index] - (index == 0 ? 0 : n_len[index - 1]) , mm_len[index] - (index == 0 ? 0 : mm_len[index - 1]) ), s_len[index] - (index == 0 ? 0 : s_len[index - 1]) ); int edigits = e_len[index] - ( index == 0 ? 0 : e_len[index - 1] ); - + // int ndigits = 64; // int edigits = 1; - + // the result is copied in here, compare it to mm DIGIT_T yout[MAX_ALLOC_SIZE *2]; - + DIGIT_T mask; size_t n; - + __global DIGIT_T * __private window_x; __global DIGIT_T * __private window_e; __global DIGIT_T * __private window_m; __global DIGIT_T * __private window_mm; - - + + window_x = &x[index == 0 ? 0 : (s_len[index - 1])]; window_e = &e[index == 0 ? 
0 : (e_len[index - 1])]; window_m = &m[index == 0 ? 0 : (n_len[index - 1])]; window_mm = &mm[index == 0 ? 0 : (mm_len[index - 1])]; -// +// // window_x = &x[0]; // window_e = &e[0]; // window_m = &m[0]; // window_mm = &mm[0]; -// +// // can probably be smaller __private DIGIT_T t1[MAX_ALLOC_SIZE *2]; __private DIGIT_T t2[MAX_ALLOC_SIZE *2]; __private DIGIT_T y[MAX_ALLOC_SIZE *2]; - + assert(ndigits <= MAX_FIXED_DIGITS); assert(ndigits != 0); - + n = mpSizeof_g(window_e, edigits); /* Catch e==0 => x^0=1 */ if (0 == n) { mpSetDigit(yout, 1, ndigits); - + } /* Find second-most significant bit in e */ for (mask = HIBITMASK; mask > 0; mask >>= 1) @@ -1258,42 +1267,42 @@ __kernel void several(__global DIGIT_T* x, __global const size_t *s_len, break; } mpNEXTBITMASK(mask, n); - + /* Set y = x */ mpSetEqual_lg(y, window_x, ndigits); - + /* For bit j = k-2 downto 0 */ while (n) // I think it just goes the bit length of e { /* Square y = y * y mod n */ mpMODSQUARETEMP(y, window_m, ndigits, t1, t2); - - + + if (e[n-1] & mask) { /* if e(j) == 1 then multiply y = y * x mod n */ mpMODMULTTEMP(y, window_x, window_m, ndigits, t1, t2); - + } - + /* Move to next bit */ mpNEXTBITMASK(mask, n); } - + mpSetEqual(yout, y, ndigits); - + int len = ( mm_len[index] - (index == 0 ? 0 : mm_len[index - 1]) ); - - + + // MARK: valid cannot be written to by several at once (the same unit anyway) if (mpCompare_lg(yout,window_mm,len) == 0 && index == count - 1) { *valid |= 0x1 << index; } - - - + + + } - + // if (index == 8) { *valid = 0xBA; } - + }