Skip to content

Commit 11a54ea

Browse files
committed
Properly accept char and cellstr in extractOpenAIEmbeddings
Fixes #85
1 parent 3f312b5 commit 11a54ea

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

extractOpenAIEmbeddings.m

+7-6
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,18 @@
2323
% Copyright 2023-2024 The MathWorks, Inc.
2424

2525
arguments
26-
text (1,:) {mustBeNonzeroLengthText}
27-
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
28-
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
29-
nvp.TimeOut (1,1) {mustBeNumeric,mustBeReal,mustBePositive} = 10
30-
nvp.Dimensions (1,1) {mustBeNumeric,mustBeInteger,mustBePositive}
31-
nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar}
26+
text (1,:) {mustBeNonzeroLengthText}
27+
nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
28+
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
29+
nvp.TimeOut (1,1) {mustBeNumeric,mustBeReal,mustBePositive} = 10
30+
nvp.Dimensions (1,1) {mustBeNumeric,mustBeInteger,mustBePositive}
31+
nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar}
3232
end
3333

3434
END_POINT = "https://api.openai.com/v1/embeddings";
3535

3636
key = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");
37+
text = convertCharsToStrings(text);
3738

3839
parameters = struct("input",text,"model",nvp.ModelName);
3940

tests/textractOpenAIEmbeddings.m

+29-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
properties(TestParameter)
77
InvalidInput = iGetInvalidInput();
8+
ValidInput = iGetValidInput();
89
ValidDimensionsModelCombinations = iGetValidDimensionsModelCombinations();
910
end
1011

@@ -34,8 +35,9 @@ function validCombinationOfModelAndDimension(testCase, ValidDimensionsModelCombi
3435
APIKey="not-real"));
3536
end
3637

37-
function embedStringWithSuccessfulOpenAICall(testCase)
38-
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla"));
38+
function embedTextWithSuccessfulOpenAICall(testCase,ValidInput)
39+
result = testCase.verifyWarningFree(@()extractOpenAIEmbeddings(ValidInput.Input{:}));
40+
testCase.verifySize(result, ValidInput.ExpectedSize);
3941
end
4042

4143
function invalidCombinationOfModelAndDimension(testCase)
@@ -57,6 +59,31 @@ function testInvalidInputs(testCase, InvalidInput)
5759
end
5860
end
5961

62+
function validInput = iGetValidInput()
63+
validInput = struct( ...
64+
"ScalarString", struct( ...
65+
"Input",{{ "blah" }}, ...
66+
"ExpectedSize",[1,1536]), ...
67+
"StringVector", struct( ...
68+
"Input",{{ ["a", "b", "c"] }}, ...
69+
"ExpectedSize",[3,1536]), ...
70+
"CharVector", struct( ...
71+
"Input", {{ 'foo' }}, ...
72+
"ExpectedSize",[1,1536]), ...
73+
"Cellstr", struct( ...
74+
"Input",{{ {'cat', 'dog', 'mouse'} }}, ...
75+
"ExpectedSize",[3,1536]), ...
76+
"ModelAsString", struct( ...
77+
"Input",{{ "foo","ModelName","text-embedding-3-small" }}, ...
78+
"ExpectedSize",[1,1536]), ...
79+
"ModelAsChar", struct( ...
80+
"Input",{{ "foo","ModelName",'text-embedding-3-small' }}, ...
81+
"ExpectedSize",[1,1536]), ...
82+
"ModelAsCellstr", struct( ...
83+
"Input",{{ "foo","ModelName",{'text-embedding-3-small'} }}, ...
84+
"ExpectedSize",[1,1536]));
85+
end
86+
6087
function invalidInput = iGetInvalidInput()
6188
invalidInput = struct( ...
6289
"InvalidEmptyText", struct( ...

0 commit comments

Comments
 (0)