Skip to content

Commit 83ea7d0

Browse files
authored
Merge pull request #21 from matlab-deep-learning/fix-image-gen-bugs
Fixing checks for images and empty prompts.
2 parents 69ca8b6 + 2d27fec commit 83ea7d0

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

openAIImages.m

+6-4
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@
8585
% Only "dall-e-3" supports this parameter.
8686

8787
arguments
88-
this (1,1) openAIImages
89-
prompt {mustBeTextScalar}
88+
this (1,1) openAIImages
89+
prompt {mustBeNonzeroLengthTextScalar}
9090
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
9191
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
9292
nvp.Size (1,1) string {mustBeMember(nvp.Size, ["256x256", "512x512", ...
@@ -176,7 +176,7 @@
176176
arguments
177177
this (1,1) openAIImages
178178
imagePath {mustBeValidFileType(imagePath)}
179-
prompt {mustBeTextScalar}
179+
prompt {mustBeNonzeroLengthTextScalar}
180180
nvp.MaskImagePath {mustBeValidFileType(nvp.MaskImagePath)}
181181
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
182182
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
@@ -345,7 +345,9 @@ function validatePromptSize(model, prompt)
345345
function mustBeValidFileType(filePath)
346346
mustBeFile(filePath);
347347
s = dir(filePath);
348-
if ~endsWith(s.name, ".png")
348+
imgDetails = imfinfo(filePath);
349+
imgFormat = imgDetails.Format;
350+
if ~(imgFormat=="png")
349351
error("llms:pngExpected", ...
350352
llms.utils.errorMessageCatalog.getMessage("llms:pngExpected"));
351353
end

tests/test_files/solar.png

64.7 KB
Loading

tests/topenAIImages.m

+18-4
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ function constructModelWithAllNVP(testCase)
105105
testCase.verifyEqual(mdl.ModelName, modelName);
106106
end
107107

108+
function fakePNGImage(testCase)
109+
mdl = openAIImages(ApiKey="this-is-not-a-real-key");
110+
fakePng = fullfile("test_files", "solar.png");
111+
testCase.verifyError(@()edit(mdl,fakePng, "bla"), "llms:pngExpected");
112+
end
113+
108114
function invalidInputsConstructor(testCase, InvalidConstructorInput)
109115
testCase.verifyError(@()openAIImages(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
110116
end
@@ -157,11 +163,15 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
157163
invalidGenerateInput = struct( ...
158164
"EmptyInput",struct( ...
159165
"Input",{{ [] }},...
160-
"Error","MATLAB:validators:mustBeTextScalar"),...
166+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
161167
...
162168
"InvalidInputType",struct( ...
163169
"Input",{{ 123 }},...
164-
"Error","MATLAB:validators:mustBeTextScalar"),...
170+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
171+
...
172+
"InvalidPromptLen",struct( ...
173+
"Input",{{ "" }},...
174+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
165175
...
166176
"InvalidNumImagesType",struct( ...
167177
"Input",{{ "prompt" "NumImages" "2" }},...
@@ -233,17 +243,21 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
233243
"Input",{{ 123, "prompt" }},...
234244
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
235245
...
246+
"InvalidPromptLen",struct( ...
247+
"Input",{{ validImage, "" }},...
248+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
249+
...
236250
"InvalidImageExtension",struct( ...
237251
"Input",{{ nonPNGImage, "prompt" }},...
238252
"Error","llms:pngExpected"),...
239253
...
240254
"EmptyPrompt",struct( ...
241255
"Input",{{ validImage, [] }},...
242-
"Error","MATLAB:validators:mustBeTextScalar"),...
256+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
243257
...
244258
"InvalidPromptType",struct( ...
245259
"Input",{{ validImage, 123 }},...
246-
"Error","MATLAB:validators:mustBeTextScalar"),...
260+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
247261
...
248262
"InvalidMaskImage",struct( ...
249263
"Input",{{ validImage, "foo", "MaskImagePath", 123 }},...

0 commit comments

Comments
 (0)