Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions apidump_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package aibridge_test

import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"

"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge"
"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/provider"
"github.com/stretchr/testify/require"
"golang.org/x/tools/txtar"
)

func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI {
return config.OpenAI{
BaseURL: url,
Key: key,
APIDumpDir: dumpDir,
}
}

func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic {
return config.Anthropic{
BaseURL: url,
Key: key,
APIDumpDir: dumpDir,
}
}

func TestAPIDump(t *testing.T) {
t.Parallel()

cases := []struct {
name string
fixture []byte
providerName string
providersFunc func(addr, dumpDir string) []aibridge.Provider
createRequestFunc createRequestFunc
}{
{
name: config.ProviderAnthropic,
fixture: fixtures.AntSimple,
providerName: config.ProviderAnthropic,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)}
},
createRequestFunc: createAnthropicMessagesReq,
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatSimple,
providerName: config.ProviderOpenAI,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
},
createRequestFunc: createOpenAIChatCompletionsReq,
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiResponsesBlockingSimple,
providerName: config.ProviderOpenAI,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
},
createRequestFunc: createOpenAIResponsesReq,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

arc := txtar.Parse(tc.fixture)
files := filesMap(arc)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureNonStreamingResponse)

reqBody := files[fixtureRequest]

// Setup mock upstream server.
srv := newMockServer(ctx, t, files, nil)
t.Cleanup(srv.Close)

// Create temp dir for API dumps.
dumpDir := t.TempDir()

recorderClient := &testutil.MockRecorder{}
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
t.Cleanup(mockSrv.Close)
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibcontext.AsActor(ctx, userID, nil)
}
mockSrv.Start()

req := tc.createRequestFunc(t, mockSrv.URL, reqBody)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)

// Verify dump files were created.
interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1)
interceptionID := interceptions[0].ID

// Find dump files for this interception by walking the dump directory.
var reqDumpFile, respDumpFile string
err = filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
// Files are named: {timestamp}-{interceptionID}.{req|resp}.txt
if strings.Contains(path, interceptionID) {
if strings.HasSuffix(path, apidump.SuffixRequest) {
reqDumpFile = path
} else if strings.HasSuffix(path, apidump.SuffixResponse) {
respDumpFile = path
}
}
return nil
})
require.NoError(t, err)
require.NotEmpty(t, reqDumpFile, "request dump file should exist")
require.NotEmpty(t, respDumpFile, "response dump file should exist")

// Verify request dump contains expected HTTP request format.
reqDumpData, err := os.ReadFile(reqDumpFile)
require.NoError(t, err)

// Parse the dumped HTTP request.
dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData)))
require.NoError(t, err)
dumpBody, err := io.ReadAll(dumpReq.Body)
require.NoError(t, err)

// Compare requests semantically (key order may differ).
require.JSONEq(t, string(dumpBody), string(reqBody), "request body JSON should match semantically")

// Verify response dump contains expected HTTP response format.
respDumpData, err := os.ReadFile(respDumpFile)
require.NoError(t, err)

// Parse the dumped HTTP response.
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)

// Compare responses semantically (key order may differ).
expectedRespBody := files[fixtureNonStreamingResponse]
require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically")

recorderClient.VerifyAllInterceptionsEnded(t)
})
}
}
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func DefaultCircuitBreaker() CircuitBreaker {
type Anthropic struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}

Expand All @@ -53,5 +54,6 @@ type AWSBedrock struct {
type OpenAI struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ go 1.24.6
// Misc libs.
require (
cdr.dev/slog/v3 v3.0.0-rc1
github.com/coder/quartz v0.3.0
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/mark3labs/mcp-go v0.38.0
github.com/prometheus/client_golang v1.23.2
github.com/sony/gobreaker/v2 v2.3.0
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/pretty v1.2.1
github.com/tidwall/sjson v1.2.5
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
Expand Down Expand Up @@ -73,7 +75,6 @@ require (
github.com/rivo/uniseg v0.4.4 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E=
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
github.com/coder/quartz v0.3.0 h1:bUoSEJ77NBfKtUqv6CPSC0AS8dsjqAqqAv7bN02m1mg=
github.com/coder/quartz v0.3.0/go.mod h1:BgE7DOj/8NfvRgvKw0jPLDQH/2Lya2kxcTaNJ8X0rZk=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
Expand Down
Loading