bignum.c 63 KB


  1. /**
  2. * Copyright (c) 2013-2014 Tomas Dzetkulic
  3. * Copyright (c) 2013-2014 Pavol Rusnak
  4. * Copyright (c) 2015 Jochen Hoenicke
  5. * Copyright (c) 2016 Alex Beregszaszi
  6. *
  7. * Permission is hereby granted, free of charge, to any person obtaining
  8. * a copy of this software and associated documentation files (the "Software"),
  9. * to deal in the Software without restriction, including without limitation
  10. * the rights to use, copy, modify, merge, publish, distribute, sublicense,
  11. * and/or sell copies of the Software, and to permit persons to whom the
  12. * Software is furnished to do so, subject to the following conditions:
  13. *
  14. * The above copyright notice and this permission notice shall be included
  15. * in all copies or substantial portions of the Software.
  16. *
  17. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  18. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  19. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  20. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
  21. * OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
  22. * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
  23. * OTHER DEALINGS IN THE SOFTWARE.
  24. */
  25. #include "bignum.h"
  26. #include <assert.h>
  27. #include <stdint.h>
  28. #include <stdio.h>
  29. #include <string.h>
  30. #include "memzero.h"
  31. #include "script.h"
  32. /*
  33. This library implements 256-bit numbers arithmetic.
  34. An unsigned 256-bit number is represented by a bignum256 structure, that is an
  35. array of nine 32-bit values called limbs. Limbs are digits of the number in
  36. the base 2**29 representation in the little endian order. This means that
  37. bignum256 x;
  38. represents the value
  39. sum([x[i] * 2**(29*i) for i in range(9)]).
  40. A limb of a bignum256 is *normalized* iff it's less than 2**29.
  41. A bignum256 is *normalized* iff every one of its limbs is normalized.
  42. A number is *fully reduced modulo p* iff it is less than p.
  43. A number is *partly reduced modulo p* iff it is less than 2*p.
  44. The number p is usually a prime number such that 2^256 - 2^224 <= p <= 2^256.
  45. All functions except bn_fast_mod expect that all their bignum256 inputs are
  46. normalized. (The function bn_fast_mod allows the input number to have the
  47. most significant limb unnormalized). All bignum256 outputs of all functions
  48. are guaranteed to be normalized.
  49. A number can be partly reduced with bn_fast_mod, a partly reduced number can
  50. be fully reduced with bn_mod.
  51. A function has *constant control flow with regard to its argument* iff the
  52. order in which instructions of the function are executed doesn't depend on the
  53. value of the argument.
  54. A function has *constant memory access flow with regard to its argument* iff
  55. the memory addresses that are accessed and the order in which they are accessed
  56. don't depend on the value of the argument.
  57. A function *has constant control (memory access) flow* iff it has constant
  58. control (memory access) flow with regard to all its arguments.
  59. The following function has constant control flow with regard to its argument
  60. n, however it doesn't have constant memory access flow with regard to it:
  61. void f(int n, int *a) {
  62. a[0] = 0;
  63. a[n] = 0; // memory address reveals the value of n
  64. }
  65. Unless stated otherwise all functions are supposed to have both constant
  66. control flow and constant memory access flow.
  67. */
  68. #define BN_MAX_DECIMAL_DIGITS 79 // floor(log10(2**(LIMBS * BITS_PER_LIMB))) + 1 == 79: decimal digits of the largest value the limbs can hold; bounds the loop in bn_digitcount
  69. // out_number = (bignum256) in_number
  70. // Assumes in_number is a raw bigendian 256-bit number
  71. // Guarantees out_number is normalized
  72. void bn_read_be(const uint8_t* in_number, bignum256* out_number) {
  73. uint32_t temp = 0;
  74. for(int i = 0; i < BN_LIMBS - 1; i++) {
  75. uint32_t limb = read_be(in_number + (BN_LIMBS - 2 - i) * 4);
  76. temp |= limb << (BN_EXTRA_BITS * i);
  77. out_number->val[i] = temp & BN_LIMB_MASK;
  78. temp = limb >> (32 - BN_EXTRA_BITS * (i + 1));
  79. }
  80. out_number->val[BN_LIMBS - 1] = temp;
  81. }
  82. // out_number = (256BE) in_number
  83. // Assumes in_number < 2**256
  84. // Guarantess out_number is a raw bigendian 256-bit number
  85. void bn_write_be(const bignum256* in_number, uint8_t* out_number) {
  86. uint32_t temp = in_number->val[BN_LIMBS - 1];
  87. for(int i = BN_LIMBS - 2; i >= 0; i--) {
  88. uint32_t limb = in_number->val[i];
  89. temp = (temp << (BN_BITS_PER_LIMB - BN_EXTRA_BITS * i)) | (limb >> (BN_EXTRA_BITS * i));
  90. write_be(out_number + (BN_LIMBS - 2 - i) * 4, temp);
  91. temp = limb;
  92. }
  93. }
  94. // out_number = (bignum256) in_number
  95. // Assumes in_number is a raw little endian 256-bit number
  96. // Guarantees out_number is normalized
  97. void bn_read_le(const uint8_t* in_number, bignum256* out_number) {
  98. uint32_t temp = 0;
  99. for(int i = 0; i < BN_LIMBS - 1; i++) {
  100. uint32_t limb = read_le(in_number + i * 4);
  101. temp |= limb << (BN_EXTRA_BITS * i);
  102. out_number->val[i] = temp & BN_LIMB_MASK;
  103. temp = limb >> (32 - BN_EXTRA_BITS * (i + 1));
  104. }
  105. out_number->val[BN_LIMBS - 1] = temp;
  106. }
  107. // out_number = (256LE) in_number
  108. // Assumes in_number < 2**256
  109. // Guarantess out_number is a raw little endian 256-bit number
  110. void bn_write_le(const bignum256* in_number, uint8_t* out_number) {
  111. uint32_t temp = in_number->val[BN_LIMBS - 1];
  112. for(int i = BN_LIMBS - 2; i >= 0; i--) {
  113. uint32_t limb = in_number->val[i];
  114. temp = (temp << (BN_BITS_PER_LIMB - BN_EXTRA_BITS * i)) | (limb >> (BN_EXTRA_BITS * i));
  115. write_le(out_number + i * 4, temp);
  116. temp = limb;
  117. }
  118. }
  119. // out_number = (bignum256) in_number
  120. // Guarantees out_number is normalized
  121. void bn_read_uint32(uint32_t in_number, bignum256* out_number) {
  122. out_number->val[0] = in_number & BN_LIMB_MASK;
  123. out_number->val[1] = in_number >> BN_BITS_PER_LIMB;
  124. for(uint32_t i = 2; i < BN_LIMBS; i++) out_number->val[i] = 0;
  125. }
  126. // out_number = (bignum256) in_number
  127. // Guarantees out_number is normalized
  128. void bn_read_uint64(uint64_t in_number, bignum256* out_number) {
  129. out_number->val[0] = in_number & BN_LIMB_MASK;
  130. out_number->val[1] = (in_number >>= BN_BITS_PER_LIMB) & BN_LIMB_MASK;
  131. out_number->val[2] = in_number >> BN_BITS_PER_LIMB;
  132. for(uint32_t i = 3; i < BN_LIMBS; i++) out_number->val[i] = 0;
  133. }
  134. // Returns the bitsize of x
  135. // Assumes x is normalized
  136. // The function doesn't have neither constant control flow nor constant memory
  137. // access flow
  138. int bn_bitcount(const bignum256* x) {
  139. for(int i = BN_LIMBS - 1; i >= 0; i--) {
  140. uint32_t limb = x->val[i];
  141. if(limb != 0) {
  142. // __builtin_clz returns the number of leading zero bits starting at the
  143. // most significant bit position
  144. return i * BN_BITS_PER_LIMB + (32 - __builtin_clz(limb));
  145. }
  146. }
  147. return 0;
  148. }
  149. // Returns the number of decimal digits of x; if x is 0, returns 1
  150. // Assumes x is normalized
  151. // The function doesn't have neither constant control flow nor constant memory
  152. // access flow
  153. unsigned int bn_digitcount(const bignum256* x) {
  154. bignum256 val = {0};
  155. bn_copy(x, &val);
  156. unsigned int digits = 1;
  157. for(unsigned int i = 0; i < BN_MAX_DECIMAL_DIGITS; i += 3) {
  158. uint32_t limb = 0;
  159. bn_divmod1000(&val, &limb);
  160. if(limb >= 100) {
  161. digits = i + 3;
  162. } else if(limb >= 10) {
  163. digits = i + 2;
  164. } else if(limb >= 1) {
  165. digits = i + 1;
  166. }
  167. }
  168. memzero(&val, sizeof(val));
  169. return digits;
  170. }
  171. // x = 0
  172. // Guarantees x is normalized
  173. void bn_zero(bignum256* x) {
  174. for(int i = 0; i < BN_LIMBS; i++) {
  175. x->val[i] = 0;
  176. }
  177. }
  178. // x = 1
  179. // Guarantees x is normalized
  180. void bn_one(bignum256* x) {
  181. x->val[0] = 1;
  182. for(int i = 1; i < BN_LIMBS; i++) {
  183. x->val[i] = 0;
  184. }
  185. }
  186. // Returns x == 0
  187. // Assumes x is normalized
  188. int bn_is_zero(const bignum256* x) {
  189. uint32_t result = 0;
  190. for(int i = 0; i < BN_LIMBS; i++) {
  191. result |= x->val[i];
  192. }
  193. return !result;
  194. }
  195. // Returns x == 1
  196. // Assumes x is normalized
  197. int bn_is_one(const bignum256* x) {
  198. uint32_t result = x->val[0] ^ 1;
  199. for(int i = 1; i < BN_LIMBS; i++) {
  200. result |= x->val[i];
  201. }
  202. return !result;
  203. }
  204. // Returns x < y
  205. // Assumes x, y are normalized
  206. int bn_is_less(const bignum256* x, const bignum256* y) {
  207. uint32_t res1 = 0;
  208. uint32_t res2 = 0;
  209. for(int i = BN_LIMBS - 1; i >= 0; i--) {
  210. res1 = (res1 << 1) | (x->val[i] < y->val[i]);
  211. res2 = (res2 << 1) | (x->val[i] > y->val[i]);
  212. }
  213. return res1 > res2;
  214. }
  215. // Returns x == y
  216. // Assumes x, y are normalized
  217. int bn_is_equal(const bignum256* x, const bignum256* y) {
  218. uint32_t result = 0;
  219. for(int i = 0; i < BN_LIMBS; i++) {
  220. result |= x->val[i] ^ y->val[i];
  221. }
  222. return !result;
  223. }
  224. // res = cond if truecase else falsecase
  225. // Assumes cond is either 0 or 1
  226. // Works properly even if &res == &truecase or &res == &falsecase or
  227. // &truecase == &falsecase or &res == &truecase == &falsecase
  228. void bn_cmov(
  229. bignum256* res,
  230. volatile uint32_t cond,
  231. const bignum256* truecase,
  232. const bignum256* falsecase) {
  233. // Intentional use of bitwise OR operator to ensure constant-time
  234. assert((int)(cond == 1) | (int)(cond == 0));
  235. uint32_t tmask = -cond; // tmask = 0xFFFFFFFF if cond else 0x00000000
  236. uint32_t fmask = ~tmask; // fmask = 0x00000000 if cond else 0xFFFFFFFF
  237. for(int i = 0; i < BN_LIMBS; i++) {
  238. res->val[i] = (truecase->val[i] & tmask) | (falsecase->val[i] & fmask);
  239. }
  240. }
  241. // x = -x % prime if cond else x,
  242. // Explicitly x = (3 * prime - x if x > prime else 2 * prime - x) if cond else
  243. // else (x if x > prime else x + prime)
  244. // Assumes x is normalized and partly reduced
  245. // Assumes cond is either 1 or 0
  246. // Guarantees x is normalized
  247. // Assumes prime is normalized and
  248. // 0 < prime < 2**260 == 2**(BITS_PER_LIMB * LIMBS - 1)
  249. void bn_cnegate(volatile uint32_t cond, bignum256* x, const bignum256* prime) {
  250. // Intentional use of bitwise OR operator to ensure constant time
  251. assert((int)(cond == 1) | (int)(cond == 0));
  252. uint32_t tmask = -cond; // tmask = 0xFFFFFFFF if cond else 0x00000000
  253. uint32_t fmask = ~tmask; // fmask = 0x00000000 if cond else 0xFFFFFFFF
  254. bn_mod(x, prime);
  255. // x < prime
  256. uint32_t acc1 = 1;
  257. uint32_t acc2 = 0;
  258. for(int i = 0; i < BN_LIMBS; i++) {
  259. acc1 += (BN_BASE - 1) + 2 * prime->val[i] - x->val[i];
  260. // acc1 neither overflows 32 bits nor underflows 0
  261. // Proof:
  262. // acc1 + (BASE - 1) + 2 * prime[i] - x[i]
  263. // >= (BASE - 1) - x >= (2**BITS_PER_LIMB - 1) - (2**BITS_PER_LIMB - 1)
  264. // == 0
  265. // acc1 + (BASE - 1) + 2 * prime[i] - x[i]
  266. // <= acc1 + (BASE - 1) + 2 * prime[i]
  267. // <= (2**(32 - BITS_PER_LIMB) - 1) + 2 * (2**BITS_PER_LIMB - 1) +
  268. // (2**BITS_PER_LIMB - 1)
  269. // == 7 + 3 * 2**29 < 2**32
  270. acc2 += prime->val[i] + x->val[i];
  271. // acc2 doesn't overflow 32 bits
  272. // Proof:
  273. // acc2 + prime[i] + x[i]
  274. // <= 2**(32 - BITS_PER_LIMB) - 1 + 2 * (2**BITS_PER_LIMB - 1)
  275. // == 2**(32 - BITS_PER_LIMB) + 2**(BITS_PER_LIMB + 1) - 2
  276. // == 2**30 + 5 < 2**32
  277. // x = acc1 & LIMB_MASK if cond else acc2 & LIMB_MASK
  278. x->val[i] = ((acc1 & tmask) | (acc2 & fmask)) & BN_LIMB_MASK;
  279. acc1 >>= BN_BITS_PER_LIMB;
  280. // acc1 <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  281. // acc1 == 2**(BITS_PER_LIMB * (i + 1)) + 2 * prime[:i + 1] - x[:i + 1]
  282. // >> BITS_PER_LIMB * (i + 1)
  283. acc2 >>= BN_BITS_PER_LIMB;
  284. // acc2 <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  285. // acc2 == prime[:i + 1] + x[:i + 1] >> BITS_PER_LIMB * (i + 1)
  286. }
  287. // assert(acc1 == 1); // assert prime <= 2**260
  288. // assert(acc2 == 0);
  289. // clang-format off
  290. // acc1 == 1
  291. // Proof:
  292. // acc1 == 2**(BITS_PER_LIMB * LIMBS) + 2 * prime[:LIMBS] - x[:LIMBS] >> BITS_PER_LIMB * LIMBS
  293. // == 2**(BITS_PER_LIMB * LIMBS) + 2 * prime - x >> BITS_PER_LIMB * LIMBS
  294. // <= 2**(BITS_PER_LIMB * LIMBS) + 2 * prime >> BITS_PER_LIMB * LIMBS
  295. // <= 2**(BITS_PER_LIMB * LIMBS) + 2 * (2**(BITS_PER_LIMB * LIMBS - 1) - 1) >> BITS_PER_LIMB * LIMBS
  296. // <= 2**(BITS_PER_LIMB * LIMBS) + 2**(BITS_PER_LIMB * LIMBS) - 2 >> BITS_PER_LIMB * LIMBS
  297. // == 1
  298. // acc1 == 2**(BITS_PER_LIMB * LIMBS) + 2 * prime[:LIMBS] - x[:LIMBS] >> BITS_PER_LIMB * LIMBS
  299. // == 2**(BITS_PER_LIMB * LIMBS) + 2 * prime - x >> BITS_PER_LIMB * LIMBS
  300. // >= 2**(BITS_PER_LIMB * LIMBS) + 0 >> BITS_PER_LIMB * LIMBS
  301. // == 1
  302. // acc2 == 0
  303. // Proof:
  304. // acc2 == prime[:LIMBS] + x[:LIMBS] >> BITS_PER_LIMB * LIMBS
  305. // == prime + x >> BITS_PER_LIMB * LIMBS
  306. // <= 2 * prime - 1 >> BITS_PER_LIMB * LIMBS
  307. // <= 2 * (2**(BITS_PER_LIMB * LIMBS - 1) - 1) - 1 >> 261
  308. // == 2**(BITS_PER_LIMB * LIMBS) - 3 >> BITS_PER_LIMB * LIMBS
  309. // == 0
  310. // clang-format on
  311. }
  312. // x <<= 1
  313. // Assumes x is normalized, x < 2**260 == 2**(LIMBS*BITS_PER_LIMB - 1)
  314. // Guarantees x is normalized
  315. void bn_lshift(bignum256* x) {
  316. for(int i = BN_LIMBS - 1; i > 0; i--) {
  317. x->val[i] = ((x->val[i] << 1) & BN_LIMB_MASK) | (x->val[i - 1] >> (BN_BITS_PER_LIMB - 1));
  318. }
  319. x->val[0] = (x->val[0] << 1) & BN_LIMB_MASK;
  320. }
  321. // x >>= 1, i.e. x = floor(x/2)
  322. // Assumes x is normalized
  323. // Guarantees x is normalized
  324. // If x is partly reduced (fully reduced) modulo prime,
  325. // guarantess x will be partly reduced (fully reduced) modulo prime
  326. void bn_rshift(bignum256* x) {
  327. for(int i = 0; i < BN_LIMBS - 1; i++) {
  328. x->val[i] = (x->val[i] >> 1) | ((x->val[i + 1] & 1) << (BN_BITS_PER_LIMB - 1));
  329. }
  330. x->val[BN_LIMBS - 1] >>= 1;
  331. }
  332. // Sets i-th least significant bit (counting from zero)
  333. // Assumes x is normalized and 0 <= i < 261 == LIMBS*BITS_PER_LIMB
  334. // Guarantees x is normalized
  335. // The function has constant control flow but not constant memory access flow
  336. // with regard to i
  337. void bn_setbit(bignum256* x, uint16_t i) {
  338. assert(i < BN_LIMBS * BN_BITS_PER_LIMB);
  339. x->val[i / BN_BITS_PER_LIMB] |= (1u << (i % BN_BITS_PER_LIMB));
  340. }
  341. // clears i-th least significant bit (counting from zero)
  342. // Assumes x is normalized and 0 <= i < 261 == LIMBS*BITS_PER_LIMB
  343. // Guarantees x is normalized
  344. // The function has constant control flow but not constant memory access flow
  345. // with regard to i
  346. void bn_clearbit(bignum256* x, uint16_t i) {
  347. assert(i < BN_LIMBS * BN_BITS_PER_LIMB);
  348. x->val[i / BN_BITS_PER_LIMB] &= ~(1u << (i % BN_BITS_PER_LIMB));
  349. }
  350. // returns i-th least significant bit (counting from zero)
  351. // Assumes x is normalized and 0 <= i < 261 == LIMBS*BITS_PER_LIMB
  352. // The function has constant control flow but not constant memory access flow
  353. // with regard to i
  354. uint32_t bn_testbit(const bignum256* x, uint16_t i) {
  355. assert(i < BN_LIMBS * BN_BITS_PER_LIMB);
  356. return (x->val[i / BN_BITS_PER_LIMB] >> (i % BN_BITS_PER_LIMB)) & 1;
  357. }
  358. // res = x ^ y
  359. // Assumes x, y are normalized
  360. // Guarantees res is normalized
  361. // Works properly even if &res == &x or &res == &y or &res == &x == &y
  362. void bn_xor(bignum256* res, const bignum256* x, const bignum256* y) {
  363. for(int i = 0; i < BN_LIMBS; i++) {
  364. res->val[i] = x->val[i] ^ y->val[i];
  365. }
  366. }
  367. // x = x / 2 % prime
  368. // Explicitly x = x / 2 if is_even(x) else (x + prime) / 2
  369. // Assumes x is normalized, x + prime < 261 == LIMBS * BITS_PER_LIMB
  370. // Guarantees x is normalized
  371. // If x is partly reduced (fully reduced) modulo prime,
  372. // guarantess x will be partly reduced (fully reduced) modulo prime
  373. // Assumes prime is an odd number and normalized
  374. void bn_mult_half(bignum256* x, const bignum256* prime) {
  375. // x = x / 2 if is_even(x) else (x + prime) / 2
  376. uint32_t x_is_odd_mask = -(x->val[0] & 1); // x_is_odd_mask = 0xFFFFFFFF if is_odd(x) else 0
  377. uint32_t acc = (x->val[0] + (prime->val[0] & x_is_odd_mask)) >> 1;
  378. // acc < 2**BITS_PER_LIMB
  379. // Proof:
  380. // acc == x[0] + prime[0] & x_is_odd_mask >> 1
  381. // <= (2**(BITS_PER_LIMB) - 1) + (2**(BITS_PER_LIMB) - 1) >> 1
  382. // == 2**(BITS_PER_LIMB + 1) - 2 >> 1
  383. // < 2**(BITS_PER_LIMB)
  384. for(int i = 0; i < BN_LIMBS - 1; i++) {
  385. uint32_t temp = (x->val[i + 1] + (prime->val[i + 1] & x_is_odd_mask));
  386. // temp < 2**(BITS_PER_LIMB + 1)
  387. // Proof:
  388. // temp == x[i + 1] + val[i + 1] & x_is_odd_mask
  389. // <= (2**(BITS_PER_LIMB) - 1) + (2**(BITS_PER_LIMB) - 1)
  390. // < 2**(BITS_PER_LIMB + 1)
  391. acc += (temp & 1) << (BN_BITS_PER_LIMB - 1);
  392. // acc doesn't overflow 32 bits
  393. // Proof:
  394. // acc + (temp & 1 << BITS_PER_LIMB - 1)
  395. // <= 2**(BITS_PER_LIMB + 1) + 2**(BITS_PER_LIMB - 1)
  396. // <= 2**30 + 2**28 < 2**32
  397. x->val[i] = acc & BN_LIMB_MASK;
  398. acc >>= BN_BITS_PER_LIMB;
  399. acc += temp >> 1;
  400. // acc < 2**(BITS_PER_LIMB + 1)
  401. // Proof:
  402. // acc + (temp >> 1)
  403. // <= (2**(32 - BITS_PER_LIMB) - 1) + (2**(BITS_PER_LIMB + 1) - 1 >> 1)
  404. // == 7 + 2**(BITS_PER_LIMB) - 1 < 2**(BITS_PER_LIMB + 1)
  405. // acc == x[:i+2]+(prime[:i+2] & x_is_odd_mask) >> BITS_PER_LIMB * (i+1)
  406. }
  407. x->val[BN_LIMBS - 1] = acc;
  408. // assert(acc >> BITS_PER_LIMB == 0);
  409. // acc >> BITS_PER_LIMB == 0
  410. // Proof:
  411. // acc
  412. // == x[:LIMBS] + (prime[:LIMBS] & x_is_odd_mask) >> BITS_PER_LIMB*LIMBS
  413. // == x + (prime & x_is_odd_mask) >> BITS_PER_LIMB * LIMBS
  414. // <= x + prime >> BITS_PER_LIMB * LIMBS
  415. // <= 2**(BITS_PER_LIMB * LIMBS) - 1 >> BITS_PER_LIMB * LIMBS
  416. // == 0
  417. }
  418. // x = x * k % prime
  419. // Assumes x is normalized, 0 <= k <= 8 = 2**(32 - BITS_PER_LIMB)
  420. // Assumes prime is normalized and 2^256 - 2^224 <= prime <= 2^256
  421. // Guarantees x is normalized and partly reduced modulo prime
  422. void bn_mult_k(bignum256* x, uint8_t k, const bignum256* prime) {
  423. assert(k <= 8);
  424. for(int i = 0; i < BN_LIMBS; i++) {
  425. x->val[i] = k * x->val[i];
  426. // x[i] doesn't overflow 32 bits
  427. // k * x[i] <= 2**(32 - BITS_PER_LIMB) * (2**BITS_PER_LIMB - 1)
  428. // < 2**(32 - BITS_PER_LIMB) * 2**BITS_PER_LIMB == 2**32
  429. }
  430. bn_fast_mod(x, prime);
  431. }
  432. // Reduces partly reduced x modulo prime
  433. // Explicitly x = x if x < prime else x - prime
  434. // Assumes x is partly reduced modulo prime
  435. // Guarantees x is fully reduced modulo prime
  436. // Assumes prime is nonzero and normalized
  437. void bn_mod(bignum256* x, const bignum256* prime) {
  438. uint32_t x_less_prime = bn_is_less(x, prime);
  439. bignum256 temp = {0};
  440. bn_subtract(x, prime, &temp);
  441. bn_cmov(x, x_less_prime, x, &temp);
  442. memzero(&temp, sizeof(temp));
  443. }
  444. // Auxiliary function for bn_multiply
  445. // res = k * x
  446. // Assumes k and x are normalized
  447. // Guarantees res is normalized 18 digit little endian number in base 2**29
  448. void bn_multiply_long(const bignum256* k, const bignum256* x, uint32_t res[2 * BN_LIMBS]) {
  449. // Uses long multiplication in base 2**29, see
  450. // https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
  451. uint64_t acc = 0;
  452. // compute lower half
  453. for(int i = 0; i < BN_LIMBS; i++) {
  454. for(int j = 0; j <= i; j++) {
  455. acc += k->val[j] * (uint64_t)x->val[i - j];
  456. // acc doesn't overflow 64 bits
  457. // Proof:
  458. // acc <= acc + sum([k[j] * x[i-j] for j in range(i)])
  459. // <= (2**(64 - BITS_PER_LIMB) - 1) +
  460. // LIMBS * (2**BITS_PER_LIMB - 1) * (2**BITS_PER_LIMB - 1)
  461. // == (2**35 - 1) + 9 * (2**29 - 1) * (2**29 - 1)
  462. // <= 2**35 + 9 * 2**58 < 2**64
  463. }
  464. res[i] = acc & BN_LIMB_MASK;
  465. acc >>= BN_BITS_PER_LIMB;
  466. // acc <= 2**35 - 1 == 2**(64 - BITS_PER_LIMB) - 1
  467. }
  468. // compute upper half
  469. for(int i = BN_LIMBS; i < 2 * BN_LIMBS - 1; i++) {
  470. for(int j = i - BN_LIMBS + 1; j < BN_LIMBS; j++) {
  471. acc += k->val[j] * (uint64_t)x->val[i - j];
  472. // acc doesn't overflow 64 bits
  473. // Proof:
  474. // acc <= acc + sum([k[j] * x[i-j] for j in range(i)])
  475. // <= (2**(64 - BITS_PER_LIMB) - 1)
  476. // LIMBS * (2**BITS_PER_LIMB - 1) * (2**BITS_PER_LIMB - 1)
  477. // == (2**35 - 1) + 9 * (2**29 - 1) * (2**29 - 1)
  478. // <= 2**35 + 9 * 2**58 < 2**64
  479. }
  480. res[i] = acc & (BN_BASE - 1);
  481. acc >>= BN_BITS_PER_LIMB;
  482. // acc < 2**35 == 2**(64 - BITS_PER_LIMB)
  483. }
  484. res[2 * BN_LIMBS - 1] = acc;
  485. }
  486. // Auxiliary function for bn_multiply
  487. // Assumes 0 <= d <= 8 == LIMBS - 1
  488. // Assumes res is normalized and res < 2**(256 + 29*d + 31)
  489. // Guarantess res in normalized and res < 2 * prime * 2**(29*d)
  490. // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
  491. void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256* prime, uint32_t d) {
  492. // clang-format off
  493. // Computes res = res - (res // 2**(256 + BITS_PER_LIMB * d)) * prime * 2**(BITS_PER_LIMB * d)
  494. // res - (res // 2**(256 + BITS_PER_LIMB * d)) * prime * 2**(BITS_PER_LIMB * d) < 2 * prime * 2**(BITS_PER_LIMB * d)
  495. // Proof:
  496. // res - res // (2**(256 + BITS_PER_LIMB * d)) * 2**(BITS_PER_LIMB * d) * prime
  497. // == res - res // (2**(256 + BITS_PER_LIMB * d)) * 2**(BITS_PER_LIMB * d) * (2**256 - (2**256 - prime))
  498. // == res - res // (2**(256 + BITS_PER_LIMB * d)) * 2**(BITS_PER_LIMB * d) * 2**256 + res // (2**(256 + BITS_PER_LIMB * d)) * 2**(BITS_PER_LIMB * d) * (2**256 - prime)
  499. // == (res % 2**(256 + BITS_PER_LIMB * d)) + res // (2**256 + BITS_PER_LIMB * d) * 2**(BITS_PER_LIMB * d) * (2**256 - prime)
  500. // <= (2**(256 + 29*d + 31) % 2**(256 + 29*d)) + (2**(256 + 29*d + 31) - 1) / (2**256 + 29*d) * 2**(29*d) * (2**256 - prime)
  501. // <= 2**(256 + 29*d) + 2**(256 + 29*d + 31) / (2**256 + 29*d) * 2**(29*d) * (2**256 - prime)
  502. // == 2**(256 + 29*d) + 2**31 * 2**(29*d) * (2**256 - prime)
  503. // == 2**(29*d) * (2**256 + 2**31 * (2*256 - prime))
  504. // <= 2**(29*d) * (2**256 + 2**31 * 2*224)
  505. // <= 2**(29*d) * (2**256 + 2**255)
  506. // <= 2**(29*d) * 2 * (2**256 - 2**224)
  507. // <= 2 * prime * 2**(29*d)
  508. // clang-format on
  509. uint32_t coef = (res[d + BN_LIMBS - 1] >> (256 - (BN_LIMBS - 1) * BN_BITS_PER_LIMB)) +
  510. (res[d + BN_LIMBS] << ((BN_LIMBS * BN_BITS_PER_LIMB) - 256));
  511. // coef == res // 2**(256 + BITS_PER_LIMB * d)
  512. // coef < 2**31
  513. // Proof:
  514. // coef == res // 2**(256 + BITS_PER_LIMB * d)
  515. // < 2**(256 + 29 * d + 31) // 2**(256 + 29 * d)
  516. // == 2**31
  517. const int shift = 31;
  518. uint64_t acc = 1ull << shift;
  519. for(int i = 0; i < BN_LIMBS; i++) {
  520. acc += (((uint64_t)(BN_BASE - 1)) << shift) + res[d + i] - prime->val[i] * (uint64_t)coef;
  521. // acc neither overflow 64 bits nor underflow zero
  522. // Proof:
  523. // acc + ((BASE - 1) << shift) + res[d + i] - prime[i] * coef
  524. // >= ((BASE - 1) << shift) - prime[i] * coef
  525. // == 2**shift * (2**BITS_PER_LIMB - 1) - (2**BITS_PER_LIMB - 1) *
  526. // (2**31 - 1)
  527. // == (2**shift - 2**31 + 1) * (2**BITS_PER_LIMB - 1)
  528. // == (2**31 - 2**31 + 1) * (2**29 - 1)
  529. // == 2**29 - 1 > 0
  530. // acc + ((BASE - 1) << shift) + res[d + i] - prime[i] * coef
  531. // <= acc + ((BASE - 1) << shift) + res[d+i]
  532. // <= (2**(64 - BITS_PER_LIMB) - 1) + 2**shift * (2**BITS_PER_LIMB - 1)
  533. // + (2*BITS_PER_LIMB - 1)
  534. // == (2**(64 - BITS_PER_LIMB) - 1) + (2**shift + 1) *
  535. // (2**BITS_PER_LIMB - 1)
  536. // == (2**35 - 1) + (2**31 + 1) * (2**29 - 1)
  537. // <= 2**35 + 2**60 + 2**29 < 2**64
  538. res[d + i] = acc & BN_LIMB_MASK;
  539. acc >>= BN_BITS_PER_LIMB;
  540. // acc <= 2**(64 - BITS_PER_LIMB) - 1 == 2**35 - 1
  541. // acc == (1 << BITS_PER_LIMB * (i + 1) + shift) + res[d : d + i + 1]
  542. // - coef * prime[:i + 1] >> BITS_PER_LIMB * (i + 1)
  543. }
  544. // acc += (((uint64_t)(BASE - 1)) << shift) + res[d + LIMBS];
  545. // acc >>= BITS_PER_LIMB;
  546. // assert(acc <= 1ul << shift);
  547. // clang-format off
  548. // acc == 1 << shift
  549. // Proof:
  550. // acc
  551. // == (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + res[d : d + LIMBS + 1] - coef * prime[:LIMBS] >> BITS_PER_LIMB * (LIMBS + 1)
  552. // == (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + res[d : d + LIMBS + 1] - coef * prime >> BITS_PER_LIMB * (LIMBS + 1)
  553. // == (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + (res[d : d + LIMBS + 1] - coef * prime) >> BITS_PER_LIMB * (LIMBS + 1)
  554. // <= (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + (res[:d] + BASE**d * res[d : d + LIMBS + 1] - BASE**d * coef * prime)//BASE**d >> BITS_PER_LIMB * (LIMBS + 1)
  555. // <= (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + (res - BASE**d * coef * prime) // BASE**d >> BITS_PER_LIMB * (LIMBS + 1)
  556. // == (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + (2 * prime * BASE**d) // BASE**d >> BITS_PER_LIMB * (LIMBS + 1)
  557. // <= (1 << 321) + 2 * 2**256 >> 290
  558. // == 1 << 31 == 1 << shift
  559. // == (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + res[d : d + LIMBS + 1] - coef * prime[:LIMBS + 1] >> BITS_PER_LIMB * (LIMBS + 1)
  560. // >= (1 << BITS_PER_LIMB * (LIMBS + 1) + shift) + 0 >> BITS_PER_LIMB * (LIMBS + 1)
  561. // == 1 << shift
  562. // clang-format on
  563. res[d + BN_LIMBS] = 0;
  564. }
  565. // Auxiliary function for bn_multiply
  566. // Partly reduces res and stores both in x and res
  567. // Assumes res in normalized and res < 2**519
  568. // Guarantees x is normalized and partly reduced modulo prime
  569. // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
  570. void bn_multiply_reduce(bignum256* x, uint32_t res[2 * BN_LIMBS], const bignum256* prime) {
  571. for(int i = BN_LIMBS - 1; i >= 0; i--) {
  572. // res < 2**(256 + 29*i + 31)
  573. // Proof:
  574. // if i == LIMBS - 1:
  575. // res < 2**519
  576. // == 2**(256 + 29 * 8 + 31)
  577. // == 2**(256 + 29 * (LIMBS - 1) + 31)
  578. // else:
  579. // res < 2 * prime * 2**(29 * (i + 1))
  580. // <= 2**256 * 2**(29*i + 29) < 2**(256 + 29*i + 31)
  581. bn_multiply_reduce_step(res, prime, i);
  582. }
  583. for(int i = 0; i < BN_LIMBS; i++) {
  584. x->val[i] = res[i];
  585. }
  586. }
  587. // x = k * x % prime
  588. // Assumes k, x are normalized, k * x < 2**519
  589. // Guarantees x is normalized and partly reduced modulo prime
  590. // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
  591. void bn_multiply(const bignum256* k, bignum256* x, const bignum256* prime) {
  592. uint32_t res[2 * BN_LIMBS] = {0};
  593. bn_multiply_long(k, x, res);
  594. bn_multiply_reduce(x, res, prime);
  595. memzero(res, sizeof(res));
  596. }
  597. // Partly reduces x modulo prime
  598. // Assumes limbs of x except the last (the most significant) one are normalized
  599. // Assumes prime is normalized and 2^256 - 2^224 <= prime <= 2^256
  600. // Guarantees x is normalized and partly reduced modulo prime
  601. void bn_fast_mod(bignum256* x, const bignum256* prime) {
  602. // Computes x = x - (x // 2**256) * prime
  603. // x < 2**((LIMBS - 1) * BITS_PER_LIMB + 32) == 2**264
  604. // x - (x // 2**256) * prime < 2 * prime
  605. // Proof:
  606. // x - (x // 2**256) * prime
  607. // == x - (x // 2**256) * (2**256 - (2**256 - prime))
  608. // == x - ((x // 2**256) * 2**256) + (x // 2**256) * (2**256 - prime)
  609. // == (x % prime) + (x // 2**256) * (2**256 - prime)
  610. // <= prime - 1 + (2**264 // 2**256) * (2**256 - prime)
  611. // <= 2**256 + 2**8 * 2**224 == 2**256 + 2**232
  612. // < 2 * (2**256 - 2**224)
  613. // <= 2 * prime
  614. // x - (x // 2**256 - 1) * prime < 2 * prime
  615. // Proof:
  616. // x - (x // 2**256) * prime + prime
  617. // == x - (x // 2**256) * (2**256 - (2**256 - prime)) + prime
  618. // == x - ((x//2**256) * 2**256) + (x//2**256) * (2**256 - prime) + prime
  619. // == (x % prime) + (x // 2**256) * (2**256 - prime) + prime
  620. // <= 2 * prime - 1 + (2**264 // 2**256) * (2**256 - prime)
  621. // <= 2 * prime + 2**8 * 2**224 == 2**256 + 2**232 + 2**256 - 2**224
  622. // < 2 * (2**256 - 2**224)
  623. // <= 2 * prime
  624. uint32_t coef = x->val[BN_LIMBS - 1] >> (256 - ((BN_LIMBS - 1) * BN_BITS_PER_LIMB));
  625. // clang-format off
  626. // coef == x // 2**256
  627. // 0 <= coef < 2**((LIMBS - 1) * BITS_PER_LIMB + 32 - 256) == 256
  628. // Proof:
  629. //* Let x[[a : b] be the number consisting of a-th to (b-1)-th bit of the number x.
  630. // x[LIMBS - 1] >> (256 - ((LIMBS - 1) * BITS_PER_LIMB))
  631. // == x[[(LIMBS - 1) * BITS_PER_LIMB : (LIMBS - 1) * BITS_PER_LIMB + 32]] >> (256 - ((LIMBS - 1) * BITS_PER_LIMB))
  632. // == x[[256 - ((LIMBS - 1) * BITS_PER_LIMB) + (LIMBS - 1) * BITS_PER_LIMB : (LIMBS - 1) * BITS_PER_LIMB + 32]]
  633. // == x[[256 : (LIMBS - 1) * BITS_PER_LIMB + 32]]
  634. // == x[[256 : 264]] == x // 2**256
  635. // clang-format on
  636. const int shift = 8;
  637. uint64_t acc = 1ull << shift;
  638. for(int i = 0; i < BN_LIMBS; i++) {
  639. acc += (((uint64_t)(BN_BASE - 1)) << shift) + x->val[i] - prime->val[i] * (uint64_t)coef;
  640. // acc neither overflows 64 bits nor underflows 0
  641. // Proof:
  642. // acc + (BASE - 1 << shift) + x[i] - prime[i] * coef
  643. // >= (BASE - 1 << shift) - prime[i] * coef
  644. // >= 2**shift * (2**BITS_PER_LIMB - 1) - (2**BITS_PER_LIMB - 1) * 255
  645. // == (2**shift - 255) * (2**BITS_PER_LIMB - 1)
  646. // == (2**8 - 255) * (2**29 - 1) == 2**29 - 1 >= 0
  647. // acc + (BASE - 1 << shift) + x[i] - prime[i] * coef
  648. // <= acc + ((BASE - 1) << shift) + x[i]
  649. // <= (2**(64 - BITS_PER_LIMB) - 1) + 2**shift * (2**BITS_PER_LIMB - 1)
  650. // + (2**32 - 1)
  651. // == (2**35 - 1) + 2**8 * (2**29 - 1) + 2**32
  652. // < 2**35 + 2**37 + 2**32 < 2**64
  653. x->val[i] = acc & BN_LIMB_MASK;
  654. acc >>= BN_BITS_PER_LIMB;
  655. // acc <= 2**(64 - BITS_PER_LIMB) - 1 == 2**35 - 1
  656. // acc == (1 << BITS_PER_LIMB * (i + 1) + shift) + x[:i + 1]
  657. // - coef * prime[:i + 1] >> BITS_PER_LIMB * (i + 1)
  658. }
  659. // assert(acc == 1 << shift);
  660. // clang-format off
  661. // acc == 1 << shift
  662. // Proof:
  663. // acc
  664. // == (1 << BITS_PER_LIMB * LIMBS + shift) + x[:LIMBS] - coef * prime[:LIMBS] >> BITS_PER_LIMB * LIMBS
  665. // == (1 << BITS_PER_LIMB * LIMBS + shift) + (x - coef * prime) >> BITS_PER_LIMB * LIMBS
  666. // <= (1 << BITS_PER_LIMB * LIMBS + shift) + (2 * prime) >> BITS_PER_LIMB * LIMBS
  667. // <= (1 << BITS_PER_LIMB * LIMBS + shift) + 2 * 2**256 >> BITS_PER_LIMB * LIMBS
  668. // <= 2**269 + 2**257 >> 2**261
  669. // <= 1 << 8 == 1 << shift
  670. // acc
  671. // == (1 << BITS_PER_LIMB * LIMBS + shift) + x[:LIMBS] - coef * prime[:LIMBS] >> BITS_PER_LIMB * LIMBS
  672. // >= (1 << BITS_PER_LIMB * LIMBS + shift) + 0 >> BITS_PER_LIMB * LIMBS
  673. // == (1 << BITS_PER_LIMB * LIMBS + shift) + 0 >> BITS_PER_LIMB * LIMBS
  674. // <= 1 << 8 == 1 << shift
  675. // clang-format on
  676. }
  677. // res = x**e % prime
  678. // Assumes both x and e are normalized, x < 2**259
  679. // Guarantees res is normalized and partly reduced modulo prime
  680. // Works properly even if &x == &res
  681. // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
  682. // The function doesn't have neither constant control flow nor constant memory
  683. // access flow with regard to e
  684. void bn_power_mod(const bignum256* x, const bignum256* e, const bignum256* prime, bignum256* res) {
  685. // Uses iterative right-to-left exponentiation by squaring, see
  686. // https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
  687. bignum256 acc = {0};
  688. bn_copy(x, &acc);
  689. bn_one(res);
  690. for(int i = 0; i < BN_LIMBS; i++) {
  691. uint32_t limb = e->val[i];
  692. for(int j = 0; j < BN_BITS_PER_LIMB; j++) {
  693. // Break if the following bits of the last limb are zero
  694. if(i == BN_LIMBS - 1 && limb == 0) break;
  695. if(limb & 1)
  696. // acc * res < 2**519
  697. // Proof:
  698. // acc * res <= max(2**259 - 1, 2 * prime) * (2 * prime)
  699. // == max(2**259 - 1, 2**257) * 2**257 < 2**259 * 2**257
  700. // == 2**516 < 2**519
  701. bn_multiply(&acc, res, prime);
  702. limb >>= 1;
  703. // acc * acc < 2**519
  704. // Proof:
  705. // acc * acc <= max(2**259 - 1, 2 * prime)**2
  706. // <= (2**259)**2 == 2**518 < 2**519
  707. bn_multiply(&acc, &acc, prime);
  708. }
  709. // acc == x**(e[:i + 1]) % prime
  710. }
  711. memzero(&acc, sizeof(acc));
  712. }
  713. // x = sqrt(x) % prime
  714. // Explicitly x = x**((prime+1)/4) % prime
  715. // The other root is -sqrt(x)
  716. // Assumes x is normalized, x < 2**259 and quadratic residuum mod prime
  717. // Assumes prime is a prime number, prime % 4 == 3, it is normalized and
  718. // 2**256 - 2**224 <= prime <= 2**256
  719. // Guarantees x is normalized and fully reduced modulo prime
  720. // The function doesn't have neither constant control flow nor constant memory
  721. // access flow with regard to prime
  722. void bn_sqrt(bignum256* x, const bignum256* prime) {
  723. // Uses the Lagrange formula for the primes of the special form, see
  724. // http://en.wikipedia.org/wiki/Quadratic_residue#Prime_or_prime_power_modulus
  725. // If prime % 4 == 3, then sqrt(x) % prime == x**((prime+1)//4) % prime
  726. assert(prime->val[BN_LIMBS - 1] % 4 == 3);
  727. // e = (prime + 1) // 4
  728. bignum256 e = {0};
  729. bn_copy(prime, &e);
  730. bn_addi(&e, 1);
  731. bn_rshift(&e);
  732. bn_rshift(&e);
  733. bn_power_mod(x, &e, prime, x);
  734. bn_mod(x, prime);
  735. memzero(&e, sizeof(e));
  736. }
  737. // a = 1/a % 2**n
  738. // Assumes a is odd, 1 <= n <= 32
  739. // The function doesn't have neither constant control flow nor constant memory
  740. // access flow with regard to n
  741. uint32_t inverse_mod_power_two(uint32_t a, uint32_t n) {
  742. // Uses "Explicit Quadratic Modular inverse modulo 2" from section 3.3 of "On
  743. // Newton-Raphson iteration for multiplicative inverses modulo prime powers"
  744. // by Jean-Guillaume Dumas, see
  745. // https://arxiv.org/pdf/1209.6626.pdf
  746. // 1/a % 2**n
  747. // = (2-a) * product([1 + (a-1)**(2**i) for i in range(1, floor(log2(n)))])
  748. uint32_t acc = 2 - a;
  749. uint32_t f = a - 1;
  750. // mask = (1 << n) - 1
  751. uint32_t mask = n == 32 ? 0xFFFFFFFF : (1u << n) - 1;
  752. for(uint32_t i = 1; i < n; i <<= 1) {
  753. f = (f * f) & mask;
  754. acc = (acc * (1 + f)) & mask;
  755. }
  756. return acc;
  757. }
  758. // x = (x / 2**BITS_PER_LIMB) % prime
  759. // Assumes both x and prime are normalized
  760. // Assumes prime is an odd number and normalized
  761. // Guarantees x is normalized
  762. // If x is partly reduced (fully reduced) modulo prime,
  763. // guarantess x will be partly reduced (fully reduced) modulo prime
  764. void bn_divide_base(bignum256* x, const bignum256* prime) {
  765. // Uses an explicit formula for the modular inverse of power of two
  766. // (x / 2**n) % prime == (x + ((-x / prime) % 2**n) * prime) // 2**n
  767. // Proof:
  768. // (x + ((-x / prime) % 2**n) * prime) % 2**n
  769. // == (x - x / prime * prime) % 2**n
  770. // == 0
  771. // (x + ((-1 / prime) % 2**n) * prime) % prime
  772. // == x
  773. // if x < prime:
  774. // (x + ((-x / prime) % 2**n) * prime) // 2**n
  775. // <= ((prime - 1) + (2**n - 1) * prime) / 2**n
  776. // == (2**n * prime - 1) / 2**n == prime - 1 / 2**n < prime
  777. // if x < 2 * prime:
  778. // (x + ((-x / prime) % 2**n) * prime) // 2**n
  779. // <= ((2 * prime - 1) + (2**n - 1) * prime) / 2**n
  780. // == (2**n * prime + prime - 1) / 2**n
  781. // == prime + (prime - 1) / 2**n < 2 * prime
  782. // m = (-x / prime) % 2**BITS_PER_LIMB
  783. uint32_t m = (x->val[0] * (BN_BASE - inverse_mod_power_two(prime->val[0], BN_BITS_PER_LIMB))) &
  784. BN_LIMB_MASK;
  785. // m < 2**BITS_PER_LIMB
  786. uint64_t acc = x->val[0] + (uint64_t)m * prime->val[0];
  787. acc >>= BN_BITS_PER_LIMB;
  788. for(int i = 1; i < BN_LIMBS; i++) {
  789. acc = acc + x->val[i] + (uint64_t)m * prime->val[i];
  790. // acc does not overflow 64 bits
  791. // acc == acc + x + m * prime
  792. // <= 2**(64 - BITS_PER_LIMB) + 2**(BITS_PER_LIMB)
  793. // 2**(BITS_PER_LIMB) * 2**(BITS_PER_LIMB)
  794. // <= 2**(2 * BITS_PER_LIMB) + 2**(64 - BITS_PER_LIMB) +
  795. // 2**(BITS_PER_LIMB)
  796. // <= 2**58 + 2**35 + 2**29 < 2**64
  797. x->val[i - 1] = acc & BN_LIMB_MASK;
  798. acc >>= BN_BITS_PER_LIMB;
  799. // acc < 2**35 == 2**(64 - BITS_PER_LIMB)
  800. // acc == x[:i + 1] + m * prime[:i + 1] >> BITS_PER_LIMB * (i + 1)
  801. }
  802. x->val[BN_LIMBS - 1] = acc;
  803. assert(acc >> BN_BITS_PER_LIMB == 0);
  804. // clang-format off
  805. // acc >> BITS_PER_LIMB == 0
  806. // Proof:
  807. // acc >> BITS_PER_LIMB
  808. // == (x[:LIMB] + m * prime[:LIMB] >> BITS_PER_LIMB * LIMBS) >> BITS_PER_LIMB * (LIMBS + 1)
  809. // == x + m * prime >> BITS_PER_LIMB * (LIMBS + 1)
  810. // <= (2**(BITS_PER_LIMB * LIMBS) - 1) + (2**BITS_PER_LIMB - 1) * (2**(BITS_PER_LIMB * LIMBS) - 1) >> BITS_PER_LIMB * (LIMBS + 1)
  811. // == 2**(BITS_PER_LIMB * LIMBS) - 1 + 2**(BITS_PER_LIMB * (LIMBS + 1)) - 2**(BITS_PER_LIMB * LIMBS) - 2**BITS_PER_LIMB + 1 >> BITS_PER_LIMB * (LIMBS + 1)
  812. // == 2**(BITS_PER_LIMB * (LIMBS + 1)) - 2**BITS_PER_LIMB >> BITS_PER_LIMB * (LIMBS + 1)
  813. // == 0
  814. // clang-format on
  815. }
  816. #if !USE_INVERSE_FAST
  817. // x = 1/x % prime if x != 0 else 0
  818. // Assumes x is normalized
  819. // Assumes prime is a prime number
  820. // Guarantees x is normalized and fully reduced modulo prime
  821. // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
  822. // The function doesn't have neither constant control flow nor constant memory
  823. // access flow with regard to prime
  824. static void bn_inverse_slow(bignum256* x, const bignum256* prime) {
  825. // Uses formula 1/x % prime == x**(prime - 2) % prime
  826. // See https://en.wikipedia.org/wiki/Fermat%27s_little_theorem
  827. bn_fast_mod(x, prime);
  828. // e = prime - 2
  829. bignum256 e = {0};
  830. bn_read_uint32(2, &e);
  831. bn_subtract(prime, &e, &e);
  832. bn_power_mod(x, &e, prime, x);
  833. bn_mod(x, prime);
  834. memzero(&e, sizeof(e));
  835. }
  836. #endif
  837. #if false
  838. // x = 1/x % prime if x != 0 else 0
  839. // Assumes x is is_normalized
  840. // Assumes GCD(x, prime) = 1
  841. // Guarantees x is normalized and fully reduced modulo prime
  842. // Assumes prime is odd, normalized, 2**256 - 2**224 <= prime <= 2**256
  843. // The function doesn't have neither constant control flow nor constant memory
  844. // access flow with regard to prime and x
  845. static void bn_inverse_fast(bignum256 *x, const bignum256 *prime) {
  846. // "The Almost Montgomery Inverse" from the section 3 of "Constant Time
  847. // Modular Inversion" by Joppe W. Bos
  848. // See http://www.joppebos.com/files/CTInversion.pdf
  849. /*
  850. u = prime
  851. v = x & prime
  852. s = 1
  853. r = 0
  854. k = 0
  855. while v != 1:
  856. k += 1
  857. if is_even(u):
  858. u = u // 2
  859. s = 2 * s
  860. elif is_even(v):
  861. v = v // 2
  862. r = 2 * r
  863. elif v < u:
  864. u = (u - v) // 2
  865. r = r + s
  866. s = 2 * s
  867. else:
  868. v = (v - u) // 2
  869. s = r + s
  870. r = 2 * r
  871. s = (s / 2**k) % prime
  872. return s
  873. */
  874. if (bn_is_zero(x)) return;
  875. bn_fast_mod(x, prime);
  876. bn_mod(x, prime);
  877. bignum256 u = {0}, v = {0}, r = {0}, s = {0};
  878. bn_copy(prime, &u);
  879. bn_copy(x, &v);
  880. bn_one(&s);
  881. bn_zero(&r);
  882. int k = 0;
  883. while (!bn_is_one(&v)) {
  884. if ((u.val[0] & 1) == 0) {
  885. bn_rshift(&u);
  886. bn_lshift(&s);
  887. } else if ((v.val[0] & 1) == 0) {
  888. bn_rshift(&v);
  889. bn_lshift(&r);
  890. } else if (bn_is_less(&v, &u)) {
  891. bn_subtract(&u, &v, &u);
  892. bn_rshift(&u);
  893. bn_add(&r, &s);
  894. bn_lshift(&s);
  895. } else {
  896. bn_subtract(&v, &u, &v);
  897. bn_rshift(&v);
  898. bn_add(&s, &r);
  899. bn_lshift(&r);
  900. }
  901. k += 1;
  902. assert(!bn_is_zero(&v)); // assert GCD(x, prime) == 1
  903. }
  904. // s = s / 2**(k // BITS_PER_LIMB * BITS_PER_LIMB)
  905. for (int i = 0; i < k / BITS_PER_LIMB; i++) {
  906. bn_divide_base(&s, prime);
  907. }
  908. // s = s / 2**(k % BITS_PER_LIMB)
  909. for (int i = 0; i < k % BN_BITS_PER_LIMB; i++) {
  910. bn_mult_half(&s, prime);
  911. }
  912. bn_copy(&s, x);
  913. memzero(&u, sizeof(u));
  914. memzero(&v, sizeof(v));
  915. memzero(&r, sizeof(r));
  916. memzero(&s, sizeof(s));
  917. }
  918. #endif
  919. #if USE_INVERSE_FAST
  920. // x = 1/x % prime if x != 0 else 0
  921. // Assumes x is is_normalized
  922. // Assumes GCD(x, prime) = 1
  923. // Guarantees x is normalized and fully reduced modulo prime
  924. // Assumes prime is odd, normalized, 2**256 - 2**224 <= prime <= 2**256
  925. // The function has constant control flow but not constant memory access flow
  926. // with regard to prime and x
  927. static void bn_inverse_fast(bignum256* x, const bignum256* prime) {
  928. // Custom constant time version of "The Almost Montgomery Inverse" from the
  929. // section 3 of "Constant Time Modular Inversion" by Joppe W. Bos
  930. // See http://www.joppebos.com/files/CTInversion.pdf
  931. /*
  932. u = prime
  933. v = x % prime
  934. s = 1
  935. r = 0
  936. k = 0
  937. while v != 1:
  938. k += 1
  939. if is_even(u): # b1
  940. u = u // 2
  941. s = 2 * s
  942. elif is_even(v): # b2
  943. v = v // 2
  944. r = 2 * r
  945. elif v < u: # b3
  946. u = (u - v) // 2
  947. r = r + s
  948. s = 2 * s
  949. else: # b4
  950. v = (v - u) // 2
  951. s = r + s
  952. r = 2 * r
  953. s = (s / 2**k) % prime
  954. return s
  955. */
  956. bn_fast_mod(x, prime);
  957. bn_mod(x, prime);
  958. bignum256 u = {0}, v = {0}, r = {0}, s = {0};
  959. bn_copy(prime, &u);
  960. bn_copy(x, &v);
  961. bn_one(&s);
  962. bn_zero(&r);
  963. bignum256 zero = {0};
  964. bn_zero(&zero);
  965. int k = 0;
  966. int finished = 0, u_even = 0, v_even = 0, v_less_u = 0, b1 = 0, b2 = 0, b3 = 0, b4 = 0;
  967. finished = 0;
  968. for(int i = 0; i < 2 * BN_LIMBS * BN_BITS_PER_LIMB; i++) {
  969. finished = finished | -bn_is_one(&v);
  970. u_even = -bn_is_even(&u);
  971. v_even = -bn_is_even(&v);
  972. v_less_u = -bn_is_less(&v, &u);
  973. b1 = ~finished & u_even;
  974. b2 = ~finished & ~b1 & v_even;
  975. b3 = ~finished & ~b1 & ~b2 & v_less_u;
  976. b4 = ~finished & ~b1 & ~b2 & ~b3;
  977. // The ternary operator for pointers with constant control flow
  978. // BN_INVERSE_FAST_TERNARY(c, t, f) = t if c else f
  979. // Very nasty hack, sorry for that
  980. #define BN_INVERSE_FAST_TERNARY(c, t, f) \
  981. ((void*)(((c) & (uintptr_t)(t)) | (~(c) & (uintptr_t)(f))))
  982. bn_subtract(
  983. BN_INVERSE_FAST_TERNARY(b3, &u, &v),
  984. BN_INVERSE_FAST_TERNARY(b3 | b4, BN_INVERSE_FAST_TERNARY(b3, &v, &u), &zero),
  985. BN_INVERSE_FAST_TERNARY(b3, &u, &v));
  986. bn_add(
  987. BN_INVERSE_FAST_TERNARY(b3, &r, &s),
  988. BN_INVERSE_FAST_TERNARY(b3 | b4, BN_INVERSE_FAST_TERNARY(b3, &s, &r), &zero));
  989. bn_rshift(BN_INVERSE_FAST_TERNARY(b1 | b3, &u, &v));
  990. bn_lshift(BN_INVERSE_FAST_TERNARY(b1 | b3, &s, &r));
  991. k = k - ~finished;
  992. }
  993. // s = s / 2**(k // BITS_PER_LIMB * BITS_PER_LIMB)
  994. for(int i = 0; i < 2 * BN_LIMBS; i++) {
  995. // s = s / 2**BITS_PER_LIMB % prime if i < k // BITS_PER_LIMB else s
  996. bn_copy(&s, &r);
  997. bn_divide_base(&r, prime);
  998. bn_cmov(&s, i < k / BN_BITS_PER_LIMB, &r, &s);
  999. }
  1000. // s = s / 2**(k % BITS_PER_LIMB)
  1001. for(int i = 0; i < BN_BITS_PER_LIMB; i++) {
  1002. // s = s / 2 % prime if i < k % BITS_PER_LIMB else s
  1003. bn_copy(&s, &r);
  1004. bn_mult_half(&r, prime);
  1005. bn_cmov(&s, i < k % BN_BITS_PER_LIMB, &r, &s);
  1006. }
  1007. bn_cmov(x, bn_is_zero(x), x, &s);
  1008. memzero(&u, sizeof(u));
  1009. memzero(&v, sizeof(v));
  1010. memzero(&r, sizeof(s));
  1011. memzero(&s, sizeof(s));
  1012. }
  1013. #endif
  1014. #if false
  1015. // x = 1/x % prime if x != 0 else 0
  1016. // Assumes x is is_normalized
  1017. // Assumes GCD(x, prime) = 1
  1018. // Guarantees x is normalized and fully reduced modulo prime
  1019. // Assumes prime is odd, normalized, 2**256 - 2**224 <= prime <= 2**256
  1020. static void bn_inverse_fast(bignum256 *x, const bignum256 *prime) {
  1021. // Custom constant time version of "The Almost Montgomery Inverse" from the
  1022. // section 3 of "Constant Time Modular Inversion" by Joppe W. Bos
  1023. // See http://www.joppebos.com/files/CTInversion.pdf
  1024. /*
  1025. u = prime
  1026. v = x % prime
  1027. s = 1
  1028. r = 0
  1029. k = 0
  1030. while v != 1:
  1031. k += 1
  1032. if is_even(u): # b1
  1033. u = u // 2
  1034. s = 2 * s
  1035. elif is_even(v): # b2
  1036. v = v // 2
  1037. r = 2 * r
  1038. elif v < u: # b3
  1039. u = (u - v) // 2
  1040. r = r + s
  1041. s = 2 * s
  1042. else: # b4
  1043. v = (v - u) // 2
  1044. s = r + s
  1045. r = 2 * r
  1046. s = (s / 2**k) % prime
  1047. return s
  1048. */
  1049. bn_fast_mod(x, prime);
  1050. bn_mod(x, prime);
  1051. bignum256 u = {0}, v = {0}, r = {0}, s = {0};
  1052. bn_copy(prime, &u);
  1053. bn_copy(x, &v);
  1054. bn_one(&s);
  1055. bn_zero(&r);
  1056. bignum256 zero = {0};
  1057. bn_zero(&zero);
  1058. int k = 0;
  1059. uint32_t finished = 0, u_even = 0, v_even = 0, v_less_u = 0, b1 = 0, b2 = 0,
  1060. b3 = 0, b4 = 0;
  1061. finished = 0;
  1062. bignum256 u_half = {0}, v_half = {0}, u_minus_v_half = {0}, v_minus_u_half = {0}, r_plus_s = {0}, r_twice = {0}, s_twice = {0};
  1063. for (int i = 0; i < 2 * BN_LIMBS * BN_BITS_PER_LIMB; i++) {
  1064. finished = finished | bn_is_one(&v);
  1065. u_even = bn_is_even(&u);
  1066. v_even = bn_is_even(&v);
  1067. v_less_u = bn_is_less(&v, &u);
  1068. b1 = (finished ^ 1) & u_even;
  1069. b2 = (finished ^ 1) & (b1 ^ 1) & v_even;
  1070. b3 = (finished ^ 1) & (b1 ^ 1) & (b2 ^ 1) & v_less_u;
  1071. b4 = (finished ^ 1) & (b1 ^ 1) & (b2 ^ 1) & (b3 ^ 1);
  1072. // u_half = u // 2
  1073. bn_copy(&u, &u_half);
  1074. bn_rshift(&u_half);
  1075. // v_half = v // 2
  1076. bn_copy(&v, &v_half);
  1077. bn_rshift(&v_half);
  1078. // u_minus_v_half = (u - v) // 2
  1079. bn_subtract(&u, &v, &u_minus_v_half);
  1080. bn_rshift(&u_minus_v_half);
  1081. // v_minus_u_half = (v - u) // 2
  1082. bn_subtract(&v, &u, &v_minus_u_half);
  1083. bn_rshift(&v_minus_u_half);
  1084. // r_plus_s = r + s
  1085. bn_copy(&r, &r_plus_s);
  1086. bn_add(&r_plus_s, &s);
  1087. // r_twice = 2 * r
  1088. bn_copy(&r, &r_twice);
  1089. bn_lshift(&r_twice);
  1090. // s_twice = 2 * s
  1091. bn_copy(&s, &s_twice);
  1092. bn_lshift(&s_twice);
  1093. bn_cmov(&u, b1, &u_half, &u);
  1094. bn_cmov(&u, b3, &u_minus_v_half, &u);
  1095. bn_cmov(&v, b2, &v_half, &v);
  1096. bn_cmov(&v, b4, &v_minus_u_half, &v);
  1097. bn_cmov(&r, b2 | b4, &r_twice, &r);
  1098. bn_cmov(&r, b3, &r_plus_s, &r);
  1099. bn_cmov(&s, b1 | b3, &s_twice, &s);
  1100. bn_cmov(&s, b4, &r_plus_s, &s);
  1101. k = k + (finished ^ 1);
  1102. }
  1103. // s = s / 2**(k // BITS_PER_LIMB * BITS_PER_LIMB)
  1104. for (int i = 0; i < 2 * BN_LIMBS; i++) {
  1105. // s = s / 2**BITS_PER_LIMB % prime if i < k // BITS_PER_LIMB else s
  1106. bn_copy(&s, &r);
  1107. bn_divide_base(&r, prime);
  1108. bn_cmov(&s, i < k / BITS_PER_LIMB, &r, &s);
  1109. }
  1110. // s = s / 2**(k % BITS_PER_LIMB)
  1111. for (int i = 0; i < BN_BITS_PER_LIMB; i++) {
  1112. // s = s / 2 % prime if i < k % BITS_PER_LIMB else s
  1113. bn_copy(&s, &r);
  1114. bn_mult_half(&r, prime);
  1115. bn_cmov(&s, i < k % BN_BITS_PER_LIMB, &r, &s);
  1116. }
  1117. bn_cmov(x, bn_is_zero(x), x, &s);
  1118. memzero(&u, sizeof(u));
  1119. memzero(&v, sizeof(v));
  1120. memzero(&r, sizeof(r));
  1121. memzero(&s, sizeof(s));
  1122. memzero(&u_half, sizeof(u_half));
  1123. memzero(&v_half, sizeof(v_half));
  1124. memzero(&u_minus_v_half, sizeof(u_minus_v_half));
  1125. memzero(&v_minus_u_half, sizeof(v_minus_u_half));
  1126. memzero(&r_twice, sizeof(r_twice));
  1127. memzero(&s_twice, sizeof(s_twice));
  1128. memzero(&r_plus_s, sizeof(r_plus_s));
  1129. }
  1130. #endif
  1131. // Normalizes x
  1132. // Assumes x < 2**261 == 2**(LIMBS * BITS_PER_LIMB)
  1133. // Guarantees x is normalized
  1134. void bn_normalize(bignum256* x) {
  1135. uint32_t acc = 0;
  1136. for(int i = 0; i < BN_LIMBS; i++) {
  1137. acc += x->val[i];
  1138. // acc doesn't overflow 32 bits
  1139. // Proof:
  1140. // acc + x[i]
  1141. // <= (2**(32 - BITS_PER_LIMB) - 1) + (2**BITS_PER_LIMB - 1)
  1142. // == 7 + 2**29 - 1 < 2**32
  1143. x->val[i] = acc & BN_LIMB_MASK;
  1144. acc >>= (BN_BITS_PER_LIMB);
  1145. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1146. }
  1147. }
  1148. // x = x + y
  1149. // Assumes x, y are normalized, x + y < 2**(LIMBS*BITS_PER_LIMB) == 2**261
  1150. // Guarantees x is normalized
  1151. // Works properly even if &x == &y
  1152. void bn_add(bignum256* x, const bignum256* y) {
  1153. uint32_t acc = 0;
  1154. for(int i = 0; i < BN_LIMBS; i++) {
  1155. acc += x->val[i] + y->val[i];
  1156. // acc doesn't overflow 32 bits
  1157. // Proof:
  1158. // acc + x[i] + y[i]
  1159. // <= (2**(32 - BITS_PER_LIMB) - 1) + 2 * (2**BITS_PER_LIMB - 1)
  1160. // == (2**(32 - BITS_PER_LIMB) - 1) + 2**(BITS_PER_LIMB + 1) - 2
  1161. // == 7 + 2**30 - 2 < 2**32
  1162. x->val[i] = acc & BN_LIMB_MASK;
  1163. acc >>= BN_BITS_PER_LIMB;
  1164. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1165. // acc == x[:i + 1] + y[:i + 1] >> BITS_PER_LIMB * (i + 1)
  1166. }
  1167. // assert(acc == 0); // assert x + y < 2**261
  1168. // acc == 0
  1169. // Proof:
  1170. // acc == x[:LIMBS] + y[:LIMBS] >> LIMBS * BITS_PER_LIMB
  1171. // == x + y >> LIMBS * BITS_PER_LIMB
  1172. // <= 2**(LIMBS * BITS_PER_LIMB) - 1 >> LIMBS * BITS_PER_LIMB == 0
  1173. }
  1174. // x = x + y % prime
  1175. // Assumes x, y are normalized
  1176. // Guarantees x is normalized and partly reduced modulo prime
  1177. // Assumes prime is normalized and 2^256 - 2^224 <= prime <= 2^256
  1178. void bn_addmod(bignum256* x, const bignum256* y, const bignum256* prime) {
  1179. for(int i = 0; i < BN_LIMBS; i++) {
  1180. x->val[i] += y->val[i];
  1181. // x[i] doesn't overflow 32 bits
  1182. // Proof:
  1183. // x[i] + y[i]
  1184. // <= 2 * (2**BITS_PER_LIMB - 1)
  1185. // == 2**30 - 2 < 2**32
  1186. }
  1187. bn_fast_mod(x, prime);
  1188. }
  1189. // x = x + y
  1190. // Assumes x is normalized
  1191. // Assumes y <= 2**32 - 2**29 == 2**32 - 2**BITS_PER_LIMB and
  1192. // x + y < 2**261 == 2**(LIMBS * BITS_PER_LIMB)
  1193. // Guarantees x is normalized
  1194. void bn_addi(bignum256* x, uint32_t y) {
  1195. // assert(y <= 3758096384); // assert y <= 2**32 - 2**29
  1196. uint32_t acc = y;
  1197. for(int i = 0; i < BN_LIMBS; i++) {
  1198. acc += x->val[i];
  1199. // acc doesn't overflow 32 bits
  1200. // Proof:
  1201. // if i == 0:
  1202. // acc + x[i] == y + x[0]
  1203. // <= (2**32 - 2**BITS_PER_LIMB) + (2**BITS_PER_LIMB - 1)
  1204. // == 2**32 - 1 < 2**32
  1205. // else:
  1206. // acc + x[i]
  1207. // <= (2**(32 - BITS_PER_LIMB) - 1) + (2**BITS_PER_LIMB - 1)
  1208. // == 7 + 2**29 - 1 < 2**32
  1209. x->val[i] = acc & BN_LIMB_MASK;
  1210. acc >>= (BN_BITS_PER_LIMB);
  1211. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1212. // acc == x[:i + 1] + y >> BITS_PER_LIMB * (i + 1)
  1213. }
  1214. // assert(acc == 0); // assert x + y < 2**261
  1215. // acc == 0
  1216. // Proof:
  1217. // acc == x[:LIMBS] + y << LIMBS * BITS_PER_LIMB
  1218. // == x + y << LIMBS * BITS_PER_LIMB
  1219. // <= 2**(LIMBS + BITS_PER_LIMB) - 1 << LIMBS * BITS_PER_LIMB
  1220. // == 0
  1221. }
  1222. // x = x - y % prime
  1223. // Explicitly x = x + prime - y
  1224. // Assumes x, y are normalized
  1225. // Assumes y < prime[0], x + prime - y < 2**261 == 2**(LIMBS * BITS_PER_LIMB)
  1226. // Guarantees x is normalized
  1227. // If x is fully reduced modulo prime,
  1228. // guarantess x will be partly reduced modulo prime
  1229. // Assumes prime is nonzero and normalized
  1230. void bn_subi(bignum256* x, uint32_t y, const bignum256* prime) {
  1231. assert(y < prime->val[0]);
  1232. // x = x + prime - y
  1233. uint32_t acc = -y;
  1234. for(int i = 0; i < BN_LIMBS; i++) {
  1235. acc += x->val[i] + prime->val[i];
  1236. // acc neither overflows 32 bits nor underflows 0
  1237. // Proof:
  1238. // acc + x[i] + prime[i]
  1239. // <= (2**(32 - BITS_PER_LIMB) - 1) + 2 * (2**BITS_PER_LIMB - 1)
  1240. // <= 7 + 2**30 - 2 < 2**32
  1241. // acc + x[i] + prime[i]
  1242. // >= -y + prime[0] >= 0
  1243. x->val[i] = acc & BN_LIMB_MASK;
  1244. acc >>= BN_BITS_PER_LIMB;
  1245. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1246. // acc == x[:i + 1] + prime[:i + 1] - y >> BITS_PER_LIMB * (i + 1)
  1247. }
  1248. // assert(acc == 0); // assert x + prime - y < 2**261
  1249. // acc == 0
  1250. // Proof:
  1251. // acc == x[:LIMBS] + prime[:LIMBS] - y >> BITS_PER_LIMB * LIMBS
  1252. // == x + prime - y >> BITS_PER_LIMB * LIMBS
  1253. // <= 2**(LIMBS * BITS_PER_LIMB) - 1 >> BITS_PER_LIMB * LIMBS == 0
  1254. }
  1255. // res = x - y % prime
  1256. // Explicitly res = x + (2 * prime - y)
  1257. // Assumes x, y are normalized, y is partly reduced
  1258. // Assumes x + 2 * prime - y < 2**261 == 2**(BITS_PER_LIMB * LIMBS)
  1259. // Guarantees res is normalized
  1260. // Assumes prime is nonzero and normalized
  1261. void bn_subtractmod(const bignum256* x, const bignum256* y, bignum256* res, const bignum256* prime) {
  1262. // res = x + (2 * prime - y)
  1263. uint32_t acc = 1;
  1264. for(int i = 0; i < BN_LIMBS; i++) {
  1265. acc += (BN_BASE - 1) + x->val[i] + 2 * prime->val[i] - y->val[i];
  1266. // acc neither overflows 32 bits nor underflows 0
  1267. // Proof:
  1268. // acc + (BASE - 1) + x[i] + 2 * prime[i] - y[i]
  1269. // >= (BASE - 1) - y[i]
  1270. // == (2**BITS_PER_LIMB - 1) - (2**BITS_PER_LIMB - 1) == 0
  1271. // acc + (BASE - 1) + x[i] + 2 * prime[i] - y[i]
  1272. // <= acc + (BASE - 1) + x[i] + 2 * prime[i]
  1273. // <= (2**(32 - BITS_PER_LIMB) - 1) + (2**BITS_PER_LIMB - 1) +
  1274. // (2**BITS_PER_LIMB - 1) + 2 * (2**BITS_PER_LIMB - 1)
  1275. // <= (2**(32 - BITS_PER_LIMB) - 1) + 4 * (2**BITS_PER_LIMB - 1)
  1276. // == 7 + 4 * 2**29 - 4 == 2**31 + 3 < 2**32
  1277. res->val[i] = acc & (BN_BASE - 1);
  1278. acc >>= BN_BITS_PER_LIMB;
  1279. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1280. // acc == 2**(BITS_PER_LIMB * (i + 1)) + x[:i+1] - y[:i+1] + 2*prime[:i+1]
  1281. // >> BITS_PER_LIMB * (i+1)
  1282. }
  1283. // assert(acc == 1); // assert x + 2 * prime - y < 2**261
  1284. // clang-format off
  1285. // acc == 1
  1286. // Proof:
  1287. // acc == 2**(BITS_PER_LIMB * LIMBS) + x[:LIMBS] - y[:LIMBS] + 2 * prime[:LIMBS] >> BITS_PER_LIMB * LIMBS
  1288. // == 2**(BITS_PER_LIMB * LIMBS) + x - y + 2 * prime >> BITS_PER_LIMB * LIMBS
  1289. // == 2**(BITS_PER_LIMB * LIMBS) + x + (2 * prime - y) >> BITS_PER_LIMB * LIMBS
  1290. // <= 2**(BITS_PER_LIMB * LIMBS) + 2**(BITS_PER_LIMB * LIMBS) - 1 >> BITS_PER_LIMB * LIMBS
  1291. // <= 2 * 2**(BITS_PER_LIMB * LIMBS) - 1 >> BITS_PER_LIMB * LIMBS
  1292. // == 1
  1293. // acc == 2**(BITS_PER_LIMB * LIMBS) + x[:LIMBS] - y[:LIMBS] + 2 * prime[:LIMBS] >> BITS_PER_LIMB * LIMBS
  1294. // == 2**(BITS_PER_LIMB * LIMBS) + x - y + 2 * prime >> BITS_PER_LIMB * LIMBS
  1295. // == 2**(BITS_PER_LIMB * LIMBS) + x + (2 * prime - y) >> BITS_PER_LIMB * LIMBS
  1296. // >= 2**(BITS_PER_LIMB * LIMBS) + 0 + 1 >> BITS_PER_LIMB * LIMBS
  1297. // == 1
  1298. // clang-format on
  1299. }
  1300. // res = x - y
  1301. // Assumes x, y are normalized and x >= y
  1302. // Guarantees res is normalized
  1303. // Works properly even if &x == &y or &x == &res or &y == &res or
  1304. // &x == &y == &res
  1305. void bn_subtract(const bignum256* x, const bignum256* y, bignum256* res) {
  1306. uint32_t acc = 1;
  1307. for(int i = 0; i < BN_LIMBS; i++) {
  1308. acc += (BN_BASE - 1) + x->val[i] - y->val[i];
  1309. // acc neither overflows 32 bits nor underflows 0
  1310. // Proof:
  1311. // acc + (BASE - 1) + x[i] - y[i]
  1312. // >= (BASE - 1) - y == (2**BITS_PER_LIMB - 1) - (2**BITS_PER_LIMB - 1)
  1313. // == 0
  1314. // acc + (BASE - 1) + x[i] - y[i]
  1315. // <= acc + (BASE - 1) + x[i]
  1316. // <= (2**(32 - BITS_PER_LIMB) - 1) + (2**BITS_PER_LIMB - 1) +
  1317. // (2**BITS_PER_LIMB - 1)
  1318. // == 7 + 2 * 2**29 < 2 **32
  1319. res->val[i] = acc & BN_LIMB_MASK;
  1320. acc >>= BN_BITS_PER_LIMB;
  1321. // acc <= 7 == 2**(32 - BITS_PER_LIMB) - 1
  1322. // acc == 2**(BITS_PER_LIMB * (i + 1)) + x[:i + 1] - y[:i + 1]
  1323. // >> BITS_PER_LIMB * (i + 1)
  1324. }
  1325. // assert(acc == 1); // assert x >= y
  1326. // clang-format off
  1327. // acc == 1
  1328. // Proof:
  1329. // acc == 2**(BITS_PER_LIMB * LIMBS) + x[:LIMBS] - y[:LIMBS] >> BITS_PER_LIMB * LIMBS
  1330. // == 2**(BITS_PER_LIMB * LIMBS) + x - y >> BITS_PER_LIMB * LIMBS
  1331. // == 2**(BITS_PER_LIMB * LIMBS) + x >> BITS_PER_LIMB * LIMBS
  1332. // <= 2**(BITS_PER_LIMB * LIMBS) + 2**(BITS_PER_LIMB * LIMBS) - 1 >> BITS_PER_LIMB * LIMBS
  1333. // <= 2 * 2**(BITS_PER_LIMB * LIMBS) - 1 >> BITS_PER_LIMB * LIMBS
  1334. // == 1
  1335. // acc == 2**(BITS_PER_LIMB * LIMBS) + x[:LIMBS] - y[:LIMBS] >> BITS_PER_LIMB * LIMBS
  1336. // == 2**(BITS_PER_LIMB * LIMBS) + x - y >> BITS_PER_LIMB * LIMBS
  1337. // >= 2**(BITS_PER_LIMB * LIMBS) >> BITS_PER_LIMB * LIMBS
  1338. // == 1
  1339. }
  1340. // q = x // d, r = x % d
  1341. // Assumes x is normalized, 1 <= d <= 61304
  1342. // Guarantees q is normalized
  1343. void bn_long_division(bignum256 *x, uint32_t d, bignum256 *q, uint32_t *r) {
  1344. assert(1 <= d && d < 61304);
  1345. uint32_t acc = 0;
  1346. *r = x->val[BN_LIMBS - 1] % d;
  1347. q->val[BN_LIMBS - 1] = x->val[BN_LIMBS - 1] / d;
  1348. for (int i = BN_LIMBS - 2; i >= 0; i--) {
  1349. acc = *r * (BN_BASE % d) + x->val[i];
  1350. // acc doesn't overflow 32 bits
  1351. // Proof:
  1352. // r * (BASE % d) + x[i]
  1353. // <= (d - 1) * (d - 1) + (2**BITS_PER_LIMB - 1)
  1354. // == d**2 - 2*d + 2**BITS_PER_LIMB
  1355. // == 61304**2 - 2 * 61304 + 2**29
  1356. // == 3758057808 + 2**29 < 2**32
  1357. q->val[i] = *r * (BN_BASE / d) + (acc / d);
  1358. // q[i] doesn't overflow 32 bits
  1359. // Proof:
  1360. // r * (BASE // d) + (acc // d)
  1361. // <= (d - 1) * (2**BITS_PER_LIMB / d) +
  1362. // ((d**2 - 2*d + 2**BITS_PER_LIMB) / d)
  1363. // <= (d - 1) * (2**BITS_PER_LIMB / d) + (d - 2 + 2**BITS_PER_LIMB / d)
  1364. // == (d - 1 + 1) * (2**BITS_PER_LIMB / d) + d - 2
  1365. // == 2**BITS_PER_LIMB + d - 2 <= 2**29 + 61304 < 2**32
  1366. // q[i] == (r * BASE + x[i]) // d
  1367. // Proof:
  1368. // q[i] == r * (BASE // d) + (acc // d)
  1369. // == r * (BASE // d) + (r * (BASE % d) + x[i]) // d
  1370. // == (r * d * (BASE // d) + r * (BASE % d) + x[i]) // d
  1371. // == (r * (d * (BASE // d) + (BASE % d)) + x[i]) // d
  1372. // == (r * BASE + x[i]) // d
  1373. // q[i] < 2**BITS_PER_LIMB
  1374. // Proof:
  1375. // q[i] == (r * BASE + x[i]) // d
  1376. // <= ((d - 1) * 2**BITS_PER_LIMB + (2**BITS_PER_LIMB - 1)) / d
  1377. // == (d * 2**BITS_PER_LIMB - 1) / d == 2**BITS_PER_LIMB - 1 / d
  1378. // < 2**BITS_PER_LIMB
  1379. *r = acc % d;
  1380. // r == (r * BASE + x[i]) % d
  1381. // Proof:
  1382. // r == acc % d == (r * (BASE % d) + x[i]) % d
  1383. // == (r * BASE + x[i]) % d
  1384. // x[:i] == q[:i] * d + r
  1385. }
  1386. }
  1387. // x = x // 58, r = x % 58
  1388. // Assumes x is normalized
  1389. // Guarantees x is normalized
  1390. void bn_divmod58(bignum256 *x, uint32_t *r) { bn_long_division(x, 58, x, r); }
  1391. // x = x // 1000, r = x % 1000
  1392. // Assumes x is normalized
  1393. // Guarantees x is normalized
  1394. void bn_divmod1000(bignum256 *x, uint32_t *r) {
  1395. bn_long_division(x, 1000, x, r);
  1396. }
  1397. // x = x // 10, r = x % 10
  1398. // Assumes x is normalized
  1399. // Guarantees x is normalized
  1400. void bn_divmod10(bignum256 *x, uint32_t *r) { bn_long_division(x, 10, x, r); }
  1401. // Formats amount
  1402. // Assumes amount is normalized
  1403. // Assumes prefix and suffix are null-terminated strings
  1404. // Assumes output is an array of length output_length
  1405. // The function doesn't have neither constant control flow nor constant memory
  1406. // access flow with regard to any its argument
  1407. size_t bn_format(const bignum256 *amount, const char *prefix, const char *suffix, unsigned int decimals, int exponent, bool trailing, char thousands, char *output, size_t output_length) {
  1408. /*
  1409. Python prototype of the function:
  1410. def format(amount, prefix, suffix, decimals, exponent, trailing, thousands):
  1411. if exponent >= 0:
  1412. amount *= 10**exponent
  1413. else:
  1414. amount //= 10 ** (-exponent)
  1415. d = pow(10, decimals)
  1416. integer_part = amount // d
  1417. integer_str = f"{integer_part:,}".replace(",", thousands or "")
  1418. if decimals:
  1419. decimal_part = amount % d
  1420. decimal_str = f".{decimal_part:0{decimals}d}"
  1421. if not trailing:
  1422. decimal_str = decimal_str.rstrip("0").rstrip(".")
  1423. else:
  1424. decimal_str = ""
  1425. return prefix + integer_str + decimal_str + suffix
  1426. */
  1427. // Auxiliary macro for bn_format
  1428. // If enough space adds one character to output starting from the end
  1429. #define BN_FORMAT_ADD_OUTPUT_CHAR(c) \
  1430. { \
  1431. --position; \
  1432. if (output <= position && position < output + output_length) { \
  1433. *position = (c); \
  1434. } else { \
  1435. memset(output, '\0', output_length); \
  1436. return 0; \
  1437. } \
  1438. }
  1439. bignum256 temp = {0};
  1440. bn_copy(amount, &temp);
  1441. uint32_t digit = 0;
  1442. char *position = output + output_length;
  1443. // Add string ending character
  1444. BN_FORMAT_ADD_OUTPUT_CHAR('\0');
  1445. // Add suffix
  1446. size_t suffix_length = suffix ? strlen(suffix) : 0;
  1447. for (int i = suffix_length - 1; i >= 0; --i)
  1448. BN_FORMAT_ADD_OUTPUT_CHAR(suffix[i])
  1449. // amount //= 10**exponent
  1450. for (; exponent < 0; ++exponent) {
  1451. // if temp == 0, there is no need to divide it by 10 anymore
  1452. if (bn_is_zero(&temp)) {
  1453. exponent = 0;
  1454. break;
  1455. }
  1456. bn_divmod10(&temp, &digit);
  1457. }
  1458. // exponent >= 0 && decimals >= 0
  1459. bool fractional_part = false; // is fractional-part of amount present
  1460. { // Add fractional-part digits of amount
  1461. // Add trailing zeroes
  1462. unsigned int trailing_zeros = decimals < (unsigned int) exponent ? decimals : (unsigned int) exponent;
  1463. // When casting a negative int to unsigned int, UINT_MAX is added to the int before
  1464. // Since exponent >= 0, the value remains unchanged
  1465. decimals -= trailing_zeros;
  1466. exponent -= trailing_zeros;
  1467. if (trailing && trailing_zeros) {
  1468. fractional_part = true;
  1469. for (; trailing_zeros > 0; --trailing_zeros)
  1470. BN_FORMAT_ADD_OUTPUT_CHAR('0')
  1471. }
  1472. // exponent == 0 || decimals == 0
  1473. // Add significant digits and leading zeroes
  1474. for (; decimals > 0; --decimals) {
  1475. bn_divmod10(&temp, &digit);
  1476. if (fractional_part || digit || trailing) {
  1477. fractional_part = true;
  1478. BN_FORMAT_ADD_OUTPUT_CHAR('0' + digit)
  1479. }
  1480. else if (bn_is_zero(&temp)) {
  1481. // We break since the remaining digits are zeroes and fractional_part == trailing == false
  1482. decimals = 0;
  1483. break;
  1484. }
  1485. }
  1486. // decimals == 0
  1487. }
  1488. if (fractional_part) {
  1489. BN_FORMAT_ADD_OUTPUT_CHAR('.')
  1490. }
  1491. { // Add integer-part digits of amount
  1492. // Add trailing zeroes
  1493. int digits = 0;
  1494. if (!bn_is_zero(&temp)) {
  1495. for (; exponent > 0; --exponent) {
  1496. ++digits;
  1497. BN_FORMAT_ADD_OUTPUT_CHAR('0')
  1498. if (thousands != 0 && digits % 3 == 0) {
  1499. BN_FORMAT_ADD_OUTPUT_CHAR(thousands)
  1500. }
  1501. }
  1502. }
  1503. // decimals == 0 && exponent == 0
  1504. // Add significant digits
  1505. bool is_zero = false;
  1506. do {
  1507. ++digits;
  1508. bn_divmod10(&temp, &digit);
  1509. is_zero = bn_is_zero(&temp);
  1510. BN_FORMAT_ADD_OUTPUT_CHAR('0' + digit)
  1511. if (thousands != 0 && !is_zero && digits % 3 == 0) {
  1512. BN_FORMAT_ADD_OUTPUT_CHAR(thousands)
  1513. }
  1514. } while (!is_zero);
  1515. }
  1516. // Add prefix
  1517. size_t prefix_length = prefix ? strlen(prefix) : 0;
  1518. for (int i = prefix_length - 1; i >= 0; --i)
  1519. BN_FORMAT_ADD_OUTPUT_CHAR(prefix[i])
  1520. // Move formatted amount to the start of output
  1521. int length = output - position + output_length;
  1522. memmove(output, position, length);
  1523. return length - 1;
  1524. }
  1525. #if USE_BN_PRINT
  1526. // Prints x in hexadecimal
  1527. // Assumes x is normalized and x < 2**256
  1528. void bn_print(const bignum256 *x) {
  1529. printf("%06x", x->val[8]);
  1530. printf("%08x", ((x->val[7] << 3) | (x->val[6] >> 26)));
  1531. printf("%07x", ((x->val[6] << 2) | (x->val[5] >> 27)) & 0x0FFFFFFF);
  1532. printf("%07x", ((x->val[5] << 1) | (x->val[4] >> 28)) & 0x0FFFFFFF);
  1533. printf("%07x", x->val[4] & 0x0FFFFFFF);
  1534. printf("%08x", ((x->val[3] << 3) | (x->val[2] >> 26)));
  1535. printf("%07x", ((x->val[2] << 2) | (x->val[1] >> 27)) & 0x0FFFFFFF);
  1536. printf("%07x", ((x->val[1] << 1) | (x->val[0] >> 28)) & 0x0FFFFFFF);
  1537. printf("%07x", x->val[0] & 0x0FFFFFFF);
  1538. }
  1539. // Prints comma separated list of limbs of x
  1540. void bn_print_raw(const bignum256 *x) {
  1541. for (int i = 0; i < BN_LIMBS - 1; i++) {
  1542. printf("0x%08x, ", x->val[i]);
  1543. }
  1544. printf("0x%08x", x->val[BN_LIMBS - 1]);
  1545. }
  1546. #endif
  1547. #if USE_INVERSE_FAST
  1548. void bn_inverse(bignum256 *x, const bignum256 *prime) {
  1549. bn_inverse_fast(x, prime);
  1550. }
  1551. #else
  1552. void bn_inverse(bignum256 *x, const bignum256 *prime) {
  1553. bn_inverse_slow(x, prime);
  1554. }
  1555. #endif