Skip to content

Commit 28a3b4a

Browse files
authored
Add error handling for OpenAI (home-assistant#86671)
* Add error handling for OpenAI * Simplify area filtering * better prompt
1 parent c395698 commit 28a3b4a

File tree

3 files changed

+70
-39
lines changed

3 files changed

+70
-39
lines changed

homeassistant/components/openai_conversation/__init__.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from functools import partial
55
import logging
6-
from typing import cast
76

87
import openai
98
from openai import error
@@ -13,7 +12,7 @@
1312
from homeassistant.const import CONF_API_KEY
1413
from homeassistant.core import HomeAssistant
1514
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
16-
from homeassistant.helpers import area_registry, device_registry, intent, template
15+
from homeassistant.helpers import area_registry, intent, template
1716
from homeassistant.util import ulid
1817

1918
from .const import DEFAULT_MODEL, DEFAULT_PROMPT
@@ -97,15 +96,26 @@ async def async_process(
9796

9897
_LOGGER.debug("Prompt for %s: %s", model, prompt)
9998

100-
result = await self.hass.async_add_executor_job(
101-
partial(
102-
openai.Completion.create,
103-
engine=model,
104-
prompt=prompt,
105-
max_tokens=150,
106-
user=conversation_id,
99+
try:
100+
result = await self.hass.async_add_executor_job(
101+
partial(
102+
openai.Completion.create,
103+
engine=model,
104+
prompt=prompt,
105+
max_tokens=150,
106+
user=conversation_id,
107+
)
107108
)
108-
)
109+
except error.OpenAIError as err:
110+
intent_response = intent.IntentResponse(language=user_input.language)
111+
intent_response.async_set_error(
112+
intent.IntentResponseErrorCode.UNKNOWN,
113+
f"Sorry, I had a problem talking to OpenAI: {err}",
114+
)
115+
return conversation.ConversationResult(
116+
response=intent_response, conversation_id=conversation_id
117+
)
118+
109119
_LOGGER.debug("Response %s", result)
110120
response = result["choices"][0]["text"].strip()
111121
self.history[conversation_id] = prompt + response
@@ -122,20 +132,9 @@ async def async_process(
122132

123133
def _async_generate_prompt(self) -> str:
124134
"""Generate a prompt for the user."""
125-
dev_reg = device_registry.async_get(self.hass)
126135
return template.Template(DEFAULT_PROMPT, self.hass).async_render(
127136
{
128137
"ha_name": self.hass.config.location_name,
129-
"areas": [
130-
area
131-
for area in area_registry.async_get(self.hass).areas.values()
132-
# Filter out areas without devices
133-
if any(
134-
not dev.disabled_by
135-
for dev in device_registry.async_entries_for_area(
136-
dev_reg, cast(str, area.id)
137-
)
138-
)
139-
],
138+
"areas": list(area_registry.async_get(self.hass).areas.values()),
140139
}
141140
)

homeassistant/components/openai_conversation/const.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,26 @@
33
DOMAIN = "openai_conversation"
44
CONF_PROMPT = "prompt"
55
DEFAULT_MODEL = "text-davinci-003"
6-
DEFAULT_PROMPT = """
7-
You are a conversational AI for a smart home named {{ ha_name }}.
8-
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
6+
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
97
108
An overview of the areas and the devices in this smart home:
11-
{% for area in areas %}
9+
{%- for area in areas %}
10+
{%- set area_info = namespace(printed=false) %}
11+
{%- for device in area_devices(area.name) -%}
12+
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") %}
13+
{%- if not area_info.printed %}
14+
1215
{{ area.name }}:
13-
{% for device in area_devices(area.name) -%}
14-
{%- if not device_attr(device, "disabled_by") %}
15-
- {{ device_attr(device, "name") }} ({{ device_attr(device, "model") }} by {{ device_attr(device, "manufacturer") }})
16+
{%- set area_info.printed = true %}
17+
{%- endif %}
18+
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") not in device_attr(device, "name") %} ({{ device_attr(device, "model") }}){% endif %}
1619
{%- endif %}
1720
{%- endfor %}
18-
{% endfor %}
21+
{%- endfor %}
22+
23+
Answer the users questions about the world truthfully.
24+
25+
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
1926
2027
Now finish this conversation:
2128
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
"""Tests for the OpenAI integration."""
22
from unittest.mock import patch
33

4+
from openai import error
5+
46
from homeassistant.components import conversation
57
from homeassistant.core import Context
6-
from homeassistant.helpers import device_registry
8+
from homeassistant.helpers import area_registry, device_registry, intent
79

810

911
async def test_default_prompt(hass, mock_init_component):
1012
"""Test that the default prompt works."""
1113
device_reg = device_registry.async_get(hass)
14+
area_reg = area_registry.async_get(hass)
15+
16+
for i in range(3):
17+
area_reg.async_create(f"{i}Empty Area")
1218

1319
device_reg.async_get_or_create(
1420
config_entry_id="1234",
@@ -18,20 +24,30 @@ async def test_default_prompt(hass, mock_init_component):
1824
model="Test Model",
1925
suggested_area="Test Area",
2026
)
27+
for i in range(3):
28+
device_reg.async_get_or_create(
29+
config_entry_id="1234",
30+
connections={("test", f"{i}abcd")},
31+
name="Test Service",
32+
manufacturer="Test Manufacturer",
33+
model="Test Model",
34+
suggested_area="Test Area",
35+
entry_type=device_registry.DeviceEntryType.SERVICE,
36+
)
2137
device_reg.async_get_or_create(
2238
config_entry_id="1234",
2339
connections={("test", "5678")},
2440
name="Test Device 2",
2541
manufacturer="Test Manufacturer 2",
26-
model="Test Model 2",
42+
model="Device 2",
2743
suggested_area="Test Area 2",
2844
)
2945
device_reg.async_get_or_create(
3046
config_entry_id="1234",
3147
connections={("test", "9876")},
3248
name="Test Device 3",
3349
manufacturer="Test Manufacturer 3",
34-
model="Test Model 3",
50+
model="Test Model 3A",
3551
suggested_area="Test Area 2",
3652
)
3753

@@ -40,24 +56,33 @@ async def test_default_prompt(hass, mock_init_component):
4056

4157
assert (
4258
mock_create.mock_calls[0][2]["prompt"]
43-
== """You are a conversational AI for a smart home named test home.
44-
If a user wants to control a device, reject the request and suggest using the Home Assistant UI.
59+
== """This smart home is controlled by Home Assistant.
4560
4661
An overview of the areas and the devices in this smart home:
4762
4863
Test Area:
49-
50-
- Test Device (Test Model by Test Manufacturer)
64+
- Test Device (Test Model)
5165
5266
Test Area 2:
67+
- Test Device 2
68+
- Test Device 3 (Test Model 3A)
5369
54-
- Test Device 2 (Test Model 2 by Test Manufacturer 2)
55-
- Test Device 3 (Test Model 3 by Test Manufacturer 3)
70+
Answer the users questions about the world truthfully.
5671
72+
If the user wants to control a device, reject the request and suggest using the Home Assistant UI.
5773
5874
Now finish this conversation:
5975
6076
Smart home: How can I assist?
6177
User: hello
6278
Smart home: """
6379
)
80+
81+
82+
async def test_error_handling(hass, mock_init_component):
83+
"""Test that the default prompt works."""
84+
with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError):
85+
result = await conversation.async_converse(hass, "hello", None, Context())
86+
87+
assert result.response.response_type == intent.IntentResponseType.ERROR, result
88+
assert result.response.error_code == "unknown", result

0 commit comments

Comments
 (0)