sdk.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. import operator
  2. import os
  3. import csv
  4. import operator
  5. from enum import Enum, auto
  6. from typing import List, Set, ClassVar, Any
  7. from dataclasses import dataclass, field
  8. from cxxheaderparser.parser import CxxParser
  9. # 'Fixing' complaints about typedefs
  10. CxxParser._fundamentals.discard("wchar_t")
  11. from cxxheaderparser.types import (
  12. EnumDecl,
  13. Field,
  14. ForwardDecl,
  15. FriendDecl,
  16. Function,
  17. Method,
  18. Typedef,
  19. UsingAlias,
  20. UsingDecl,
  21. Variable,
  22. Pointer,
  23. Type,
  24. PQName,
  25. NameSpecifier,
  26. FundamentalSpecifier,
  27. Parameter,
  28. Array,
  29. Value,
  30. Token,
  31. FunctionType,
  32. )
  33. from cxxheaderparser.parserstate import (
  34. State,
  35. EmptyBlockState,
  36. ClassBlockState,
  37. ExternBlockState,
  38. NamespaceBlockState,
  39. )
  40. @dataclass(frozen=True)
  41. class ApiEntryFunction:
  42. name: str
  43. returns: str
  44. params: str
  45. csv_type: ClassVar[str] = "Function"
  46. def dictify(self):
  47. return dict(name=self.name, type=self.returns, params=self.params)
  48. @dataclass(frozen=True)
  49. class ApiEntryVariable:
  50. name: str
  51. var_type: str
  52. csv_type: ClassVar[str] = "Variable"
  53. def dictify(self):
  54. return dict(name=self.name, type=self.var_type, params=None)
  55. @dataclass(frozen=True)
  56. class ApiHeader:
  57. name: str
  58. csv_type: ClassVar[str] = "Header"
  59. def dictify(self):
  60. return dict(name=self.name, type=None, params=None)
  61. @dataclass
  62. class ApiEntries:
  63. # These are sets, to avoid creating duplicates when we have multiple
  64. # declarations with same signature
  65. functions: Set[ApiEntryFunction] = field(default_factory=set)
  66. variables: Set[ApiEntryVariable] = field(default_factory=set)
  67. headers: Set[ApiHeader] = field(default_factory=set)
  68. class SymbolManager:
  69. def __init__(self):
  70. self.api = ApiEntries()
  71. self.name_hashes = set()
  72. # Calculate hash of name and raise exception if it already is in the set
  73. def _name_check(self, name: str):
  74. name_hash = gnu_sym_hash(name)
  75. if name_hash in self.name_hashes:
  76. raise Exception(f"Hash collision on {name}")
  77. self.name_hashes.add(name_hash)
  78. def add_function(self, function_def: ApiEntryFunction):
  79. if function_def in self.api.functions:
  80. return
  81. self._name_check(function_def.name)
  82. self.api.functions.add(function_def)
  83. def add_variable(self, variable_def: ApiEntryVariable):
  84. if variable_def in self.api.variables:
  85. return
  86. self._name_check(variable_def.name)
  87. self.api.variables.add(variable_def)
  88. def add_header(self, header: str):
  89. self.api.headers.add(ApiHeader(header))
  90. def gnu_sym_hash(name: str):
  91. h = 0x1505
  92. for c in name:
  93. h = (h << 5) + h + ord(c)
  94. return str(hex(h))[-8:]
  95. class SdkCollector:
  96. def __init__(self):
  97. self.symbol_manager = SymbolManager()
  98. def add_header_to_sdk(self, header: str):
  99. self.symbol_manager.add_header(header)
  100. def process_source_file_for_sdk(self, file_path: str):
  101. visitor = SdkCxxVisitor(self.symbol_manager)
  102. with open(file_path, "rt") as f:
  103. content = f.read()
  104. parser = CxxParser(file_path, content, visitor, None)
  105. parser.parse()
  106. def get_api(self):
  107. return self.symbol_manager.api
  108. def stringify_array_dimension(size_descr):
  109. if not size_descr:
  110. return ""
  111. return stringify_descr(size_descr)
  112. def stringify_array_descr(type_descr):
  113. assert isinstance(type_descr, Array)
  114. return (
  115. stringify_descr(type_descr.array_of),
  116. stringify_array_dimension(type_descr.size),
  117. )
  118. def stringify_descr(type_descr):
  119. if isinstance(type_descr, (NameSpecifier, FundamentalSpecifier)):
  120. return type_descr.name
  121. elif isinstance(type_descr, PQName):
  122. return "::".join(map(stringify_descr, type_descr.segments))
  123. elif isinstance(type_descr, Pointer):
  124. # Hack
  125. if isinstance(type_descr.ptr_to, FunctionType):
  126. return stringify_descr(type_descr.ptr_to)
  127. return f"{stringify_descr(type_descr.ptr_to)}*"
  128. elif isinstance(type_descr, Type):
  129. return (
  130. f"{'const ' if type_descr.const else ''}"
  131. f"{'volatile ' if type_descr.volatile else ''}"
  132. f"{stringify_descr(type_descr.typename)}"
  133. )
  134. elif isinstance(type_descr, Parameter):
  135. return stringify_descr(type_descr.type)
  136. elif isinstance(type_descr, Array):
  137. # Hack for 2d arrays
  138. if isinstance(type_descr.array_of, Array):
  139. argtype, dimension = stringify_array_descr(type_descr.array_of)
  140. return (
  141. f"{argtype}[{stringify_array_dimension(type_descr.size)}][{dimension}]"
  142. )
  143. return f"{stringify_descr(type_descr.array_of)}[{stringify_array_dimension(type_descr.size)}]"
  144. elif isinstance(type_descr, Value):
  145. return " ".join(map(stringify_descr, type_descr.tokens))
  146. elif isinstance(type_descr, FunctionType):
  147. return f"{stringify_descr(type_descr.return_type)} (*)({', '.join(map(stringify_descr, type_descr.parameters))})"
  148. elif isinstance(type_descr, Token):
  149. return type_descr.value
  150. elif type_descr is None:
  151. return ""
  152. else:
  153. raise Exception("unsupported type_descr: %s" % type_descr)
  154. class SdkCxxVisitor:
  155. def __init__(self, symbol_manager: SymbolManager):
  156. self.api = symbol_manager
  157. def on_variable(self, state: State, v: Variable) -> None:
  158. if not v.extern:
  159. return
  160. self.api.add_variable(
  161. ApiEntryVariable(
  162. stringify_descr(v.name),
  163. stringify_descr(v.type),
  164. )
  165. )
  166. def on_function(self, state: State, fn: Function) -> None:
  167. if fn.inline or fn.has_body:
  168. return
  169. self.api.add_function(
  170. ApiEntryFunction(
  171. stringify_descr(fn.name),
  172. stringify_descr(fn.return_type),
  173. ", ".join(map(stringify_descr, fn.parameters))
  174. + (", ..." if fn.vararg else ""),
  175. )
  176. )
  177. def on_define(self, state: State, content: str) -> None:
  178. pass
  179. def on_pragma(self, state: State, content: str) -> None:
  180. pass
  181. def on_include(self, state: State, filename: str) -> None:
  182. pass
  183. def on_empty_block_start(self, state: EmptyBlockState) -> None:
  184. pass
  185. def on_empty_block_end(self, state: EmptyBlockState) -> None:
  186. pass
  187. def on_extern_block_start(self, state: ExternBlockState) -> None:
  188. pass
  189. def on_extern_block_end(self, state: ExternBlockState) -> None:
  190. pass
  191. def on_namespace_start(self, state: NamespaceBlockState) -> None:
  192. pass
  193. def on_namespace_end(self, state: NamespaceBlockState) -> None:
  194. pass
  195. def on_forward_decl(self, state: State, fdecl: ForwardDecl) -> None:
  196. pass
  197. def on_typedef(self, state: State, typedef: Typedef) -> None:
  198. pass
  199. def on_using_namespace(self, state: State, namespace: List[str]) -> None:
  200. pass
  201. def on_using_alias(self, state: State, using: UsingAlias) -> None:
  202. pass
  203. def on_using_declaration(self, state: State, using: UsingDecl) -> None:
  204. pass
  205. def on_enum(self, state: State, enum: EnumDecl) -> None:
  206. pass
  207. def on_class_start(self, state: ClassBlockState) -> None:
  208. pass
  209. def on_class_field(self, state: State, f: Field) -> None:
  210. pass
  211. def on_class_method(self, state: ClassBlockState, method: Method) -> None:
  212. pass
  213. def on_class_friend(self, state: ClassBlockState, friend: FriendDecl) -> None:
  214. pass
  215. def on_class_end(self, state: ClassBlockState) -> None:
  216. pass
  217. @dataclass(frozen=True)
  218. class SdkVersion:
  219. major: int = 0
  220. minor: int = 0
  221. csv_type: ClassVar[str] = "Version"
  222. def __str__(self) -> str:
  223. return f"{self.major}.{self.minor}"
  224. def as_int(self) -> int:
  225. return ((self.major & 0xFFFF) << 16) | (self.minor & 0xFFFF)
  226. @staticmethod
  227. def from_str(s: str) -> "SdkVersion":
  228. major, minor = s.split(".")
  229. return SdkVersion(int(major), int(minor))
  230. def dictify(self) -> dict:
  231. return dict(name=str(self), type=None, params=None)
  232. class VersionBump(Enum):
  233. NONE = auto()
  234. MAJOR = auto()
  235. MINOR = auto()
  236. class ApiEntryState(Enum):
  237. PENDING = "?"
  238. APPROVED = "+"
  239. DISABLED = "-"
  240. # Special value for API version entry so users have less incentive to edit it
  241. VERSION_PENDING = "v"
  242. # Class that stores all known API entries, both enabled and disabled.
  243. # Also keeps track of API versioning
  244. # Allows comparison and update from newly-generated API
  245. class SdkCache:
  246. CSV_FIELD_NAMES = ("entry", "status", "name", "type", "params")
  247. def __init__(self, cache_file: str, load_version_only=False):
  248. self.cache_file_name = cache_file
  249. self.version = SdkVersion(0, 0)
  250. self.sdk = ApiEntries()
  251. self.disabled_entries = set()
  252. self.new_entries = set()
  253. self.loaded_dirty_version = False
  254. self.version_action = VersionBump.NONE
  255. self._load_version_only = load_version_only
  256. self.load_cache()
  257. def is_buildable(self) -> bool:
  258. return (
  259. self.version != SdkVersion(0, 0)
  260. and self.version_action == VersionBump.NONE
  261. and not self._have_pending_entries()
  262. )
  263. def _filter_enabled(self, sdk_entries):
  264. return sorted(
  265. filter(lambda e: e not in self.disabled_entries, sdk_entries),
  266. key=operator.attrgetter("name"),
  267. )
  268. def get_valid_names(self):
  269. syms = set(map(lambda e: e.name, self.get_functions()))
  270. syms.update(map(lambda e: e.name, self.get_variables()))
  271. return syms
  272. def get_functions(self):
  273. return self._filter_enabled(self.sdk.functions)
  274. def get_variables(self):
  275. return self._filter_enabled(self.sdk.variables)
  276. def get_headers(self):
  277. return self._filter_enabled(self.sdk.headers)
  278. def _get_entry_status(self, entry) -> str:
  279. if entry in self.disabled_entries:
  280. return ApiEntryState.DISABLED
  281. elif entry in self.new_entries:
  282. if isinstance(entry, SdkVersion):
  283. return ApiEntryState.VERSION_PENDING
  284. return ApiEntryState.PENDING
  285. else:
  286. return ApiEntryState.APPROVED
  287. def _format_entry(self, obj):
  288. obj_dict = obj.dictify()
  289. obj_dict.update(
  290. dict(
  291. entry=obj.csv_type,
  292. status=self._get_entry_status(obj).value,
  293. )
  294. )
  295. return obj_dict
  296. def save(self) -> None:
  297. if self._load_version_only:
  298. raise Exception("Only SDK version was loaded, cannot save")
  299. if self.version_action == VersionBump.MINOR:
  300. self.version = SdkVersion(self.version.major, self.version.minor + 1)
  301. elif self.version_action == VersionBump.MAJOR:
  302. self.version = SdkVersion(self.version.major + 1, 0)
  303. if self._have_pending_entries():
  304. self.new_entries.add(self.version)
  305. print(
  306. f"API version is still WIP: {self.version}. Review the changes and re-run command."
  307. )
  308. print(f"Entries to review:")
  309. print(
  310. "\n".join(
  311. map(
  312. str,
  313. filter(
  314. lambda e: not isinstance(e, SdkVersion), self.new_entries
  315. ),
  316. )
  317. )
  318. )
  319. else:
  320. print(f"API version {self.version} is up to date")
  321. regenerate_csv = (
  322. self.loaded_dirty_version
  323. or self._have_pending_entries()
  324. or self.version_action != VersionBump.NONE
  325. )
  326. if regenerate_csv:
  327. str_cache_entries = [self.version]
  328. name_getter = operator.attrgetter("name")
  329. str_cache_entries.extend(sorted(self.sdk.headers, key=name_getter))
  330. str_cache_entries.extend(sorted(self.sdk.functions, key=name_getter))
  331. str_cache_entries.extend(sorted(self.sdk.variables, key=name_getter))
  332. with open(self.cache_file_name, "wt", newline="") as f:
  333. writer = csv.DictWriter(f, fieldnames=SdkCache.CSV_FIELD_NAMES)
  334. writer.writeheader()
  335. for entry in str_cache_entries:
  336. writer.writerow(self._format_entry(entry))
  337. def _process_entry(self, entry_dict: dict) -> None:
  338. entry_class = entry_dict["entry"]
  339. entry_status = entry_dict["status"]
  340. entry_name = entry_dict["name"]
  341. entry = None
  342. if entry_class == SdkVersion.csv_type:
  343. self.version = SdkVersion.from_str(entry_name)
  344. if entry_status == ApiEntryState.VERSION_PENDING.value:
  345. self.loaded_dirty_version = True
  346. elif entry_class == ApiHeader.csv_type:
  347. self.sdk.headers.add(entry := ApiHeader(entry_name))
  348. elif entry_class == ApiEntryFunction.csv_type:
  349. self.sdk.functions.add(
  350. entry := ApiEntryFunction(
  351. entry_name,
  352. entry_dict["type"],
  353. entry_dict["params"],
  354. )
  355. )
  356. elif entry_class == ApiEntryVariable.csv_type:
  357. self.sdk.variables.add(
  358. entry := ApiEntryVariable(entry_name, entry_dict["type"])
  359. )
  360. else:
  361. print(entry_dict)
  362. raise Exception("Unknown entry type: %s" % entry_class)
  363. if entry is None:
  364. return
  365. if entry_status == ApiEntryState.DISABLED.value:
  366. self.disabled_entries.add(entry)
  367. elif entry_status == ApiEntryState.PENDING.value:
  368. self.new_entries.add(entry)
  369. def load_cache(self) -> None:
  370. if not os.path.exists(self.cache_file_name):
  371. raise Exception(
  372. f"Cannot load symbol cache '{self.cache_file_name}'! File does not exist"
  373. )
  374. with open(self.cache_file_name, "rt") as f:
  375. reader = csv.DictReader(f)
  376. for row in reader:
  377. self._process_entry(row)
  378. if self._load_version_only and row.get("entry") == SdkVersion.csv_type:
  379. break
  380. def _have_pending_entries(self) -> bool:
  381. return any(
  382. filter(
  383. lambda e: not isinstance(e, SdkVersion),
  384. self.new_entries,
  385. )
  386. )
  387. def sync_sets(
  388. self, known_set: Set[Any], new_set: Set[Any], update_version: bool = True
  389. ):
  390. new_entries = new_set - known_set
  391. if new_entries:
  392. print(f"New: {new_entries}")
  393. known_set |= new_entries
  394. self.new_entries |= new_entries
  395. if update_version and self.version_action == VersionBump.NONE:
  396. self.version_action = VersionBump.MINOR
  397. removed_entries = known_set - new_set
  398. if removed_entries:
  399. print(f"Removed: {removed_entries}")
  400. known_set -= removed_entries
  401. # If any of removed entries was a part of active API, that's a major bump
  402. if update_version and any(
  403. filter(
  404. lambda e: e not in self.disabled_entries
  405. and e not in self.new_entries,
  406. removed_entries,
  407. )
  408. ):
  409. self.version_action = VersionBump.MAJOR
  410. self.disabled_entries -= removed_entries
  411. self.new_entries -= removed_entries
  412. def validate_api(self, api: ApiEntries) -> None:
  413. self.sync_sets(self.sdk.headers, api.headers, False)
  414. self.sync_sets(self.sdk.functions, api.functions)
  415. self.sync_sets(self.sdk.variables, api.variables)