Skip to content

Commit 4db5dc8

Browse files
authored
Add Vertex AI unit tests (#6090)
1 parent d26a5f8 commit 4db5dc8

File tree

11 files changed

+747
-2
lines changed

11 files changed

+747
-2
lines changed

.github/workflows/ci_tests.yml

+5
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ jobs:
6464
./gradlew :common:updateVersion common:publishToMavenLocal
6565
cd ..
6666
67+
- name: Clone mock responses
68+
if: matrix.module == ':firebase-vertexai'
69+
run: |
70+
firebase-vertexai/update_responses.sh
71+
6772
- name: Add google-services.json
6873
env:
6974
INTEG_TESTS_GOOGLE_SERVICES: ${{ secrets.INTEG_TESTS_GOOGLE_SERVICES }}

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ smoke-test-logs/
1313
smoke-tests/build-debug-headGit-smoke-test
1414
smoke-tests/firehorn.log
1515
macrobenchmark-output.json
16+
vertexai-sdk-test-data/
1617

1718
# generated Terraform docs
1819
.terraform/*

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ Unit tests can be executed on the command line by running
6363
./gradlew :<firebase-project>:check
6464
```
6565

66+
#### Vertex AI for Firebase
67+
68+
See the Vertex AI for Firebase [README](firebase-vertexai#running-tests) for setup
69+
instructions specific to that project.
70+
6671
### Integration Testing
6772

6873
These are tests that run on a hardware device or emulator. These tests have

firebase-vertexai/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ All Gradle commands should be run from the root of this repository.
1515

1616
## Running Tests
1717

18+
> [!IMPORTANT]
19+
> These unit tests require mock response files, which can be downloaded by running
20+
`./firebase-vertexai/update_responses.sh` from the root of this repository.
21+
1822
Unit tests:
1923

2024
`./gradlew :firebase-vertexai:check`

firebase-vertexai/firebase-vertexai.gradle.kts

+9-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ android {
4949
targetCompatibility = JavaVersion.VERSION_1_8
5050
}
5151
kotlinOptions { jvmTarget = "1.8" }
52-
testOptions.unitTests.isIncludeAndroidResources = true
52+
testOptions {
53+
unitTests.isIncludeAndroidResources = true
54+
unitTests.isReturnDefaultValues = true
55+
}
5356
}
5457

5558
dependencies {
@@ -58,7 +61,7 @@ dependencies {
5861
implementation("com.google.firebase:firebase-components:18.0.0")
5962
implementation("com.google.firebase:firebase-annotations:16.2.0")
6063
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
61-
implementation("com.google.ai.client.generativeai:common:0.7.1")
64+
implementation("com.google.ai.client.generativeai:common:0.9.0")
6265
implementation(libs.androidx.annotation)
6366
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
6467
implementation("androidx.core:core-ktx:1.12.0")
@@ -71,8 +74,12 @@ dependencies {
7174
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03")
7275
implementation("com.google.firebase:firebase-auth-interop:18.0.0")
7376

77+
val ktorVersion = "2.3.2"
7478
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
7579
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
80+
testImplementation("io.ktor:ktor-client-okhttp:$ktorVersion")
81+
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
82+
testImplementation("org.json:json:20240303")
7683
testImplementation(libs.androidx.test.junit)
7784
testImplementation(libs.androidx.test.runner)
7885
testImplementation(libs.junit)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai
18+
19+
import com.google.firebase.vertexai.type.BlockReason
20+
import com.google.firebase.vertexai.type.FinishReason
21+
import com.google.firebase.vertexai.type.HarmCategory
22+
import com.google.firebase.vertexai.type.InvalidAPIKeyException
23+
import com.google.firebase.vertexai.type.PromptBlockedException
24+
import com.google.firebase.vertexai.type.ResponseStoppedException
25+
import com.google.firebase.vertexai.type.SerializationException
26+
import com.google.firebase.vertexai.type.ServerException
27+
import com.google.firebase.vertexai.type.TextPart
28+
import com.google.firebase.vertexai.util.goldenStreamingFile
29+
import io.kotest.assertions.throwables.shouldThrow
30+
import io.kotest.matchers.nulls.shouldNotBeNull
31+
import io.kotest.matchers.shouldBe
32+
import io.kotest.matchers.string.shouldContain
33+
import io.ktor.http.HttpStatusCode
34+
import kotlin.time.Duration.Companion.seconds
35+
import kotlinx.coroutines.flow.collect
36+
import kotlinx.coroutines.flow.first
37+
import kotlinx.coroutines.flow.toList
38+
import kotlinx.coroutines.withTimeout
39+
import org.junit.Test
40+
41+
internal class StreamingSnapshotTests {
42+
private val testTimeout = 5.seconds
43+
44+
@Test
45+
fun `short reply`() =
46+
goldenStreamingFile("success-basic-reply-short.txt") {
47+
val responses = model.generateContentStream("prompt")
48+
49+
withTimeout(testTimeout) {
50+
val responseList = responses.toList()
51+
responseList.isEmpty() shouldBe false
52+
responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP
53+
responseList.first().candidates.first().content.parts.isEmpty() shouldBe false
54+
responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false
55+
}
56+
}
57+
58+
@Test
59+
fun `long reply`() =
60+
goldenStreamingFile("success-basic-reply-long.txt") {
61+
val responses = model.generateContentStream("prompt")
62+
63+
withTimeout(testTimeout) {
64+
val responseList = responses.toList()
65+
responseList.isEmpty() shouldBe false
66+
responseList.forEach {
67+
it.candidates.first().finishReason shouldBe FinishReason.STOP
68+
it.candidates.first().content.parts.isEmpty() shouldBe false
69+
it.candidates.first().safetyRatings.isEmpty() shouldBe false
70+
}
71+
}
72+
}
73+
74+
@Test
75+
fun `unknown enum`() =
76+
goldenStreamingFile("success-unknown-enum.txt") {
77+
val responses = model.generateContentStream("prompt")
78+
79+
withTimeout(testTimeout) {
80+
val responseList = responses.toList()
81+
82+
responseList.isEmpty() shouldBe false
83+
responseList.any {
84+
it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } }
85+
} shouldBe true
86+
}
87+
}
88+
89+
@Test
90+
fun `quotes escaped`() =
91+
goldenStreamingFile("success-quotes-escaped.txt") {
92+
val responses = model.generateContentStream("prompt")
93+
94+
withTimeout(testTimeout) {
95+
val responseList = responses.toList()
96+
97+
responseList.isEmpty() shouldBe false
98+
val part = responseList.first().candidates.first().content.parts.first() as? TextPart
99+
part.shouldNotBeNull()
100+
part.text shouldContain "\""
101+
}
102+
}
103+
104+
@Test
105+
fun `prompt blocked for safety`() =
106+
goldenStreamingFile("failure-prompt-blocked-safety.txt") {
107+
val responses = model.generateContentStream("prompt")
108+
109+
withTimeout(testTimeout) {
110+
val exception = shouldThrow<PromptBlockedException> { responses.collect() }
111+
exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY
112+
}
113+
}
114+
115+
@Test
116+
fun `empty content`() =
117+
goldenStreamingFile("failure-empty-content.txt") {
118+
val responses = model.generateContentStream("prompt")
119+
120+
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
121+
}
122+
123+
@Test
124+
fun `http errors`() =
125+
goldenStreamingFile("failure-http-error.txt", HttpStatusCode.PreconditionFailed) {
126+
val responses = model.generateContentStream("prompt")
127+
128+
withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
129+
}
130+
131+
@Test
132+
fun `stopped for safety`() =
133+
goldenStreamingFile("failure-finish-reason-safety.txt") {
134+
val responses = model.generateContentStream("prompt")
135+
136+
withTimeout(testTimeout) {
137+
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
138+
exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY
139+
}
140+
}
141+
142+
@Test
143+
fun `citation parsed correctly`() =
144+
goldenStreamingFile("success-citations.txt") {
145+
val responses = model.generateContentStream("prompt")
146+
147+
withTimeout(testTimeout) {
148+
val responseList = responses.toList()
149+
responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true
150+
}
151+
}
152+
153+
@Test
154+
fun `stopped for recitation`() =
155+
goldenStreamingFile("failure-recitation-no-content.txt") {
156+
val responses = model.generateContentStream("prompt")
157+
158+
withTimeout(testTimeout) {
159+
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
160+
exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION
161+
}
162+
}
163+
164+
@Test
165+
fun `image rejected`() =
166+
goldenStreamingFile("failure-image-rejected.txt", HttpStatusCode.BadRequest) {
167+
val responses = model.generateContentStream("prompt")
168+
169+
withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
170+
}
171+
172+
@Test
173+
fun `unknown model`() =
174+
goldenStreamingFile("failure-unknown-model.txt", HttpStatusCode.NotFound) {
175+
val responses = model.generateContentStream("prompt")
176+
177+
withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
178+
}
179+
180+
@Test
181+
fun `invalid api key`() =
182+
goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) {
183+
val responses = model.generateContentStream("prompt")
184+
185+
withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
186+
}
187+
}

0 commit comments

Comments
 (0)