Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dd-java-agent/agent-aiguard/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies {
implementation libs.okhttp

api project(':dd-trace-api')
api project(':utils:version-utils')
implementation project(':internal-api')
implementation project(':communication')

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.datadog.aiguard;

import static datadog.communication.ddagent.TracerVersion.TRACER_VERSION;
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT;
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES;
import static datadog.trace.util.Strings.isBlank;
import static java.util.Collections.singletonMap;

Expand All @@ -21,6 +24,7 @@
import datadog.trace.api.aiguard.AIGuard.ToolCall.Function;
import datadog.trace.api.aiguard.Evaluator;
import datadog.trace.api.aiguard.noop.NoOpEvaluator;
import datadog.trace.api.telemetry.WafMetricCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentScope;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
Expand Down Expand Up @@ -79,7 +83,18 @@ public static void install() {
if (isBlank(endpoint)) {
endpoint = String.format("https://app.%s/api/v2/ai-guard", config.getSite());
}
final Map<String, String> headers = mapOf("DD-API-KEY", apiKey, "DD-APPLICATION-KEY", appKey);
final Map<String, String> headers =
mapOf(
"DD-API-KEY",
apiKey,
"DD-APPLICATION-KEY",
appKey,
"DD-AI-GUARD-VERSION",
TRACER_VERSION,
"DD-AI-GUARD-SOURCE",
"SDK",
"DD-AI-GUARD-LANGUAGE",
"jvm");
final HttpUrl url = HttpUrl.get(endpoint).newBuilder().addPathSegment("evaluate").build();
final int timeout = config.getAiGuardTimeout();
final OkHttpClient client = buildClient(url, timeout);
Expand Down Expand Up @@ -113,12 +128,17 @@ static void uninstall() {
private static List<Message> messagesForMetaStruct(List<Message> messages) {
final Config config = Config.get();
final int size = Math.min(messages.size(), config.getAiGuardMaxMessagesLength());
if (size < messages.size()) {
WafMetricCollector.get().aiGuardTruncated(MESSAGES);
}
final List<Message> result = new ArrayList<>(size);
final int maxContent = config.getAiGuardMaxContentSize();
boolean contentTruncated = false;
for (int i = 0; i < size; i++) {
Message source = messages.get(i);
final String content = source.getContent();
if (content != null && content.length() > maxContent) {
contentTruncated = true;
source =
new Message(
source.getRole(),
Expand All @@ -128,6 +148,9 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
}
result.add(source);
}
if (contentTruncated) {
WafMetricCollector.get().aiGuardTruncated(CONTENT);
}
return result;
}

Expand Down Expand Up @@ -203,20 +226,27 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
final String reason = (String) result.get("reason");
span.setTag(ACTION_TAG, action);
span.setTag(REASON_TAG, reason);
final boolean blockingEnabled =
isBlockingEnabled(options, result.get("is_blocking_enabled"));
if (blockingEnabled && action != Action.ALLOW) {
final boolean shouldBlock =
isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW;
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
if (shouldBlock) {
span.setTag(BLOCKED_TAG, true);
throw new AIGuardAbortError(action, reason);
}
return new Evaluation(action, reason);
}
} catch (AIGuardAbortError | AIGuardClientError e) {
} catch (AIGuardAbortError e) {
span.addThrowable(e);
throw e;
} catch (AIGuardClientError e) {
WafMetricCollector.get().aiGuardError();
span.addThrowable(e);
throw e;
} catch (final Exception e) {
WafMetricCollector.get().aiGuardError();
final AIGuardClientError error =
new AIGuardClientError("AI Guard service returned unexpected response", e);
new AIGuardClientError(
"AI Guard service returned unexpected response: " + e.getMessage(), e);
span.addThrowable(error);
throw error;
} finally {
Expand Down Expand Up @@ -248,11 +278,14 @@ private static OkHttpClient buildClient(final HttpUrl url, final long timeout) {
return OkHttpUtils.buildHttpClient(url, timeout).newBuilder().build();
}

private static Map<String, String> mapOf(
final String key1, final String prop1, final String key2, final String prop2) {
final Map<String, String> map = new HashMap<>(2);
map.put(key1, prop1);
map.put(key2, prop2);
private static Map<String, String> mapOf(final String... props) {
if (props.length % 2 != 0) {
throw new IllegalArgumentException("Props must be even");
}
final Map<String, String> map = new HashMap<>(props.length << 1);
for (int i = 0; i < props.length; ) {
map.put(props[i++], props[i++]);
}
return map;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.PropertyNamingStrategies
import com.squareup.moshi.Moshi
import datadog.common.version.VersionInfo
import datadog.trace.api.Config
import datadog.trace.api.aiguard.AIGuard
import datadog.trace.api.telemetry.WafMetricCollector
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.test.util.DDSpecification
Expand Down Expand Up @@ -35,7 +37,11 @@ class AIGuardInternalTests extends DDSpecification {
protected static final URL = HttpUrl.parse('https://app.datadoghq.com/api/v2/ai-guard/evaluate')

@Shared
protected static final HEADERS = ['DD-API-KEY': 'api', 'DD-APPLICATION-KEY': 'app']
protected static final HEADERS = ['DD-API-KEY': 'api',
'DD-APPLICATION-KEY': 'app',
'DD-AI-GUARD-VERSION': VersionInfo.VERSION,
'DD-AI-GUARD-SOURCE': 'SDK',
'DD-AI-GUARD-LANGUAGE': 'jvm']

@Shared
protected static final ORIGINAL_TRACER = AgentTracer.get()
Expand Down Expand Up @@ -79,6 +85,11 @@ class AIGuardInternalTests extends DDSpecification {
buildSpan(_ as String, _ as String) >> builder
}
AgentTracer.forceRegister(tracer)

WafMetricCollector.get().tap {
prepareMetrics()
drain()
}
}

void cleanup() {
Expand Down Expand Up @@ -193,6 +204,7 @@ class AIGuardInternalTests extends DDSpecification {
eval.action == suite.action
eval.reason == suite.reason
}
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')

where:
suite << TestSuite.build()
Expand Down Expand Up @@ -222,6 +234,7 @@ class AIGuardInternalTests extends DDSpecification {
final exception = thrown(AIGuard.AIGuardClientError)
exception.errors == errors
1 * span.addThrowable(_ as AIGuard.AIGuardClientError)
assertTelemetry('ai_guard.requests', 'error:true')
}

void 'test evaluate with invalid JSON'() {
Expand All @@ -246,6 +259,7 @@ class AIGuardInternalTests extends DDSpecification {
then:
thrown(AIGuard.AIGuardClientError)
1 * span.addThrowable(_ as AIGuard.AIGuardClientError)
assertTelemetry('ai_guard.requests', 'error:true')
}

void 'test evaluate with missing action'() {
Expand All @@ -270,6 +284,7 @@ class AIGuardInternalTests extends DDSpecification {
then:
thrown(AIGuard.AIGuardClientError)
1 * span.addThrowable(_ as AIGuard.AIGuardClientError)
assertTelemetry('ai_guard.requests', 'error:true')
}

void 'test evaluate with non JSON response'() {
Expand All @@ -294,6 +309,7 @@ class AIGuardInternalTests extends DDSpecification {
then:
thrown(AIGuard.AIGuardClientError)
1 * span.addThrowable(_ as AIGuard.AIGuardClientError)
assertTelemetry('ai_guard.requests', 'error:true')
}

void 'test evaluate with empty response'() {
Expand All @@ -318,6 +334,7 @@ class AIGuardInternalTests extends DDSpecification {
then:
thrown(AIGuard.AIGuardClientError)
1 * span.addThrowable(_ as AIGuard.AIGuardClientError)
assertTelemetry('ai_guard.requests', 'error:true')
}

void 'test message length truncation'() {
Expand Down Expand Up @@ -349,6 +366,7 @@ class AIGuardInternalTests extends DDSpecification {
assert received.size() == maxMessages
assert received.size() < messages.size()
}
assertTelemetry('ai_guard.truncated', 'type:messages')
}

void 'test message content truncation'() {
Expand Down Expand Up @@ -380,6 +398,7 @@ class AIGuardInternalTests extends DDSpecification {
assert it.content.length() < message.content.length()
}
}
assertTelemetry('ai_guard.truncated', 'type:content')
}

void 'test no messages'() {
Expand Down Expand Up @@ -425,6 +444,21 @@ class AIGuardInternalTests extends DDSpecification {
0 * span.setTag(AIGuardInternal.TOOL_TAG, _)
}

private static assertTelemetry(final String metric, final String...tags) {
final metrics = WafMetricCollector.get().with {
prepareMetrics()
drain()
}
final filtered = metrics.findAll {
it.namespace == 'appsec'
&& it.metricName == metric
&& it.tags == tags.toList()
}
assert filtered.size() == 1 : metrics
assert filtered*.value.sum() == 1
return true
}

private static assertRequest(final Request request, final List<AIGuard.Message> messages) {
assert request.url() == URL
assert request.method() == 'POST'
Expand Down Expand Up @@ -452,12 +486,12 @@ class AIGuardInternalTests extends DDSpecification {

private static Response mockResponse(final Request request, final int status, final Object body) {
return new Response.Builder()
.protocol(Protocol.HTTP_1_1)
.message('ok')
.request(request)
.code(status)
.body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body)))
.build()
.protocol(Protocol.HTTP_1_1)
.message('ok')
.request(request)
.code(status)
.body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body)))
.build()
}

private static class TestSuite {
Expand Down Expand Up @@ -495,13 +529,13 @@ class AIGuardInternalTests extends DDSpecification {
@Override
String toString() {
return "TestSuite{" +
"description='" + description + '\'' +
", action=" + action +
", reason='" + reason + '\'' +
", blocking=" + blocking +
", target='" + target + '\'' +
", messages=" + messages +
'}'
"description='" + description + '\'' +
", action=" + action +
", reason='" + reason + '\'' +
", blocking=" + blocking +
", target='" + target + '\'' +
", messages=" + messages +
'}'
}
}
}
Loading