Skip to content

Commit bb7dab5

Browse files
authored
[Vertex AI] Move ImagenModelConfig params to ImagenGenerationConfig (#14340)
1 parent dbbfb38 commit bb7dab5

File tree

6 files changed

+29
-81
lines changed

6 files changed

+29
-81
lines changed

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@
1616
public struct ImagenGenerationConfig {
1717
public var numberOfImages: Int?
1818
public var negativePrompt: String?
19+
public var imageFormat: ImagenImageFormat?
1920
public var aspectRatio: ImagenAspectRatio?
21+
public var addWatermark: Bool?
2022

2123
public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
22-
aspectRatio: ImagenAspectRatio? = nil) {
24+
imageFormat: ImagenImageFormat? = nil, aspectRatio: ImagenAspectRatio? = nil,
25+
addWatermark: Bool? = nil) {
2326
self.numberOfImages = numberOfImages
2427
self.negativePrompt = negativePrompt
28+
self.imageFormat = imageFormat
2529
self.aspectRatio = aspectRatio
30+
self.addWatermark = addWatermark
2631
}
2732
}

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

+2-9
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ public final class ImagenModel {
2424
/// The backing service responsible for sending and receiving model requests to the backend.
2525
let generativeAIService: GenerativeAIService
2626

27-
let modelConfig: ImagenModelConfig?
28-
2927
let safetySettings: ImagenSafetySettings?
3028

3129
/// Configuration parameters for sending requests to the backend.
@@ -34,7 +32,6 @@ public final class ImagenModel {
3432
init(name: String,
3533
projectID: String,
3634
apiKey: String,
37-
modelConfig: ImagenModelConfig?,
3835
safetySettings: ImagenSafetySettings?,
3936
requestOptions: RequestOptions,
4037
appCheck: AppCheckInterop?,
@@ -48,7 +45,6 @@ public final class ImagenModel {
4845
auth: auth,
4946
urlSession: urlSession
5047
)
51-
self.modelConfig = modelConfig
5248
self.safetySettings = safetySettings
5349
self.requestOptions = requestOptions
5450
}
@@ -61,7 +57,6 @@ public final class ImagenModel {
6157
parameters: ImagenModel.imageGenerationParameters(
6258
storageURI: nil,
6359
generationConfig: generationConfig,
64-
modelConfig: modelConfig,
6560
safetySettings: safetySettings
6661
)
6762
)
@@ -75,7 +70,6 @@ public final class ImagenModel {
7570
parameters: ImagenModel.imageGenerationParameters(
7671
storageURI: storageURI,
7772
generationConfig: generationConfig,
78-
modelConfig: modelConfig,
7973
safetySettings: safetySettings
8074
)
8175
)
@@ -96,7 +90,6 @@ public final class ImagenModel {
9690

9791
static func imageGenerationParameters(storageURI: String?,
9892
generationConfig: ImagenGenerationConfig?,
99-
modelConfig: ImagenModelConfig?,
10093
safetySettings: ImagenSafetySettings?)
10194
-> ImageGenerationParameters {
10295
return ImageGenerationParameters(
@@ -106,13 +99,13 @@ public final class ImagenModel {
10699
aspectRatio: generationConfig?.aspectRatio?.rawValue,
107100
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
108101
personGeneration: safetySettings?.personFilterLevel?.rawValue,
109-
outputOptions: modelConfig?.imageFormat.map {
102+
outputOptions: generationConfig?.imageFormat.map {
110103
ImageGenerationOutputOptions(
111104
mimeType: $0.mimeType,
112105
compressionQuality: $0.compressionQuality
113106
)
114107
},
115-
addWatermark: modelConfig?.addWatermark,
108+
addWatermark: generationConfig?.addWatermark,
116109
includeResponsibleAIFilterReason: true
117110
)
118111
}

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModelConfig.swift

-24
This file was deleted.

FirebaseVertexAI/Sources/VertexAI.swift

+1-3
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,12 @@ public class VertexAI {
104104
)
105105
}
106106

107-
public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
108-
safetySettings: ImagenSafetySettings? = nil,
107+
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
109108
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
110109
return ImagenModel(
111110
name: modelResourceName(modelName: modelName),
112111
projectID: projectID,
113112
apiKey: apiKey,
114-
modelConfig: modelConfig,
115113
safetySettings: safetySettings,
116114
requestOptions: requestOptions,
117115
appCheck: appCheck,

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ final class IntegrationTests: XCTestCase {
6363
)
6464
imagenModel = vertex.imagenModel(
6565
modelName: "imagen-3.0-fast-generate-001",
66-
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
6766
safetySettings: ImagenSafetySettings(
6867
safetyFilterLevel: .blockLowAndAbove,
6968
personFilterLevel: .blockAll
@@ -254,7 +253,10 @@ final class IntegrationTests: XCTestCase {
254253
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
255254
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
256255
"""
257-
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)
256+
let generationConfig = ImagenGenerationConfig(
257+
imageFormat: .jpeg(compressionQuality: 70),
258+
aspectRatio: .landscape16x9
259+
)
258260

259261
let response = try await imagenModel.generateImages(
260262
prompt: imagePrompt,

FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift

+16-42
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ final class ImageGenerationParametersTests: XCTestCase {
4040
let parameters = ImagenModel.imageGenerationParameters(
4141
storageURI: nil,
4242
generationConfig: nil,
43-
modelConfig: nil,
4443
safetySettings: nil
4544
)
4645

@@ -64,37 +63,6 @@ final class ImageGenerationParametersTests: XCTestCase {
6463
let parameters = ImagenModel.imageGenerationParameters(
6564
storageURI: storageURI,
6665
generationConfig: nil,
67-
modelConfig: nil,
68-
safetySettings: nil
69-
)
70-
71-
XCTAssertEqual(parameters, expectedParameters)
72-
}
73-
74-
func testParameters_includeModelConfig() throws {
75-
let compressionQuality = 80
76-
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
77-
let addWatermark = true
78-
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
79-
let expectedParameters = ImageGenerationParameters(
80-
sampleCount: 1,
81-
storageURI: nil,
82-
negativePrompt: nil,
83-
aspectRatio: nil,
84-
safetyFilterLevel: nil,
85-
personGeneration: nil,
86-
outputOptions: ImageGenerationOutputOptions(
87-
mimeType: imageFormat.mimeType,
88-
compressionQuality: imageFormat.compressionQuality
89-
),
90-
addWatermark: addWatermark,
91-
includeResponsibleAIFilterReason: true
92-
)
93-
94-
let parameters = ImagenModel.imageGenerationParameters(
95-
storageURI: nil,
96-
generationConfig: nil,
97-
modelConfig: modelConfig,
9866
safetySettings: nil
9967
)
10068

@@ -104,11 +72,16 @@ final class ImageGenerationParametersTests: XCTestCase {
10472
func testParameters_includeGenerationConfig() throws {
10573
let sampleCount = 2
10674
let negativePrompt = "test-negative-prompt"
75+
let compressionQuality = 80
76+
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
10777
let aspectRatio = ImagenAspectRatio.landscape16x9
78+
let addWatermark = true
10879
let generationConfig = ImagenGenerationConfig(
10980
numberOfImages: sampleCount,
11081
negativePrompt: negativePrompt,
111-
aspectRatio: aspectRatio
82+
imageFormat: imageFormat,
83+
aspectRatio: aspectRatio,
84+
addWatermark: addWatermark
11285
)
11386
let expectedParameters = ImageGenerationParameters(
11487
sampleCount: sampleCount,
@@ -117,15 +90,17 @@ final class ImageGenerationParametersTests: XCTestCase {
11790
aspectRatio: aspectRatio.rawValue,
11891
safetyFilterLevel: nil,
11992
personGeneration: nil,
120-
outputOptions: nil,
121-
addWatermark: nil,
93+
outputOptions: ImageGenerationOutputOptions(
94+
mimeType: imageFormat.mimeType,
95+
compressionQuality: imageFormat.compressionQuality
96+
),
97+
addWatermark: addWatermark,
12298
includeResponsibleAIFilterReason: true
12399
)
124100

125101
let parameters = ImagenModel.imageGenerationParameters(
126102
storageURI: nil,
127103
generationConfig: generationConfig,
128-
modelConfig: nil,
129104
safetySettings: nil
130105
)
131106

@@ -155,7 +130,6 @@ final class ImageGenerationParametersTests: XCTestCase {
155130
let parameters = ImagenModel.imageGenerationParameters(
156131
storageURI: nil,
157132
generationConfig: nil,
158-
modelConfig: nil,
159133
safetySettings: safetySettings
160134
)
161135

@@ -168,15 +142,16 @@ final class ImageGenerationParametersTests: XCTestCase {
168142
let storageURI = "gs://test-bucket/path"
169143
let sampleCount = 4
170144
let negativePrompt = "test-negative-prompt"
145+
let imageFormat = ImagenImageFormat.png()
171146
let aspectRatio = ImagenAspectRatio.portrait3x4
147+
let addWatermark = false
172148
let generationConfig = ImagenGenerationConfig(
173149
numberOfImages: sampleCount,
174150
negativePrompt: negativePrompt,
175-
aspectRatio: aspectRatio
151+
imageFormat: imageFormat,
152+
aspectRatio: aspectRatio,
153+
addWatermark: addWatermark
176154
)
177-
let imageFormat = ImagenImageFormat.png()
178-
let addWatermark = false
179-
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
180155
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
181156
let personFilterLevel = ImagenPersonFilterLevel.blockAll
182157
let safetySettings = ImagenSafetySettings(
@@ -201,7 +176,6 @@ final class ImageGenerationParametersTests: XCTestCase {
201176
let parameters = ImagenModel.imageGenerationParameters(
202177
storageURI: storageURI,
203178
generationConfig: generationConfig,
204-
modelConfig: modelConfig,
205179
safetySettings: safetySettings
206180
)
207181

0 commit comments

Comments
 (0)