sdk.py 16 KB


  1. import operator
  2. import os
  3. import csv
  4. import operator
  5. from enum import Enum, auto
  6. from typing import List, Set, ClassVar, Any
  7. from dataclasses import dataclass, field
  8. from ansi.color import fg
  9. from cxxheaderparser.parser import CxxParser
# Module-level monkey-patch: remove "wchar_t" from cxxheaderparser's set of
# fundamental type names. Per the original note this silences parser
# complaints about typedefs (presumably headers that typedef `wchar_t`
# themselves — confirm). NOTE(review): `_fundamentals` is a private attribute
# of CxxParser and may change between cxxheaderparser versions.
CxxParser._fundamentals.discard("wchar_t")
  12. from cxxheaderparser.types import (
  13. EnumDecl,
  14. Field,
  15. ForwardDecl,
  16. FriendDecl,
  17. Function,
  18. Method,
  19. Typedef,
  20. UsingAlias,
  21. UsingDecl,
  22. Variable,
  23. Pointer,
  24. Type,
  25. PQName,
  26. NameSpecifier,
  27. FundamentalSpecifier,
  28. Parameter,
  29. Array,
  30. Value,
  31. Token,
  32. FunctionType,
  33. )
  34. from cxxheaderparser.parserstate import (
  35. State,
  36. EmptyBlockState,
  37. ClassBlockState,
  38. ExternBlockState,
  39. NamespaceBlockState,
  40. )
  41. @dataclass(frozen=True)
  42. class ApiEntryFunction:
  43. name: str
  44. returns: str
  45. params: str
  46. csv_type: ClassVar[str] = "Function"
  47. def dictify(self):
  48. return dict(name=self.name, type=self.returns, params=self.params)
  49. @dataclass(frozen=True)
  50. class ApiEntryVariable:
  51. name: str
  52. var_type: str
  53. csv_type: ClassVar[str] = "Variable"
  54. def dictify(self):
  55. return dict(name=self.name, type=self.var_type, params=None)
  56. @dataclass(frozen=True)
  57. class ApiHeader:
  58. name: str
  59. csv_type: ClassVar[str] = "Header"
  60. def dictify(self):
  61. return dict(name=self.name, type=None, params=None)
  62. @dataclass
  63. class ApiEntries:
  64. # These are sets, to avoid creating duplicates when we have multiple
  65. # declarations with same signature
  66. functions: Set[ApiEntryFunction] = field(default_factory=set)
  67. variables: Set[ApiEntryVariable] = field(default_factory=set)
  68. headers: Set[ApiHeader] = field(default_factory=set)
  69. class SymbolManager:
  70. def __init__(self):
  71. self.api = ApiEntries()
  72. self.name_hashes = set()
  73. # Calculate hash of name and raise exception if it already is in the set
  74. def _name_check(self, name: str):
  75. name_hash = gnu_sym_hash(name)
  76. if name_hash in self.name_hashes:
  77. raise Exception(f"Hash collision on {name}")
  78. self.name_hashes.add(name_hash)
  79. def add_function(self, function_def: ApiEntryFunction):
  80. if function_def in self.api.functions:
  81. return
  82. self._name_check(function_def.name)
  83. self.api.functions.add(function_def)
  84. def add_variable(self, variable_def: ApiEntryVariable):
  85. if variable_def in self.api.variables:
  86. return
  87. self._name_check(variable_def.name)
  88. self.api.variables.add(variable_def)
  89. def add_header(self, header: str):
  90. self.api.headers.add(ApiHeader(header))
  91. def gnu_sym_hash(name: str):
  92. h = 0x1505
  93. for c in name:
  94. h = (h << 5) + h + ord(c)
  95. return str(hex(h))[-8:]
  96. class SdkCollector:
  97. def __init__(self):
  98. self.symbol_manager = SymbolManager()
  99. def add_header_to_sdk(self, header: str):
  100. self.symbol_manager.add_header(header)
  101. def process_source_file_for_sdk(self, file_path: str):
  102. visitor = SdkCxxVisitor(self.symbol_manager)
  103. with open(file_path, "rt") as f:
  104. content = f.read()
  105. parser = CxxParser(file_path, content, visitor, None)
  106. parser.parse()
  107. def get_api(self):
  108. return self.symbol_manager.api
  109. def stringify_array_dimension(size_descr):
  110. if not size_descr:
  111. return ""
  112. return stringify_descr(size_descr)
  113. def stringify_array_descr(type_descr):
  114. assert isinstance(type_descr, Array)
  115. return (
  116. stringify_descr(type_descr.array_of),
  117. stringify_array_dimension(type_descr.size),
  118. )
  119. def stringify_descr(type_descr):
  120. if isinstance(type_descr, (NameSpecifier, FundamentalSpecifier)):
  121. return type_descr.name
  122. elif isinstance(type_descr, PQName):
  123. return "::".join(map(stringify_descr, type_descr.segments))
  124. elif isinstance(type_descr, Pointer):
  125. # Hack
  126. if isinstance(type_descr.ptr_to, FunctionType):
  127. return stringify_descr(type_descr.ptr_to)
  128. return f"{stringify_descr(type_descr.ptr_to)}*"
  129. elif isinstance(type_descr, Type):
  130. return (
  131. f"{'const ' if type_descr.const else ''}"
  132. f"{'volatile ' if type_descr.volatile else ''}"
  133. f"{stringify_descr(type_descr.typename)}"
  134. )
  135. elif isinstance(type_descr, Parameter):
  136. return stringify_descr(type_descr.type)
  137. elif isinstance(type_descr, Array):
  138. # Hack for 2d arrays
  139. if isinstance(type_descr.array_of, Array):
  140. argtype, dimension = stringify_array_descr(type_descr.array_of)
  141. return (
  142. f"{argtype}[{stringify_array_dimension(type_descr.size)}][{dimension}]"
  143. )
  144. return f"{stringify_descr(type_descr.array_of)}[{stringify_array_dimension(type_descr.size)}]"
  145. elif isinstance(type_descr, Value):
  146. return " ".join(map(stringify_descr, type_descr.tokens))
  147. elif isinstance(type_descr, FunctionType):
  148. return f"{stringify_descr(type_descr.return_type)} (*)({', '.join(map(stringify_descr, type_descr.parameters))})"
  149. elif isinstance(type_descr, Token):
  150. return type_descr.value
  151. elif type_descr is None:
  152. return ""
  153. else:
  154. raise Exception("unsupported type_descr: %s" % type_descr)
  155. class SdkCxxVisitor:
  156. def __init__(self, symbol_manager: SymbolManager):
  157. self.api = symbol_manager
  158. def on_variable(self, state: State, v: Variable) -> None:
  159. if not v.extern:
  160. return
  161. self.api.add_variable(
  162. ApiEntryVariable(
  163. stringify_descr(v.name),
  164. stringify_descr(v.type),
  165. )
  166. )
  167. def on_function(self, state: State, fn: Function) -> None:
  168. if fn.inline or fn.has_body:
  169. return
  170. self.api.add_function(
  171. ApiEntryFunction(
  172. stringify_descr(fn.name),
  173. stringify_descr(fn.return_type),
  174. ", ".join(map(stringify_descr, fn.parameters))
  175. + (", ..." if fn.vararg else ""),
  176. )
  177. )
  178. def on_define(self, state: State, content: str) -> None:
  179. pass
  180. def on_pragma(self, state: State, content: str) -> None:
  181. pass
  182. def on_include(self, state: State, filename: str) -> None:
  183. pass
  184. def on_empty_block_start(self, state: EmptyBlockState) -> None:
  185. pass
  186. def on_empty_block_end(self, state: EmptyBlockState) -> None:
  187. pass
  188. def on_extern_block_start(self, state: ExternBlockState) -> None:
  189. pass
  190. def on_extern_block_end(self, state: ExternBlockState) -> None:
  191. pass
  192. def on_namespace_start(self, state: NamespaceBlockState) -> None:
  193. pass
  194. def on_namespace_end(self, state: NamespaceBlockState) -> None:
  195. pass
  196. def on_forward_decl(self, state: State, fdecl: ForwardDecl) -> None:
  197. pass
  198. def on_typedef(self, state: State, typedef: Typedef) -> None:
  199. pass
  200. def on_using_namespace(self, state: State, namespace: List[str]) -> None:
  201. pass
  202. def on_using_alias(self, state: State, using: UsingAlias) -> None:
  203. pass
  204. def on_using_declaration(self, state: State, using: UsingDecl) -> None:
  205. pass
  206. def on_enum(self, state: State, enum: EnumDecl) -> None:
  207. pass
  208. def on_class_start(self, state: ClassBlockState) -> None:
  209. pass
  210. def on_class_field(self, state: State, f: Field) -> None:
  211. pass
  212. def on_class_method(self, state: ClassBlockState, method: Method) -> None:
  213. pass
  214. def on_class_friend(self, state: ClassBlockState, friend: FriendDecl) -> None:
  215. pass
  216. def on_class_end(self, state: ClassBlockState) -> None:
  217. pass
  218. @dataclass(frozen=True)
  219. class SdkVersion:
  220. major: int = 0
  221. minor: int = 0
  222. csv_type: ClassVar[str] = "Version"
  223. def __str__(self) -> str:
  224. return f"{self.major}.{self.minor}"
  225. def as_int(self) -> int:
  226. return ((self.major & 0xFFFF) << 16) | (self.minor & 0xFFFF)
  227. @staticmethod
  228. def from_str(s: str) -> "SdkVersion":
  229. major, minor = s.split(".")
  230. return SdkVersion(int(major), int(minor))
  231. def dictify(self) -> dict:
  232. return dict(name=str(self), type=None, params=None)
  233. class VersionBump(Enum):
  234. NONE = auto()
  235. MAJOR = auto()
  236. MINOR = auto()
  237. class ApiEntryState(Enum):
  238. PENDING = "?"
  239. APPROVED = "+"
  240. DISABLED = "-"
  241. # Special value for API version entry so users have less incentive to edit it
  242. VERSION_PENDING = "v"
  243. # Class that stores all known API entries, both enabled and disabled.
  244. # Also keeps track of API versioning
  245. # Allows comparison and update from newly-generated API
  246. class SdkCache:
  247. CSV_FIELD_NAMES = ("entry", "status", "name", "type", "params")
  248. def __init__(self, cache_file: str, load_version_only=False):
  249. self.cache_file_name = cache_file
  250. self.version = SdkVersion(0, 0)
  251. self.sdk = ApiEntries()
  252. self.disabled_entries = set()
  253. self.new_entries = set()
  254. self.loaded_dirty_version = False
  255. self.version_action = VersionBump.NONE
  256. self._load_version_only = load_version_only
  257. self.load_cache()
  258. def is_buildable(self) -> bool:
  259. return (
  260. self.version != SdkVersion(0, 0)
  261. and self.version_action == VersionBump.NONE
  262. and not self._have_pending_entries()
  263. )
  264. def _filter_enabled(self, sdk_entries):
  265. return sorted(
  266. filter(lambda e: e not in self.disabled_entries, sdk_entries),
  267. key=operator.attrgetter("name"),
  268. )
  269. def get_valid_names(self):
  270. syms = set(map(lambda e: e.name, self.get_functions()))
  271. syms.update(map(lambda e: e.name, self.get_variables()))
  272. return syms
  273. def get_functions(self):
  274. return self._filter_enabled(self.sdk.functions)
  275. def get_variables(self):
  276. return self._filter_enabled(self.sdk.variables)
  277. def get_headers(self):
  278. return self._filter_enabled(self.sdk.headers)
  279. def _get_entry_status(self, entry) -> str:
  280. if entry in self.disabled_entries:
  281. return ApiEntryState.DISABLED
  282. elif entry in self.new_entries:
  283. if isinstance(entry, SdkVersion):
  284. return ApiEntryState.VERSION_PENDING
  285. return ApiEntryState.PENDING
  286. else:
  287. return ApiEntryState.APPROVED
  288. def _format_entry(self, obj):
  289. obj_dict = obj.dictify()
  290. obj_dict.update(
  291. dict(
  292. entry=obj.csv_type,
  293. status=self._get_entry_status(obj).value,
  294. )
  295. )
  296. return obj_dict
  297. def save(self) -> None:
  298. if self._load_version_only:
  299. raise Exception("Only SDK version was loaded, cannot save")
  300. if self.version_action == VersionBump.MINOR:
  301. self.version = SdkVersion(self.version.major, self.version.minor + 1)
  302. elif self.version_action == VersionBump.MAJOR:
  303. self.version = SdkVersion(self.version.major + 1, 0)
  304. if self._have_pending_entries():
  305. self.new_entries.add(self.version)
  306. print(
  307. fg.red(
  308. f"API version is still WIP: {self.version}. Review the changes and re-run command."
  309. )
  310. )
  311. print(f"CSV file entries to mark up:")
  312. print(
  313. fg.yellow(
  314. "\n".join(
  315. map(
  316. str,
  317. filter(
  318. lambda e: not isinstance(e, SdkVersion),
  319. self.new_entries,
  320. ),
  321. )
  322. )
  323. )
  324. )
  325. else:
  326. print(fg.green(f"API version {self.version} is up to date"))
  327. regenerate_csv = (
  328. self.loaded_dirty_version
  329. or self._have_pending_entries()
  330. or self.version_action != VersionBump.NONE
  331. )
  332. if regenerate_csv:
  333. str_cache_entries = [self.version]
  334. name_getter = operator.attrgetter("name")
  335. str_cache_entries.extend(sorted(self.sdk.headers, key=name_getter))
  336. str_cache_entries.extend(sorted(self.sdk.functions, key=name_getter))
  337. str_cache_entries.extend(sorted(self.sdk.variables, key=name_getter))
  338. with open(self.cache_file_name, "wt", newline="") as f:
  339. writer = csv.DictWriter(f, fieldnames=SdkCache.CSV_FIELD_NAMES)
  340. writer.writeheader()
  341. for entry in str_cache_entries:
  342. writer.writerow(self._format_entry(entry))
  343. def _process_entry(self, entry_dict: dict) -> None:
  344. entry_class = entry_dict["entry"]
  345. entry_status = entry_dict["status"]
  346. entry_name = entry_dict["name"]
  347. entry = None
  348. if entry_class == SdkVersion.csv_type:
  349. self.version = SdkVersion.from_str(entry_name)
  350. if entry_status == ApiEntryState.VERSION_PENDING.value:
  351. self.loaded_dirty_version = True
  352. elif entry_class == ApiHeader.csv_type:
  353. self.sdk.headers.add(entry := ApiHeader(entry_name))
  354. elif entry_class == ApiEntryFunction.csv_type:
  355. self.sdk.functions.add(
  356. entry := ApiEntryFunction(
  357. entry_name,
  358. entry_dict["type"],
  359. entry_dict["params"],
  360. )
  361. )
  362. elif entry_class == ApiEntryVariable.csv_type:
  363. self.sdk.variables.add(
  364. entry := ApiEntryVariable(entry_name, entry_dict["type"])
  365. )
  366. else:
  367. print(entry_dict)
  368. raise Exception("Unknown entry type: %s" % entry_class)
  369. if entry is None:
  370. return
  371. if entry_status == ApiEntryState.DISABLED.value:
  372. self.disabled_entries.add(entry)
  373. elif entry_status == ApiEntryState.PENDING.value:
  374. self.new_entries.add(entry)
  375. def load_cache(self) -> None:
  376. if not os.path.exists(self.cache_file_name):
  377. raise Exception(
  378. f"Cannot load symbol cache '{self.cache_file_name}'! File does not exist"
  379. )
  380. with open(self.cache_file_name, "rt") as f:
  381. reader = csv.DictReader(f)
  382. for row in reader:
  383. self._process_entry(row)
  384. if self._load_version_only and row.get("entry") == SdkVersion.csv_type:
  385. break
  386. def _have_pending_entries(self) -> bool:
  387. return any(
  388. filter(
  389. lambda e: not isinstance(e, SdkVersion),
  390. self.new_entries,
  391. )
  392. )
  393. def sync_sets(
  394. self, known_set: Set[Any], new_set: Set[Any], update_version: bool = True
  395. ):
  396. new_entries = new_set - known_set
  397. if new_entries:
  398. print(f"New: {new_entries}")
  399. known_set |= new_entries
  400. self.new_entries |= new_entries
  401. if update_version and self.version_action == VersionBump.NONE:
  402. self.version_action = VersionBump.MINOR
  403. removed_entries = known_set - new_set
  404. if removed_entries:
  405. print(f"Removed: {removed_entries}")
  406. known_set -= removed_entries
  407. # If any of removed entries was a part of active API, that's a major bump
  408. if update_version and any(
  409. filter(
  410. lambda e: e not in self.disabled_entries
  411. and e not in self.new_entries,
  412. removed_entries,
  413. )
  414. ):
  415. self.version_action = VersionBump.MAJOR
  416. self.disabled_entries -= removed_entries
  417. self.new_entries -= removed_entries
  418. def validate_api(self, api: ApiEntries) -> None:
  419. self.sync_sets(self.sdk.headers, api.headers, False)
  420. self.sync_sets(self.sdk.functions, api.functions)
  421. self.sync_sets(self.sdk.variables, api.variables)