Skip to content

Commit 362a1a5

Browse files
committed
Add persistent client connections
1 parent ea352b2 commit 362a1a5

File tree

3 files changed

+30
-27
lines changed

3 files changed

+30
-27
lines changed

lib/openai/client.rb

+21-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class Client
1212
request_timeout
1313
extra_headers
1414
].freeze
15-
attr_reader *CONFIG_KEYS, :faraday_middleware
15+
attr_reader *CONFIG_KEYS
1616

1717
def initialize(config = {}, &faraday_middleware)
1818
CONFIG_KEYS.each do |key|
@@ -23,7 +23,12 @@ def initialize(config = {}, &faraday_middleware)
2323
config[key].nil? ? OpenAI.configuration.send(key) : config[key]
2424
)
2525
end
26-
@faraday_middleware = faraday_middleware
26+
27+
@connection = build_connection
28+
faraday_middleware&.call(@connection)
29+
30+
@multipart_connection = build_connection(multipart: true)
31+
faraday_middleware&.call(@multipart_connection)
2732
end
2833

2934
def chat(parameters: {})
@@ -107,5 +112,19 @@ def beta(apis)
107112
client.add_headers("OpenAI-Beta": apis.map { |k, v| "#{k}=#{v}" }.join(";"))
108113
end
109114
end
115+
116+
private
117+
118+
attr_reader :connection, :multipart_connection
119+
120+
def build_connection(multipart: false)
121+
Faraday.new do |faraday|
122+
faraday.options[:timeout] = @request_timeout
123+
faraday.request(:multipart) if multipart
124+
faraday.use MiddlewareErrors if @log_errors
125+
faraday.response :raise_error
126+
faraday.response :json
127+
end
128+
end
110129
end
111130
end

lib/openai/http.rb

+5-19
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,32 @@ module HTTP
77
include HTTPHeaders
88

99
def get(path:, parameters: nil)
10-
parse_jsonl(conn.get(uri(path: path), parameters) do |req|
10+
parse_jsonl(connection.get(uri(path: path), parameters) do |req|
1111
req.headers = headers
1212
end&.body)
1313
end
1414

1515
def post(path:)
16-
parse_jsonl(conn.post(uri(path: path)) do |req|
16+
parse_jsonl(connection.post(uri(path: path)) do |req|
1717
req.headers = headers
1818
end&.body)
1919
end
2020

2121
def json_post(path:, parameters:)
22-
conn.post(uri(path: path)) do |req|
22+
connection.post(uri(path: path)) do |req|
2323
configure_json_post_request(req, parameters)
2424
end&.body
2525
end
2626

2727
def multipart_post(path:, parameters: nil)
28-
conn(multipart: true).post(uri(path: path)) do |req|
28+
multipart_connection.post(uri(path: path)) do |req|
2929
req.headers = headers.merge({ "Content-Type" => "multipart/form-data" })
3030
req.body = multipart_parameters(parameters)
3131
end&.body
3232
end
3333

3434
def delete(path:)
35-
conn.delete(uri(path: path)) do |req|
35+
connection.delete(uri(path: path)) do |req|
3636
req.headers = headers
3737
end&.body
3838
end
@@ -70,20 +70,6 @@ def to_json_stream(user_proc:)
7070
end
7171
end
7272

73-
def conn(multipart: false)
74-
connection = Faraday.new do |f|
75-
f.options[:timeout] = @request_timeout
76-
f.request(:multipart) if multipart
77-
f.use MiddlewareErrors if @log_errors
78-
f.response :raise_error
79-
f.response :json
80-
end
81-
82-
@faraday_middleware&.call(connection)
83-
84-
connection
85-
end
86-
8773
def uri(path:)
8874
if azure?
8975
base = File.join(@uri_base, path)

spec/openai/client/client_spec.rb

+4-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
expect(c0.uri_base).to eq(OpenAI::Configuration::DEFAULT_URI_BASE)
4545
expect(c0.send(:headers).values).to include("Bearer #{c0.access_token}")
4646
expect(c0.send(:headers).values).to include(c0.organization_id)
47-
expect(c0.send(:conn).options.timeout).to eq(OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT)
47+
expect(c0.send(:connection).options.timeout).to eq(OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT)
4848
expect(c0.send(:uri, path: "")).to include(OpenAI::Configuration::DEFAULT_URI_BASE)
4949
expect(c0.send(:headers).values).to include("X-Default")
5050
expect(c0.send(:headers).values).not_to include("X-Test")
@@ -55,7 +55,7 @@
5555
expect(c1.request_timeout).to eq(60)
5656
expect(c1.uri_base).to eq("https://oai.hconeai.com/")
5757
expect(c1.send(:headers).values).to include(c1.access_token)
58-
expect(c1.send(:conn).options.timeout).to eq(60)
58+
expect(c1.send(:connection).options.timeout).to eq(60)
5959
expect(c1.send(:uri, path: "")).to include("https://oai.hconeai.com/")
6060
expect(c1.send(:headers).values).not_to include("X-Default")
6161
expect(c1.send(:headers).values).to include("X-Test")
@@ -67,7 +67,7 @@
6767
expect(c2.uri_base).to eq("https://example.com/")
6868
expect(c2.send(:headers).values).to include("Bearer #{c2.access_token}")
6969
expect(c2.send(:headers).values).to include(c2.organization_id)
70-
expect(c2.send(:conn).options.timeout).to eq(1)
70+
expect(c2.send(:connection).options.timeout).to eq(1)
7171
expect(c2.send(:uri, path: "")).to include("https://example.com/")
7272
expect(c2.send(:headers).values).to include("X-Default")
7373
expect(c2.send(:headers).values).not_to include("X-Test")
@@ -128,9 +128,7 @@
128128
end
129129

130130
it "sets the logger" do
131-
connection = Faraday.new
132-
client.faraday_middleware.call(connection)
133-
expect(connection.builder.handlers).to include Faraday::Response::Logger
131+
expect(client.send(:connection).builder.handlers).to include Faraday::Response::Logger
134132
end
135133
end
136134
end

0 commit comments

Comments
 (0)