commit e7b32caad0c0abf324ccd1469062046d7b2f2a3d
parent 8c195299c2dbac1ea152aaa5b61a769116743405
Author: Cedric <cedric.zwahlen@students.bfh.ch>
Date: Fri, 24 Nov 2023 13:49:10 +0100
Extend montgomery.cl
Not tested yet
Diffstat:
4 files changed, 594 insertions(+), 12 deletions(-)
diff --git a/source/lib-gpu-verify.c b/source/lib-gpu-verify.c
@@ -35,7 +35,7 @@ int main(int argc, char** argv)
//opencl_tests();
- rsa_tests();
+ //rsa_tests();
// montgomery_test();
diff --git a/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate b/xcode/lib-gpu-verify.xcodeproj/project.xcworkspace/xcuserdata/cedriczwahlen.xcuserdatad/UserInterfaceState.xcuserstate
Binary files differ.
diff --git a/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist b/xcode/lib-gpu-verify.xcodeproj/xcuserdata/cedriczwahlen.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist
@@ -2888,8 +2888,8 @@
filePath = "montgomery.cl"
startingColumnNumber = "9223372036854775807"
endingColumnNumber = "9223372036854775807"
- startingLineNumber = "1662"
- endingLineNumber = "1662"
+ startingLineNumber = "1664"
+ endingLineNumber = "1664"
landmarkName = "mpz_sizeinbase()"
landmarkType = "9">
</BreakpointContent>
diff --git a/xcode/montgomery.cl b/xcode/montgomery.cl
@@ -213,6 +213,7 @@ typedef struct
typedef __mpz_struct mpz_t[1];
typedef __mpz_struct *mpz_ptr;
+
typedef const __mpz_struct *mpz_srcptr;
struct gmp_div_inverse
@@ -292,6 +293,7 @@ mpz_set (mpz_t r, const mpz_t x)
}
}
+
void
mpz_set_ui (mpz_t r, unsigned long int x)
{
@@ -1749,6 +1751,8 @@ mpz_abs (mpz_t r, const mpz_t u)
mpz_set (r, u);
r->_mp_size = GMP_ABS (r->_mp_size);
}
+
+
mp_bitcnt_t
mpn_scan1 (mp_srcptr ptr, mp_bitcnt_t bit)
{
@@ -1759,6 +1763,42 @@ mpn_scan1 (mp_srcptr ptr, mp_bitcnt_t bit)
i, ptr, i, 0);
}
+/* Scan U for a 1 bit, starting at bit index STARTING_BIT and moving
+   towards more significant bits.
+
+   If STARTING_BIT lies beyond the limbs of U, the early return yields
+   ~0 (all ones) for U >= 0 and STARTING_BIT itself for U < 0.
+   Otherwise the first limb to inspect is prepared here and the walk
+   over the limbs is delegated to mpn_common_scan.
+   NOTE(review): negative U appears to be scanned as if stored in
+   two's complement -- confirm against mpn_common_scan.  */
+mp_bitcnt_t
+mpz_scan1 (const mpz_t u, mp_bitcnt_t starting_bit)
+{
+ mp_ptr up;
+ mp_size_t us, un, i;
+ mp_limb_t limb, ux;
+
+ /* us keeps the sign, un is the limb count, i is the limb that holds
+    starting_bit. */
+ us = u->_mp_size;
+ un = GMP_ABS (us);
+ i = starting_bit / GMP_LIMB_BITS;
+
+ /* Past the end there's no 1 bits for u>=0, or an immediate 1 bit
+ for u<0. Notice this test picks up any u==0 too. */
+ if (i >= un)
+ return (us >= 0 ? ~(mp_bitcnt_t) 0 : starting_bit);
+
+ up = u->_mp_d;
+ ux = 0;
+ limb = up[i];
+
+ if (starting_bit != 0)
+ {
+ if (us < 0)
+ {
+ /* ux is the result of mpn_zero_p on the i limbs below limb i
+    (presumably 1 when they are all zero); it is added to the
+    complemented limb, then turned into an all-ones or zero word
+    that is passed on to mpn_common_scan. */
+ ux = mpn_zero_p (up, i);
+ limb = ~ limb + ux;
+ ux = - (mp_limb_t) (limb >= ux);
+ }
+
+ /* Mask to 0 all bits before starting_bit, thus ignoring them. */
+ limb &= GMP_LIMB_MAX << (starting_bit % GMP_LIMB_BITS);
+ }
+
+ return mpn_common_scan (limb, i, up, un, ux);
+}
+
mp_bitcnt_t
mpz_make_odd (mpz_t r)
@@ -2063,15 +2103,557 @@ mpz_addmul_ui (mpz_t r, const mpz_t u, unsigned long int v)
mpz_clear (t);
}
-__kernel void montgomery(__global unsigned long* x, __global const unsigned long *s_len,
- __global unsigned long* e, __global const unsigned long *e_len,
- __global unsigned long* m, __global const unsigned long *n_len,
- __global unsigned long *mm, __global const unsigned long *mm_len,
- __global unsigned long* valid,
- const unsigned int count,
- const unsigned int pks
- ) {
+
+// STRING CONVERSION
+
+/* Return log2(B) when B is one of 2, 4, 8, ..., 256, and 0 for every
+   other value (including 0, 1 and powers of two above 256).  */
+unsigned
+mpn_base_power_of_two_p (unsigned b)
+{
+  unsigned log2_b;
+
+  for (log2_b = 1; log2_b <= 8; log2_b++)
+    {
+      if (b == (1u << log2_b))
+        return log2_b;
+    }
+
+  return 0;
+}
+
+/* Per-base constants used when converting digit strings in a base that
+   is not a power of two; filled in by mpn_get_base_info. */
+struct mpn_base_info
+{
+ /* bb is the largest power of the base which fits in one limb, and
+ exp is the corresponding exponent. */
+ unsigned exp; /* number of digits folded into one limb-sized chunk */
+ mp_limb_t bb; /* base raised to the power exp */
+};
+
+/* Fill INFO for base B: info->bb becomes the largest power of B that
+   still fits in one limb and info->exp the matching exponent.  */
+void
+mpn_get_base_info (struct mpn_base_info *info, mp_limb_t b)
+{
+  /* Any power not above this bound can be multiplied by B once more
+     without overflowing a limb. */
+  const mp_limb_t bound = GMP_LIMB_MAX / b;
+  mp_limb_t power = b;
+  unsigned exponent = 1;
+
+  while (power <= bound)
+    {
+      power *= b;
+      exponent++;
+    }
+
+  info->exp = exponent;
+  info->bb = power;
+}
+
+// Returns 1 when c is one of the six ASCII whitespace characters
+// (space, \t, \n, \v, \f, \r) and 0 otherwise.
+int isspace(unsigned char c) {
+    switch (c) {
+    case ' ':
+    case '\t':
+    case '\n':
+    case '\v':
+    case '\f':
+    case '\r':
+        return 1;
+    default:
+        return 0;
+    }
+}
+
+// Returns the number of characters in c before the terminating '\0'.
+// The string must be NUL-terminated; an empty string yields 0.
+int strlen(__constant char *c) {
+
+    int i = 0;
+
+    // Advance until the terminator. The index has to move on every
+    // iteration, otherwise any non-empty string loops forever.
+    while (c[i] != '\0')
+        i++;
+
+    return i;
+}
+
+/* Pack SN digit values from SP (most significant digit first, each
+   BITS bits wide) into limbs at RP, least significant limb first.
+   Returns the number of limbs written, without high zero limbs.  */
+mp_size_t
+mpn_set_str_bits (mp_ptr rp, const unsigned char *sp, size_t sn,
+                  unsigned bits)
+{
+  mp_size_t limbs = 0;
+  mp_limb_t acc = 0;
+  unsigned fill = 0;
+
+  /* Consume the digits from the least significant end. */
+  while (sn > 0)
+    {
+      unsigned digit = sp[--sn];
+
+      acc |= (mp_limb_t) digit << fill;
+      fill += bits;
+      if (fill < GMP_LIMB_BITS)
+        continue;
+
+      /* The limb is full: store it and carry the bits of this digit
+         that did not fit into the next limb.  This is correct also if
+         fill becomes 0, bits == 8, and mp_limb_t == unsigned char. */
+      fill -= GMP_LIMB_BITS;
+      rp[limbs++] = acc;
+      acc = digit >> (bits - fill);
+    }
+
+  if (acc == 0)
+    return mpn_normalized_size (rp, limbs);
+
+  rp[limbs++] = acc;
+  return limbs;
+}
+
+/* Convert SN digit values from SP (most significant first) in base B
+   into limbs at RP.  Digits are folded into chunks of info->exp digits;
+   each chunk is merged with one multiply by info->bb and one add.
+   SN must be positive.  Returns the number of limbs used.  */
+mp_size_t
+mpn_set_str_other (mp_ptr rp, const unsigned char *sp, size_t sn,
+                   mp_limb_t b, const struct mpn_base_info *info)
+{
+  mp_size_t limbs;
+  mp_limb_t chunk;
+  unsigned remaining;
+  size_t pos = 0;
+
+  assert (sn > 0);
+
+  /* The leading chunk takes the digits left over so that every later
+     chunk holds exactly info->exp digits. */
+  remaining = 1 + (sn - 1) % info->exp;
+
+  for (chunk = 0; remaining != 0; remaining--)
+    chunk = chunk * b + sp[pos++];
+
+  rp[0] = chunk;
+  limbs = 1;
+
+  while (pos < sn)
+    {
+      mp_limb_t carry;
+
+      chunk = 0;
+      for (remaining = info->exp; remaining != 0; remaining--)
+        chunk = chunk * b + sp[pos++];
+
+      /* rp = rp * bb + chunk, growing by one limb on carry out. */
+      carry = mpn_mul_1 (rp, rp, limbs, info->bb);
+      carry += mpn_add_1 (rp, rp, limbs, chunk);
+      if (carry > 0)
+        rp[limbs++] = carry;
+    }
+  assert (pos == sn);
+
+  return limbs;
+}
+
+
+/* Parse the text SP as an integer in BASE and store it in R.
+
+   Leading whitespace and one optional '-' are accepted; whitespace
+   between digits is skipped.  BASE must be 0 or in 2..62.  With BASE
+   == 0 the prefix selects the base: "0x"/"0X" hex, "0b"/"0B" binary,
+   a leading "0" octal, anything else decimal.
+
+   Returns 0 on success.  Returns -1 and sets R to 0 when the string
+   has no digits, contains a character that is not a digit of BASE, or
+   holds more digits than the fixed staging buffer can take.  */
+int
+mpz_set_str (mpz_t r, __constant char *sp, int base)
+{
+  unsigned bits, value_of_a;
+  mp_size_t rn, alloc;
+  mp_ptr rp;
+  size_t dn;
+  int sign;
+  /* Digit values are staged here in place of the gmp_alloc'ed buffer
+     this code previously used; its size caps the accepted input. */
+  unsigned char dp[2048];
+
+  assert (base == 0 || (base >= 2 && base <= 62));
+
+  while (isspace ((unsigned char) *sp))
+    sp++;
+
+  sign = (*sp == '-');
+  sp += sign;
+
+  if (base == 0)
+    {
+      if (sp[0] == '0')
+        {
+          if (sp[1] == 'x' || sp[1] == 'X')
+            {
+              base = 16;
+              sp += 2;
+            }
+          else if (sp[1] == 'b' || sp[1] == 'B')
+            {
+              base = 2;
+              sp += 2;
+            }
+          else
+            base = 8;
+        }
+      else
+        base = 10;
+    }
+
+  if (!*sp)
+    {
+      r->_mp_size = 0;
+      return -1;
+    }
+
+  /* For bases above 36 lower-case letters continue after the upper-case
+     ones (a == 36); otherwise both cases mean the same digit. */
+  value_of_a = (base > 36) ? 36 : 10;
+  for (dn = 0; *sp; sp++)
+    {
+      unsigned digit;
+
+      if (isspace ((unsigned char) *sp))
+        continue;
+      else if (*sp >= '0' && *sp <= '9')
+        digit = *sp - '0';
+      else if (*sp >= 'a' && *sp <= 'z')
+        digit = *sp - 'a' + value_of_a;
+      else if (*sp >= 'A' && *sp <= 'Z')
+        digit = *sp - 'A' + 10;
+      else
+        digit = base; /* fail */
+
+      /* Reject invalid digits, and refuse input that would run past
+         the end of the staging buffer instead of overflowing it. */
+      if (digit >= (unsigned) base || dn >= sizeof dp)
+        {
+          r->_mp_size = 0;
+          return -1;
+        }
+
+      dp[dn++] = digit;
+    }
+
+  if (!dn)
+    {
+      r->_mp_size = 0;
+      return -1;
+    }
+  bits = mpn_base_power_of_two_p (base);
+
+  if (bits > 0)
+    {
+      alloc = (dn * bits + GMP_LIMB_BITS - 1) / GMP_LIMB_BITS;
+      rp = MPZ_REALLOC (r, alloc);
+      rn = mpn_set_str_bits (rp, dp, dn, bits);
+    }
+  else
+    {
+      struct mpn_base_info info;
+      mpn_get_base_info (&info, base);
+      alloc = (dn + info.exp - 1) / info.exp;
+      rp = MPZ_REALLOC (r, alloc);
+      rn = mpn_set_str_other (rp, dp, dn, base, &info);
+      /* Normalization, needed for all-zero input. */
+      assert (rn > 0);
+      rn -= rp[rn-1] == 0;
+    }
+  assert (rn <= alloc);
+
+  r->_mp_size = sign ? - rn : rn;
+
+  return 0;
+}
+
+
+
+/* Initialise R, then parse SP in BASE into it.  The return value is
+   that of mpz_set_str.  */
+int
+mpz_init_set_str (mpz_t r, __constant char *sp, int base)
+{
+  int status;
+
+  mpz_init (r);
+  status = mpz_set_str (r, sp, base);
+
+  return status;
+}
+
+
+
+
+
+
+
+
+// Montgomery multiplication
+
+// Forward declarations for the Montgomery helpers defined further down.
+
+// Derives the constants for modulus m: r (a power of two above m),
+// r_1 and ni with r*r_1 - ni*m == 1, M = b*r mod m and x = r mod m.
+void mont_prepare(mpz_t b, mpz_t e, mpz_t m,
+ mpz_t r, mpz_t r_1,
+ mpz_t ni, mpz_t M, mpz_t x
+ );
+
+// ret = Montgomery product of a and b for modulus n, using r and ni.
+void mont_product(mpz_t ret,
+ const mpz_t a, const mpz_t b,
+ const mpz_t r, const mpz_t r_1,
+ const mpz_t n, const mpz_t ni
+ );
+
+// Square-and-multiply over the bits of e built on mont_product;
+// a is the starting accumulator and M the factor multiplied in.
+void mont_modexp(mpz_t ret,
+ mpz_t a, mpz_t e,
+ const mpz_t M,
+ const mpz_t n, const mpz_t ni,
+ const mpz_t r, const mpz_t r_1
+ );
+
+// ret = mont_product(xx, 1), the final step after mont_modexp.
+void mont_finish(mpz_t ret,
+ const mpz_t xx,
+ const mpz_t n, const mpz_t ni,
+ const mpz_t r, const mpz_t r_1
+ );
+
+// Splits an even modulus m into q (m shifted right by its trailing zero
+// bits) and powj (the matching power of two).
+void mont_prepare_even_modulus(mpz_t m, mpz_t q, mpz_t powj);
+
+// Shift-and-add modular multiplication of a and b modulo mod into res.
+void mont_mulmod(mpz_t res, const mpz_t a, const mpz_t b, const mpz_t mod);
+
+
+
+
+// Splits the modulus m into m = q * 2^j, where j is the index of the
+// lowest set bit of m.
+//   m    - input modulus; must be non-zero (no set bit to find otherwise)
+//   q    - output: m shifted right by j bits
+//   powj - output: 2^j
+void mont_prepare_even_modulus(mpz_t m, mpz_t q, mpz_t powj) {
+
+    mp_bitcnt_t j = mpz_scan1(m, 0);
+
+    mpz_tdiv_q_2exp(q, m, j);
+
+    // Build 2^j as 1 << j. Shifting 2 by j - 1 gives the same value for
+    // j >= 1 but wraps to an enormous shift count when j == 0 (odd m).
+    mpz_set_ui(powj, 1);
+    mpz_mul_2exp(powj, powj, j);
+
+}
+
+// CPU
+// Derives the Montgomery constants for modulus m and base b:
+//   r    = 2^(bit length of m), a power of two larger than m
+//   r_1, ni chosen so that r*r_1 - ni*m == 1 (verified below)
+//   M    = b*r mod m
+//   x    = r mod m
+// r and m must be relatively prime, which holds whenever m is odd.
+// The exponent e is accepted but not used here.
+// MARK: break this up, reduce the amount of temporary variables
+void mont_prepare(mpz_t b, mpz_t e, mpz_t m,
+                  mpz_t r, mpz_t r_1,
+                  mpz_t ni, mpz_t M, mpz_t x) {
+
+    mpz_t g, prod; // scratch: gcd / identity check, and ni*m
+    mpz_init_set_si(g, 1);
+    mpz_init_set_si(prod, 0);
+
+    // r = 1 << bits(m)
+    size_t modulus_bits = mpz_sizeinbase(m, 2);
+    mpz_mul_2exp(r, g, modulus_bits);
+
+    // g = gcd(r, m) = r*r_1 + m*ni (extended Euclid sets r_1 and ni)
+    mpz_set_si(g, 0);
+    mpz_gcdext(g, r_1, ni, r, m);
+
+    // Work with magnitudes; when r_1 came back negative, fold both
+    // coefficients so that r*r_1 - ni*m == 1 still holds.
+    int r_1_sign = mpz_sgn(r_1);
+
+    mpz_abs(r_1, r_1);
+    mpz_abs(ni, ni);
+
+    if (r_1_sign == -1) {
+        mpz_sub(ni, r, ni);
+        mpz_sub(r_1, m, r_1);
+    }
+
+    // the gcd must be one
+    if (mpz_cmp_ui(g, 1) != 0) {
+        assert(0);
+    }
+
+    // r*r_1 - ni*m must be one
+    mpz_mul(g, r, r_1);
+    mpz_mul(prod, ni, m);
+    mpz_sub(g, g, prod);
+
+    if (mpz_cmp_ui(g, 1) != 0) {
+        assert(0);
+    }
+
+    // M = b*r mod m
+    mpz_mul(M, b, r);
+    mpz_mod(M, M, m);
+
+    // x = r mod m
+    mpz_mod(x, r, m);
+
+    mpz_clear(g);
+    mpz_clear(prod);
+
+}
+
+// maybe GPU?
+// MARK: n MUST be an odd number
+// Left-to-right square-and-multiply built on mont_product.
+//   ret - receives the result
+//   a   - starting value of the accumulator
+//   e   - exponent, scanned from its top bit down to bit 0
+//   M   - factor multiplied in for every set bit of e
+//   n, ni, r, r_1 - modulus and constants, passed through to mont_product
+void mont_modexp(mpz_t ret,
+                 mpz_t a, mpz_t e,
+                 const mpz_t M,
+                 const mpz_t n, const mpz_t ni,
+                 const mpz_t r, const mpz_t r_1
+                 ) {
+
+    mpz_t factor, acc;
+
+    mpz_init_set(factor, M);
+    mpz_init_set(acc, a);
+
+    int bit = (int)mpz_sizeinbase(e, 2);
+
+    while (bit > 0) {
+
+        bit--;
+
+        // square for every bit ...
+        mont_product(acc, acc, acc, r, r_1, n, ni);
+
+        // ... and multiply when the bit is set
+        if (mpz_tstbit(e, bit)) {
+            mont_product(acc, factor, acc, r, r_1, n, ni);
+        }
+    }
+
+    mpz_set(ret, acc);
+
+    mpz_clear(factor);
+    mpz_clear(acc);
+
+}
+
+// Final step after mont_modexp: ret = mont_product(xx, 1).
+//   n, ni, r, r_1 are passed through to mont_product unchanged.
+void mont_finish(mpz_t ret,
+                 const mpz_t xx,
+                 const mpz_t n, const mpz_t ni,
+                 const mpz_t r, const mpz_t r_1
+                 ) {
+
+    mpz_t unit, reduced;
+
+    mpz_init_set_ui(unit, 1);
+    mpz_init(reduced);
+
+    mont_product(reduced, xx, unit, r, r_1, n, ni);
+    mpz_set(ret, reduced);
+
+    mpz_clear(reduced);
+    mpz_clear(unit);
+
+}
+
+
+// GPU
+// Montgomery product: with t = a*b mod r and m = ni*t mod r, computes
+// u = (a*b + m*n) >> (bits(r) - 1), subtracts n once if u >= n, and
+// stores the result in ret. r_1 is accepted but not used here.
+// ret is written only after all inputs have been read.
+void mont_product(mpz_t ret,
+                  const mpz_t a, const mpz_t b,
+                  const mpz_t r, const mpz_t r_1,
+                  const mpz_t n, const mpz_t ni
+                  ) {
+
+    mpz_t low, factor, reduced;
+    mpz_t prod_ab, prod_mn;
+
+    // mont_mulmod adds into its result, so these must start at zero
+    mpz_init(low);
+    mpz_init(factor);
+    mpz_init(reduced);
+
+    // low = a*b mod r, factor = ni*low mod r
+    mont_mulmod(low, b, a, r);
+    mont_mulmod(factor, ni, low, r);
+
+    mpz_init(prod_ab);
+    mpz_init(prod_mn);
+
+    // prod_ab = a*b + factor*n
+    mpz_mul(prod_ab, a, b);
+    mpz_mul(prod_mn, factor, n);
+    mpz_add(prod_ab, prod_ab, prod_mn);
+
+    // this is essentially a bit shift, instead of a division by r
+    unsigned long shift = mpz_sizeinbase(r, 2) - 1;
+    mpz_tdiv_q_2exp(reduced, prod_ab, shift);
+
+    // one conditional subtraction of n
+    if (mpz_cmp(reduced, n) >= 0) {
+        mpz_sub(reduced, reduced, n);
+    }
+
+    mpz_set(ret, reduced);
+
+    mpz_clear(prod_ab);
+    mpz_clear(prod_mn);
+    mpz_clear(low);
+    mpz_clear(factor);
+    mpz_clear(reduced);
+
+}
+
+
+// res = (a * b) mod `mod`, computed by shift-and-add (double a, halve b).
+// Not the fastest... but intermediate values stay reduced modulo `mod`
+// instead of growing to the size of the full product.
+//   res - receives the product; its previous value is ignored
+//   a   - first factor, reduced modulo mod before use
+//   b   - second factor; the loop only runs while it is positive
+//   mod - modulus
+// res may alias any of the inputs: it is written only at the very end.
+void mont_mulmod(mpz_t res, const mpz_t a, const mpz_t b, const mpz_t mod) {
+
+    mpz_t aa, bb, acc;
+    mpz_init_set(aa, a);
+    mpz_init_set(bb, b);
+    mpz_init(acc); // starts at 0, independent of what res held before
+
+    mpz_mod(aa, aa, mod); // in case a is bigger
+
+    while (mpz_cmp_ui(bb, 0) > 0) {
+        if (mpz_odd_p(bb)) {
+            mpz_add(acc, acc, aa);
+            mpz_mod(acc, acc, mod);
+        }
+
+        mpz_mul_2exp(aa, aa, 1);
+        mpz_mod(aa, aa, mod);
+        mpz_tdiv_q_2exp(bb, bb, 1);
+    }
+
+    mpz_set(res, acc);
+
+    // release the temporaries, as the other mont_* helpers do
+    mpz_clear(aa);
+    mpz_clear(bb);
+    mpz_clear(acc);
+}
+
+
+
+// Kernel: parses base, exponent and modulus as hexadecimal strings and
+// computes res = base^exponent mod modulus.
+//   - odd modulus:  mont_prepare / mont_modexp / mont_finish directly.
+//   - even modulus: split m = q * 2^j, compute the power modulo q with
+//     the Montgomery path and modulo 2^j with mpz_powm, then combine the
+//     two residues (x1 + q * ((x2 - x1) * q^-1 mod 2^j)).
+// NOTE(review): `valid` and `signature` are not referenced, and res is
+// not written to any output buffer -- the result is currently discarded.
+// Reporting through `valid` would need it to be writable (__global).
+__kernel void montgomery(__constant unsigned long *valid, __constant char *base, __constant char *exponent, __constant char *modulus, __constant char *signature) {
+
+    int radix = 16;
+
+    mpz_t b,e,m, res;
+
+    mpz_init(res);
+
+    mpz_init_set_str(b,base,radix); // M
+    mpz_init_set_str(e,exponent,radix);
+    mpz_init_set_str(m,modulus,radix); // n
+
+    // Montgomery constants, filled in by mont_prepare
+    mpz_t r, r_1, ni, M, x;
+    mpz_init(r);
+    mpz_init(r_1);
+    mpz_init(ni);
+    mpz_init(M);
+    mpz_init(x);
+
+    mpz_t xx;
+    mpz_init(xx);
+
+    if (mpz_even_p(m)) {
+
+        mpz_t bb, x1, x2, q, powj;
+        mpz_init(bb);
+        mpz_init(x1);
+        mpz_init(x2);
+        mpz_init(q);
+        mpz_init(powj);
+
+        mont_prepare_even_modulus(m, q, powj);
+
+        // q is uneven, so we can use regular modexp
+        // MARK: we can improve the efficiency here by doing simple reductions
+
+        mpz_mod(bb, b, q); // reductions like this
+
+        // x1 = b^e mod q
+        mont_prepare(bb, e, q, r, r_1, ni, M, x);
+        mont_modexp(xx, x, e, M, q, ni, r, r_1);
+        mont_finish(x1, xx, q, ni, r, r_1);
+
+        // x2 = b^e mod 2^j
+        // MARK: we can also reduce and really speed this up as well -> binary method?
+        mpz_powm(x2, b, e, powj);
+
+        mpz_t y, q_1;
+        mpz_init(y);
+        mpz_init(q_1);
+
+        // y = (x2 - x1) * q^-1 mod 2^j
+        mpz_sub(y, x2, x1);
+
+        mpz_invert(q_1, q, powj);
+
+        mpz_mul(y, y, q_1);
+        mpz_mod(y, y, powj);
+
+        // res = x1 + q * y
+        mpz_addmul(x1, q, y);
+
+        mpz_set(res, x1);
+
+        printf("--\n");
+
+        // release the temporaries of this branch
+        mpz_clear(y);
+        mpz_clear(q_1);
+        mpz_clear(bb);
+        mpz_clear(x1);
+        mpz_clear(x2);
+        mpz_clear(q);
+        mpz_clear(powj);
+
+    } else {
+
+        mont_prepare(b, e, m, r, r_1, ni, M, x);
+
+        mont_modexp(xx, x, e, M, m, ni, r, r_1);
+
+        mont_finish(res, xx, m, ni, r, r_1);
+
+    }
+
+    // release everything initialised above, as the mont_* helpers do
+    mpz_clear(xx);
+    mpz_clear(r);
+    mpz_clear(r_1);
+    mpz_clear(ni);
+    mpz_clear(M);
+    mpz_clear(x);
+    mpz_clear(b);
+    mpz_clear(e);
+    mpz_clear(m);
+    mpz_clear(res);
+
- *valid = 1;
}