Procházet zdrojové kódy

Script that can find programmer and flash firmware via it. (#2193)

* Init
* Fallback to networked interface
* remove unneeded cmsis_dap_backend
* serial number
* windows :(
* remove jlink, fix path handling
* scripts: program: path normalization
* scripts: program: path normalization: second encounter

Co-authored-by: hedger <hedger@nanode.su>
Co-authored-by: あく <alleteam@gmail.com>
Sergey Gavrilov před 2 roky
rodič
revize
6e179bda1f
1 změnil soubory, kde provedl 459 přidání a 0 odebrání
  1. 459 0
      scripts/program.py

+ 459 - 0
scripts/program.py

@@ -0,0 +1,459 @@
+#!/usr/bin/env python3
+import typing
+import subprocess
+import logging
+import time
+import os
+import socket
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from flipper.app import App
+
+
+class Programmer(ABC):
+    @abstractmethod
+    def flash(self, bin: str) -> bool:
+        pass
+
+    @abstractmethod
+    def probe(self) -> bool:
+        pass
+
+    @abstractmethod
+    def get_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def set_serial(self, serial: str):
+        pass
+
+
+@dataclass
+class OpenOCDInterface:
+    name: str
+    file: str
+    serial_cmd: str
+    additional_args: typing.Optional[list[str]] = None
+
+
+class OpenOCDProgrammer(Programmer):
+    def __init__(self, interface: OpenOCDInterface):
+        self.interface = interface
+        self.logger = logging.getLogger("OpenOCD")
+        self.serial: typing.Optional[str] = None
+
+    def _add_file(self, params: list[str], file: str):
+        params.append("-f")
+        params.append(file)
+
+    def _add_command(self, params: list[str], command: str):
+        params.append("-c")
+        params.append(command)
+
+    def _add_serial(self, params: list[str], serial: str):
+        self._add_command(params, f"{self.interface.serial_cmd} {serial}")
+
+    def set_serial(self, serial: str):
+        self.serial = serial
+
+    def flash(self, bin: str) -> bool:
+        i = self.interface
+
+        if os.altsep:
+            bin = bin.replace(os.sep, os.altsep)
+
+        openocd_launch_params = ["openocd"]
+        self._add_file(openocd_launch_params, i.file)
+        if self.serial:
+            self._add_serial(openocd_launch_params, self.serial)
+        if i.additional_args:
+            for a in i.additional_args:
+                self._add_command(openocd_launch_params, a)
+        self._add_file(openocd_launch_params, "target/stm32wbx.cfg")
+        self._add_command(openocd_launch_params, "init")
+        self._add_command(openocd_launch_params, f"program {bin} reset exit 0x8000000")
+
+        # join the list of parameters into a string, but add quote if there are spaces
+        openocd_launch_params_string = " ".join(
+            [f'"{p}"' if " " in p else p for p in openocd_launch_params]
+        )
+
+        self.logger.debug(f"Launching: {openocd_launch_params_string}")
+
+        process = subprocess.Popen(
+            openocd_launch_params,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+        )
+
+        while process.poll() is None:
+            time.sleep(0.25)
+            print(".", end="", flush=True)
+        print()
+
+        success = process.returncode == 0
+
+        if not success:
+            self.logger.error("OpenOCD failed to flash")
+            if process.stdout:
+                self.logger.error(process.stdout.read().decode("utf-8").strip())
+
+        return success
+
+    def probe(self) -> bool:
+        i = self.interface
+
+        openocd_launch_params = ["openocd"]
+        self._add_file(openocd_launch_params, i.file)
+        if self.serial:
+            self._add_serial(openocd_launch_params, self.serial)
+        if i.additional_args:
+            for a in i.additional_args:
+                self._add_command(openocd_launch_params, a)
+        self._add_file(openocd_launch_params, "target/stm32wbx.cfg")
+        self._add_command(openocd_launch_params, "init")
+        self._add_command(openocd_launch_params, "exit")
+
+        self.logger.debug(f"Launching: {' '.join(openocd_launch_params)}")
+
+        process = subprocess.Popen(
+            openocd_launch_params,
+            stderr=subprocess.STDOUT,
+            stdout=subprocess.PIPE,
+        )
+
+        # Wait for OpenOCD to end and get the return code
+        process.wait()
+        found = process.returncode == 0
+
+        if process.stdout:
+            self.logger.debug(process.stdout.read().decode("utf-8").strip())
+
+        return found
+
+    def get_name(self) -> str:
+        return self.interface.name
+
+
+def blackmagic_find_serial(serial: str):
+    import serial.tools.list_ports as list_ports
+
+    if serial and os.name == "nt":
+        if not serial.startswith("\\\\.\\"):
+            serial = f"\\\\.\\{serial}"
+
+    ports = list(list_ports.grep("blackmagic"))
+    if len(ports) == 0:
+        return None
+    elif len(ports) > 2:
+        if serial:
+            ports = list(
+                filter(
+                    lambda p: p.serial_number == serial
+                    or p.name == serial
+                    or p.device == serial,
+                    ports,
+                )
+            )
+            if len(ports) == 0:
+                return None
+
+        if len(ports) > 2:
+            raise Exception("More than one Blackmagic probe found")
+
+    # If you're getting any issues with auto lookup, uncomment this
+    # print("\n".join([f"{p.device} {vars(p)}" for p in ports]))
+    port = sorted(ports, key=lambda p: f"{p.location}_{p.name}")[0]
+
+    if serial:
+        if (
+            serial != port.serial_number
+            and serial != port.name
+            and serial != port.device
+        ):
+            return None
+
+    if os.name == "nt":
+        port.device = f"\\\\.\\{port.device}"
+    return port.device
+
+
+def _resolve_hostname(hostname):
+    try:
+        return socket.gethostbyname(hostname)
+    except socket.gaierror:
+        return None
+
+
+def blackmagic_find_networked(serial: str):
+    if not serial:
+        serial = "blackmagic.local"
+
+    # remove the tcp: prefix if it's there
+    if serial.startswith("tcp:"):
+        serial = serial[4:]
+
+    # remove the port if it's there
+    if ":" in serial:
+        serial = serial.split(":")[0]
+
+    if not (probe := _resolve_hostname(serial)):
+        return None
+
+    return f"tcp:{probe}:2345"
+
+
+class BlackmagicProgrammer(Programmer):
+    def __init__(
+        self,
+        port_resolver,  # typing.Callable[typing.Union[str, None], typing.Optional[str]]
+        name: str,
+    ):
+        self.port_resolver = port_resolver
+        self.name = name
+        self.logger = logging.getLogger("BlackmagicUSB")
+        self.port: typing.Optional[str] = None
+
+    def _add_command(self, params: list[str], command: str):
+        params.append("-ex")
+        params.append(command)
+
+    def _valid_ip(self, address):
+        try:
+            socket.inet_aton(address)
+            return True
+        except:
+            return False
+
+    def set_serial(self, serial: str):
+        if self._valid_ip(serial):
+            self.port = f"{serial}:2345"
+        elif ip := _resolve_hostname(serial):
+            self.port = f"{ip}:2345"
+        else:
+            self.port = serial
+
+    def flash(self, bin: str) -> bool:
+        if not self.port:
+            if not self.probe():
+                return False
+
+        # We can convert .bin to .elf with objcopy:
+        # arm-none-eabi-objcopy -I binary -O elf32-littlearm --change-section-address=.data=0x8000000 -B arm -S app.bin app.elf
+        # But I choose to use the .elf file directly because we are flashing our own firmware and it always has an elf predecessor.
+        elf = bin.replace(".bin", ".elf")
+        if not os.path.exists(elf):
+            self.logger.error(
+                f"Sorry, but Blackmagic can't flash .bin file, and {elf} doesn't exist"
+            )
+            return False
+
+        # arm-none-eabi-gdb build/f7-firmware-D/firmware.bin
+        # -ex 'set pagination off'
+        # -ex 'target extended-remote /dev/cu.usbmodem21201'
+        # -ex 'set confirm off'
+        # -ex 'monitor swdp_scan'
+        # -ex 'attach 1'
+        # -ex 'set mem inaccessible-by-default off'
+        # -ex 'load'
+        # -ex 'compare-sections'
+        # -ex 'quit'
+
+        gdb_launch_params = ["arm-none-eabi-gdb", elf]
+        self._add_command(gdb_launch_params, f"target extended-remote {self.port}")
+        self._add_command(gdb_launch_params, "set pagination off")
+        self._add_command(gdb_launch_params, "set confirm off")
+        self._add_command(gdb_launch_params, "monitor swdp_scan")
+        self._add_command(gdb_launch_params, "attach 1")
+        self._add_command(gdb_launch_params, "set mem inaccessible-by-default off")
+        self._add_command(gdb_launch_params, "load")
+        self._add_command(gdb_launch_params, "compare-sections")
+        self._add_command(gdb_launch_params, "quit")
+
+        self.logger.debug(f"Launching: {' '.join(gdb_launch_params)}")
+
+        process = subprocess.Popen(
+            gdb_launch_params,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+        )
+
+        while process.poll() is None:
+            time.sleep(0.5)
+            print(".", end="", flush=True)
+        print()
+
+        if not process.stdout:
+            return False
+
+        output = process.stdout.read().decode("utf-8").strip()
+        flashed = "Loading section .text," in output
+
+        # Check flash verification
+        if "MIS-MATCHED!" in output:
+            flashed = False
+
+        if "target image does not match the loaded file" in output:
+            flashed = False
+
+        if not flashed:
+            self.logger.error("Blackmagic failed to flash")
+            self.logger.error(output)
+
+        return flashed
+
+    def probe(self) -> bool:
+        if not (port := self.port_resolver(self.port)):
+            return False
+
+        self.port = port
+        return True
+
+    def get_name(self) -> str:
+        return self.name
+
+
+programmers: list[Programmer] = [
+    OpenOCDProgrammer(
+        OpenOCDInterface(
+            "cmsis-dap",
+            "interface/cmsis-dap.cfg",
+            "cmsis_dap_serial",
+            ["transport select swd"],
+        ),
+    ),
+    OpenOCDProgrammer(
+        OpenOCDInterface(
+            "stlink", "interface/stlink.cfg", "hla_serial", ["transport select hla_swd"]
+        ),
+    ),
+    BlackmagicProgrammer(blackmagic_find_serial, "blackmagic_usb"),
+]
+
+network_programmers = [
+    BlackmagicProgrammer(blackmagic_find_networked, "blackmagic_wifi")
+]
+
+
+class Main(App):
+    def init(self):
+        self.subparsers = self.parser.add_subparsers(help="sub-command help")
+        self.parser_flash = self.subparsers.add_parser("flash", help="Flash a binary")
+        self.parser_flash.add_argument(
+            "bin",
+            type=str,
+            help="Binary to flash",
+        )
+        interfaces = [i.get_name() for i in programmers]
+        interfaces.extend([i.get_name() for i in network_programmers])
+        self.parser_flash.add_argument(
+            "--interface",
+            choices=interfaces,
+            type=str,
+            help="Interface to use",
+        )
+        self.parser_flash.add_argument(
+            "--serial",
+            type=str,
+            help="Serial number or port of the programmer",
+        )
+        self.parser_flash.set_defaults(func=self.flash)
+
+    def _search_interface(self, serial: typing.Optional[str]) -> list[Programmer]:
+        found_programmers = []
+
+        for p in programmers:
+            name = p.get_name()
+            if serial:
+                p.set_serial(serial)
+                self.logger.debug(f"Trying {name} with {serial}")
+            else:
+                self.logger.debug(f"Trying {name}")
+
+            if p.probe():
+                self.logger.debug(f"Found {name}")
+                found_programmers += [p]
+            else:
+                self.logger.debug(f"Failed to probe {name}")
+
+        return found_programmers
+
+    def _search_network_interface(
+        self, serial: typing.Optional[str]
+    ) -> list[Programmer]:
+        found_programmers = []
+
+        for p in network_programmers:
+            name = p.get_name()
+
+            if serial:
+                p.set_serial(serial)
+                self.logger.debug(f"Trying {name} with {serial}")
+            else:
+                self.logger.debug(f"Trying {name}")
+
+            if p.probe():
+                self.logger.debug(f"Found {name}")
+                found_programmers += [p]
+            else:
+                self.logger.debug(f"Failed to probe {name}")
+
+        return found_programmers
+
+    def flash(self):
+        start_time = time.time()
+        bin_path = os.path.abspath(self.args.bin)
+
+        if not os.path.exists(bin_path):
+            self.logger.error(f"Binary file not found: {bin_path}")
+            return 1
+
+        if self.args.interface:
+            i_name = self.args.interface
+            interfaces = [p for p in programmers if p.get_name() == i_name]
+            if len(interfaces) == 0:
+                interfaces = [p for p in network_programmers if p.get_name() == i_name]
+        else:
+            self.logger.info(f"Probing for interfaces...")
+            interfaces = self._search_interface(self.args.serial)
+
+            if len(interfaces) == 0:
+                # Probe network blackmagic
+                self.logger.info(f"Probing for network interfaces...")
+                interfaces = self._search_network_interface(self.args.serial)
+
+            if len(interfaces) == 0:
+                self.logger.error("No interface found")
+                return 1
+
+            if len(interfaces) > 1:
+                self.logger.error("Multiple interfaces found: ")
+                self.logger.error(
+                    f"Please specify '--interface={[i.get_name() for i in interfaces]}'"
+                )
+                return 1
+
+        interface = interfaces[0]
+
+        if self.args.serial:
+            interface.set_serial(self.args.serial)
+            self.logger.info(
+                f"Flashing {bin_path} via {interface.get_name()} with {self.args.serial}"
+            )
+        else:
+            self.logger.info(f"Flashing {bin_path} via {interface.get_name()}")
+
+        if not interface.flash(bin_path):
+            self.logger.error(f"Failed to flash via {interface.get_name()}")
+            return 1
+
+        flash_time = time.time() - start_time
+        bin_size = os.path.getsize(bin_path)
+        self.logger.info(f"Flashed successfully in {flash_time:.2f}s")
+        self.logger.info(f"Effective speed: {bin_size / flash_time / 1024:.2f} KiB/s")
+        return 0
+
+
+if __name__ == "__main__":
+    Main()()