From 20579c80db5dd50d1b0e53b50a11b9eb2f90deaf Mon Sep 17 00:00:00 2001 From: Aditya Mehra Date: Thu, 12 Jan 2023 21:55:33 +1030 Subject: [PATCH] Switch wsconnector to internal socket client --- ovos_PHAL_plugin_homeassistant/__init__.py | 50 ++-- .../logic/connector.py | 77 +++-- .../logic/socketclient.py | 274 ++++++++++++++++++ requirements.txt | 2 +- 4 files changed, 336 insertions(+), 67 deletions(-) create mode 100644 ovos_PHAL_plugin_homeassistant/logic/socketclient.py diff --git a/ovos_PHAL_plugin_homeassistant/__init__.py b/ovos_PHAL_plugin_homeassistant/__init__.py index a6d2daf..7085c14 100644 --- a/ovos_PHAL_plugin_homeassistant/__init__.py +++ b/ovos_PHAL_plugin_homeassistant/__init__.py @@ -32,6 +32,8 @@ def __init__(self, bus=None, config=None): self.gui = GUIInterface(bus=self.bus, skill_id=self.name) self.integrator = Integrator(self.bus, self.gui) self.instance_available = False + self.use_ws = False + self.enable_debug = self.config.get("enable_debug", False) self.device_types = { "sensor": HomeAssistantSensor, "binary_sensor": HomeAssistantBinarySensor, @@ -95,14 +97,19 @@ def validate_instance_connection(self, host, api_key): bool: True if the connection is valid, False otherwise """ try: - if self.config.get('use_websocket'): - LOG.info("Using websocket connection") - validator = HomeAssistantWSConnector(host, api_key) + if self.use_ws: + validator = HomeAssistantWSConnector(host, api_key, self.enable_debug) else: - validator = HomeAssistantRESTConnector(host, api_key) + validator = HomeAssistantRESTConnector(host, api_key, self.enable_debug) validator.get_all_devices() + + if self.use_ws: + if validator.client: + validator.disconnect() + return True + except Exception as e: LOG.error(e) return False @@ -118,15 +125,7 @@ def setup_configuration(self, message): if host and key: if host.startswith("ws") or host.startswith("wss"): - config_patch = { - "PHAL": { - "ovos-PHAL-plugin-homeassistant": { - "use_websocket": True - } - } - } - update_mycroft_config(config=config_patch, bus=self.bus) - sleep(2) # wait for config to be updated + self.use_ws = True if self.validate_instance_connection(host, key): self.config["host"] = host @@ -149,21 +148,24 @@ def init_configuration(self, message=None): """ Initialize instance configuration """ configuration_host = self.config.get("host", "") configuration_api_key = self.config.get("api_key", "") + if configuration_host.startswith("ws") or configuration_host.startswith("wss"): + self.use_ws = True + if not self.config.get("use_group_display"): self.config["use_group_display"] = False if configuration_host != "" and configuration_api_key != "": self.instance_available = True - if self.config.get('use_websocket'): + if self.use_ws: self.connector = HomeAssistantWSConnector(configuration_host, - configuration_api_key) + configuration_api_key, self.enable_debug) else: self.connector = HomeAssistantRESTConnector( - configuration_host, configuration_api_key) + configuration_host, configuration_api_key, self.enable_debug) self.devices = self.connector.get_all_devices() self.registered_devices = [] self.build_devices() - self.gui["use_websocket"] = self.config.get("use_websocket", False) + self.gui["use_websocket"] = self.use_ws self.gui["instanceAvailable"] = True self.bus.emit(Message("ovos.phal.plugin.homeassistant.ready")) else: @@ -185,8 +187,10 @@ def build_devices(self): device_icon = f"mdi:{device_type}" device_state = device.get("state", None) device_area = device.get("area_id", None) - LOG.info( - f"Device added: {device_name} - {device_type} - {device_area}") + if self.enable_debug: + LOG.info( + f"Device added: {device_name} - {device_type} - {device_area}") + device_attributes = device.get("attributes", {}) if device_type in self.device_types: self.registered_devices.append(self.device_types[device_type]( @@ -396,7 +400,7 @@ def handle_show_dashboard(self, message=None): message (Message): The message object """ if self.instance_available: - self.gui["use_websocket"] = self.config.get("use_websocket", False) + self.gui["use_websocket"] = self.use_ws if not self.config.get("use_group_display"): display_list_model = { "items": self.build_display_dashboard_device_model()} @@ -420,8 +424,9 @@ def handle_show_dashboard(self, message=None): self.gui["use_group_display"] = self.config.get("use_group_display", False) self.gui.show_page(page, override_idle=True) - LOG.info("Using group display") - LOG.info(self.config["use_group_display"]) + if self.enable_debug: + LOG.debug("Using group display") + LOG.debug(self.config["use_group_display"]) def handle_close_dashboard(self, message): """ Handle the close dashboard message @@ -499,7 +504,6 @@ def handle_set_group_display_settings(self, message): "ovos-PHAL-plugin-homeassistant": { "host": self.config.get("host"), "api_key": self.config.get("api_key"), - "use_websocket": self.config.get("use_websocket", False), "use_group_display": use_group_display } } diff --git a/ovos_PHAL_plugin_homeassistant/logic/connector.py b/ovos_PHAL_plugin_homeassistant/logic/connector.py index 6368444..2cf970c 100644 --- a/ovos_PHAL_plugin_homeassistant/logic/connector.py +++ b/ovos_PHAL_plugin_homeassistant/logic/connector.py @@ -4,20 +4,19 @@ import requests import json import sys -import asyncio -from hass_client.client import HomeAssistantClient from ovos_utils.log import LOG - +from ovos_PHAL_plugin_homeassistant.logic.socketclient import HomeAssistantClient class HomeAssistantConnector: - def __init__(self, host, api_key): + def __init__(self, host, api_key, enable_debug=False): """ Constructor Args: host (str): The host of the home assistant instance. api_key (str): The api key """ + self.enable_debug = enable_debug self.host = host self.api_key = api_key @@ -115,8 +114,9 @@ def call_function(self, device_id, device_type, function, arguments=None): class HomeAssistantRESTConnector(HomeAssistantConnector): - def __init__(self, host, api_key): + def __init__(self, host, api_key, enable_debug=False): super().__init__(host, api_key) + self.enable_debug = enable_debug self.headers = {'Authorization': 'Bearer ' + self.api_key, 'content-type': 'application/json'} @@ -251,37 +251,33 @@ def call_function(self, device_id, device_type, function, arguments=None): class HomeAssistantWSConnector(HomeAssistantConnector): - def __init__(self, host, api_key): + def __init__(self, host, api_key, enable_debug=False): super().__init__(host, api_key) if self.host.startswith('http'): self.host.replace('http', 'ws', 1) - try: - self._loop = asyncio.get_event_loop() - except RuntimeError: - asyncio.set_event_loop(asyncio.new_event_loop()) - self._loop = asyncio.get_event_loop() - self._client: HomeAssistantClient = None - self._loop.run_until_complete(self.start_client()) - - @property - def client(self): - if not self._client.connected: - LOG.error("Client not connected, re-initializing") - self._loop.run_until_complete(self.start_client()) - return self._client - - async def start_client(self): - self._client = self._client or HomeAssistantClient(self.host, - self.api_key) - await self._client.connect() + self.enable_debug = enable_debug + self._connection = HomeAssistantClient(self.host, self.api_key) + self._connection.connect() + + # Initialize client instance + self.client = self._connection.get_instance_sync() + self.client.build_registries_sync() + self.client.register_event_listener(self.event_listener) + self.client.subscribe_events_sync() + + def event_listener(self, message): + # Todo: Implementation with UI + # For now it uses the old states update method + pass @staticmethod - def _device_entry_compat(devices: dict): + def _device_entry_compat(devices: dict, enable_debug): disabled_devices = list() for idx, dev in devices.items(): if dev.get('disabled_by'): - LOG.debug(f'Ignoring {dev.get("entity_id")} disabled by ' - f'{dev.get("disabled_by")}') + if(enable_debug): + LOG.debug(f'Ignoring {dev.get("entity_id")} disabled by ' + f'{dev.get("disabled_by")}') disabled_devices.append(idx) else: devices[idx].setdefault( @@ -291,14 +287,12 @@ def _device_entry_compat(devices: dict): def get_all_devices(self) -> list: devices = self.client.entity_registry - self._device_entry_compat(devices) + self._device_entry_compat(devices, self.enable_debug) devices_with_area = self.assign_group_for_devices(devices) return list(devices_with_area.values()) def get_device_state(self, entity_id: str) -> dict: - message = {'type': 'get_states'} - states = self._loop.run_until_complete( - self.client.send_command(message)) + states = self.client.get_states_sync() for state in states: if state['entity_id'] == entity_id: return state @@ -324,30 +318,27 @@ def get_all_devices_with_type_and_attribute_not_in(self, device_type, attribute, def turn_on(self, device_id, device_type): LOG.debug(f"Turn on {device_id}") - self._loop.run_until_complete( - self.client.call_service(device_type, 'turn_on', - {'entity_id': device_id})) + self.client.call_service_sync(device_type, 'turn_on', {'entity_id': device_id}) def turn_off(self, device_id, device_type): LOG.debug(f"Turn off {device_id}") - self._loop.run_until_complete( - self.client.call_service(device_type, 'turn_off', - {'entity_id': device_id})) + self.client.call_service_sync(device_type, 'turn_off', {'entity_id': device_id}) def call_function(self, device_id, device_type, function, arguments=None): arguments = arguments or dict() arguments['entity_id'] = device_id - self._loop.run_until_complete( - self.client.call_service(device_type, function, arguments)) + self.client.call_service_sync(device_type, function, arguments) def assign_group_for_devices(self, devices): - devices_from_registry = self._loop.run_until_complete( - self.client.send_command({'type': 'config/device_registry/list'})) + devices_from_registry = self.client.send_command_sync('config/device_registry/list') - for device_item in devices_from_registry: + for device_item in devices_from_registry["result"]: for device in devices.values(): if device['device_id'] == device_item['id']: device['area_id'] = device_item['area_id'] break return devices + + def disconnect(self): + self._connection.disconnect() \ No newline at end of file diff --git a/ovos_PHAL_plugin_homeassistant/logic/socketclient.py b/ovos_PHAL_plugin_homeassistant/logic/socketclient.py new file mode 100644 index 0000000..abf9d42 --- /dev/null +++ b/ovos_PHAL_plugin_homeassistant/logic/socketclient.py @@ -0,0 +1,274 @@ +import asyncio +import json +import threading + +import requests +import websockets +from ovos_utils.log import LOG + + +class HomeAssistantClient: + def __init__(self, url, token): + self.url = url + self.token = token + self.websocket = None + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.response_queue = asyncio.Queue() + self.event_queue = asyncio.Queue() + self.id_list = [] + self.thread = threading.Thread(target=self.run) + self.last_id = None + self.event_listener = None + self.authenticated = False + + self._device_registry = {} + self._entity_registry = {} + self._area_registry = {} + + async def authenticate(self): + await self.websocket.send(f'{{"type": "auth", "access_token": "{self.token}"}}') + message = await self.websocket.recv() + message = json.loads(message) + if message.get("type") == "auth_ok": + self.authenticated = True + await self.listen() + else: + self.authenticated = False + LOG.error("WS HA Connection Failed to authenticate") + return + + async def _connect(self): + try: + uri = f"{self.url}/api/websocket" + self.websocket = await websockets.connect(uri) + + # Wait for the auth_required message + message = await self.websocket.recv() + message = json.loads(message) + if message.get("type") == "auth_required": + await self.authenticate() + if not self.authenticated: + return + else: + raise Exception("Expected auth_required message") + except Exception as e: + LOG.error(e) + await self._disconnect() + return + + async def _disconnect(self): + if self.websocket is not None: + await self.websocket.close() + self.websocket = None + + async def listen(self): + while self.websocket is not None: + message = await self.websocket.recv() + message = json.loads(message) + if message.get("type") == "event": + if self.event_listener is not None: + self.event_listener(message) + else: + await self.event_queue.put(message) + else: + await self.response_queue.put(message) + + async def send_command(self, command): + id = self.counter + self.last_id = id + message = { + "id": id, + "type": command + } + await self.websocket.send(json.dumps(message)) + + async def send_raw_command(self, command): + id = self.counter + self.last_id = id + message = { + "id": id, + "type": command + } + await self.websocket.send(json.dumps(message)) + response = await self.response_queue.get() + self.response_queue.task_done() + return response + + async def call_service(self, domain, service, service_data): + id = self.counter + self.last_id = id + message = { + "id": id, + "type": "call_service", + "domain": domain, + "service": service, + "service_data": service_data + } + await self.websocket.send(json.dumps(message)) + response = await self.response_queue.get() + self.response_queue.task_done() + return response + + async def get_states(self): + await self.send_command("get_states") + message = await self.response_queue.get() + self.response_queue.task_done() + if message.get("result") is None: + LOG.info("No states found") + return [] + else: + return message["result"] + + async def subscribe_events(self): + await self.send_command("subscribe_events") + message = await self.response_queue.get() + self.response_queue.task_done() + return message + + async def get_instance(self): + while self.websocket is None: + await asyncio.sleep(0.1) + return self + + async def build_registries(self): + # First clean the registries + self._device_registry = {} + self._entity_registry = {} + self._area_registry = {} + + # device registry + await self.send_command("config/device_registry/list") + message = await self.response_queue.get() + self.response_queue.task_done() + for item in message["result"]: + item_id = item["id"] + self._device_registry[item_id] = item + + # entity registry + await self.send_command("config/entity_registry/list") + message = await self.response_queue.get() + self.response_queue.task_done() + for item in message["result"]: + item_id = item["entity_id"] + self._entity_registry[item_id] = item + + # area registry + await self.send_command("config/area_registry/list") + message = await self.response_queue.get() + self.response_queue.task_done() + for item in message["result"]: + item_id = item["area_id"] + self._area_registry[item_id] = item + + return True + + @property + def device_registry(self) -> dict: + """Return device registry.""" + if not self._device_registry: + asyncio.run_coroutine_threadsafe( + self.build_registries(), self.loop) + LOG.debug("Registry is empty, building registry first.") + return self._device_registry + + @property + def entity_registry(self) -> dict: + """Return device registry.""" + if not self._entity_registry: + asyncio.run_coroutine_threadsafe( + self.build_registries(), self.loop) + LOG.debug("Registry is empty, building registry first.") + return self._entity_registry + + @property + def area_registry(self) -> dict: + """Return device registry.""" + if not self._area_registry: + asyncio.run_coroutine_threadsafe( + self.build_registries(), self.loop) + LOG.debug("Registry is empty, building registry first.") + return self._area_registry + + @property + def counter(self): + if len(self.id_list) == 0: + self.id_list.append(1) + return 1 + else: + new_id = max(self.id_list) + 1 + self.id_list.append(new_id) + return new_id + + def set_state(self, entity_id, state, attributes): + id = self.counter + self.last_id = id + data = { + "id": id, + "type": "get_state", + "entity_id": entity_id, + "state": state, + "attributes": attributes + } + response = self._post_request(f"states/{entity_id}", data) + return response + + def _post_request(self, endpoint, data): + url = self.url.replace("wss", "https").replace("ws", "http") + full_url = f"{url}{endpoint}" + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json", + } + response = requests.post(full_url, headers=headers, json=data) + return response.json() + + def run(self): + self.loop.run_until_complete(self._connect()) + LOG.info(self.loop.is_running()) + + def connect(self): + self.thread.start() + + def disconnect(self): + asyncio.run_coroutine_threadsafe(self._disconnect(), self.loop) + self.thread.join() + + def get_states_sync(self): + task = asyncio.run_coroutine_threadsafe(self.get_states(), self.loop) + return task.result() + + def subscribe_events_sync(self): + task = asyncio.run_coroutine_threadsafe( + self.subscribe_events(), self.loop) + return task.result() + + def get_instance_sync(self): + task = asyncio.run_coroutine_threadsafe(self.get_instance(), self.loop) + return task.result() + + def get_event_sync(self): + task = asyncio.run_coroutine_threadsafe( + self.event_queue.get(), self.loop) + return task.result() + + def build_registries_sync(self): + task = asyncio.run_coroutine_threadsafe( + self.build_registries(), self.loop) + return task.result() + + def register_event_listener(self, listener): + self.event_listener = listener + + def unregister_event_listener(self): + self.event_listener = None + + def send_command_sync(self, command): + task = asyncio.run_coroutine_threadsafe( + self.send_raw_command(command), self.loop) + return task.result() + + def call_service_sync(self, domain, service, service_data): + task = asyncio.run_coroutine_threadsafe( + self.call_service(domain, service, service_data), self.loop) + return task.result() diff --git a/requirements.txt b/requirements.txt index 690bf0b..669fcbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ ovos-config~=0.0.5 mycroft-messagebus-client youtube-search pytube -hass-client~=0.1 \ No newline at end of file +websockets