program.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. #!/usr/bin/env python3
  2. import logging
  3. import os
  4. import socket
  5. import subprocess
  6. import time
  7. import typing
  8. from abc import ABC, abstractmethod
  9. from dataclasses import dataclass
  10. from flipper.app import App
  11. class Programmer(ABC):
  12. @abstractmethod
  13. def flash(self, bin: str) -> bool:
  14. pass
  15. @abstractmethod
  16. def probe(self) -> bool:
  17. pass
  18. @abstractmethod
  19. def get_name(self) -> str:
  20. pass
  21. @abstractmethod
  22. def set_serial(self, serial: str):
  23. pass
  24. @dataclass
  25. class OpenOCDInterface:
  26. name: str
  27. file: str
  28. serial_cmd: str
  29. additional_args: typing.Optional[list[str]] = None
  30. class OpenOCDProgrammer(Programmer):
  31. def __init__(self, interface: OpenOCDInterface):
  32. self.interface = interface
  33. self.logger = logging.getLogger("OpenOCD")
  34. self.serial: typing.Optional[str] = None
  35. def _add_file(self, params: list[str], file: str):
  36. params.append("-f")
  37. params.append(file)
  38. def _add_command(self, params: list[str], command: str):
  39. params.append("-c")
  40. params.append(command)
  41. def _add_serial(self, params: list[str], serial: str):
  42. self._add_command(params, f"{self.interface.serial_cmd} {serial}")
  43. def set_serial(self, serial: str):
  44. self.serial = serial
  45. def flash(self, bin: str) -> bool:
  46. i = self.interface
  47. if os.altsep:
  48. bin = bin.replace(os.sep, os.altsep)
  49. openocd_launch_params = ["openocd"]
  50. self._add_file(openocd_launch_params, i.file)
  51. if self.serial:
  52. self._add_serial(openocd_launch_params, self.serial)
  53. if i.additional_args:
  54. for a in i.additional_args:
  55. self._add_command(openocd_launch_params, a)
  56. self._add_file(openocd_launch_params, "target/stm32wbx.cfg")
  57. self._add_command(openocd_launch_params, "init")
  58. self._add_command(openocd_launch_params, f"program {bin} reset exit 0x8000000")
  59. # join the list of parameters into a string, but add quote if there are spaces
  60. openocd_launch_params_string = " ".join(
  61. [f'"{p}"' if " " in p else p for p in openocd_launch_params]
  62. )
  63. self.logger.debug(f"Launching: {openocd_launch_params_string}")
  64. process = subprocess.Popen(
  65. openocd_launch_params,
  66. stdout=subprocess.PIPE,
  67. stderr=subprocess.STDOUT,
  68. )
  69. while process.poll() is None:
  70. time.sleep(0.25)
  71. print(".", end="", flush=True)
  72. print()
  73. success = process.returncode == 0
  74. if not success:
  75. self.logger.error("OpenOCD failed to flash")
  76. if process.stdout:
  77. self.logger.error(process.stdout.read().decode("utf-8").strip())
  78. return success
  79. def probe(self) -> bool:
  80. i = self.interface
  81. openocd_launch_params = ["openocd"]
  82. self._add_file(openocd_launch_params, i.file)
  83. if self.serial:
  84. self._add_serial(openocd_launch_params, self.serial)
  85. if i.additional_args:
  86. for a in i.additional_args:
  87. self._add_command(openocd_launch_params, a)
  88. self._add_file(openocd_launch_params, "target/stm32wbx.cfg")
  89. self._add_command(openocd_launch_params, "init")
  90. self._add_command(openocd_launch_params, "exit")
  91. self.logger.debug(f"Launching: {' '.join(openocd_launch_params)}")
  92. process = subprocess.Popen(
  93. openocd_launch_params,
  94. stderr=subprocess.STDOUT,
  95. stdout=subprocess.PIPE,
  96. )
  97. # Wait for OpenOCD to end and get the return code
  98. process.wait()
  99. found = process.returncode == 0
  100. if process.stdout:
  101. self.logger.debug(process.stdout.read().decode("utf-8").strip())
  102. return found
  103. def get_name(self) -> str:
  104. return self.interface.name
  105. def blackmagic_find_serial(serial: str):
  106. import serial.tools.list_ports as list_ports
  107. if serial and os.name == "nt":
  108. if not serial.startswith("\\\\.\\"):
  109. serial = f"\\\\.\\{serial}"
  110. ports = list(list_ports.grep("blackmagic"))
  111. if len(ports) == 0:
  112. return None
  113. elif len(ports) > 2:
  114. if serial:
  115. ports = list(
  116. filter(
  117. lambda p: p.serial_number == serial
  118. or p.name == serial
  119. or p.device == serial,
  120. ports,
  121. )
  122. )
  123. if len(ports) == 0:
  124. return None
  125. if len(ports) > 2:
  126. raise Exception("More than one Blackmagic probe found")
  127. # If you're getting any issues with auto lookup, uncomment this
  128. # print("\n".join([f"{p.device} {vars(p)}" for p in ports]))
  129. port = sorted(ports, key=lambda p: f"{p.location}_{p.name}")[0]
  130. if serial:
  131. if (
  132. serial != port.serial_number
  133. and serial != port.name
  134. and serial != port.device
  135. ):
  136. return None
  137. if os.name == "nt":
  138. port.device = f"\\\\.\\{port.device}"
  139. return port.device
  140. def _resolve_hostname(hostname):
  141. try:
  142. return socket.gethostbyname(hostname)
  143. except socket.gaierror:
  144. return None
  145. def blackmagic_find_networked(serial: str):
  146. if not serial:
  147. serial = "blackmagic.local"
  148. # remove the tcp: prefix if it's there
  149. if serial.startswith("tcp:"):
  150. serial = serial[4:]
  151. # remove the port if it's there
  152. if ":" in serial:
  153. serial = serial.split(":")[0]
  154. if not (probe := _resolve_hostname(serial)):
  155. return None
  156. return f"tcp:{probe}:2345"
  157. class BlackmagicProgrammer(Programmer):
  158. def __init__(
  159. self,
  160. port_resolver, # typing.Callable[typing.Union[str, None], typing.Optional[str]]
  161. name: str,
  162. ):
  163. self.port_resolver = port_resolver
  164. self.name = name
  165. self.logger = logging.getLogger("BlackmagicUSB")
  166. self.port: typing.Optional[str] = None
  167. def _add_command(self, params: list[str], command: str):
  168. params.append("-ex")
  169. params.append(command)
  170. def _valid_ip(self, address):
  171. try:
  172. socket.inet_aton(address)
  173. return True
  174. except Exception:
  175. return False
  176. def set_serial(self, serial: str):
  177. if self._valid_ip(serial):
  178. self.port = f"{serial}:2345"
  179. elif ip := _resolve_hostname(serial):
  180. self.port = f"{ip}:2345"
  181. else:
  182. self.port = serial
  183. def flash(self, bin: str) -> bool:
  184. if not self.port:
  185. if not self.probe():
  186. return False
  187. # We can convert .bin to .elf with objcopy:
  188. # arm-none-eabi-objcopy -I binary -O elf32-littlearm --change-section-address=.data=0x8000000 -B arm -S app.bin app.elf
  189. # But I choose to use the .elf file directly because we are flashing our own firmware and it always has an elf predecessor.
  190. elf = bin.replace(".bin", ".elf")
  191. if not os.path.exists(elf):
  192. self.logger.error(
  193. f"Sorry, but Blackmagic can't flash .bin file, and {elf} doesn't exist"
  194. )
  195. return False
  196. # arm-none-eabi-gdb build/f7-firmware-D/firmware.bin
  197. # -ex 'set pagination off'
  198. # -ex 'target extended-remote /dev/cu.usbmodem21201'
  199. # -ex 'set confirm off'
  200. # -ex 'monitor swdp_scan'
  201. # -ex 'attach 1'
  202. # -ex 'set mem inaccessible-by-default off'
  203. # -ex 'load'
  204. # -ex 'compare-sections'
  205. # -ex 'quit'
  206. gdb_launch_params = ["arm-none-eabi-gdb", elf]
  207. self._add_command(gdb_launch_params, f"target extended-remote {self.port}")
  208. self._add_command(gdb_launch_params, "set pagination off")
  209. self._add_command(gdb_launch_params, "set confirm off")
  210. self._add_command(gdb_launch_params, "monitor swdp_scan")
  211. self._add_command(gdb_launch_params, "attach 1")
  212. self._add_command(gdb_launch_params, "set mem inaccessible-by-default off")
  213. self._add_command(gdb_launch_params, "load")
  214. self._add_command(gdb_launch_params, "compare-sections")
  215. self._add_command(gdb_launch_params, "quit")
  216. self.logger.debug(f"Launching: {' '.join(gdb_launch_params)}")
  217. process = subprocess.Popen(
  218. gdb_launch_params,
  219. stdout=subprocess.PIPE,
  220. stderr=subprocess.STDOUT,
  221. )
  222. while process.poll() is None:
  223. time.sleep(0.5)
  224. print(".", end="", flush=True)
  225. print()
  226. if not process.stdout:
  227. return False
  228. output = process.stdout.read().decode("utf-8").strip()
  229. flashed = "Loading section .text," in output
  230. # Check flash verification
  231. if "MIS-MATCHED!" in output:
  232. flashed = False
  233. if "target image does not match the loaded file" in output:
  234. flashed = False
  235. if not flashed:
  236. self.logger.error("Blackmagic failed to flash")
  237. self.logger.error(output)
  238. return flashed
  239. def probe(self) -> bool:
  240. if not (port := self.port_resolver(self.port)):
  241. return False
  242. self.port = port
  243. return True
  244. def get_name(self) -> str:
  245. return self.name
  246. programmers: list[Programmer] = [
  247. OpenOCDProgrammer(
  248. OpenOCDInterface(
  249. "cmsis-dap",
  250. "interface/cmsis-dap.cfg",
  251. "cmsis_dap_serial",
  252. ["transport select swd"],
  253. ),
  254. ),
  255. OpenOCDProgrammer(
  256. OpenOCDInterface(
  257. "stlink", "interface/stlink.cfg", "hla_serial", ["transport select hla_swd"]
  258. ),
  259. ),
  260. BlackmagicProgrammer(blackmagic_find_serial, "blackmagic_usb"),
  261. ]
  262. network_programmers = [
  263. BlackmagicProgrammer(blackmagic_find_networked, "blackmagic_wifi")
  264. ]
  265. class Main(App):
  266. def init(self):
  267. self.subparsers = self.parser.add_subparsers(help="sub-command help")
  268. self.parser_flash = self.subparsers.add_parser("flash", help="Flash a binary")
  269. self.parser_flash.add_argument(
  270. "bin",
  271. type=str,
  272. help="Binary to flash",
  273. )
  274. interfaces = [i.get_name() for i in programmers]
  275. interfaces.extend([i.get_name() for i in network_programmers])
  276. self.parser_flash.add_argument(
  277. "--interface",
  278. choices=interfaces,
  279. type=str,
  280. help="Interface to use",
  281. )
  282. self.parser_flash.add_argument(
  283. "--serial",
  284. type=str,
  285. help="Serial number or port of the programmer",
  286. )
  287. self.parser_flash.set_defaults(func=self.flash)
  288. def _search_interface(self, serial: typing.Optional[str]) -> list[Programmer]:
  289. found_programmers = []
  290. for p in programmers:
  291. name = p.get_name()
  292. if serial:
  293. p.set_serial(serial)
  294. self.logger.debug(f"Trying {name} with {serial}")
  295. else:
  296. self.logger.debug(f"Trying {name}")
  297. if p.probe():
  298. self.logger.debug(f"Found {name}")
  299. found_programmers += [p]
  300. else:
  301. self.logger.debug(f"Failed to probe {name}")
  302. return found_programmers
  303. def _search_network_interface(
  304. self, serial: typing.Optional[str]
  305. ) -> list[Programmer]:
  306. found_programmers = []
  307. for p in network_programmers:
  308. name = p.get_name()
  309. if serial:
  310. p.set_serial(serial)
  311. self.logger.debug(f"Trying {name} with {serial}")
  312. else:
  313. self.logger.debug(f"Trying {name}")
  314. if p.probe():
  315. self.logger.debug(f"Found {name}")
  316. found_programmers += [p]
  317. else:
  318. self.logger.debug(f"Failed to probe {name}")
  319. return found_programmers
  320. def flash(self):
  321. start_time = time.time()
  322. bin_path = os.path.abspath(self.args.bin)
  323. if not os.path.exists(bin_path):
  324. self.logger.error(f"Binary file not found: {bin_path}")
  325. return 1
  326. if self.args.interface:
  327. i_name = self.args.interface
  328. interfaces = [p for p in programmers if p.get_name() == i_name]
  329. if len(interfaces) == 0:
  330. interfaces = [p for p in network_programmers if p.get_name() == i_name]
  331. else:
  332. self.logger.info("Probing for interfaces...")
  333. interfaces = self._search_interface(self.args.serial)
  334. if len(interfaces) == 0:
  335. # Probe network blackmagic
  336. self.logger.info("Probing for network interfaces...")
  337. interfaces = self._search_network_interface(self.args.serial)
  338. if len(interfaces) == 0:
  339. self.logger.error("No interface found")
  340. return 1
  341. if len(interfaces) > 1:
  342. self.logger.error("Multiple interfaces found: ")
  343. self.logger.error(
  344. f"Please specify '--interface={[i.get_name() for i in interfaces]}'"
  345. )
  346. return 1
  347. interface = interfaces[0]
  348. if self.args.serial:
  349. interface.set_serial(self.args.serial)
  350. self.logger.info(
  351. f"Flashing {bin_path} via {interface.get_name()} with {self.args.serial}"
  352. )
  353. else:
  354. self.logger.info(f"Flashing {bin_path} via {interface.get_name()}")
  355. if not interface.flash(bin_path):
  356. self.logger.error(f"Failed to flash via {interface.get_name()}")
  357. return 1
  358. flash_time = time.time() - start_time
  359. bin_size = os.path.getsize(bin_path)
  360. self.logger.info(f"Flashed successfully in {flash_time:.2f}s")
  361. self.logger.info(f"Effective speed: {bin_size / flash_time / 1024:.2f} KiB/s")
  362. return 0
  363. if __name__ == "__main__":
  364. Main()()