libgpuverify

Signature verification on GPUs (WiP)
Log | Files | Refs | README | LICENSE

commit ba7f0ed534e534f30f8f7e8b4e13d7aadbf7d753
parent d50b5b70b89f9fd9a7208fa52c8bf0c9aadb9f92
Author: Cedric <cedric.zwahlen@students.bfh.ch>
Date:   Sat,  4 Nov 2023 18:12:13 +0100

the kernel compiles, and outputs data, though it's not verifying correctly yet

Diffstat:
M source/big-int-test.c | 10 +++++-----
M source/rsa-test.c | 4 ++--
M xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate | 0
M xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
M xcode/verify.cl | 650 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
5 files changed, 575 insertions(+), 172 deletions(-)

diff --git a/source/big-int-test.c b/source/big-int-test.c @@ -405,7 +405,7 @@ static int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, return 0; } -static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], +static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T v[], DIGIT_T q, size_t n) { /* Compute w = w - qv where w = (WnW[n-1]...W[0]) @@ -439,7 +439,7 @@ static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], return wn; } -DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, +DIGIT_T mpShiftLeft(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b << shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ @@ -485,7 +485,7 @@ DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, return carry; } -DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits) +DIGIT_T mpShiftRight(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b >> shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ size_t i, y, nw, bits; @@ -783,7 +783,7 @@ void mpPrintHex(const char *prefix, const DIGIT_T *a, size_t len, const char *su } -int mpModExpO(DIGIT_T yout[], const DIGIT_T x[], const DIGIT_T e[], DIGIT_T m[], size_t ndigits) +int mpModExpO(DIGIT_T *yout, const DIGIT_T *x, const DIGIT_T *e, DIGIT_T *m, size_t ndigits) { /* Computes y = x^e mod m */ /* "Classic" binary left-to-right method */ /* [v2.2] removed const restriction on m[] to avoid using an extra alloc'd var @@ -850,7 +850,7 @@ done: return 0; } -int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits) +int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits) /* New in Version 2.0 */ { /* Computes square w = x * x diff --git a/source/rsa-test.c b/source/rsa-test.c @@ -325,7 +325,7 @@ int rsa_tests(void) { err = clEnqueueWriteBuffer(commands, s_mem, CL_TRUE, 0, s_len, s_buf, 0, NULL, NULL); err |= clEnqueueWriteBuffer(commands, e_mem, CL_TRUE, 0, e_len, e_buf, 0, NULL, NULL); err |= 
clEnqueueWriteBuffer(commands, n_mem, CL_TRUE, 0, n_len, n_buf, 0, NULL, NULL); - //err |= clEnqueueWriteBuffer(commands, res_mem, CL_TRUE, 0, res_len, res_buf, 0, NULL, NULL); + err |= clEnqueueWriteBuffer(commands, res_mem, CL_TRUE, 0, res_len, res_buf, 0, NULL, NULL); if (err != CL_SUCCESS) { printf("Error: Failed to write to source array!\n"); @@ -394,7 +394,7 @@ int rsa_tests(void) { mpConvToHex(res_buf, sz_res, comp, sz_mm); - printf("%s",comp); + printf("%s\n",comp); // Print a brief summary detailing the results // diff --git a/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate b/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate Binary files differ. diff --git a/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist b/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist @@ -684,7 +684,7 @@ BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint"> <BreakpointContent uuid = "E34A3BBB-4BEA-4FC9-B0F1-55FB3A180116" - shouldBeEnabled = "Yes" + shouldBeEnabled = "No" ignoreCount = "0" continueAfterRunningActions = "No" filePath = "../source/rsa-test.c" @@ -755,6 +755,21 @@ endingLineNumber = "397" offsetFromSymbolStart = "4098"> </Location> + <Location + uuid = "E34A3BBB-4BEA-4FC9-B0F1-55FB3A180116 - b0b9078e770c934d" + shouldBeEnabled = "Yes" + ignoreCount = "0" + continueAfterRunningActions = "No" + symbolName = "rsa_tests" + moduleName = "lib-gpu-verify" + usesParentBreakpointCondition = "Yes" + urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c" + startingColumnNumber = "9223372036854775807" + endingColumnNumber = "9223372036854775807" + startingLineNumber = "397" + endingLineNumber = "397" + offsetFromSymbolStart = "4187"> + </Location> </Locations> </BreakpointContent> </BreakpointProxy> @@ -810,7 
+825,7 @@ BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint"> <BreakpointContent uuid = "0B50C23D-36DE-4C9C-B911-72E48EA8C7FD" - shouldBeEnabled = "Yes" + shouldBeEnabled = "No" ignoreCount = "0" continueAfterRunningActions = "No" filePath = "../source/rsa-test.c" @@ -822,5 +837,69 @@ landmarkType = "9"> </BreakpointContent> </BreakpointProxy> + <BreakpointProxy + BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint"> + <BreakpointContent + uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C" + shouldBeEnabled = "Yes" + ignoreCount = "0" + continueAfterRunningActions = "No" + filePath = "../source/rsa-test.c" + startingColumnNumber = "9223372036854775807" + endingColumnNumber = "9223372036854775807" + startingLineNumber = "329" + endingLineNumber = "329" + landmarkName = "rsa_tests()" + landmarkType = "9"> + <Locations> + <Location + uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C - b0b9078e770c9a09" + shouldBeEnabled = "Yes" + ignoreCount = "0" + continueAfterRunningActions = "No" + symbolName = "rsa_tests" + moduleName = "lib-gpu-verify" + usesParentBreakpointCondition = "Yes" + urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c" + startingColumnNumber = "9223372036854775807" + endingColumnNumber = "9223372036854775807" + startingLineNumber = "329" + endingLineNumber = "329" + offsetFromSymbolStart = "3113"> + </Location> + <Location + uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C - b0b9078e770c9a09" + shouldBeEnabled = "Yes" + ignoreCount = "0" + continueAfterRunningActions = "No" + symbolName = "rsa_tests" + moduleName = "lib-gpu-verify" + usesParentBreakpointCondition = "Yes" + urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c" + startingColumnNumber = "9223372036854775807" + endingColumnNumber = "9223372036854775807" + startingLineNumber = "329" + endingLineNumber = "329" + offsetFromSymbolStart = "3202"> + </Location> + </Locations> + </BreakpointContent> + </BreakpointProxy> + <BreakpointProxy + 
BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint"> + <BreakpointContent + uuid = "C1B34A3E-98EE-448B-A02F-A50E969F7C92" + shouldBeEnabled = "No" + ignoreCount = "0" + continueAfterRunningActions = "No" + filePath = "../source/rsa-test.c" + startingColumnNumber = "9223372036854775807" + endingColumnNumber = "9223372036854775807" + startingLineNumber = "506" + endingLineNumber = "506" + landmarkName = "verify(sign, ee, nn, mm)" + landmarkType = "9"> + </BreakpointContent> + </BreakpointProxy> </Breakpoints> </Bucket> diff --git a/xcode/verify.cl b/xcode/verify.cl @@ -41,56 +41,84 @@ typedef uint16 HALF_DIGIT_T; // forward definitions -int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits, DIGIT_T v[], size_t vdigits); +int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits); -int mpModMult(DIGIT_T a[], const DIGIT_T x[], const DIGIT_T y[], DIGIT_T m[], size_t ndigits); +//int mpModMult(__global DIGIT_T *a, __global DIGIT_T *x, const DIGIT_T *y, DIGIT_T *m, size_t ndigits); + +int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits); + +DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits); + +int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits); -int mpMultiply(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits); -DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits); -int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], size_t udigits, DIGIT_T v[], size_t vdigits); int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2); -DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], DIGIT_T q, size_t n); -DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, size_t shift, size_t ndigits); +DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n); + +DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits); 
-void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits); -int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits); +void mpSetDigit(DIGIT_T *a, DIGIT_T d, size_t ndigits); +int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits); -DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits); + +DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits); int spMultiply(uint p[2], uint x, uint y); uint spDivide(uint *pq, uint *pr, const uint u[2], uint v); -int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits); +int mpSquare( DIGIT_T *w, const DIGIT_T *x, size_t ndigits); -size_t mpBitLength(const DIGIT_T d[], size_t ndigits); +size_t mpBitLength(const DIGIT_T *d, size_t ndigits); -DIGIT_T mpShortDiv(DIGIT_T q[], const DIGIT_T u[], DIGIT_T v, +DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, size_t ndigits); -void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits); +void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits); size_t uiceil(float x); volatile uint8 zeroise_bytes(volatile void *v, size_t n); -size_t mpSizeof(const DIGIT_T a[], size_t ndigits); +size_t mpSizeof(const DIGIT_T *a, size_t ndigits); -volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits); +volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits); -int mpIsZero(const DIGIT_T a[], size_t ndigits); +int mpIsZero( const DIGIT_T *a, size_t ndigits); void mpFail(char *msg); -int mpModExpO(DIGIT_T *yout[], DIGIT_T *x[], DIGIT_T *e[], DIGIT_T *m[], size_t ndigits); +int mpModExpO( DIGIT_T *yout, DIGIT_T *x, DIGIT_T *e, DIGIT_T *m, size_t ndigits); void assert(bool precondition); +// global memory definitions + +size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits); + +DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits); + +DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits); + +DIGIT_T 
mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits); + +int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits); + +int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits); + +void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits); +void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits); + +DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n); + +DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits); + +void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits); + +volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits); // implementation -int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits, - DIGIT_T v[], size_t vdigits) +int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits) { /* Computes r = u mod v where r, v are multiprecision integers of length vdigits @@ -121,7 +149,7 @@ int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits, mpDivide(qq, rr, u, udigits, v, vdigits); /* Final r is only vdigits long */ - mpSetEqual(r, rr, vdigits); + mpSetEqual_gl(r, rr, vdigits); mpDESTROY(rr, udigits); mpDESTROY(qq, udigits); @@ -129,31 +157,26 @@ int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits, return 0; } -int mpModMult(DIGIT_T a[], const DIGIT_T x[], const DIGIT_T y[], - DIGIT_T m[], size_t ndigits) -{ /* Computes a = (x * y) mod m */ - -/* Double-length temp variable p */ -// #ifdef NO_ALLOCS - DIGIT_T p[MAX_FIXED_DIGITS * 2]; -// assert(ndigits <= MAX_FIXED_DIGITS); -/*#else - DIGIT_T *p; - p = mpAlloc(ndigits * 2); -#endif -*/ - /* Calc p[2n] = x * y */ - mpMultiply(p, x, y, ndigits); - - /* Then modulo (NOTE: a is OK at only ndigits long) */ - mpModulo(a, p, ndigits * 2, m, ndigits); - - mpDESTROY(p, ndigits * 2); - - return 0; -} - -int mpMultiply(DIGIT_T w[], const DIGIT_T 
u[], const DIGIT_T v[], size_t ndigits) +//int mpModMult(__global DIGIT_T *a, __global DIGIT_T *x, const DIGIT_T *y, DIGIT_T *m, size_t ndigits) +//{ /* Computes a = (x * y) mod m */ +// +///* Double-length temp variable p */ +// +// DIGIT_T p[MAX_FIXED_DIGITS * 2]; +//// assert(ndigits <= MAX_FIXED_DIGITS); +// +// //Calc p[2n] = x * y +// mpMultiply(p, x, y, ndigits); +// +// /* Then modulo (NOTE: a is OK at only ndigits long) */ +// mpModulo(a, p, ndigits * 2, m, ndigits); +// +// mpDESTROY(p, ndigits * 2); +// +// return 0; +//} + +int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits) { /* Computes product w = u * v where u, v are multiprecision integers of ndigits each @@ -208,7 +231,7 @@ int mpMultiply(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits return 0; } -DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits) +DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits) { /* Calculates w = u + v where w, u, v are multiprecision integers of ndigits each @@ -245,10 +268,44 @@ DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits) return k; /* w_n = k */ } -// MARK: This function is causing problems – this function calls mpShiftLeft, mpShiftRight at some point (and so does mpShortDiv) they contain recursions, which are forbidden +DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits) +{ + /* Calculates w = u + v + where w, u, v are multiprecision integers of ndigits each + Returns carry if overflow. Carry = 0 or 1. + + Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A. + */ + + DIGIT_T k; + size_t j; + + // assert(w != v); + + /* Step A1. Initialise */ + k = 0; + + for (j = 0; j < ndigits; j++) + { + /* Step A2. 
Add digits w_j = (u_j + v_j + k) + Set k = 1 if carry (overflow) occurs + */ + w[j] = u[j] + k; + if (w[j] < k) + k = 1; + else + k = 0; + + w[j] += v[j]; + if (w[j] < v[j]) + k++; -int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], - size_t udigits, DIGIT_T v[], size_t vdigits) + } /* Step A3. Loop on j */ + + return k; /* w_n = k */ +} + +int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits) { /* Computes quotient q = u / v and remainder r = u mod v where q, r, u are multiple precision digits all of udigits and the divisor v is vdigits. @@ -275,7 +332,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], mpSetZero(r, udigits); /* Work out exact sizes of u and v */ - n = (int)mpSizeof(v, vdigits); + n = (int)mpSizeof_g(v, vdigits); m = (int)mpSizeof(u, udigits); m -= n; @@ -297,7 +354,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], if (m == 0) { /* u and v are the same length */ - cmp = mpCompare(u, v, (size_t)n); + cmp = mpCompare_lg(u, v, (size_t)n); if (cmp < 0) { /* v > u, as above */ mpSetEqual(r, u, udigits); @@ -333,7 +390,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], } /* Normalise v in situ - NB only shift non-zero digits */ - overflow = mpShiftLeft(v, v, shift, n); + overflow = mpShiftLeft_gg(v, v, shift, n); /* Copy normalised dividend u*d into r */ overflow = mpShiftLeft(r, u, shift, n + m); @@ -378,14 +435,14 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], /* Step D4. Multiply and subtract */ ww = &uu[j]; - overflow = mpMultSub(t[1], ww, v, qhat, (size_t)n); + overflow = mpMultSub_lg(t[1], ww, v, qhat, (size_t)n); /* Step D5. Test remainder. Set Qj = Qhat */ q[j] = qhat; if (overflow) { /* Step D6. 
Add back if D4 was negative */ q[j]--; - overflow = mpAdd(ww, ww, v, (size_t)n); + overflow = mpAdd_llg(ww, ww, v, (size_t)n); } t[0] = uu[j+n-1]; /* Uj+n on next round */ @@ -399,12 +456,23 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], /* Step D8. Unnormalise. */ mpShiftRight(r, r, shift, n); - mpShiftRight(v, v, shift, n); + mpShiftRight_gg(v, v, shift, n); return 0; } -void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits) +void mpSetDigit( DIGIT_T *a, DIGIT_T d, size_t ndigits) +{ /* Sets a = d where d is a single digit */ + size_t i; + + for (i = 1; i < ndigits; i++) + { + a[i] = 0; + } + a[0] = d; +} + +void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits) { /* Sets a = d where d is a single digit */ size_t i; @@ -415,7 +483,7 @@ void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits) a[0] = d; } -DIGIT_T mpShortDiv(DIGIT_T q[], const DIGIT_T u[], DIGIT_T v, +DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v, size_t ndigits) { /* Calculates quotient q = u div v @@ -487,7 +555,7 @@ int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, return 0; } -DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], +DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n) { /* Compute w = w - qv where w = (WnW[n-1]...W[0]) @@ -521,8 +589,95 @@ DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], return wn; } -DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, - size_t shift, size_t ndigits) +DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n) +{ /* Compute w = w - qv + where w = (WnW[n-1]...W[0]) + return modified Wn. 
+ */ + DIGIT_T k, t[2]; + size_t i; + + if (q == 0) /* No change */ + return wn; + + k = 0; + + for (i = 0; i < n; i++) + { + spMultiply(t, q, v[i]); + w[i] -= k; + if (w[i] > MAX_DIGIT - k) + k = 1; + else + k = 0; + w[i] -= t[0]; + if (w[i] > MAX_DIGIT - t[0]) + k++; + k += t[1]; + } + + /* Cope with Wn not stored in array w[0..n-1] */ + wn -= k; + + return wn; +} + +DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) +{ /* Computes a = b << shift */ + /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ + + DIGIT_T carry = 0; + + // this replaces the recursion + while (1) { + + size_t i, y, nw, bits; + DIGIT_T mask, tempCarry, nextcarry; + + /* Do we shift whole digits? */ + if (shift >= BITS_PER_DIGIT) + { + nw = shift / BITS_PER_DIGIT; + i = ndigits; + while (i--) + { + if (i >= nw) + a[i] = b[i-nw]; + else + a[i] = 0; + } + /* Call again to shift bits inside digits */ + bits = shift % BITS_PER_DIGIT; + tempCarry = b[ndigits-nw] << bits; + if (bits) { + carry |= tempCarry; + continue; + } + return carry; + } + else + { + bits = shift; + } + + /* Construct mask = high bits set */ + mask = ~(~(DIGIT_T)0 >> bits); + + y = BITS_PER_DIGIT - bits; + carry = 0; + for (i = 0; i < ndigits; i++) + { + nextcarry = (b[i] & mask) >> y; + a[i] = b[i] << bits | carry; + carry = nextcarry; + } + + return carry; + + } +} + +DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b << shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ @@ -577,7 +732,62 @@ DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, } } -DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits) +DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) +{ /* Computes a = b << shift */ + /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ + + DIGIT_T carry = 0; + + // this replaces the recursion + while (1) { + + size_t i, y, nw, 
bits; + DIGIT_T mask, tempCarry, nextcarry; + + /* Do we shift whole digits? */ + if (shift >= BITS_PER_DIGIT) + { + nw = shift / BITS_PER_DIGIT; + i = ndigits; + while (i--) + { + if (i >= nw) + a[i] = b[i-nw]; + else + a[i] = 0; + } + /* Call again to shift bits inside digits */ + bits = shift % BITS_PER_DIGIT; + tempCarry = b[ndigits-nw] << bits; + if (bits) { + carry |= tempCarry; + continue; + } + return carry; + } + else + { + bits = shift; + } + + /* Construct mask = high bits set */ + mask = ~(~(DIGIT_T)0 >> bits); + + y = BITS_PER_DIGIT - bits; + carry = 0; + for (i = 0; i < ndigits; i++) + { + nextcarry = (b[i] & mask) >> y; + a[i] = b[i] << bits | carry; + carry = nextcarry; + } + + return carry; + + } +} + +DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits) { /* Computes a = b >> shift */ /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ @@ -630,7 +840,58 @@ DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigit } } +DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits) +{ /* Computes a = b >> shift */ + /* [v2.1] Modified to cope with shift > BITS_PERDIGIT */ + + DIGIT_T carry = 0; + + while (1) { + + size_t i, y, nw, bits; + DIGIT_T mask, tempCarry, nextcarry; + + /* Do we shift whole digits? 
*/ + if (shift >= BITS_PER_DIGIT) + { + nw = shift / BITS_PER_DIGIT; + for (i = 0; i < ndigits; i++) + { + if ((i+nw) < ndigits) + a[i] = b[i+nw]; + else + a[i] = 0; + } + /* Call again to shift bits inside digits */ + bits = shift % BITS_PER_DIGIT; + tempCarry = b[nw-1] >> bits; + if (bits) + carry |= tempCarry; + return carry; + } + else + { + bits = shift; + } + + /* Construct mask to set low bits */ + /* (thanks to Jesse Chisholm for suggesting this improved technique) */ + mask = ~(~(DIGIT_T)0 << bits); + + y = BITS_PER_DIGIT - bits; + carry = 0; + i = ndigits; + while (i--) + { + nextcarry = (b[i] & mask) << y; + a[i] = b[i] >> bits | carry; + carry = nextcarry; + } + return carry; + + } +} int spMultiply(uint p[2], uint x, uint y) { @@ -654,7 +915,22 @@ uint spDivide(uint *pq, uint *pr, const uint u[2], uint v) return (uint)(q >> 32); } -int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits) +int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits) +{ + /* if (ndigits == 0) return 0; // deleted [v2.5] */ + + while (ndigits--) + { + if (a[ndigits] > b[ndigits]) + return 1; /* GT */ + if (a[ndigits] < b[ndigits]) + return -1; /* LT */ + } + + return 0; /* EQ */ +} + +int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) { /* if (ndigits == 0) return 0; // deleted [v2.5] */ @@ -669,7 +945,42 @@ int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits) return 0; /* EQ */ } -void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits) +int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) +{ + /* if (ndigits == 0) return 0; // deleted [v2.5] */ + + while (ndigits--) + { + if (a[ndigits] > b[ndigits]) + return 1; /* GT */ + if (a[ndigits] < b[ndigits]) + return -1; /* LT */ + } + + return 0; /* EQ */ +} + +void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits) +{ /* Sets a = b */ + size_t i; + + for (i = 0; i < ndigits; i++) + { + a[i] = b[i]; + } +} + +void 
mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits) +{ /* Sets a = b */ + size_t i; + + for (i = 0; i < ndigits; i++) + { + a[i] = b[i]; + } +} + +void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits) { /* Sets a = b */ size_t i; @@ -679,7 +990,7 @@ void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits) } } -volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits) +volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits) { /* Sets a = 0 */ /* Prevent optimiser ignoring this */ @@ -693,7 +1004,31 @@ volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits) return optdummy; } -size_t mpSizeof(const DIGIT_T a[], size_t ndigits) +volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits) +{ /* Sets a = 0 */ + + /* Prevent optimiser ignoring this */ + volatile DIGIT_T optdummy; + __global volatile DIGIT_T *p = a; + + while (ndigits--) + a[ndigits] = 0; + + optdummy = *p; + return optdummy; +} + +size_t mpSizeof(const DIGIT_T *a, size_t ndigits) +{ + while(ndigits--) + { + if (a[ndigits] != 0) + return (++ndigits); + } + return 0; +} + +size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits) { while(ndigits--) { @@ -719,7 +1054,7 @@ void mpFail(char *msg) printf("the program should stop here"); } -size_t mpBitLength(const DIGIT_T d[], size_t ndigits) +size_t mpBitLength( const DIGIT_T *d, size_t ndigits) /* Returns no of significant bits in d */ { size_t n, i, bits; @@ -741,8 +1076,7 @@ size_t mpBitLength(const DIGIT_T d[], size_t ndigits) return bits; } - - +/* void mpModSquareTemp(DIGIT_T *y,DIGIT_T *m,size_t n,DIGIT_T *t1,DIGIT_T *t2) { mpSquare(t1,y,n); @@ -757,88 +1091,8 @@ void mpModMultTemp(DIGIT_T *y, DIGIT_T *x, DIGIT_T *m, size_t n, DIGIT_T* t1, DI } - - - - -int mpModExpO(DIGIT_T *yout[], DIGIT_T *x[], DIGIT_T *e[], DIGIT_T *m[], size_t ndigits) -{ /* Computes y = x^e mod m */ - /* "Classic" binary left-to-right method */ - /* [v2.2] removed const restriction on m[] to avoid 
using an extra alloc'd var - (m is changed in-situ during the divide operation then restored) */ - DIGIT_T mask; - size_t n; - size_t nn = ndigits * 2; - /* Create some double-length temps */ -//#ifdef NO_ALLOCS - DIGIT_T t1[MAX_FIXED_DIGITS * 2]; - DIGIT_T t2[MAX_FIXED_DIGITS * 2]; - DIGIT_T y[MAX_FIXED_DIGITS * 2]; - assert(ndigits <= MAX_FIXED_DIGITS); -/*#else - DIGIT_T *t1, *t2, *y; - t1 = mpAlloc(nn); - t2 = mpAlloc(nn); - y = mpAlloc(nn); -#endif - */ - assert(ndigits != 0); - - n = mpSizeof(*e, ndigits); - /* Catch e==0 => x^0=1 */ - if (0 == n) - { - mpSetDigit(*yout, 1, ndigits); - goto done; - } - /* Find second-most significant bit in e */ - for (mask = HIBITMASK; mask > 0; mask >>= 1) - { - if (*e[n-1] & mask) - break; - } - mpNEXTBITMASK(mask, n); - - /* Set y = x */ - mpSetEqual(*y, *x, ndigits); - - /* For bit j = k-2 downto 0 */ - while (n) - { - /* Square y = y * y mod n */ - //mpMODSQUARETEMP(*y, *m, ndigits, t1, t2); - //mpModSquareTemp(*y, *m, ndigits, t1, t2); - - mpSquare(t1,y,n); - mpDivide(t2,y,t1,n*2,m,n); - - if (*e[n-1] & mask) - { /* if e(j) == 1 then multiply - y = y * x mod n */ - //mpMODMULTTEMP(*y, *x, *m, ndigits, t1, t2); - //mpModMultTemp(*y, *x, *m, ndigits, t1, t2); - - mpMultiply(t1,x,y,n); - mpDivide(t2,y,t1,n*2,m,n); - - } - - /* Move to next bit */ - mpNEXTBITMASK(mask, n); - } - - /* Return y */ - mpSetEqual(*yout, y, ndigits); - -done: - mpDESTROY(t1, nn); - mpDESTROY(t2, nn); - mpDESTROY(y, ndigits); - - return 0; -} - -int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits) +*/ +int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits) /* New in Version 2.0 */ { /* Computes square w = x * x @@ -947,10 +1201,7 @@ int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits) return 0; } - - - -int mpIsZero(const DIGIT_T a[], size_t ndigits) +int mpIsZero( const DIGIT_T *a, size_t ndigits) { size_t i; @@ -965,7 +1216,6 @@ int mpIsZero(const DIGIT_T a[], size_t ndigits) return (!0); /* True */ } - void assert(bool 
precondition) { char str[] = "assert reached, also this message leaks memory"; @@ -976,21 +1226,95 @@ void assert(bool precondition) { } + +//int mpModExpO(__global DIGIT_T *yout, __global DIGIT_T *x, __global DIGIT_T *e, __global DIGIT_T *m, size_t ndigits) + // some might be constants -__kernel void single(global DIGIT_T* s, const unsigned int s_len, - global DIGIT_T* e, const unsigned int e_len, - global DIGIT_T* n, const unsigned int n_len, - global DIGIT_T* res, const unsigned int res_len, +__kernel void single(__global DIGIT_T* x, const unsigned int s_len, + __global DIGIT_T* e, const unsigned int e_len, + __global DIGIT_T* m, const unsigned int n_len, + __global DIGIT_T *yout, const unsigned int res_len, //global DIGIT_T* comp, const unsigned int comp_len, - const unsigned int max_len, - global int8* valid + const unsigned int ndigits, + __global int8* valid //const unsigned int count ) { + // memory(res); + + + // __global DIGIT_T * __local ptr_x; + + // ptr_x = x; + // ptr_x[3] = 4; - mpModExpO(&res,&s,&e,&n,max_len); + // mpModExpO(res,s,e,n,max_len); + + DIGIT_T mask; + size_t n; + size_t nn = ndigits * 2; + /* Create some double-length temps */ + + DIGIT_T t1[MAX_FIXED_DIGITS * 2]; + DIGIT_T t2[MAX_FIXED_DIGITS * 2]; + DIGIT_T y[MAX_FIXED_DIGITS * 2]; + assert(ndigits <= MAX_FIXED_DIGITS); + + assert(ndigits != 0); + + n = mpSizeof_g(e, ndigits); + /* Catch e==0 => x^0=1 */ + if (0 == n) + { + mpSetDigit_g(yout, 1, ndigits); + goto done; + } + /* Find second-most significant bit in e */ + for (mask = HIBITMASK; mask > 0; mask >>= 1) + { + if (e[n-1] & mask) + break; + } + mpNEXTBITMASK(mask, n); + /* Set y = x */ + mpSetEqual_lg(y, x, ndigits); + + /* For bit j = k-2 downto 0 */ + while (n) + { + /* Square y = y * y mod n */ + mpMODSQUARETEMP(y, m, ndigits, t1, t2); + //mpModSquareTemp(*y, *m, ndigits, t1, t2); + + //mpSquare(t1,y,n); + // mpDivide(t2,y,t1,n*2,m,n); + + if (e[n-1] & mask) + { /* if e(j) == 1 then multiply + y = y * x mod n */ + 
mpMODMULTTEMP(y, x, m, ndigits, t1, t2); + //mpModMultTemp(*y, *x, *m, ndigits, t1, t2); + + // mpMultiply(t1,x,y,n); + // mpDivide(t2,y,t1,n*2,m,n); + + } + + /* Move to next bit */ + mpNEXTBITMASK(mask, n); + } + + /* Return y */ + mpSetEqual_gl(yout, y, ndigits); + +done: + mpDESTROY(t1, nn); + mpDESTROY(t2, nn); + mpDESTROY(y, ndigits); + + }