Skip to content

Commit 2b264c3

Browse files
Add support for fine-tunning and files using the Azure API. (#76)
* Add support for fine-tunning and files using the Azure API. * Small changes + version bumps * Version bump after merge * fix typo * adressed comments
1 parent 101b444 commit 2b264c3

12 files changed

+153
-39
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ __pycache__
66
build
77
*.egg
88
.vscode/settings.json
9-
.ipynb_checkpoints
9+
.ipynb_checkpoints
10+
.vscode/launch.json

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ search = openai.Engine(id="deployment-namme").search(documents=["White House", "
7777
print(search)
7878
```
7979

80-
Please note that for the moment, the Microsoft Azure endpoints can only be used for completion and search operations.
80+
Please note that for the moment, the Microsoft Azure endpoints can only be used for completion, search and fine-tuning operations.
8181

8282
### Command-line interface
8383

openai/api_requestor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def request(
111111

112112
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
113113
try:
114-
error_data = resp["error"]
114+
error_data = resp["error"] if self.api_type == ApiType.OPEN_AI else resp
115115
except (KeyError, TypeError):
116116
raise error.APIError(
117117
"Invalid response object from API: %r (HTTP response code "
@@ -322,6 +322,10 @@ def _interpret_response(
322322
def _interpret_response_line(
323323
self, rbody, rcode, rheaders, stream: bool
324324
) -> OpenAIResponse:
325+
# HTTP 204 response code does not have any content in the body.
326+
if rcode == 204:
327+
return OpenAIResponse(None, rheaders)
328+
325329
if rcode == 503:
326330
raise error.ServiceUnavailableError(
327331
"The server is overloaded or not ready yet.", rbody, rcode, headers=rheaders

openai/api_resources/abstract/api_resource.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
class APIResource(OpenAIObject):
1010
api_prefix = ""
11-
azure_api_prefix = 'openai/deployments'
11+
azure_api_prefix = 'openai'
12+
azure_deployments_prefix = 'deployments'
1213

1314
@classmethod
1415
def retrieve(cls, id, api_key=None, request_id=None, **params):
@@ -46,20 +47,21 @@ def instance_url(self, operation=None):
4647
"id",
4748
)
4849
api_version = self.api_version or openai.api_version
50+
extn = quote_plus(id)
4951

5052
if self.typed_api_type == ApiType.AZURE:
5153
if not api_version:
5254
raise error.InvalidRequestError("An API version is required for the Azure API type.")
55+
5356
if not operation:
54-
raise error.InvalidRequestError(
55-
"The request needs an operation (eg: 'search') for the Azure OpenAI API type."
56-
)
57-
extn = quote_plus(id)
58-
return "/%s/%s/%s?api-version=%s" % (self.azure_api_prefix, extn, operation, api_version)
57+
base = self.class_url()
58+
return "/%s%s/%s?api-version=%s" % (self.azure_api_prefix, base, extn, api_version)
59+
60+
return "/%s/%s/%s/%s?api-version=%s" % (
61+
self.azure_api_prefix, self.azure_deployments_prefix, extn, operation, api_version)
5962

6063
elif self.typed_api_type == ApiType.OPEN_AI:
6164
base = self.class_url()
62-
extn = quote_plus(id)
6365
return "%s/%s" % (base, extn)
6466

6567
else:
@@ -75,6 +77,7 @@ def _static_request(
7577
url_,
7678
api_key=None,
7779
api_base=None,
80+
api_type=None,
7881
request_id=None,
7982
api_version=None,
8083
organization=None,
@@ -85,10 +88,18 @@ def _static_request(
8588
api_version=api_version,
8689
organization=organization,
8790
api_base=api_base,
91+
api_type=api_type
8892
)
8993
response, _, api_key = requestor.request(
9094
method_, url_, params, request_id=request_id
9195
)
9296
return util.convert_to_openai_object(
9397
response, api_key, api_version, organization
9498
)
99+
100+
@classmethod
101+
def _get_api_type_and_version(cls, api_type: str, api_version: str):
102+
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
103+
typed_api_version = api_version or openai.api_version
104+
return (typed_api_type, typed_api_version)
105+

openai/api_resources/abstract/createable_api_resource.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from openai import api_requestor, util
1+
from openai import api_requestor, util, error
22
from openai.api_resources.abstract.api_resource import APIResource
3+
from openai.util import ApiType
34

45

56
class CreateableAPIResource(APIResource):
@@ -10,6 +11,7 @@ def create(
1011
cls,
1112
api_key=None,
1213
api_base=None,
14+
api_type=None,
1315
request_id=None,
1416
api_version=None,
1517
organization=None,
@@ -18,10 +20,20 @@ def create(
1820
requestor = api_requestor.APIRequestor(
1921
api_key,
2022
api_base=api_base,
23+
api_type=api_type,
2124
api_version=api_version,
2225
organization=organization,
2326
)
24-
url = cls.class_url()
27+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
28+
29+
if typed_api_type == ApiType.AZURE:
30+
base = cls.class_url()
31+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
32+
elif typed_api_type == ApiType.OPEN_AI:
33+
url = cls.class_url()
34+
else:
35+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
36+
2537
response, _, api_key = requestor.request(
2638
"post", url, params, request_id=request_id
2739
)
Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
from urllib.parse import quote_plus
22

3+
from openai import error
34
from openai.api_resources.abstract.api_resource import APIResource
4-
5+
from openai.util import ApiType
56

67
class DeletableAPIResource(APIResource):
78
@classmethod
8-
def delete(cls, sid, **params):
9+
def delete(cls, sid, api_type=None, api_version=None, **params):
910
if isinstance(cls, APIResource):
1011
raise ValueError(".delete may only be called as a class method now.")
11-
url = "%s/%s" % (cls.class_url(), quote_plus(sid))
12-
return cls._static_request("delete", url, **params)
12+
13+
base = cls.class_url()
14+
extn = quote_plus(sid)
15+
16+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
17+
if typed_api_type == ApiType.AZURE:
18+
url = "/%s%s/%s?api-version=%s" % (cls.azure_api_prefix, base, extn, api_version)
19+
elif typed_api_type == ApiType.OPEN_AI:
20+
url = "%s/%s" % (base, extn)
21+
else:
22+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
23+
24+
return cls._static_request("delete", url, api_type=api_type, api_version=api_version, **params)

openai/api_resources/abstract/engine_api_resource.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
class EngineAPIResource(APIResource):
1616
engine_required = True
1717
plain_old_data = False
18-
azure_api_prefix = 'openai/deployments'
1918

2019
def __init__(self, engine: Optional[str] = None, **kwargs):
2120
super().__init__(engine=engine, **kwargs)
@@ -25,8 +24,7 @@ def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None
2524
# Namespaces are separated in object names with periods (.) and in URLs
2625
# with forward slashes (/), so replace the former with the latter.
2726
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
28-
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
29-
api_version = api_version or openai.api_version
27+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
3028

3129
if typed_api_type == ApiType.AZURE:
3230
if not api_version:
@@ -36,7 +34,8 @@ def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None
3634
"You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
3735
)
3836
extn = quote_plus(engine)
39-
return "/%s/%s/%ss?api-version=%s" % (cls.azure_api_prefix, extn, base, api_version)
37+
return "/%s/%s/%s/%ss?api-version=%s" % (
38+
cls.azure_api_prefix, cls.azure_deployments_prefix, extn, base, api_version)
4039

4140
elif typed_api_type == ApiType.OPEN_AI:
4241
if engine is None:
@@ -133,19 +132,20 @@ def instance_url(self):
133132
"id",
134133
)
135134

135+
extn = quote_plus(id)
136136
params_connector = '?'
137+
137138
if self.typed_api_type == ApiType.AZURE:
138139
api_version = self.api_version or openai.api_version
139140
if not api_version:
140141
raise error.InvalidRequestError("An API version is required for the Azure API type.")
141-
extn = quote_plus(id)
142142
base = self.OBJECT_NAME.replace(".", "/")
143-
url = "/%s/%s/%ss/%s?api-version=%s" % (self.azure_api_prefix, self.engine, base, extn, api_version)
143+
url = "/%s/%s/%s/%ss/%s?api-version=%s" % (
144+
self.azure_api_prefix, self.azure_deployments_prefix, self.engine, base, extn, api_version)
144145
params_connector = '&'
145146

146147
elif self.typed_api_type == ApiType.OPEN_AI:
147148
base = self.class_url(self.engine, self.api_type, self.api_version)
148-
extn = quote_plus(id)
149149
url = "%s/%s" % (base, extn)
150150

151151
else:

openai/api_resources/abstract/listable_api_resource.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from openai import api_requestor, util
1+
from openai import api_requestor, util, error
22
from openai.api_resources.abstract.api_resource import APIResource
3+
from openai.util import ApiType
34

45

56
class ListableAPIResource(APIResource):
@@ -15,15 +16,27 @@ def list(
1516
api_version=None,
1617
organization=None,
1718
api_base=None,
19+
api_type=None,
1820
**params,
1921
):
2022
requestor = api_requestor.APIRequestor(
2123
api_key,
2224
api_base=api_base or cls.api_base(),
2325
api_version=api_version,
26+
api_type=api_type,
2427
organization=organization,
2528
)
26-
url = cls.class_url()
29+
30+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
31+
32+
if typed_api_type == ApiType.AZURE:
33+
base = cls.class_url()
34+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
35+
elif typed_api_type == ApiType.OPEN_AI:
36+
url = cls.class_url()
37+
else:
38+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
39+
2740
response, _, api_key = requestor.request(
2841
"get", url, params, request_id=request_id
2942
)

openai/api_resources/file.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from typing import cast
44

55
import openai
6-
from openai import api_requestor, util
6+
from openai import api_requestor, util, error
77
from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource
8+
from openai.util import ApiType
89

910

1011
class File(ListableAPIResource, DeletableAPIResource):
@@ -18,6 +19,7 @@ def create(
1819
model=None,
1920
api_key=None,
2021
api_base=None,
22+
api_type=None,
2123
api_version=None,
2224
organization=None,
2325
user_provided_filename=None,
@@ -27,35 +29,61 @@ def create(
2729
requestor = api_requestor.APIRequestor(
2830
api_key,
2931
api_base=api_base or openai.api_base,
32+
api_type=api_type,
3033
api_version=api_version,
3134
organization=organization,
3235
)
33-
url = cls.class_url()
36+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
37+
38+
if typed_api_type == ApiType.AZURE:
39+
base = cls.class_url()
40+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
41+
elif typed_api_type == ApiType.OPEN_AI:
42+
url = cls.class_url()
43+
else:
44+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
45+
3446
# Set the filename on 'purpose' and 'model' to None so they are
3547
# interpreted as form data.
3648
files = [("purpose", (None, purpose))]
3749
if model is not None:
3850
files.append(("model", (None, model)))
3951
if user_provided_filename is not None:
40-
files.append(("file", (user_provided_filename, file)))
52+
files.append(("file", (user_provided_filename, file, 'application/octet-stream')))
4153
else:
42-
files.append(("file", file))
54+
files.append(("file", file, 'application/octet-stream'))
4355
response, _, api_key = requestor.request("post", url, files=files)
4456
return util.convert_to_openai_object(
4557
response, api_key, api_version, organization
4658
)
4759

4860
@classmethod
4961
def download(
50-
cls, id, api_key=None, api_base=None, api_version=None, organization=None
62+
cls,
63+
id,
64+
api_key=None,
65+
api_base=None,
66+
api_type=None,
67+
api_version=None,
68+
organization=None
5169
):
5270
requestor = api_requestor.APIRequestor(
5371
api_key,
5472
api_base=api_base or openai.api_base,
73+
api_type=api_type,
5574
api_version=api_version,
5675
organization=organization,
5776
)
58-
url = f"{cls.class_url()}/{id}/content"
77+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
78+
79+
if typed_api_type == ApiType.AZURE:
80+
base = cls.class_url()
81+
url = "/%s%s/%s/content?api-version=%s" % (cls.azure_api_prefix, base, id, api_version)
82+
elif typed_api_type == ApiType.OPEN_AI:
83+
url = f"{cls.class_url()}/{id}/content"
84+
else:
85+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
86+
5987
result = requestor.request_raw("get", url)
6088
if not 200 <= result.status_code < 300:
6189
raise requestor.handle_error_response(
@@ -75,13 +103,15 @@ def find_matching_files(
75103
purpose,
76104
api_key=None,
77105
api_base=None,
106+
api_type=None,
78107
api_version=None,
79108
organization=None,
80109
):
81110
"""Find already uploaded files with the same name, size, and purpose."""
82111
all_files = cls.list(
83112
api_key=api_key,
84113
api_base=api_base or openai.api_base,
114+
api_type=api_type,
85115
api_version=api_version,
86116
organization=organization,
87117
).get("data", [])
@@ -93,7 +123,9 @@ def find_matching_files(
93123
file_basename = os.path.basename(f["filename"])
94124
if file_basename != basename:
95125
continue
96-
if f["bytes"] != bytes:
126+
if "bytes" in f and f["bytes"] != bytes:
127+
continue
128+
if "size" in f and int(f["size"]) != bytes:
97129
continue
98130
matching_files.append(f)
99131
return matching_files

0 commit comments

Comments
 (0)