Skip to content

Commit dbeecd4

Browse files
authored
Add basic Vertex Java compilation tests (#6810)
This is a starting point for compilation testing, broadly using most symbols Vertex exposes to Java users, lightly validating usability and structure. As a note, there are some builder patterns in Vertex that don't work as expected from Java and were omitted, as fixing that would be a breaking change.
1 parent 5ff9d95 commit dbeecd4

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed

firebase-vertexai/firebase-vertexai.gradle.kts

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ android {
6464
}
6565
}
6666
lint { targetSdk = targetSdkVersion }
67+
sourceSets { getByName("test").java.srcDirs("src/testUtil") }
6768
}
6869

6970
// Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package java.com.google.firebase.vertexai;
18+
19+
import android.graphics.Bitmap;
20+
import com.google.common.util.concurrent.ListenableFuture;
21+
import com.google.firebase.concurrent.FirebaseExecutors;
22+
import com.google.firebase.vertexai.FirebaseVertexAI;
23+
import com.google.firebase.vertexai.GenerativeModel;
24+
import com.google.firebase.vertexai.java.ChatFutures;
25+
import com.google.firebase.vertexai.java.GenerativeModelFutures;
26+
import com.google.firebase.vertexai.type.BlockReason;
27+
import com.google.firebase.vertexai.type.Candidate;
28+
import com.google.firebase.vertexai.type.Citation;
29+
import com.google.firebase.vertexai.type.CitationMetadata;
30+
import com.google.firebase.vertexai.type.Content;
31+
import com.google.firebase.vertexai.type.ContentModality;
32+
import com.google.firebase.vertexai.type.CountTokensResponse;
33+
import com.google.firebase.vertexai.type.FileDataPart;
34+
import com.google.firebase.vertexai.type.FinishReason;
35+
import com.google.firebase.vertexai.type.FunctionCallPart;
36+
import com.google.firebase.vertexai.type.GenerateContentResponse;
37+
import com.google.firebase.vertexai.type.HarmCategory;
38+
import com.google.firebase.vertexai.type.HarmProbability;
39+
import com.google.firebase.vertexai.type.HarmSeverity;
40+
import com.google.firebase.vertexai.type.ImagePart;
41+
import com.google.firebase.vertexai.type.InlineDataPart;
42+
import com.google.firebase.vertexai.type.ModalityTokenCount;
43+
import com.google.firebase.vertexai.type.Part;
44+
import com.google.firebase.vertexai.type.PromptFeedback;
45+
import com.google.firebase.vertexai.type.SafetyRating;
46+
import com.google.firebase.vertexai.type.TextPart;
47+
import com.google.firebase.vertexai.type.UsageMetadata;
48+
import java.util.Calendar;
49+
import java.util.List;
50+
import java.util.Map;
51+
import java.util.concurrent.Executor;
52+
import kotlinx.serialization.json.JsonElement;
53+
import kotlinx.serialization.json.JsonNull;
54+
import org.junit.Assert;
55+
import org.reactivestreams.Publisher;
56+
import org.reactivestreams.Subscriber;
57+
import org.reactivestreams.Subscription;
58+
59+
/**
60+
* Tests in this file exist to be compiled, not invoked
61+
*/
62+
public class JavaCompileTests {
63+
64+
public void initializeJava() throws Exception {
65+
FirebaseVertexAI vertex = FirebaseVertexAI.getInstance();
66+
GenerativeModel model = vertex.generativeModel("fake-model-name");
67+
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
68+
testFutures(futures);
69+
}
70+
71+
private void testFutures(GenerativeModelFutures futures) throws Exception {
72+
Content content =
73+
new Content.Builder()
74+
.addText("Fake prompt")
75+
.addFileData("fakeuri", "image/png")
76+
.addInlineData(new byte[] {}, "text/json")
77+
.addImage(Bitmap.createBitmap(0, 0, Bitmap.Config.HARDWARE))
78+
.addPart(new FunctionCallPart("fakeFunction", Map.of("fakeArg", JsonNull.INSTANCE)))
79+
.build();
80+
// TODO b/406558430 Content.Builder.setParts and Content.Builder.setRole return void
81+
Executor executor = FirebaseExecutors.directExecutor();
82+
ListenableFuture<CountTokensResponse> countResponse = futures.countTokens(content);
83+
validateCountTokensResponse(countResponse.get());
84+
ListenableFuture<GenerateContentResponse> generateResponse = futures.generateContent(content);
85+
validateGenerateContentResponse(generateResponse.get());
86+
ChatFutures chat = futures.startChat();
87+
ListenableFuture<GenerateContentResponse> future = chat.sendMessage(content);
88+
future.addListener(
89+
() -> {
90+
try {
91+
validateGenerateContentResponse(future.get());
92+
} catch (Exception e) {
93+
// Ignore
94+
}
95+
},
96+
executor);
97+
Publisher<GenerateContentResponse> responsePublisher = futures.generateContentStream(content);
98+
responsePublisher.subscribe(
99+
new Subscriber<GenerateContentResponse>() {
100+
private boolean complete = false;
101+
102+
@Override
103+
public void onSubscribe(Subscription s) {
104+
s.request(Long.MAX_VALUE);
105+
}
106+
107+
@Override
108+
public void onNext(GenerateContentResponse response) {
109+
Assert.assertFalse(complete);
110+
validateGenerateContentResponse(response);
111+
}
112+
113+
@Override
114+
public void onError(Throwable t) {
115+
// Ignore
116+
}
117+
118+
@Override
119+
public void onComplete() {
120+
complete = true;
121+
}
122+
});
123+
}
124+
125+
public void validateCountTokensResponse(CountTokensResponse response) {
126+
int tokens = response.getTotalTokens();
127+
Integer billable = response.getTotalBillableCharacters();
128+
Assert.assertEquals(tokens, response.component1());
129+
Assert.assertEquals(billable, response.component2());
130+
Assert.assertEquals(response.getPromptTokensDetails(), response.component3());
131+
for (ModalityTokenCount count : response.getPromptTokensDetails()) {
132+
ContentModality modality = count.getModality();
133+
int tokenCount = count.getTokenCount();
134+
}
135+
}
136+
137+
public void validateGenerateContentResponse(GenerateContentResponse response) {
138+
List<Candidate> candidates = response.getCandidates();
139+
if (candidates.size() == 1
140+
&& candidates.get(0).getContent().getParts().stream()
141+
.anyMatch(p -> p instanceof TextPart && !((TextPart) p).getText().isEmpty())) {
142+
String text = response.getText();
143+
Assert.assertNotNull(text);
144+
Assert.assertFalse(text.isBlank());
145+
}
146+
validateCandidates(candidates);
147+
validateFunctionCalls(response.getFunctionCalls());
148+
validatePromptFeedback(response.getPromptFeedback());
149+
validateUsageMetadata(response.getUsageMetadata());
150+
}
151+
152+
public void validateCandidates(List<Candidate> candidates) {
153+
for (Candidate candidate : candidates) {
154+
validateCitationMetadata(candidate.getCitationMetadata());
155+
FinishReason reason = candidate.getFinishReason();
156+
validateSafetyRatings(candidate.getSafetyRatings());
157+
validateCitationMetadata(candidate.getCitationMetadata());
158+
validateContent(candidate.getContent());
159+
}
160+
}
161+
162+
public void validateContent(Content content) {
163+
String role = content.getRole();
164+
for (Part part : content.getParts()) {
165+
if (part instanceof TextPart) {
166+
String text = ((TextPart) part).getText();
167+
} else if (part instanceof ImagePart) {
168+
Bitmap bitmap = ((ImagePart) part).getImage();
169+
} else if (part instanceof InlineDataPart) {
170+
String mime = ((InlineDataPart) part).getMimeType();
171+
byte[] data = ((InlineDataPart) part).getInlineData();
172+
} else if (part instanceof FileDataPart) {
173+
String mime = ((FileDataPart) part).getMimeType();
174+
String uri = ((FileDataPart) part).getUri();
175+
}
176+
}
177+
}
178+
179+
public void validateCitationMetadata(CitationMetadata metadata) {
180+
if (metadata != null) {
181+
for (Citation citation : metadata.getCitations()) {
182+
String uri = citation.getUri();
183+
String license = citation.getLicense();
184+
Calendar calendar = citation.getPublicationDate();
185+
int startIndex = citation.getStartIndex();
186+
int endIndex = citation.getEndIndex();
187+
Assert.assertTrue(startIndex <= endIndex);
188+
}
189+
}
190+
}
191+
192+
public void validateFunctionCalls(List<FunctionCallPart> parts) {
193+
if (parts != null) {
194+
for (FunctionCallPart part : parts) {
195+
String functionName = part.getName();
196+
Map<String, JsonElement> args = part.getArgs();
197+
Assert.assertFalse(functionName.isBlank());
198+
}
199+
}
200+
}
201+
202+
public void validatePromptFeedback(PromptFeedback feedback) {
203+
if (feedback != null) {
204+
String message = feedback.getBlockReasonMessage();
205+
BlockReason reason = feedback.getBlockReason();
206+
validateSafetyRatings(feedback.getSafetyRatings());
207+
}
208+
}
209+
210+
public void validateSafetyRatings(List<SafetyRating> ratings) {
211+
for (SafetyRating rating : ratings) {
212+
Boolean blocked = rating.getBlocked();
213+
HarmCategory category = rating.getCategory();
214+
HarmProbability probability = rating.getProbability();
215+
float score = rating.getProbabilityScore();
216+
HarmSeverity severity = rating.getSeverity();
217+
Float severityScore = rating.getSeverityScore();
218+
if (severity != null) {
219+
Assert.assertNotNull(severityScore);
220+
}
221+
}
222+
}
223+
224+
public void validateUsageMetadata(UsageMetadata metadata) {
225+
if (metadata != null) {
226+
int totalTokens = metadata.getTotalTokenCount();
227+
int promptTokenCount = metadata.getPromptTokenCount();
228+
for (ModalityTokenCount count : metadata.getPromptTokensDetails()) {
229+
ContentModality modality = count.getModality();
230+
int tokenCount = count.getTokenCount();
231+
}
232+
Integer candidatesTokenCount = metadata.getCandidatesTokenCount();
233+
for (ModalityTokenCount count : metadata.getCandidatesTokensDetails()) {
234+
ContentModality modality = count.getModality();
235+
int tokenCount = count.getTokenCount();
236+
}
237+
}
238+
}
239+
}

0 commit comments

Comments
 (0)