Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions db/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ CREATE TABLE ollama_embeddings (
metadata JSONB
);

CREATE TABLE bedrock_embeddings_1024 (
id BIGSERIAL PRIMARY KEY,
doc_id TEXT NOT NULL,
model TEXT NOT NULL,
embedding VECTOR(1024) NOT NULL,
metadata JSONB,
created_at TIMESTAMPTZ DEFAULT now()
);

CREATE INDEX bedrock_embeddings_1024_ivf_cos ON bedrock_embeddings_1024
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);

-- Index for efficient vector similarity search for OpenAI embeddings
CREATE INDEX openai_embeddings_idx ON openai_embeddings
USING ivfflat (embedding vector_cosine_ops)
Expand Down
23 changes: 23 additions & 0 deletions examples/bedrock/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Bedrock Example

Minimal example that uses **AWS Bedrock** through the `BedrockBackend` to:
1) generate text
2) create embeddings

## Prerequisites

- Go 1.22+
- AWS credentials configured (env/SharedConfig)
- Environment variables:
- `AWS_REGION` (e.g. `us-east-1`)
- `BEDROCK_TEXT_MODEL` (e.g. `anthropic.claude-3-haiku-20240307-v1:0`)
- `BEDROCK_EMBED_MODEL` (e.g. `amazon.titan-embed-text-v1`)

## Run

```bash
export AWS_REGION=us-east-1
export BEDROCK_TEXT_MODEL=anthropic.claude-3-haiku-20240307-v1:0
export BEDROCK_EMBED_MODEL=amazon.titan-embed-text-v1

go run ./examples/bedrock "Explain vector databases in 2 sentences."
63 changes: 63 additions & 0 deletions examples/bedrock/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"context"
"fmt"
"log"
"os"
"time"

"github.com/stackloklabs/gorag/pkg/backend"
)

func getenv(key string) string {
v := os.Getenv(key)
if v == "" {
log.Fatalf("missing env %s", key)
}
return v
}

func main() {
region := getenv("AWS_REGION")
textModel := getenv("BEDROCK_TEXT_MODEL")
embedModel := getenv("BEDROCK_EMBED_MODEL")

prompt := "Explain Retrieval-Augmented Generation (RAG) in 3 sentences."
if len(os.Args) > 1 {
prompt = os.Args[1]
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

br, err := backend.NewBedrockBackend(ctx, region, textModel, embedModel)
if err != nil {
log.Fatalf("init bedrock backend: %v", err)
}

out, err := br.Generate(ctx, prompt, map[string]any{
"textGenerationConfig": map[string]any{
"maxTokenCount": 512,
"temperature": 0.2,
"topP": 0.9,
"stopSequences": []string{},
},
})

if err != nil {
log.Fatalf("generate: %v", err)
}
fmt.Println("=== Completion ===")
fmt.Println(out)

vecs, err := br.Embed(ctx, []string{
"RAG augments LLMs with external knowledge.",
"Vector databases enable efficient similarity search.",
}, nil)
if err != nil {
log.Fatalf("embed: %v", err)
}
fmt.Println("=== Embeddings ===")
fmt.Printf("count=%d dim=%d\n", len(vecs), len(vecs[0]))
}
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2 v1.36.6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11 // indirect
github.com/aws/aws-sdk-go-v2/config v1.29.18 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.71 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.33 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.37 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.37 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.31.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.18 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.25.6 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.4 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.34.1 // indirect
github.com/aws/smithy-go v1.22.4 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.14.3 // indirect
Expand Down
198 changes: 198 additions & 0 deletions pkg/backend/bedrock_backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package backend

import (
"context"
"encoding/json"
"errors"
"strings"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

type bedrockInvoker interface {
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}

type BedrockBackend struct {
client bedrockInvoker
textModel string
embedModel string
}

func NewBedrockBackend(ctx context.Context, region, textModel, embedModel string) (*BedrockBackend, error) {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, err
}
return &BedrockBackend{
client: bedrockruntime.NewFromConfig(cfg),
textModel: textModel,
embedModel: embedModel,
}, nil
}

func newBedrockBackendWithClient(textModel, embedModel string, c bedrockInvoker) *BedrockBackend {
return &BedrockBackend{
client: c,
textModel: textModel,
embedModel: embedModel,
}
}

func (b *BedrockBackend) Generate(ctx context.Context, prompt string, params map[string]any) (string, error) {
if b.textModel == "" {
return "", errors.New("text model is not set")
}
body := map[string]any{
"inputText": prompt,
}
for k, v := range params {
body[k] = v
}
req, err := json.Marshal(body)
if err != nil {
return "", err
}
out, err := b.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
ModelId: awsString(b.textModel),
ContentType: awsString("application/json"),
Body: req,
})
if err != nil {
return "", err
}
return parseText(out.Body)
}

func (b *BedrockBackend) GenerateStream(ctx context.Context, prompt string, params map[string]any, onToken func(string) error) error {
txt, err := b.Generate(ctx, prompt, params)
if err != nil {
return err
}
for _, t := range strings.Split(txt, " ") {
if err := onToken(t + " "); err != nil {
return err
}
}
return nil
}

func (b *BedrockBackend) Embed(ctx context.Context, texts []string, params map[string]any) ([][]float32, error) {
if b.embedModel == "" {
return nil, errors.New("embed model is not set")
}
outVecs := make([][]float32, 0, len(texts))
for _, t := range texts {
body := map[string]any{
"inputText": t,
}
for k, v := range params {
body[k] = v
}
req, err := json.Marshal(body)
if err != nil {
return nil, err
}
out, err := b.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
ModelId: awsString(b.embedModel),
ContentType: awsString("application/json"),
Body: req,
})
if err != nil {
return nil, err
}
vec, err := parseEmbedding(out.Body)
if err != nil {
return nil, err
}
outVecs = append(outVecs, vec)
}
return outVecs, nil
}
func parseText(b []byte) (string, error) {
var m map[string]any
if err := json.Unmarshal(b, &m); err != nil {
return "", err
}

// Titan: {"results":[{"outputText":"..."}]}
if r, ok := m["results"].([]any); ok && len(r) > 0 {
for _, it := range r {
if mm, ok := it.(map[string]any); ok {
if s, ok := mm["outputText"].(string); ok {
return s, nil
}
if s, ok := mm["text"].(string); ok {
return s, nil
}

if msg, ok := mm["message"].(map[string]any); ok {
if content, ok := msg["content"].([]any); ok && len(content) > 0 {
if c0, ok := content[0].(map[string]any); ok {
if s, ok := c0["text"].(string); ok {
return s, nil
}
}
}
}
}
}
}

if s, ok := m["outputText"].(string); ok {
return s, nil
}
if s, ok := m["completion"].(string); ok {
return s, nil
}
if s, ok := m["generation"].(string); ok {
return s, nil
}
if arr, ok := m["content"].([]any); ok && len(arr) > 0 {
if mm, ok := arr[0].(map[string]any); ok {
if s, ok := mm["text"].(string); ok {
return s, nil
}
}
}

return "", errors.New("unexpected response schema")
}

func parseEmbedding(b []byte) ([]float32, error) {
var m map[string]any
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
if v, ok := m["embedding"].([]any); ok {
return toFloat32Slice(v)
}
if r, ok := m["results"].([]any); ok && len(r) > 0 {
if mm, ok := r[0].(map[string]any); ok {
if v, ok := mm["embedding"].([]any); ok {
return toFloat32Slice(v)
}
}
}
return nil, errors.New("embedding not found")
}

func toFloat32Slice(v []any) ([]float32, error) {
res := make([]float32, len(v))
for i, x := range v {
switch n := x.(type) {
case float64:
res[i] = float32(n)
case float32:
res[i] = n
default:
return nil, errors.New("invalid embedding element type")
}
}
return res, nil
}

func awsString(s string) *string {
return &s
}
Loading