Skip to content

Commit d1595f5

Browse files
committed
Add safety around missing ssh drivers.
1 parent 4ec6344 commit d1595f5

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pywattbox"
3-
version = "0.7.1"
3+
version = "0.7.2"
44
description = "A python wrapper for WattBox APIs."
55
license = "MIT"
66
readme = "README.md"

pywattbox/ip_wattbox.py

+27-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Union,
1616
)
1717

18+
from scrapli.exceptions import ScrapliTransportPluginError
1819
from scrapli.response import Response
1920

2021
from .base import BaseWattBox, Commands, Outlet, _async_create_wattbox, _create_wattbox
@@ -100,6 +101,10 @@ class UpdateBaseResponses(NamedTuple):
100101
_Responses = TypeVar("_Responses", bound=Union[InitialResponses, UpdateBaseResponses])
101102

102103

104+
class DriverUnavailableError(Exception):
105+
pass
106+
107+
103108
class IpWattBox(BaseWattBox):
104109
def __init__(
105110
self,
@@ -129,18 +134,26 @@ def __init__(
129134
else:
130135
raise ValueError("Non Standard Port, Transport must be set.")
131136

132-
self.driver = WattBoxDriver(
133-
**conninfo,
134-
transport="ssh2" if transport == "ssh" else "telnet",
135-
)
136-
self.async_driver = WattBoxAsyncDriver(
137-
**conninfo,
138-
transport="asyncssh" if transport == "ssh" else "asynctelnet",
139-
)
137+
try:
138+
self.driver: Optional[WattBoxDriver] = WattBoxDriver(
139+
**conninfo,
140+
transport="ssh2" if transport == "ssh" else "telnet",
141+
)
142+
except ScrapliTransportPluginError:
143+
self.driver = None
144+
try:
145+
self.async_driver: Optional[WattBoxAsyncDriver] = WattBoxAsyncDriver(
146+
**conninfo,
147+
transport="asyncssh" if transport == "ssh" else "asynctelnet",
148+
)
149+
except ScrapliTransportPluginError:
150+
self.async_driver = None
140151

141152
def send_requests(
142153
self, requests: Iterable[Union[REQUEST_MESSAGES, str]]
143154
) -> List[Response]:
155+
if not self.driver:
156+
raise DriverUnavailableError
144157
responses: List[Response] = []
145158
for request in requests:
146159
responses.append(
@@ -153,6 +166,8 @@ def send_requests(
153166
async def async_send_requests(
154167
self, requests: Iterable[Union[REQUEST_MESSAGES, str]]
155168
) -> List[Response]:
169+
if not self.async_driver:
170+
raise DriverUnavailableError
156171
responses: List[Response] = []
157172
for request in requests:
158173
responses.append(
@@ -266,6 +281,8 @@ async def async_update(self) -> None:
266281

267282
def send_command(self, outlet: int, command: Commands) -> None:
268283
logger.debug("Send Command")
284+
if not self.driver:
285+
raise DriverUnavailableError
269286
self.driver._send_command(
270287
CONTROL_MESSAGES.OUTLET_SET.value.format(
271288
outlet=outlet, action=command.name, delay=0
@@ -275,6 +292,8 @@ def send_command(self, outlet: int, command: Commands) -> None:
275292

276293
async def async_send_command(self, outlet: int, command: Commands) -> None:
277294
logger.debug("Async Send Command")
295+
if not self.async_driver:
296+
raise DriverUnavailableError
278297
await self.async_driver._send_command(
279298
CONTROL_MESSAGES.OUTLET_SET.value.format(
280299
outlet=outlet, action=command.name, delay=0

0 commit comments

Comments
 (0)