test_curves.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. #!/usr/bin/py.test
  2. import binascii
  3. import ctypes as c
  4. import hashlib
  5. import os
  6. import random
  7. import curve25519
  8. import ecdsa
  9. import pytest
  10. def bytes2num(s):
  11. res = 0
  12. for i, b in enumerate(reversed(bytearray(s))):
  13. res += b << (i * 8)
  14. return res
  15. curves = {"nist256p1": ecdsa.curves.NIST256p, "secp256k1": ecdsa.curves.SECP256k1}
  16. class Point:
  17. def __init__(self, name, x, y):
  18. self.curve = name
  19. self.x = x
  20. self.y = y
  21. points = [
  22. Point(
  23. "secp256k1",
  24. 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
  25. 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8,
  26. ),
  27. Point(
  28. "secp256k1",
  29. 0x1,
  30. 0x4218F20AE6C646B363DB68605822FB14264CA8D2587FDD6FBC750D587E76A7EE,
  31. ),
  32. Point(
  33. "secp256k1",
  34. 0x2,
  35. 0x66FBE727B2BA09E09F5A98D70A5EFCE8424C5FA425BBDA1C511F860657B8535E,
  36. ),
  37. Point(
  38. "secp256k1",
  39. 0x1B,
  40. 0x1ADCEA1CF831B0AD1653E769D1A229091D0CC68D4B0328691B9CAACC76E37C90,
  41. ),
  42. Point(
  43. "nist256p1",
  44. 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
  45. 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5,
  46. ),
  47. Point(
  48. "nist256p1",
  49. 0x0,
  50. 0x66485C780E2F83D72433BD5D84A06BB6541C2AF31DAE871728BF856A174F93F4,
  51. ),
  52. Point(
  53. "nist256p1",
  54. 0x0,
  55. 0x99B7A386F1D07C29DBCC42A27B5F9449ABE3D50DE25178E8D7407A95E8B06C0B,
  56. ),
  57. Point(
  58. "nist256p1",
  59. 0xAF8BBDFE8CDD5577ACBF345B543D28CF402F4E94D3865B97EA0787F2D3AA5D22,
  60. 0x35802B8B376B995265918B078BC109C21A535176585C40F519ACA52D6AFC147C,
  61. ),
  62. Point(
  63. "nist256p1",
  64. 0x80000,
  65. 0x580610071F440F0DCC14A22E2D5D5AFC1224C0CD11A3B4B51B8ECD2224EE1CE2,
  66. ),
  67. ]
  68. random_iters = int(os.environ.get("ITERS", 1))
  69. DIR = os.path.abspath(os.path.dirname(__file__))
  70. lib = c.cdll.LoadLibrary(os.path.join(DIR, "libtrezor-crypto.so"))
  71. if not lib.zkp_context_is_initialized():
  72. assert lib.zkp_context_init() == 0
  73. BIGNUM = c.c_uint32 * 9
  74. class curve_info(c.Structure):
  75. _fields_ = [("bip32_name", c.c_char_p), ("params", c.c_void_p)]
  76. class curve_point(c.Structure):
  77. _fields_ = [("x", BIGNUM), ("y", BIGNUM)]
  78. class ecdsa_curve(c.Structure):
  79. _fields_ = [
  80. ("prime", BIGNUM),
  81. ("G", curve_point),
  82. ("order", BIGNUM),
  83. ("order_half", BIGNUM),
  84. ("a", c.c_int),
  85. ("b", BIGNUM),
  86. ]
  87. lib.get_curve_by_name.restype = c.POINTER(curve_info)
  88. class Random(random.Random):
  89. def randbytes(self, n):
  90. buf = (c.c_uint8 * n)()
  91. for i in range(n):
  92. buf[i] = self.randrange(0, 256)
  93. return buf
  94. def randpoint(self, curve):
  95. k = self.randrange(0, curve.order)
  96. return k * curve.generator
  97. def int2bn(x, bn_type=BIGNUM):
  98. b = bn_type()
  99. b._int = x
  100. for i in range(len(b)):
  101. b[i] = x % (1 << 29)
  102. x = x >> 29
  103. return b
  104. def bn2int(b):
  105. x = 0
  106. for i in range(len(b)):
  107. x += b[i] << (29 * i)
  108. return x
  109. @pytest.fixture(params=range(random_iters))
  110. def r(request):
  111. seed = request.param
  112. return Random(seed + int(os.environ.get("SEED", 0)))
  113. def get_curve_obj(name):
  114. curve_ptr = lib.get_curve_by_name(bytes(name, "ascii")).contents.params
  115. assert curve_ptr, "curve {} not found".format(name)
  116. curve_obj = curves[name]
  117. curve_obj.ptr = c.cast(curve_ptr, c.POINTER(ecdsa_curve))
  118. curve_obj.p = curve_obj.curve.p() # shorthand
  119. return curve_obj
  120. @pytest.fixture(params=list(sorted(curves)))
  121. def curve(request):
  122. return get_curve_obj(request.param)
  123. @pytest.fixture(params=points)
  124. def point(request):
  125. name = request.param.curve
  126. curve_ptr = lib.get_curve_by_name(bytes(name, "ascii")).contents.params
  127. assert curve_ptr, "curve {} not found".format(name)
  128. curve_obj = curves[name]
  129. curve_obj.ptr = c.c_void_p(curve_ptr)
  130. curve_obj.p = ecdsa.ellipticcurve.Point(
  131. curve_obj.curve, request.param.x, request.param.y
  132. )
  133. return curve_obj
  134. POINT = BIGNUM * 2
  135. def to_POINT(p):
  136. return POINT(int2bn(p.x()), int2bn(p.y()))
  137. def from_POINT(p):
  138. return (bn2int(p[0]), bn2int(p[1]))
  139. JACOBIAN = BIGNUM * 3
  140. def to_JACOBIAN(jp):
  141. return JACOBIAN(int2bn(jp[0]), int2bn(jp[1]), int2bn(jp[2]))
  142. def from_JACOBIAN(p):
  143. return (bn2int(p[0]), bn2int(p[1]), bn2int(p[2]))
  144. def test_curve_parameters(curve):
  145. assert curve.curve.p() == bn2int(curve.ptr.contents.prime)
  146. assert curve.generator.x() == bn2int(curve.ptr.contents.G.x)
  147. assert curve.generator.y() == bn2int(curve.ptr.contents.G.y)
  148. assert curve.order == bn2int(curve.ptr.contents.order)
  149. assert curve.order // 2 == bn2int(curve.ptr.contents.order_half)
  150. assert curve.curve.a() == curve.ptr.contents.a
  151. assert curve.curve.b() == bn2int(curve.ptr.contents.b)
  152. def test_point_multiply(curve, r):
  153. p = r.randpoint(curve)
  154. k = r.randrange(0, 2**256)
  155. kp = k * p
  156. res = POINT(int2bn(0), int2bn(0))
  157. lib.point_multiply(curve.ptr, int2bn(k), to_POINT(p), res)
  158. res = from_POINT(res)
  159. assert res == (kp.x(), kp.y())
  160. def test_point_add(curve, r):
  161. p1 = r.randpoint(curve)
  162. p2 = r.randpoint(curve)
  163. # print '-' * 80
  164. q = p1 + p2
  165. q1 = to_POINT(p1)
  166. q2 = to_POINT(p2)
  167. lib.point_add(curve.ptr, q1, q2)
  168. q_ = from_POINT(q2)
  169. assert q_ == (q.x(), q.y())
  170. def test_point_double(curve, r):
  171. p = r.randpoint(curve)
  172. q = p.double()
  173. q_ = to_POINT(p)
  174. lib.point_double(curve.ptr, q_)
  175. q_ = from_POINT(q_)
  176. assert q_ == (q.x(), q.y())
  177. def test_point_to_jacobian(curve, r):
  178. p = r.randpoint(curve)
  179. jp = JACOBIAN()
  180. lib.curve_to_jacobian(to_POINT(p), jp, int2bn(curve.p))
  181. jx, jy, jz = from_JACOBIAN(jp)
  182. assert jx % curve.p == (p.x() * jz**2) % curve.p
  183. assert jy % curve.p == (p.y() * jz**3) % curve.p
  184. q = POINT()
  185. lib.jacobian_to_curve(jp, q, int2bn(curve.p))
  186. q = from_POINT(q)
  187. assert q == (p.x(), p.y())
  188. def test_jacobian_add(curve, r):
  189. p1 = r.randpoint(curve)
  190. p2 = r.randpoint(curve)
  191. prime = int2bn(curve.p)
  192. q = POINT()
  193. jp2 = JACOBIAN()
  194. lib.curve_to_jacobian(to_POINT(p2), jp2, prime)
  195. lib.point_jacobian_add(to_POINT(p1), jp2, curve.ptr)
  196. lib.jacobian_to_curve(jp2, q, prime)
  197. q = from_POINT(q)
  198. p_ = p1 + p2
  199. assert (p_.x(), p_.y()) == q
  200. def test_jacobian_add_double(curve, r):
  201. p1 = r.randpoint(curve)
  202. p2 = p1
  203. prime = int2bn(curve.p)
  204. q = POINT()
  205. jp2 = JACOBIAN()
  206. lib.curve_to_jacobian(to_POINT(p2), jp2, prime)
  207. lib.point_jacobian_add(to_POINT(p1), jp2, curve.ptr)
  208. lib.jacobian_to_curve(jp2, q, prime)
  209. q = from_POINT(q)
  210. p_ = p1 + p2
  211. assert (p_.x(), p_.y()) == q
  212. def test_jacobian_double(curve, r):
  213. p = r.randpoint(curve)
  214. p2 = p.double()
  215. prime = int2bn(curve.p)
  216. q = POINT()
  217. jp = JACOBIAN()
  218. lib.curve_to_jacobian(to_POINT(p), jp, prime)
  219. lib.point_jacobian_double(jp, curve.ptr)
  220. lib.jacobian_to_curve(jp, q, prime)
  221. q = from_POINT(q)
  222. assert (p2.x(), p2.y()) == q
  223. def sigdecode(sig, _):
  224. return map(bytes2num, [sig[:32], sig[32:]])
  225. def test_sign(curve, r):
  226. priv = r.randbytes(32)
  227. digest = r.randbytes(32)
  228. sig = r.randbytes(64)
  229. lib.ecdsa_sign_digest(curve.ptr, priv, digest, sig, c.c_void_p(0), c.c_void_p(0))
  230. exp = bytes2num(priv)
  231. sk = ecdsa.SigningKey.from_secret_exponent(exp, curve, hashfunc=hashlib.sha256)
  232. vk = sk.get_verifying_key()
  233. sig_ref = sk.sign_digest_deterministic(
  234. digest, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string_canonize
  235. )
  236. assert binascii.hexlify(sig) == binascii.hexlify(sig_ref)
  237. assert vk.verify_digest(sig, digest, sigdecode)
  238. def test_sign_zkp(r):
  239. curve = get_curve_obj("secp256k1")
  240. priv = r.randbytes(32)
  241. digest = r.randbytes(32)
  242. sig = r.randbytes(64)
  243. lib.zkp_ecdsa_sign_digest(
  244. curve.ptr, priv, digest, sig, c.c_void_p(0), c.c_void_p(0)
  245. )
  246. exp = bytes2num(priv)
  247. sk = ecdsa.SigningKey.from_secret_exponent(exp, curve, hashfunc=hashlib.sha256)
  248. vk = sk.get_verifying_key()
  249. sig_ref = sk.sign_digest_deterministic(
  250. digest, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string_canonize
  251. )
  252. assert binascii.hexlify(sig) == binascii.hexlify(sig_ref)
  253. assert vk.verify_digest(sig, digest, sigdecode)
  254. def test_validate_pubkey(curve, r):
  255. p = r.randpoint(curve)
  256. assert lib.ecdsa_validate_pubkey(curve.ptr, to_POINT(p))
  257. def test_validate_pubkey_direct(point):
  258. assert lib.ecdsa_validate_pubkey(point.ptr, to_POINT(point.p))
  259. def test_curve25519(r):
  260. sec1 = bytes(bytearray(r.randbytes(32)))
  261. sec2 = bytes(bytearray(r.randbytes(32)))
  262. pub1 = curve25519.Private(sec1).get_public()
  263. pub2 = curve25519.Private(sec2).get_public()
  264. session1 = r.randbytes(32)
  265. lib.curve25519_scalarmult(session1, sec2, pub1.public)
  266. session2 = r.randbytes(32)
  267. lib.curve25519_scalarmult(session2, sec1, pub2.public)
  268. assert bytearray(session1) == bytearray(session2)
  269. shared1 = curve25519.Private(sec2).get_shared_key(pub1, hashfunc=lambda x: x)
  270. shared2 = curve25519.Private(sec1).get_shared_key(pub2, hashfunc=lambda x: x)
  271. assert shared1 == shared2
  272. assert bytearray(session1) == shared1
  273. assert bytearray(session2) == shared2
  274. def test_curve25519_pubkey(r):
  275. sec = bytes(bytearray(r.randbytes(32)))
  276. pub = curve25519.Private(sec).get_public()
  277. res = r.randbytes(32)
  278. lib.curve25519_scalarmult_basepoint(res, sec)
  279. assert bytearray(res) == pub.public
  280. def test_curve25519_scalarmult_from_gpg(r):
  281. sec = binascii.unhexlify(
  282. "4a1e76f133afb29dbc7860bcbc16d0e829009cc15c2f81ed26de1179b1d9c938"
  283. )
  284. pub = binascii.unhexlify(
  285. "5d6fc75c016e85b17f54e0128a216d5f9229f25bac1ec85cecab8daf48621b31"
  286. )
  287. res = r.randbytes(32)
  288. lib.curve25519_scalarmult(res, sec[::-1], pub[::-1])
  289. expected = "a93dbdb23e5c99da743e203bd391af79f2b83fb8d0fd6ec813371c71f08f2d4d"
  290. assert binascii.hexlify(bytearray(res)) == bytes(expected, "ascii")