commit ba7f0ed534e534f30f8f7e8b4e13d7aadbf7d753
parent d50b5b70b89f9fd9a7208fa52c8bf0c9aadb9f92
Author: Cedric <cedric.zwahlen@students.bfh.ch>
Date: Sat, 4 Nov 2023 18:12:13 +0100
the kernel compiles, and outputs data, though it's not verifying correctly yet
Diffstat:
5 files changed, 575 insertions(+), 172 deletions(-)
diff --git a/source/big-int-test.c b/source/big-int-test.c
@@ -405,7 +405,7 @@ static int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat,
return 0;
}
-static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[],
+static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T v[],
DIGIT_T q, size_t n)
{ /* Compute w = w - qv
where w = (WnW[n-1]...W[0])
@@ -439,7 +439,7 @@ static DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[],
return wn;
}
-DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b,
+DIGIT_T mpShiftLeft(DIGIT_T *a, const DIGIT_T *b,
size_t shift, size_t ndigits)
{ /* Computes a = b << shift */
/* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
@@ -485,7 +485,7 @@ DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b,
return carry;
}
-DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits)
+DIGIT_T mpShiftRight(DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
{ /* Computes a = b >> shift */
/* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
size_t i, y, nw, bits;
@@ -783,7 +783,7 @@ void mpPrintHex(const char *prefix, const DIGIT_T *a, size_t len, const char *su
}
-int mpModExpO(DIGIT_T yout[], const DIGIT_T x[], const DIGIT_T e[], DIGIT_T m[], size_t ndigits)
+int mpModExpO(DIGIT_T *yout, const DIGIT_T *x, const DIGIT_T *e, DIGIT_T *m, size_t ndigits)
{ /* Computes y = x^e mod m */
/* "Classic" binary left-to-right method */
/* [v2.2] removed const restriction on m[] to avoid using an extra alloc'd var
@@ -850,7 +850,7 @@ done:
return 0;
}
-int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits)
+int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits)
/* New in Version 2.0 */
{
/* Computes square w = x * x
diff --git a/source/rsa-test.c b/source/rsa-test.c
@@ -325,7 +325,7 @@ int rsa_tests(void) {
err = clEnqueueWriteBuffer(commands, s_mem, CL_TRUE, 0, s_len, s_buf, 0, NULL, NULL);
err |= clEnqueueWriteBuffer(commands, e_mem, CL_TRUE, 0, e_len, e_buf, 0, NULL, NULL);
err |= clEnqueueWriteBuffer(commands, n_mem, CL_TRUE, 0, n_len, n_buf, 0, NULL, NULL);
- //err |= clEnqueueWriteBuffer(commands, res_mem, CL_TRUE, 0, res_len, res_buf, 0, NULL, NULL);
+ err |= clEnqueueWriteBuffer(commands, res_mem, CL_TRUE, 0, res_len, res_buf, 0, NULL, NULL);
if (err != CL_SUCCESS)
{
printf("Error: Failed to write to source array!\n");
@@ -394,7 +394,7 @@ int rsa_tests(void) {
mpConvToHex(res_buf, sz_res, comp, sz_mm);
- printf("%s",comp);
+ printf("%s\n",comp);
// Print a brief summary detailing the results
//
diff --git a/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate b/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate
Binary files differ.
diff --git a/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist b/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist
@@ -684,7 +684,7 @@
BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint">
<BreakpointContent
uuid = "E34A3BBB-4BEA-4FC9-B0F1-55FB3A180116"
- shouldBeEnabled = "Yes"
+ shouldBeEnabled = "No"
ignoreCount = "0"
continueAfterRunningActions = "No"
filePath = "../source/rsa-test.c"
@@ -755,6 +755,21 @@
endingLineNumber = "397"
offsetFromSymbolStart = "4098">
</Location>
+ <Location
+ uuid = "E34A3BBB-4BEA-4FC9-B0F1-55FB3A180116 - b0b9078e770c934d"
+ shouldBeEnabled = "Yes"
+ ignoreCount = "0"
+ continueAfterRunningActions = "No"
+ symbolName = "rsa_tests"
+ moduleName = "lib-gpu-verify"
+ usesParentBreakpointCondition = "Yes"
+ urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c"
+ startingColumnNumber = "9223372036854775807"
+ endingColumnNumber = "9223372036854775807"
+ startingLineNumber = "397"
+ endingLineNumber = "397"
+ offsetFromSymbolStart = "4187">
+ </Location>
</Locations>
</BreakpointContent>
</BreakpointProxy>
@@ -810,7 +825,7 @@
BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint">
<BreakpointContent
uuid = "0B50C23D-36DE-4C9C-B911-72E48EA8C7FD"
- shouldBeEnabled = "Yes"
+ shouldBeEnabled = "No"
ignoreCount = "0"
continueAfterRunningActions = "No"
filePath = "../source/rsa-test.c"
@@ -822,5 +837,69 @@
landmarkType = "9">
</BreakpointContent>
</BreakpointProxy>
+ <BreakpointProxy
+ BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint">
+ <BreakpointContent
+ uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C"
+ shouldBeEnabled = "Yes"
+ ignoreCount = "0"
+ continueAfterRunningActions = "No"
+ filePath = "../source/rsa-test.c"
+ startingColumnNumber = "9223372036854775807"
+ endingColumnNumber = "9223372036854775807"
+ startingLineNumber = "329"
+ endingLineNumber = "329"
+ landmarkName = "rsa_tests()"
+ landmarkType = "9">
+ <Locations>
+ <Location
+ uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C - b0b9078e770c9a09"
+ shouldBeEnabled = "Yes"
+ ignoreCount = "0"
+ continueAfterRunningActions = "No"
+ symbolName = "rsa_tests"
+ moduleName = "lib-gpu-verify"
+ usesParentBreakpointCondition = "Yes"
+ urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c"
+ startingColumnNumber = "9223372036854775807"
+ endingColumnNumber = "9223372036854775807"
+ startingLineNumber = "329"
+ endingLineNumber = "329"
+ offsetFromSymbolStart = "3113">
+ </Location>
+ <Location
+ uuid = "CCA5ECFD-4BDD-4A9B-8C34-A3E9A049CB5C - b0b9078e770c9a09"
+ shouldBeEnabled = "Yes"
+ ignoreCount = "0"
+ continueAfterRunningActions = "No"
+ symbolName = "rsa_tests"
+ moduleName = "lib-gpu-verify"
+ usesParentBreakpointCondition = "Yes"
+ urlString = "file:///Users/cedriczwahlen/libgpuverify/source/rsa-test.c"
+ startingColumnNumber = "9223372036854775807"
+ endingColumnNumber = "9223372036854775807"
+ startingLineNumber = "329"
+ endingLineNumber = "329"
+ offsetFromSymbolStart = "3202">
+ </Location>
+ </Locations>
+ </BreakpointContent>
+ </BreakpointProxy>
+ <BreakpointProxy
+ BreakpointExtensionID = "Xcode.Breakpoint.FileBreakpoint">
+ <BreakpointContent
+ uuid = "C1B34A3E-98EE-448B-A02F-A50E969F7C92"
+ shouldBeEnabled = "No"
+ ignoreCount = "0"
+ continueAfterRunningActions = "No"
+ filePath = "../source/rsa-test.c"
+ startingColumnNumber = "9223372036854775807"
+ endingColumnNumber = "9223372036854775807"
+ startingLineNumber = "506"
+ endingLineNumber = "506"
+ landmarkName = "verify(sign, ee, nn, mm)"
+ landmarkType = "9">
+ </BreakpointContent>
+ </BreakpointProxy>
</Breakpoints>
</Bucket>
diff --git a/xcode/verify.cl b/xcode/verify.cl
@@ -41,56 +41,84 @@ typedef uint16 HALF_DIGIT_T;
// forward definitions
-int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits, DIGIT_T v[], size_t vdigits);
+int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits);
-int mpModMult(DIGIT_T a[], const DIGIT_T x[], const DIGIT_T y[], DIGIT_T m[], size_t ndigits);
+//int mpModMult(__global DIGIT_T *a, __global DIGIT_T *x, const DIGIT_T *y, DIGIT_T *m, size_t ndigits);
+
+int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits);
+
+DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits);
+
+int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits);
-int mpMultiply(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits);
-DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits);
-int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[], size_t udigits, DIGIT_T v[], size_t vdigits);
int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2);
-DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[], DIGIT_T q, size_t n);
-DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b, size_t shift, size_t ndigits);
+DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v, DIGIT_T q, size_t n);
+
+DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);
-void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits);
-int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits);
+void mpSetDigit(DIGIT_T *a, DIGIT_T d, size_t ndigits);
+int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
-DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits);
+
+DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits);
int spMultiply(uint p[2], uint x, uint y);
uint spDivide(uint *pq, uint *pr, const uint u[2], uint v);
-int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits);
+int mpSquare( DIGIT_T *w, const DIGIT_T *x, size_t ndigits);
-size_t mpBitLength(const DIGIT_T d[], size_t ndigits);
+size_t mpBitLength(const DIGIT_T *d, size_t ndigits);
-DIGIT_T mpShortDiv(DIGIT_T q[], const DIGIT_T u[], DIGIT_T v,
+DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v,
size_t ndigits);
-void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits);
+void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
size_t uiceil(float x);
volatile uint8 zeroise_bytes(volatile void *v, size_t n);
-size_t mpSizeof(const DIGIT_T a[], size_t ndigits);
+size_t mpSizeof(const DIGIT_T *a, size_t ndigits);
-volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits);
+volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits);
-int mpIsZero(const DIGIT_T a[], size_t ndigits);
+int mpIsZero( const DIGIT_T *a, size_t ndigits);
void mpFail(char *msg);
-int mpModExpO(DIGIT_T *yout[], DIGIT_T *x[], DIGIT_T *e[], DIGIT_T *m[], size_t ndigits);
+int mpModExpO( DIGIT_T *yout, DIGIT_T *x, DIGIT_T *e, DIGIT_T *m, size_t ndigits);
void assert(bool precondition);
+// global memory definitions
+
+size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits);
+
+DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
+
+DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
+
+DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits);
+
+int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
+
+int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
+
+void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits);
+void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits);
+
+DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n);
+
+DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits);
+
+void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits);
+
+volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits);
// implementation
-int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits,
- DIGIT_T v[], size_t vdigits)
+int mpModulo(__global DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits)
{
/* Computes r = u mod v
where r, v are multiprecision integers of length vdigits
@@ -121,7 +149,7 @@ int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits,
mpDivide(qq, rr, u, udigits, v, vdigits);
/* Final r is only vdigits long */
- mpSetEqual(r, rr, vdigits);
+ mpSetEqual_gl(r, rr, vdigits);
mpDESTROY(rr, udigits);
mpDESTROY(qq, udigits);
@@ -129,31 +157,26 @@ int mpModulo(DIGIT_T r[], const DIGIT_T u[], size_t udigits,
return 0;
}
-int mpModMult(DIGIT_T a[], const DIGIT_T x[], const DIGIT_T y[],
- DIGIT_T m[], size_t ndigits)
-{ /* Computes a = (x * y) mod m */
-
-/* Double-length temp variable p */
-// #ifdef NO_ALLOCS
- DIGIT_T p[MAX_FIXED_DIGITS * 2];
-// assert(ndigits <= MAX_FIXED_DIGITS);
-/*#else
- DIGIT_T *p;
- p = mpAlloc(ndigits * 2);
-#endif
-*/
- /* Calc p[2n] = x * y */
- mpMultiply(p, x, y, ndigits);
-
- /* Then modulo (NOTE: a is OK at only ndigits long) */
- mpModulo(a, p, ndigits * 2, m, ndigits);
-
- mpDESTROY(p, ndigits * 2);
-
- return 0;
-}
-
-int mpMultiply(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits)
+//int mpModMult(__global DIGIT_T *a, __global DIGIT_T *x, const DIGIT_T *y, DIGIT_T *m, size_t ndigits)
+//{ /* Computes a = (x * y) mod m */
+//
+///* Double-length temp variable p */
+//
+// DIGIT_T p[MAX_FIXED_DIGITS * 2];
+//// assert(ndigits <= MAX_FIXED_DIGITS);
+//
+// //Calc p[2n] = x * y
+// mpMultiply(p, x, y, ndigits);
+//
+// /* Then modulo (NOTE: a is OK at only ndigits long) */
+// mpModulo(a, p, ndigits * 2, m, ndigits);
+//
+// mpDESTROY(p, ndigits * 2);
+//
+// return 0;
+//}
+
+int mpMultiply( DIGIT_T *w, __global DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
{
/* Computes product w = u * v
where u, v are multiprecision integers of ndigits each
@@ -208,7 +231,7 @@ int mpMultiply(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits
return 0;
}
-DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits)
+DIGIT_T mpAdd( DIGIT_T *w, const DIGIT_T *u, const DIGIT_T *v, size_t ndigits)
{
/* Calculates w = u + v
where w, u, v are multiprecision integers of ndigits each
@@ -245,10 +268,44 @@ DIGIT_T mpAdd(DIGIT_T w[], const DIGIT_T u[], const DIGIT_T v[], size_t ndigits)
return k; /* w_n = k */
}
-// MARK: This function is causing problems – this function calls mpShiftLeft, mpShiftRight at some point (and so does mpShortDiv) they contain recursions, which are forbidden
+DIGIT_T mpAdd_llg( DIGIT_T *w, const DIGIT_T *u, __global const DIGIT_T *v, size_t ndigits)
+{
+ /* Calculates w = u + v
+ where w, u, v are multiprecision integers of ndigits each
+ Returns carry if overflow. Carry = 0 or 1.
+
+ Address-space variant of mpAdd: w and u are unqualified pointers
+ while v is read from a __global buffer (the "_llg" suffix appears
+ to follow the order of the three pointer parameters).
+ Digits are processed least-significant first (index 0 upwards).
+
+ Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A.
+ */
+
+ DIGIT_T k;
+ size_t j;
+
+ // assert(w != v);
+
+ /* Step A1. Initialise */
+ k = 0;
+
+ for (j = 0; j < ndigits; j++)
+ {
+ /* Step A2. Add digits w_j = (u_j + v_j + k)
+ Set k = 1 if carry (overflow) occurs
+ The sum is formed in two steps so that each unsigned wrap-around
+ can be detected by comparing the result with the value just added.
+ */
+ w[j] = u[j] + k;
+ if (w[j] < k) /* u[j] + k wrapped around */
+ k = 1;
+ else
+ k = 0;
+
+ w[j] += v[j];
+ if (w[j] < v[j]) /* adding v[j] wrapped around */
+ k++;
-int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
- size_t udigits, DIGIT_T v[], size_t vdigits)
+ } /* Step A3. Loop on j */
+
+ return k; /* w_n = k */
+}
+
+int mpDivide(DIGIT_T *q, DIGIT_T *r, DIGIT_T *u, size_t udigits, __global DIGIT_T *v, size_t vdigits)
{ /* Computes quotient q = u / v and remainder r = u mod v
where q, r, u are multiple precision digits
all of udigits and the divisor v is vdigits.
@@ -275,7 +332,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
mpSetZero(r, udigits);
/* Work out exact sizes of u and v */
- n = (int)mpSizeof(v, vdigits);
+ n = (int)mpSizeof_g(v, vdigits);
m = (int)mpSizeof(u, udigits);
m -= n;
@@ -297,7 +354,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
if (m == 0)
{ /* u and v are the same length */
- cmp = mpCompare(u, v, (size_t)n);
+ cmp = mpCompare_lg(u, v, (size_t)n);
if (cmp < 0)
{ /* v > u, as above */
mpSetEqual(r, u, udigits);
@@ -333,7 +390,7 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
}
/* Normalise v in situ - NB only shift non-zero digits */
- overflow = mpShiftLeft(v, v, shift, n);
+ overflow = mpShiftLeft_gg(v, v, shift, n);
/* Copy normalised dividend u*d into r */
overflow = mpShiftLeft(r, u, shift, n + m);
@@ -378,14 +435,14 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
/* Step D4. Multiply and subtract */
ww = &uu[j];
- overflow = mpMultSub(t[1], ww, v, qhat, (size_t)n);
+ overflow = mpMultSub_lg(t[1], ww, v, qhat, (size_t)n);
/* Step D5. Test remainder. Set Qj = Qhat */
q[j] = qhat;
if (overflow)
{ /* Step D6. Add back if D4 was negative */
q[j]--;
- overflow = mpAdd(ww, ww, v, (size_t)n);
+ overflow = mpAdd_llg(ww, ww, v, (size_t)n);
}
t[0] = uu[j+n-1]; /* Uj+n on next round */
@@ -399,12 +456,23 @@ int mpDivide(DIGIT_T q[], DIGIT_T r[], const DIGIT_T u[],
/* Step D8. Unnormalise. */
mpShiftRight(r, r, shift, n);
- mpShiftRight(v, v, shift, n);
+ mpShiftRight_gg(v, v, shift, n);
return 0;
}
-void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits)
+void mpSetDigit( DIGIT_T *a, DIGIT_T d, size_t ndigits)
+{
+    /* Load the single digit d into the multiprecision integer a:
+       digit 0 receives d and digits 1 .. ndigits-1 are cleared.
+       a[0] is written unconditionally, even when ndigits is 0. */
+    size_t pos = ndigits;
+
+    a[0] = d;
+    while (pos > 1)
+    {
+        pos--;
+        a[pos] = 0;
+    }
+}
+
+void mpSetDigit_g(__global DIGIT_T *a, DIGIT_T d, size_t ndigits)
{ /* Sets a = d where d is a single digit */
size_t i;
@@ -415,7 +483,7 @@ void mpSetDigit(DIGIT_T a[], DIGIT_T d, size_t ndigits)
a[0] = d;
}
-DIGIT_T mpShortDiv(DIGIT_T q[], const DIGIT_T u[], DIGIT_T v,
+DIGIT_T mpShortDiv( DIGIT_T *q, const DIGIT_T *u, DIGIT_T v,
size_t ndigits)
{
/* Calculates quotient q = u div v
@@ -487,7 +555,7 @@ int QhatTooBig(DIGIT_T qhat, DIGIT_T rhat,
return 0;
}
-DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[],
+DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T *w, const DIGIT_T *v,
DIGIT_T q, size_t n)
{ /* Compute w = w - qv
where w = (WnW[n-1]...W[0])
@@ -521,8 +589,95 @@ DIGIT_T mpMultSub(DIGIT_T wn, DIGIT_T w[], const DIGIT_T v[],
return wn;
}
-DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b,
- size_t shift, size_t ndigits)
+DIGIT_T mpMultSub_lg(DIGIT_T wn, DIGIT_T *w, __global const DIGIT_T *v, DIGIT_T q, size_t n)
+{ /* Compute w = w - qv
+ where w = (WnW[n-1]...W[0])
+ return modified Wn.
+
+ Address-space variant of mpMultSub: w is an unqualified pointer that
+ is updated in place, v is read from a __global buffer.
+ wn is the digit above w[n-1]; it is passed and returned by value
+ because it is not stored in the array.
+ k carries the running borrow from one digit to the next.
+ */
+ DIGIT_T k, t[2];
+ size_t i;
+
+ if (q == 0) /* No change */
+ return wn;
+
+ k = 0;
+
+ for (i = 0; i < n; i++)
+ {
+ spMultiply(t, q, v[i]); /* t[0] is subtracted from w[i] below, t[1] is folded into the borrow */
+ w[i] -= k;
+ if (w[i] > MAX_DIGIT - k) /* subtracting the borrow wrapped around */
+ k = 1;
+ else
+ k = 0;
+ w[i] -= t[0];
+ if (w[i] > MAX_DIGIT - t[0]) /* subtracting t[0] wrapped around */
+ k++;
+ k += t[1];
+ }
+
+ /* Cope with Wn not stored in array w[0..n-1] */
+ wn -= k;
+
+ return wn;
+}
+
+DIGIT_T mpShiftLeft( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
+{
+    /* Computes a = b << shift and returns the carry out of the top digit.
+       [v2.1] Modified to cope with shift > BITS_PERDIGIT
+
+       a, b   : multiprecision integers of ndigits digits, least
+                significant digit first; a may be the same array as b.
+       shift  : number of bit positions to shift by.
+
+       The library version handled shift >= BITS_PER_DIGIT by moving whole
+       digits and then calling itself for the remaining bits. Recursion is
+       not available in the kernel, so the second step is done here as an
+       in-place pass over a. The previous while(1) replacement never
+       reduced shift before `continue`, so it looped forever whenever
+       shift >= BITS_PER_DIGIT and shift % BITS_PER_DIGIT != 0, and it
+       returned 0 instead of the whole-digit carry when the remainder
+       was 0.
+
+       NOTE: as in the original, b[ndigits - nw] is read without a bounds
+       check, so shift / BITS_PER_DIGIT must not exceed ndigits. */
+    size_t i, y, nw, bits;
+    DIGIT_T mask, carry, nextcarry;
+    DIGIT_T high = 0;          /* carry contributed by the whole-digit move */
+    const DIGIT_T *src = b;    /* operand of the sub-digit pass */
+
+    /* Do we shift whole digits? */
+    if (shift >= BITS_PER_DIGIT)
+    {
+        nw = shift / BITS_PER_DIGIT;
+        i = ndigits;
+        while (i--)
+        {
+            if (i >= nw)
+                a[i] = b[i-nw];
+            else
+                a[i] = 0;
+        }
+
+        bits = shift % BITS_PER_DIGIT;
+        high = b[ndigits-nw] << bits;
+        if (bits == 0)
+            return high;
+
+        /* Remaining bits are shifted inside a, exactly what the
+           recursive call mpShiftLeft(a, a, bits, ndigits) used to do. */
+        src = a;
+    }
+    else
+    {
+        bits = shift;
+    }
+
+    /* Construct mask = high bits set */
+    mask = ~(~(DIGIT_T)0 >> bits);
+
+    y = BITS_PER_DIGIT - bits;
+    carry = 0;
+    for (i = 0; i < ndigits; i++)
+    {
+        /* Read src[i] completely before a[i] is overwritten (src may be a) */
+        nextcarry = (src[i] & mask) >> y;
+        a[i] = src[i] << bits | carry;
+        carry = nextcarry;
+    }
+
+    return high | carry;
+}
+
+DIGIT_T mpShiftLeft_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
{ /* Computes a = b << shift */
/* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
@@ -577,7 +732,62 @@ DIGIT_T mpShiftLeft(DIGIT_T a[], const DIGIT_T *b,
}
}
-DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigits)
+DIGIT_T mpShiftLeft_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
+{
+    /* Computes a = b << shift and returns the carry out of the top digit.
+       [v2.1] Modified to cope with shift > BITS_PERDIGIT
+
+       Address-space variant of mpShiftLeft: the result a is an
+       unqualified pointer, the source b is a __global buffer.
+
+       The library version handled shift >= BITS_PER_DIGIT by moving whole
+       digits and then calling itself on a for the remaining bits. The
+       previous while(1) replacement never reduced shift before
+       `continue`, so it looped forever whenever shift >= BITS_PER_DIGIT
+       and shift % BITS_PER_DIGIT != 0, and it returned 0 instead of the
+       whole-digit carry when the remainder was 0. The second step is now
+       an explicit in-place pass over a.
+
+       NOTE: as in the original, b[ndigits - nw] is read without a bounds
+       check, so shift / BITS_PER_DIGIT must not exceed ndigits. */
+    size_t i, y, nw, bits;
+    DIGIT_T mask, carry, nextcarry, high;
+
+    /* Do we shift whole digits? */
+    if (shift >= BITS_PER_DIGIT)
+    {
+        nw = shift / BITS_PER_DIGIT;
+        i = ndigits;
+        while (i--)
+        {
+            if (i >= nw)
+                a[i] = b[i-nw];
+            else
+                a[i] = 0;
+        }
+
+        bits = shift % BITS_PER_DIGIT;
+        high = b[ndigits-nw] << bits;
+        if (bits == 0)
+            return high;
+
+        /* Shift the remaining bits inside a. a and b live in different
+           address spaces, so this pass cannot share the loop below
+           through a single source pointer. */
+        mask = ~(~(DIGIT_T)0 >> bits);
+        y = BITS_PER_DIGIT - bits;
+        carry = 0;
+        for (i = 0; i < ndigits; i++)
+        {
+            nextcarry = (a[i] & mask) >> y;
+            a[i] = a[i] << bits | carry;
+            carry = nextcarry;
+        }
+
+        return high | carry;
+    }
+
+    bits = shift;
+
+    /* Construct mask = high bits set */
+    mask = ~(~(DIGIT_T)0 >> bits);
+
+    y = BITS_PER_DIGIT - bits;
+    carry = 0;
+    for (i = 0; i < ndigits; i++)
+    {
+        nextcarry = (b[i] & mask) >> y;
+        a[i] = b[i] << bits | carry;
+        carry = nextcarry;
+    }
+
+    return carry;
+}
+
+DIGIT_T mpShiftRight( DIGIT_T *a, const DIGIT_T *b, size_t shift, size_t ndigits)
{ /* Computes a = b >> shift */
/* [v2.1] Modified to cope with shift > BITS_PERDIGIT */
@@ -630,7 +840,58 @@ DIGIT_T mpShiftRight(DIGIT_T a[], const DIGIT_T b[], size_t shift, size_t ndigit
}
}
+DIGIT_T mpShiftRight_gg(__global DIGIT_T *a, __global const DIGIT_T *b, size_t shift, size_t ndigits)
+{
+    /* Computes a = b >> shift and returns the bits shifted out at the
+       bottom (left-aligned in the returned digit).
+       [v2.1] Modified to cope with shift > BITS_PERDIGIT
+
+       Address-space variant of mpShiftRight: both a and b are __global
+       buffers; a may be the same buffer as b.
+
+       The library version handled shift >= BITS_PER_DIGIT by moving whole
+       digits and then calling itself on a for the remaining bits. The
+       previous non-recursive replacement returned straight after the
+       whole-digit move, so the remaining shift % BITS_PER_DIGIT bits were
+       never applied to a, and the whole-digit carry was dropped when the
+       remainder was 0. The second step is now done as an in-place pass.
+
+       NOTE: as in the original, b[nw - 1] is read without a bounds
+       check, so shift / BITS_PER_DIGIT must not exceed ndigits. */
+    size_t i, y, nw, bits;
+    DIGIT_T mask, carry, nextcarry;
+    DIGIT_T low = 0;                   /* carry contributed by the whole-digit move */
+    __global const DIGIT_T *src = b;   /* operand of the sub-digit pass */
+
+    /* Do we shift whole digits? */
+    if (shift >= BITS_PER_DIGIT)
+    {
+        nw = shift / BITS_PER_DIGIT;
+        for (i = 0; i < ndigits; i++)
+        {
+            if ((i+nw) < ndigits)
+                a[i] = b[i+nw];
+            else
+                a[i] = 0;
+        }
+
+        bits = shift % BITS_PER_DIGIT;
+        low = b[nw-1] >> bits;
+        if (bits == 0)
+            return low;
+
+        /* Remaining bits are shifted inside a, exactly what the
+           recursive call mpShiftRight(a, a, bits, ndigits) used to do. */
+        src = a;
+    }
+    else
+    {
+        bits = shift;
+    }
+
+    /* Construct mask to set low bits */
+    /* (thanks to Jesse Chisholm for suggesting this improved technique) */
+    mask = ~(~(DIGIT_T)0 << bits);
+
+    y = BITS_PER_DIGIT - bits;
+    carry = 0;
+    i = ndigits;
+    while (i--)
+    {
+        /* Read src[i] completely before a[i] is overwritten (src may be a) */
+        nextcarry = (src[i] & mask) << y;
+        a[i] = src[i] >> bits | carry;
+        carry = nextcarry;
+    }
+
+    return low | carry;
+}
int spMultiply(uint p[2], uint x, uint y)
{
@@ -654,7 +915,22 @@ uint spDivide(uint *pq, uint *pr, const uint u[2], uint v)
return (uint)(q >> 32);
}
-int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits)
+int mpCompare(const DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
+{
+    /* Three-way comparison of two multiprecision integers of ndigits
+       digits each, scanning from the most significant digit down.
+       Returns 1 if a > b, -1 if a < b and 0 if they are equal
+       (including the ndigits == 0 case). */
+    size_t pos;
+
+    for (pos = ndigits; pos > 0; pos--)
+    {
+        DIGIT_T lhs = a[pos - 1];
+        DIGIT_T rhs = b[pos - 1];
+
+        if (lhs != rhs)
+            return (lhs > rhs) ? 1 : -1;
+    }
+
+    return 0;    /* EQ */
+}
+
+int mpCompare_g(__global const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
{
/* if (ndigits == 0) return 0; // deleted [v2.5] */
@@ -669,7 +945,42 @@ int mpCompare(const DIGIT_T a[], const DIGIT_T b[], size_t ndigits)
return 0; /* EQ */
}
-void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits)
+int mpCompare_lg(const DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
+{
+    /* Three-way comparison of a (unqualified pointer) against b
+       (__global buffer), both ndigits digits long, scanning from the
+       most significant digit down.
+       Returns 1 if a > b, -1 if a < b and 0 if they are equal
+       (including the ndigits == 0 case). */
+    size_t pos;
+
+    for (pos = ndigits; pos > 0; pos--)
+    {
+        DIGIT_T lhs = a[pos - 1];
+        DIGIT_T rhs = b[pos - 1];
+
+        if (lhs != rhs)
+            return (lhs > rhs) ? 1 : -1;
+    }
+
+    return 0;    /* EQ */
+}
+
+void mpSetEqual(DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
+{
+    /* Copies the ndigits digits of b into a, lowest digit first. */
+    size_t pos = 0;
+
+    while (pos < ndigits)
+    {
+        a[pos] = b[pos];
+        pos++;
+    }
+}
+
+void mpSetEqual_lg(DIGIT_T *a, __global const DIGIT_T *b, size_t ndigits)
+{
+    /* Copies the ndigits digits of the __global buffer b into a
+       (unqualified pointer), lowest digit first. */
+    size_t pos = 0;
+
+    while (pos < ndigits)
+    {
+        a[pos] = b[pos];
+        pos++;
+    }
+}
+
+void mpSetEqual_gl(__global DIGIT_T *a, const DIGIT_T *b, size_t ndigits)
{ /* Sets a = b */
size_t i;
@@ -679,7 +990,7 @@ void mpSetEqual(DIGIT_T a[], const DIGIT_T b[], size_t ndigits)
}
}
-volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits)
+volatile DIGIT_T mpSetZero(volatile DIGIT_T *a, size_t ndigits)
{ /* Sets a = 0 */
/* Prevent optimiser ignoring this */
@@ -693,7 +1004,31 @@ volatile DIGIT_T mpSetZero(volatile DIGIT_T a[], size_t ndigits)
return optdummy;
}
-size_t mpSizeof(const DIGIT_T a[], size_t ndigits)
+volatile DIGIT_T mpSetZero_g(__global volatile DIGIT_T *a, size_t ndigits)
+{
+    /* Clears the ndigits digits of the __global buffer a, highest digit
+       first, then reads a[0] back through the volatile pointer and
+       returns it so the optimiser cannot discard the stores. */
+    volatile DIGIT_T optdummy;
+    size_t pos;
+
+    for (pos = ndigits; pos > 0; pos--)
+    {
+        a[pos - 1] = 0;
+    }
+
+    optdummy = a[0];
+    return optdummy;
+}
+
+size_t mpSizeof(const DIGIT_T *a, size_t ndigits)
+{
+    /* Returns the number of significant digits in a: the index of the
+       highest non-zero digit plus one, or 0 when every digit is zero. */
+    size_t len = ndigits;
+
+    while (len > 0 && a[len - 1] == 0)
+    {
+        len--;
+    }
+
+    return len;
+}
+
+size_t mpSizeof_g(__global DIGIT_T *a, size_t ndigits)
{
while(ndigits--)
{
@@ -719,7 +1054,7 @@ void mpFail(char *msg)
printf("the program should stop here");
}
-size_t mpBitLength(const DIGIT_T d[], size_t ndigits)
+size_t mpBitLength( const DIGIT_T *d, size_t ndigits)
/* Returns no of significant bits in d */
{
size_t n, i, bits;
@@ -741,8 +1076,7 @@ size_t mpBitLength(const DIGIT_T d[], size_t ndigits)
return bits;
}
-
-
+/*
void mpModSquareTemp(DIGIT_T *y,DIGIT_T *m,size_t n,DIGIT_T *t1,DIGIT_T *t2) {
mpSquare(t1,y,n);
@@ -757,88 +1091,8 @@ void mpModMultTemp(DIGIT_T *y, DIGIT_T *x, DIGIT_T *m, size_t n, DIGIT_T* t1, DI
}
-
-
-
-
-int mpModExpO(DIGIT_T *yout[], DIGIT_T *x[], DIGIT_T *e[], DIGIT_T *m[], size_t ndigits)
-{ /* Computes y = x^e mod m */
- /* "Classic" binary left-to-right method */
- /* [v2.2] removed const restriction on m[] to avoid using an extra alloc'd var
- (m is changed in-situ during the divide operation then restored) */
- DIGIT_T mask;
- size_t n;
- size_t nn = ndigits * 2;
- /* Create some double-length temps */
-//#ifdef NO_ALLOCS
- DIGIT_T t1[MAX_FIXED_DIGITS * 2];
- DIGIT_T t2[MAX_FIXED_DIGITS * 2];
- DIGIT_T y[MAX_FIXED_DIGITS * 2];
- assert(ndigits <= MAX_FIXED_DIGITS);
-/*#else
- DIGIT_T *t1, *t2, *y;
- t1 = mpAlloc(nn);
- t2 = mpAlloc(nn);
- y = mpAlloc(nn);
-#endif
- */
- assert(ndigits != 0);
-
- n = mpSizeof(*e, ndigits);
- /* Catch e==0 => x^0=1 */
- if (0 == n)
- {
- mpSetDigit(*yout, 1, ndigits);
- goto done;
- }
- /* Find second-most significant bit in e */
- for (mask = HIBITMASK; mask > 0; mask >>= 1)
- {
- if (*e[n-1] & mask)
- break;
- }
- mpNEXTBITMASK(mask, n);
-
- /* Set y = x */
- mpSetEqual(*y, *x, ndigits);
-
- /* For bit j = k-2 downto 0 */
- while (n)
- {
- /* Square y = y * y mod n */
- //mpMODSQUARETEMP(*y, *m, ndigits, t1, t2);
- //mpModSquareTemp(*y, *m, ndigits, t1, t2);
-
- mpSquare(t1,y,n);
- mpDivide(t2,y,t1,n*2,m,n);
-
- if (*e[n-1] & mask)
- { /* if e(j) == 1 then multiply
- y = y * x mod n */
- //mpMODMULTTEMP(*y, *x, *m, ndigits, t1, t2);
- //mpModMultTemp(*y, *x, *m, ndigits, t1, t2);
-
- mpMultiply(t1,x,y,n);
- mpDivide(t2,y,t1,n*2,m,n);
-
- }
-
- /* Move to next bit */
- mpNEXTBITMASK(mask, n);
- }
-
- /* Return y */
- mpSetEqual(*yout, y, ndigits);
-
-done:
- mpDESTROY(t1, nn);
- mpDESTROY(t2, nn);
- mpDESTROY(y, ndigits);
-
- return 0;
-}
-
-int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits)
+*/
+int mpSquare(DIGIT_T *w, const DIGIT_T *x, size_t ndigits)
/* New in Version 2.0 */
{
/* Computes square w = x * x
@@ -947,10 +1201,7 @@ int mpSquare(DIGIT_T w[], const DIGIT_T x[], size_t ndigits)
return 0;
}
-
-
-
-int mpIsZero(const DIGIT_T a[], size_t ndigits)
+int mpIsZero( const DIGIT_T *a, size_t ndigits)
{
size_t i;
@@ -965,7 +1216,6 @@ int mpIsZero(const DIGIT_T a[], size_t ndigits)
return (!0); /* True */
}
-
void assert(bool precondition) {
char str[] = "assert reached, also this message leaks memory";
@@ -976,21 +1226,95 @@ void assert(bool precondition) {
}
+
+//int mpModExpO(__global DIGIT_T *yout, __global DIGIT_T *x, __global DIGIT_T *e, __global DIGIT_T *m, size_t ndigits)
+
// some might be constants
-__kernel void single(global DIGIT_T* s, const unsigned int s_len,
- global DIGIT_T* e, const unsigned int e_len,
- global DIGIT_T* n, const unsigned int n_len,
- global DIGIT_T* res, const unsigned int res_len,
+__kernel void single(__global DIGIT_T* x, const unsigned int s_len,
+ __global DIGIT_T* e, const unsigned int e_len,
+ __global DIGIT_T* m, const unsigned int n_len,
+ __global DIGIT_T *yout, const unsigned int res_len,
//global DIGIT_T* comp, const unsigned int comp_len,
- const unsigned int max_len,
- global int8* valid
+ const unsigned int ndigits,
+ __global int8* valid
//const unsigned int count
)
{ /* Kernel entry point: the former mpModExpO body inlined, i.e. the
+ "classic" binary left-to-right exponentiation yout = x^e mod m.
+
+ x, e, m  : base, exponent and modulus, __global buffers of ndigits digits
+ yout     : __global result buffer, ndigits digits
+ ndigits  : digit count of the operands; asserted to be non-zero and
+            at most MAX_FIXED_DIGITS
+
+ NOTE(review): s_len, e_len, n_len, res_len and valid are not referenced
+ anywhere in this body, and no work-item id is queried, so every
+ work-item appears to run the same computation on the same buffers
+ - confirm that this is intended at this stage.
+ */
+ // memory(res);
+
+
+ // __global DIGIT_T * __local ptr_x;
+
+ // ptr_x = x;
+ // ptr_x[3] = 4;
- mpModExpO(&res,&s,&e,&n,max_len);
+ // mpModExpO(res,s,e,n,max_len);
+
+ DIGIT_T mask;
+ size_t n;
+ size_t nn = ndigits * 2;
+ /* Create some double-length temps */
+
+ DIGIT_T t1[MAX_FIXED_DIGITS * 2];
+ DIGIT_T t2[MAX_FIXED_DIGITS * 2];
+ DIGIT_T y[MAX_FIXED_DIGITS * 2];
+ assert(ndigits <= MAX_FIXED_DIGITS);
+
+ assert(ndigits != 0);
+
+ n = mpSizeof_g(e, ndigits); /* number of significant digits in the exponent */
+ /* Catch e==0 => x^0=1 */
+ if (0 == n)
+ {
+ mpSetDigit_g(yout, 1, ndigits);
+ goto done;
+ }
+ /* Find second-most significant bit in e */
+ for (mask = HIBITMASK; mask > 0; mask >>= 1)
+ {
+ if (e[n-1] & mask)
+ break;
+ }
+ mpNEXTBITMASK(mask, n); /* macro defined outside this view; presumably steps mask (and n at digit boundaries) to the next lower bit - TODO confirm */
+ /* Set y = x */
+ mpSetEqual_lg(y, x, ndigits);
+
+ /* For bit j = k-2 downto 0 */
+ while (n)
+ {
+ /* Square y = y * y mod n */
+ mpMODSQUARETEMP(y, m, ndigits, t1, t2); /* macro defined outside this view; t1/t2 are passed as scratch space */
+ //mpModSquareTemp(*y, *m, ndigits, t1, t2);
+
+ //mpSquare(t1,y,n);
+ // mpDivide(t2,y,t1,n*2,m,n);
+
+ if (e[n-1] & mask)
+ { /* if e(j) == 1 then multiply
+ y = y * x mod n */
+ mpMODMULTTEMP(y, x, m, ndigits, t1, t2); /* macro defined outside this view; t1/t2 are passed as scratch space */
+ //mpModMultTemp(*y, *x, *m, ndigits, t1, t2);
+
+ // mpMultiply(t1,x,y,n);
+ // mpDivide(t2,y,t1,n*2,m,n);
+
+ }
+
+ /* Move to next bit */
+ mpNEXTBITMASK(mask, n);
+ }
+
+ /* Return y */
+ mpSetEqual_gl(yout, y, ndigits);
+
+done:
+ mpDESTROY(t1, nn);
+ mpDESTROY(t2, nn);
+ mpDESTROY(y, ndigits); /* NOTE(review): y is declared with MAX_FIXED_DIGITS * 2 digits but only ndigits are passed here, while t1/t2 get nn - confirm */
+
+
+}