programmer_openocd.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import logging
  2. import os
  3. import typing
  4. from enum import Enum
  5. from flipper.utils.programmer import Programmer
  6. from flipper.utils.openocd import OpenOCD
  7. from flipper.utils.stm32wb55 import STM32WB55
  8. from flipper.assets.obdata import OptionBytesData
  class OpenOCDProgrammerResult(Enum):
      """Outcome codes returned by ``OpenOCDProgrammer.otp_write``."""

      # Operation completed (or nothing needed to be done).
      Success = 0
      # Unspecified failure.
      ErrorGeneric = 1
      # File size or target address is not 8-byte aligned.
      ErrorAlignment = 2
      # Target memory already holds different, non-empty data.
      ErrorAlreadyWritten = 3
      # Read-back after writing did not match the source data.
      ErrorValidation = 4
  class OpenOCDProgrammer(Programmer):
      """Program and configure an STM32WB55 target through an OpenOCD instance.

      Every public operation except ``reset`` starts the OpenOCD process held in
      ``self.openocd`` and stops it again before returning, including when an
      exception is raised part-way through.
      """

      def __init__(
          self,
          interface: str = "interface/cmsis-dap.cfg",
          port_base: typing.Union[int, None] = None,
          serial: typing.Union[str, None] = None,
      ):
          """Build the OpenOCD configuration for the ``target/stm32wbx.cfg`` target.

          Args:
              interface: OpenOCD interface config script.
              port_base: Optional base port, passed through to the OpenOCD config.
              serial: Optional probe serial number. Only honoured for
                  ``interface/cmsis-dap.cfg`` and for interfaces whose name
                  contains "stlink"; for any other interface it is ignored and a
                  warning is logged.
          """
          super().__init__()
          self.logger = logging.getLogger()
          config = {
              "interface": interface,
              "target": "target/stm32wbx.cfg",
          }
          if serial is not None:
              if interface == "interface/cmsis-dap.cfg":
                  config["serial"] = f"cmsis_dap_serial {serial}"
              elif "stlink" in interface:
                  config["serial"] = f"stlink_serial {serial}"
              else:
                  # Dropping the serial silently could make OpenOCD pick a
                  # different probe than the caller asked for.
                  self.logger.warning(
                      f"Serial {serial} is not supported for interface {interface}, ignoring it"
                  )
          if port_base is not None:
              config["port_base"] = port_base
          self.openocd = OpenOCD(config)

      def reset(self, mode: Programmer.RunMode = Programmer.RunMode.Run) -> bool:
          """Reset the target into the requested run mode.

          ``Programmer.RunMode.Run`` is forwarded as ``STM32WB55.RunMode.Run`` and
          ``Programmer.RunMode.Stop`` as ``STM32WB55.RunMode.Init``. This method
          does not start or stop OpenOCD itself.
          NOTE(review): presumably an OpenOCD session must already be running
          (``otp_write`` calls this inside its session) — confirm for other callers.

          Returns:
              True after the reset command has been issued.

          Raises:
              Exception: if ``mode`` is neither Run nor Stop.
          """
          stm32 = STM32WB55()
          if mode == Programmer.RunMode.Run:
              stm32.reset(self.openocd, stm32.RunMode.Run)
          elif mode == Programmer.RunMode.Stop:
              stm32.reset(self.openocd, stm32.RunMode.Init)
          else:
              raise Exception("Unknown mode")
          return True

      def flash(self, address: int, file_path: str, verify: bool = True) -> bool:
          """Write ``file_path`` to flash at ``address`` using OpenOCD's ``program``.

          The ``program`` command is issued with ``reset exit``; when ``verify`` is
          true the ``verify`` option is added as well.

          Returns:
              True once the commands have been sent.

          Raises:
              Exception: if ``file_path`` does not exist.
          """
          if not os.path.exists(file_path):
              raise Exception(f"File {file_path} not found")
          verify_opt = " verify" if verify else ""
          self.openocd.start()
          try:
              self.openocd.send_tcl("init")
              # NOTE(review): file_path is interpolated into a Tcl command unquoted;
              # paths with spaces or backslashes may be mis-parsed unless send_tcl
              # escapes them — confirm.
              self.openocd.send_tcl(
                  f"program {file_path} 0x{address:08x}{verify_opt} reset exit"
              )
          finally:
              # Tear down the OpenOCD process even if a command failed.
              self.openocd.stop()
          return True

      def _ob_print_diff_table(self, ob_reference: bytes, ob_read: bytes, print_fn):
          """Print reference and device option bytes side by side, 8 bytes per row.

          The two "Diff" columns repeat each hex byte that differs between the
          rows and show ``__`` for bytes that match.
          """
          print_fn(
              f'{"Reference": <20} {"Device": <20} {"Diff Reference": <20} {"Diff Device": <20}'
          )
          # Split into 8 byte rows, word + word
          for i in range(0, len(ob_reference), 8):
              ref_hex = ob_reference[i : i + 8].hex()
              read_hex = ob_read[i : i + 8].hex()
              diff_str1 = ""
              diff_str2 = ""
              for j in range(0, len(ref_hex), 2):
                  byte_str_1 = ref_hex[j : j + 2]
                  byte_str_2 = read_hex[j : j + 2]
                  if byte_str_1 == byte_str_2:
                      diff_str1 += "__"
                      diff_str2 += "__"
                  else:
                      diff_str1 += byte_str_1
                      diff_str2 += byte_str_2
              print_fn(
                  f"{ref_hex: <20} {read_hex: <20} {diff_str1: <20} {diff_str2: <20}"
              )

      def _end_session(self, stm32) -> None:
          """Reset the target into run mode and stop OpenOCD.

          ``stop`` runs even if the reset raises, so the OpenOCD process is never
          left behind.
          """
          try:
              stm32.reset(self.openocd, stm32.RunMode.Run)
          finally:
              self.openocd.stop()

      def option_bytes_validate(self, file_path: str) -> bool:
          """Check the device option bytes against the reference built from ``file_path``.

          Each device byte is ANDed with the compare mask before being compared
          with the reference. On mismatch a diff table is logged at error level.

          Returns:
              True if the masked device option bytes equal the reference.
          """
          # Registers
          stm32 = STM32WB55()
          # Generate Option Bytes data before touching the hardware, so a bad
          # file fails without starting OpenOCD.
          ob_values = OptionBytesData(file_path).gen_values().export()
          ob_reference = ob_values.reference
          ob_compare_mask = ob_values.compare_mask
          ob_words = len(ob_reference) // 4

          self.openocd.start()
          try:
              stm32.reset(self.openocd, stm32.RunMode.Init)
              # Read Option Bytes as little-endian 32-bit words
              ob_read = bytes()
              for i in range(ob_words):
                  value = self.openocd.read_32(stm32.OPTION_BYTE_BASE + i * 4)
                  ob_read += value.to_bytes(4, "little")
              # Keep only the bits selected by the compare mask
              ob_compare = bytes(
                  value & mask for value, mask in zip(ob_read, ob_compare_mask)
              )
              if ob_reference == ob_compare:
                  self.logger.info("Option Bytes are valid")
                  return True
              self.logger.error("Option Bytes are invalid")
              self._ob_print_diff_table(ob_reference, ob_compare, self.logger.error)
              return False
          finally:
              self._end_session(stm32)

      def _unpack_u32(self, data: bytes, offset: int):
          """Return the little-endian unsigned 32-bit value at ``data[offset:offset + 4]``."""
          return int.from_bytes(data[offset : offset + 4], "little")

      def option_bytes_set(self, file_path: str) -> bool:
          """Write the option bytes described by ``file_path`` to the device.

          For each 8-byte option-byte entry, the first 32-bit word read from the
          device is masked with the compare mask. It is then compared with the
          reference under the write mask. Entries that differ are rewritten
          through the register returned by ``stm32.option_bytes_id_to_address``,
          keeping the device bits outside the write mask. The new values are
          applied only if something changed, and the option bytes are then
          reloaded.

          Returns:
              True once the sequence has completed.
          """
          # Registers
          stm32 = STM32WB55()
          # Generate Option Bytes data before touching the hardware
          ob_values = OptionBytesData(file_path).gen_values().export()
          ob_reference_bytes = ob_values.reference
          ob_compare_mask_bytes = ob_values.compare_mask
          ob_write_mask_bytes = ob_values.write_mask
          ob_dwords = len(ob_reference_bytes) // 8

          self.openocd.start()
          try:
              stm32.reset(self.openocd, stm32.RunMode.Init)
              # Clear flash errors
              stm32.clear_flash_errors(self.openocd)
              # Unlock Flash and Option Bytes
              stm32.flash_unlock(self.openocd)
              stm32.option_bytes_unlock(self.openocd)

              ob_need_to_apply = False
              for i in range(ob_dwords):
                  device_addr = stm32.OPTION_BYTE_BASE + i * 8
                  device_value = self.openocd.read_32(device_addr)
                  ob_write_mask = self._unpack_u32(ob_write_mask_bytes, i * 8)
                  ob_compare_mask = self._unpack_u32(ob_compare_mask_bytes, i * 8)
                  ob_value_ref = self._unpack_u32(ob_reference_bytes, i * 8)
                  ob_value_masked = device_value & ob_compare_mask
                  need_patch = ((ob_value_masked ^ ob_value_ref) & ob_write_mask) != 0
                  if not need_patch:
                      continue
                  ob_need_to_apply = True
                  self.logger.info(
                      f"Need to patch: {device_addr:08X}: {ob_value_masked:08X} != {ob_value_ref:08X}, REG[{i}]"
                  )
                  # Check if this option byte (dword) is mapped to a register
                  device_reg_addr = stm32.option_bytes_id_to_address(i)
                  # Keep the device bits outside the write mask, take the rest
                  # from the reference.
                  ob_value = device_value & (~ob_write_mask)
                  ob_value |= ob_value_ref & ob_write_mask
                  self.logger.info(f"Writing {ob_value:08X} to {device_reg_addr:08X}")
                  self.openocd.write_32(device_reg_addr, ob_value)

              if ob_need_to_apply:
                  stm32.option_bytes_apply(self.openocd)
              else:
                  self.logger.info("Option Bytes are already correct")
              # Load Option Bytes
              # That will reset and also lock the Option Bytes and the Flash
              stm32.option_bytes_load(self.openocd)
          finally:
              self._end_session(stm32)
          return True

      def otp_write(self, address: int, file_path: str) -> OpenOCDProgrammerResult:
          """Write the contents of ``file_path`` into OTP memory at ``address``.

          Both the file size and ``address`` must be multiples of 8 bytes. A word
          that already holds the file's value is accepted. Any other word that is
          not 0xFFFFFFFF aborts the write. Data is written 8 bytes at a time and
          then read back for validation.

          Returns:
              ErrorAlignment: the file size or ``address`` is not 8-byte aligned.
              ErrorAlreadyWritten: the target area holds different, non-empty data.
              ErrorValidation: the read-back does not match the file.
              Success: the data was written and validated, was already present,
                  or the file is empty.
          """
          # Open file, check that it is aligned to 8 bytes
          with open(file_path, "rb") as f:
              data = f.read()
          data_size = len(data)
          if data_size % 8 != 0:
              self.logger.error(f"File {file_path} is not aligned to 8 bytes")
              return OpenOCDProgrammerResult.ErrorAlignment
          # Check that address is aligned to 8 bytes
          if address % 8 != 0:
              self.logger.error(f"Address {address} is not aligned to 8 bytes")
              return OpenOCDProgrammerResult.ErrorAlignment
          if data_size == 0:
              # Nothing to write; don't start OpenOCD at all.
              self.logger.info("OTP data is empty, nothing to write")
              return OpenOCDProgrammerResult.Success

          self.logger.debug(f"Writing {data_size} bytes to OTP at {address:08X}")
          self.logger.debug(f"Data: {data.hex().upper()}")

          # Start OpenOCD
          oocd = self.openocd
          oocd.start()
          # Registers
          stm32 = STM32WB55()
          try:
              # Check that OTP is empty for the given address
              # Also check whether the data is already written
              already_written = True
              for i in range(0, data_size, 4):
                  file_word = self._unpack_u32(data, i)
                  device_word = oocd.read_32(address + i)
                  if device_word != 0xFFFFFFFF and device_word != file_word:
                      self.logger.error(
                          f"OTP memory at {address + i:08X} is not empty: {device_word:08X}"
                      )
                      return OpenOCDProgrammerResult.ErrorAlreadyWritten
                  if device_word != file_word:
                      already_written = False
              if already_written:
                  self.logger.info("OTP memory is already written with the given data")
                  return OpenOCDProgrammerResult.Success

              self.reset(self.RunMode.Stop)
              stm32.clear_flash_errors(oocd)

              # Write OTP memory by 8 bytes
              for i in range(0, data_size, 8):
                  word_1 = self._unpack_u32(data, i)
                  word_2 = self._unpack_u32(data, i + 4)
                  self.logger.debug(
                      f"Writing {word_1:08X} {word_2:08X} to {address + i:08X}"
                  )
                  stm32.write_flash_64(oocd, address + i, word_1, word_2)

              # Validate OTP memory
              validation_result = True
              for i in range(0, data_size, 4):
                  file_word = self._unpack_u32(data, i)
                  device_word = oocd.read_32(address + i)
                  if file_word != device_word:
                      self.logger.error(
                          f"Validation failed: {file_word:08X} != {device_word:08X} at {address + i:08X}"
                      )
                      validation_result = False
              return (
                  OpenOCDProgrammerResult.Success
                  if validation_result
                  else OpenOCDProgrammerResult.ErrorValidation
              )
          finally:
              # Stop OpenOCD
              self._end_session(stm32)