homeassistant.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. """Service for communicating with Home Assistant via REST API."""
import ipaddress
import logging
from typing import TYPE_CHECKING
from urllib.parse import urlparse

import httpx

if TYPE_CHECKING:
    from backend.app.models.smart_plug import SmartPlug
# Module-level logger; HomeAssistantService reports request failures and
# successful on/off/toggle actions through it.
logger = logging.getLogger(__name__)
  9. class HomeAssistantService:
  10. """Service for controlling Home Assistant entities via REST API."""
  11. def __init__(self, timeout: float = 10.0):
  12. self.timeout = timeout
  13. self.base_url: str = ""
  14. self.token: str = ""
  15. def configure(self, url: str, token: str):
  16. """Configure HA connection settings."""
  17. self.base_url = url.rstrip("/") if url else ""
  18. self.token = token or ""
  19. def _headers(self) -> dict:
  20. return {
  21. "Authorization": f"Bearer {self.token}",
  22. "Content-Type": "application/json",
  23. }
  24. async def get_status(self, plug: "SmartPlug") -> dict:
  25. """Get current state of HA entity.
  26. Returns dict with:
  27. - state: "ON" or "OFF" or None if unreachable
  28. - reachable: bool
  29. - device_name: str or None
  30. """
  31. if not self.base_url or not self.token:
  32. return {"state": None, "reachable": False, "device_name": None}
  33. try:
  34. async with httpx.AsyncClient(timeout=self.timeout) as client:
  35. response = await client.get(
  36. f"{self.base_url}/api/states/{plug.ha_entity_id}",
  37. headers=self._headers(),
  38. )
  39. response.raise_for_status()
  40. data = response.json()
  41. state_value = data.get("state", "").lower()
  42. # Normalize to ON/OFF
  43. if state_value == "on":
  44. state = "ON"
  45. elif state_value == "off":
  46. state = "OFF"
  47. else:
  48. state = None
  49. return {
  50. "state": state,
  51. "reachable": True,
  52. "device_name": data.get("attributes", {}).get("friendly_name"),
  53. }
  54. except Exception as e:
  55. logger.warning("Failed to get HA entity state for %s: %s", plug.ha_entity_id, e)
  56. return {"state": None, "reachable": False, "device_name": None}
  57. async def turn_on(self, plug: "SmartPlug") -> bool:
  58. """Turn on HA entity. Returns True if successful."""
  59. success = await self._call_service(plug, "turn_on")
  60. if success:
  61. logger.info("Turned ON HA entity '%s' (%s)", plug.name, plug.ha_entity_id)
  62. return success
  63. async def turn_off(self, plug: "SmartPlug") -> bool:
  64. """Turn off HA entity. Returns True if successful."""
  65. success = await self._call_service(plug, "turn_off")
  66. if success:
  67. logger.info("Turned OFF HA entity '%s' (%s)", plug.name, plug.ha_entity_id)
  68. return success
  69. async def toggle(self, plug: "SmartPlug") -> bool:
  70. """Toggle HA entity. Returns True if successful."""
  71. success = await self._call_service(plug, "toggle")
  72. if success:
  73. logger.info("Toggled HA entity '%s' (%s)", plug.name, plug.ha_entity_id)
  74. return success
  75. async def _call_service(self, plug: "SmartPlug", action: str) -> bool:
  76. """Call HA service on entity."""
  77. if not self.base_url or not self.token or not plug.ha_entity_id:
  78. return False
  79. domain = plug.ha_entity_id.split(".")[0] # "switch", "light", etc.
  80. try:
  81. async with httpx.AsyncClient(timeout=self.timeout) as client:
  82. response = await client.post(
  83. f"{self.base_url}/api/services/{domain}/{action}",
  84. headers=self._headers(),
  85. json={"entity_id": plug.ha_entity_id},
  86. )
  87. response.raise_for_status()
  88. return True
  89. except Exception as e:
  90. logger.warning("Failed to %s HA entity %s: %s", action, plug.ha_entity_id, e)
  91. return False
  92. async def get_energy(self, plug: "SmartPlug") -> dict | None:
  93. """Get energy data from HA sensor entities or switch attributes.
  94. First tries dedicated sensor entities if configured, then falls back
  95. to checking the switch entity's attributes.
  96. Returns dict with energy data or None if not available.
  97. """
  98. if not self.base_url or not self.token:
  99. return None
  100. power = None
  101. today = None
  102. total = None
  103. try:
  104. async with httpx.AsyncClient(timeout=self.timeout) as client:
  105. # Fetch power from dedicated sensor entity if configured
  106. if plug.ha_power_entity:
  107. power = await self._get_sensor_value(client, plug.ha_power_entity)
  108. # Fetch today's energy from dedicated sensor entity if configured
  109. if plug.ha_energy_today_entity:
  110. today = await self._get_sensor_value(client, plug.ha_energy_today_entity)
  111. # Fetch total energy from dedicated sensor entity if configured
  112. if plug.ha_energy_total_entity:
  113. total = await self._get_sensor_value(client, plug.ha_energy_total_entity)
  114. # Fallback: try switch entity attributes (original behavior)
  115. if power is None:
  116. response = await client.get(
  117. f"{self.base_url}/api/states/{plug.ha_entity_id}",
  118. headers=self._headers(),
  119. )
  120. response.raise_for_status()
  121. attrs = response.json().get("attributes", {})
  122. power = attrs.get("current_power_w") or attrs.get("power")
  123. if today is None:
  124. today = attrs.get("today_energy_kwh")
  125. if total is None:
  126. total = attrs.get("total_energy_kwh")
  127. if power is None:
  128. return None
  129. return {
  130. "power": power,
  131. "voltage": None,
  132. "current": None,
  133. "today": today,
  134. "total": total,
  135. "yesterday": None,
  136. "factor": None,
  137. "apparent_power": None,
  138. "reactive_power": None,
  139. }
  140. except Exception as e:
  141. logger.debug("Failed to get HA energy data: %s", e)
  142. return None
  143. async def _get_sensor_value(self, client: httpx.AsyncClient, entity_id: str) -> float | None:
  144. """Fetch numeric value from a HA sensor entity."""
  145. try:
  146. response = await client.get(
  147. f"{self.base_url}/api/states/{entity_id}",
  148. headers=self._headers(),
  149. )
  150. response.raise_for_status()
  151. state = response.json().get("state")
  152. if state and state not in ("unknown", "unavailable"):
  153. return float(state)
  154. except Exception:
  155. pass # Sensor read is best-effort; caller handles None
  156. return None
  157. @staticmethod
  158. def _validate_url(url: str) -> str | None:
  159. """Validate HA URL scheme and block dangerous destinations."""
  160. try:
  161. parsed = urlparse(url)
  162. except ValueError:
  163. return None
  164. if parsed.scheme not in ("http", "https") or not parsed.hostname:
  165. return None
  166. blocked = ("169.254.169.254", "metadata.google.internal", "0.0.0.0") # nosec B104
  167. if parsed.hostname.lower() in blocked or (parsed.hostname or "").startswith("169.254."):
  168. return None
  169. return f"{parsed.scheme}://{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + (parsed.path or "")
  170. async def test_connection(self, url: str, token: str) -> dict:
  171. """Test connection to Home Assistant.
  172. Returns dict with:
  173. - success: bool
  174. - message: str or None (HA message on success)
  175. - error: str or None (error message on failure)
  176. """
  177. safe_url = self._validate_url(url)
  178. if not safe_url:
  179. return {"success": False, "message": None, "error": "Invalid Home Assistant URL"}
  180. try:
  181. async with httpx.AsyncClient(timeout=self.timeout) as client:
  182. response = await client.get(
  183. f"{safe_url.rstrip('/')}/api/",
  184. headers={"Authorization": f"Bearer {token}"},
  185. )
  186. response.raise_for_status()
  187. data = response.json()
  188. return {
  189. "success": True,
  190. "message": data.get("message", "Connected"),
  191. "error": None,
  192. }
  193. except httpx.HTTPStatusError as e:
  194. if e.response.status_code == 401:
  195. return {"success": False, "message": None, "error": "Invalid access token"}
  196. return {"success": False, "message": None, "error": f"HTTP {e.response.status_code}"}
  197. except httpx.TimeoutException:
  198. return {"success": False, "message": None, "error": "Connection timeout"}
  199. except httpx.ConnectError:
  200. return {"success": False, "message": None, "error": "Could not connect to Home Assistant"}
  201. except Exception as e:
  202. return {"success": False, "message": None, "error": str(e)}
  203. async def list_entities(self, url: str, token: str, search: str | None = None) -> list[dict]:
  204. """List available entities from HA.
  205. Always filters to switch/light/input_boolean/script — the only domains
  206. the SmartPlugBase.ha_entity_id pattern accepts. When a search query is
  207. provided it narrows the same domain-filtered list by entity_id or
  208. friendly_name substring (case-insensitive).
  209. Previously search bypassed the domain filter, which let users pick a
  210. sensor.* or binary_sensor.* entity from the dropdown that the backend
  211. schema would then reject with the cryptic Pydantic pattern error
  212. (#1388). Picking what you can't save isn't a useful UX.
  213. Returns list of entity dicts with:
  214. - entity_id: str
  215. - friendly_name: str
  216. - state: str
  217. - domain: str
  218. """
  219. # Allowed domains for smart plug control — must mirror the regex in
  220. # backend/app/schemas/smart_plug.py:17 (SmartPlugBase.ha_entity_id).
  221. allowed_domains = {"switch", "light", "input_boolean", "script"}
  222. try:
  223. async with httpx.AsyncClient(timeout=self.timeout) as client:
  224. response = await client.get(
  225. f"{url.rstrip('/')}/api/states",
  226. headers={"Authorization": f"Bearer {token}"},
  227. )
  228. response.raise_for_status()
  229. entities = []
  230. search_lower = search.lower().strip() if search else None
  231. for entity in response.json():
  232. entity_id = entity.get("entity_id", "")
  233. domain = entity_id.split(".")[0] if "." in entity_id else ""
  234. friendly_name = entity.get("attributes", {}).get("friendly_name", entity_id)
  235. if domain not in allowed_domains:
  236. continue
  237. if search_lower and (
  238. search_lower not in entity_id.lower() and search_lower not in friendly_name.lower()
  239. ):
  240. continue
  241. entities.append(
  242. {
  243. "entity_id": entity_id,
  244. "friendly_name": friendly_name,
  245. "state": entity.get("state"),
  246. "domain": domain,
  247. }
  248. )
  249. return sorted(entities, key=lambda x: x["friendly_name"].lower())
  250. except Exception as e:
  251. logger.warning("Failed to list HA entities: %s", e)
  252. return []
  253. async def list_sensor_entities(self, url: str, token: str) -> list[dict]:
  254. """List available sensor entities for energy monitoring.
  255. Returns list of sensor entities with power/energy units.
  256. """
  257. try:
  258. async with httpx.AsyncClient(timeout=self.timeout) as client:
  259. response = await client.get(
  260. f"{url.rstrip('/')}/api/states",
  261. headers={"Authorization": f"Bearer {token}"},
  262. )
  263. response.raise_for_status()
  264. # Valid units for energy monitoring sensors (lowercase for case-insensitive matching)
  265. power_units = {"w", "kw", "mw"}
  266. energy_units = {"kwh", "wh", "mwh"}
  267. valid_units = power_units | energy_units
  268. entities = []
  269. for entity in response.json():
  270. entity_id = entity.get("entity_id", "")
  271. domain = entity_id.split(".")[0] if "." in entity_id else ""
  272. # Filter to sensor domain only
  273. if domain != "sensor":
  274. continue
  275. attrs = entity.get("attributes", {})
  276. unit = attrs.get("unit_of_measurement", "")
  277. # Only include sensors with power/energy units (case-insensitive)
  278. if unit.lower() in valid_units:
  279. entities.append(
  280. {
  281. "entity_id": entity_id,
  282. "friendly_name": attrs.get("friendly_name", entity_id),
  283. "state": entity.get("state"),
  284. "unit_of_measurement": unit,
  285. }
  286. )
  287. return sorted(entities, key=lambda x: x["friendly_name"].lower())
  288. except Exception as e:
  289. logger.warning("Failed to list HA sensor entities: %s", e)
  290. return []
# Singleton instance shared by importers of this module. It starts with an
# empty base_url/token, so the per-plug methods (get_status, turn_on, ...)
# short-circuit until configure() is called.
homeassistant_service: HomeAssistantService = HomeAssistantService()