Skip to content

Commit 1ddbf59

Browse files
authored
Merge pull request #25 from bat-bs/reporting
Add cost collector and caching
2 parents 4bc8c66 + f08d678 commit 1ddbf59

12 files changed

+364
-73
lines changed

api/apiserver.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ func ApiInit(mux *http.ServeMux, a *auth.Auth, db *db.Database) {
1818
}
1919

2020
api := ApiHandler{db, a, timeZone}
21+
graph := NewGraphHandler(&api)
2122
mux.HandleFunc("/api2/user/widget", api.GetUserWidget)
2223
mux.HandleFunc("/api2/user/logout", api.LogoutUser)
2324
mux.HandleFunc("/api2/admin/table/get", api.GetAdminTable)
24-
mux.HandleFunc("/api2/admin/table/graph/get/", api.GetAdminTableGraph)
25-
mux.HandleFunc("/api2/table/graph/get/", api.GetTableGraph)
25+
mux.HandleFunc("/api2/admin/table/graph/get/", graph.GetAdminTableGraph)
26+
mux.HandleFunc("/api2/table/graph/get/", graph.GetTableGraph)
2627
mux.HandleFunc("/api2/table/get", api.GetTable)
2728
mux.HandleFunc("/api2/table/entry/save", api.CreateEntry)
2829
mux.HandleFunc("/api2/table/entry/delete/", api.DeleteEntry)

api/graph.go

+72-19
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,41 @@ import (
1515
"github.com/go-echarts/go-echarts/v2/opts"
1616
)
1717

18+
type GraphHandler struct {
19+
a *ApiHandler
20+
cache []Cache
21+
}
22+
23+
// if basecount, filter and DB Rows missmatch, Cache will be updated
24+
type Cache struct {
25+
ID string
26+
BaseCount int //
27+
Filter string
28+
Data []db.RequestSummary
29+
}
30+
31+
func NewGraphHandler(a *ApiHandler) *GraphHandler {
32+
var cache []Cache
33+
return &GraphHandler{a, cache}
34+
}
35+
36+
// Generate Token Graphs for API Key and Admin overview and allow to filter by timeframes
37+
1838
type Graph struct {
1939
key string
2040
kind string // can be "user" or "apiKey"
2141
w http.ResponseWriter
2242
r *http.Request
2343
}
2444

25-
func (a *ApiHandler) GetTableGraph(w http.ResponseWriter, r *http.Request) {
26-
ok := a.auth.ValidateSessionToken(w, r)
45+
func (g *GraphHandler) GetTableGraph(w http.ResponseWriter, r *http.Request) {
46+
ok := g.a.auth.ValidateSessionToken(w, r)
2747
if !ok {
2848
http.Error(w, "Not Authorized", http.StatusForbidden)
2949
return
3050
}
3151
key := strings.TrimPrefix(r.URL.Path, "/api2/table/graph/get/")
32-
a.RenderGraph(&Graph{
52+
g.RenderGraph(&Graph{
3353
key: key,
3454
kind: "apiKey",
3555
w: w,
@@ -38,16 +58,16 @@ func (a *ApiHandler) GetTableGraph(w http.ResponseWriter, r *http.Request) {
3858

3959
}
4060

41-
func (a *ApiHandler) GetAdminTableGraph(w http.ResponseWriter, r *http.Request) {
61+
func (g *GraphHandler) GetAdminTableGraph(w http.ResponseWriter, r *http.Request) {
4262

43-
ok, err := a.auth.ValidateAdminSession(w, r)
63+
ok, err := g.a.auth.ValidateAdminSession(w, r)
4464
if err != nil || !ok {
4565
http.Error(w, "Not Authorized", http.StatusForbidden)
4666
return
4767
}
4868

4969
key := strings.TrimPrefix(r.URL.Path, "/api2/admin/table/graph/get/")
50-
a.RenderGraph(&Graph{
70+
g.RenderGraph(&Graph{
5171
key: key,
5272
kind: "user",
5373
w: w,
@@ -56,10 +76,10 @@ func (a *ApiHandler) GetAdminTableGraph(w http.ResponseWriter, r *http.Request)
5676

5777
}
5878

59-
func (a *ApiHandler) RenderGraph(g *Graph) {
79+
func (g *GraphHandler) RenderGraph(gr *Graph) {
6080

6181
selectedfilter := "24h" // default Value
62-
header := g.r.Header.Get("HX-Current-URL")
82+
header := gr.r.Header.Get("HX-Current-URL")
6383
currentURL, err := url.Parse(header)
6484
if err != nil {
6585
log.Println("Error Parsing HX-Current-URL for Graph Table")
@@ -90,13 +110,6 @@ func (a *ApiHandler) RenderGraph(g *Graph) {
90110
filter = "This Year"
91111
}
92112

93-
data, err := a.db.LookupApiKeyUserStats(g.key, g.kind, filter)
94-
if err != nil {
95-
log.Println(err)
96-
http.Error(g.w, "Could not get Data from DB for User "+string(g.key), 500)
97-
return
98-
}
99-
100113
// create a new line instance
101114
line := charts.NewLine()
102115
// set some global options like Title/Legend/ToolTip or anything else
@@ -108,9 +121,10 @@ func (a *ApiHandler) RenderGraph(g *Graph) {
108121
charts.WithYAxisOpts(opts.YAxis{SplitNumber: 2}),
109122
// charts.WithVisualMapOpts(opts.VisualMap{Show: opts.Bool(false)})
110123
)
124+
data := g.GetTableGraphData(gr, filter)
111125

112126
// Put data into instance
113-
td, err := a.GetAdminTableGraphData(g.w, data, filter)
127+
td, err := g.SetTableGraphData(gr.w, data, filter)
114128
if err != nil {
115129
return
116130
}
@@ -141,7 +155,7 @@ func (a *ApiHandler) RenderGraph(g *Graph) {
141155
Filter: filter,
142156
}
143157
// var buf bytes.Buffer
144-
if err := t.Execute(g.w, snippetData); err != nil {
158+
if err := t.Execute(gr.w, snippetData); err != nil {
145159
log.Println("Error Templating Chart", err)
146160
return
147161
}
@@ -157,7 +171,46 @@ type TableData struct {
157171
totalCount int
158172
}
159173

160-
func (a *ApiHandler) GetAdminTableGraphData(w http.ResponseWriter, d []db.RequestSummary, filter string) (*TableData, error) {
174+
func (g *GraphHandler) GetTableGraphData(gr *Graph, filter string) []db.RequestSummary {
175+
for _, row := range g.cache {
176+
if row.ID == gr.key && row.Filter == filter {
177+
count, err := g.a.db.LookupApiKeyUserStatsRows(gr.key, gr.kind)
178+
if err == nil && count == row.BaseCount {
179+
return row.Data
180+
}
181+
if err != nil {
182+
log.Println(err)
183+
}
184+
row.Data = g.LookupTableGraphData(gr, filter)
185+
row.BaseCount = count
186+
}
187+
}
188+
rowCount, err := g.a.db.LookupApiKeyUserStatsRows(gr.key, gr.kind)
189+
if err != nil {
190+
log.Println("could not count rows for caching: ", err)
191+
}
192+
193+
row := Cache{
194+
Data: g.LookupTableGraphData(gr, filter),
195+
BaseCount: rowCount,
196+
Filter: filter,
197+
ID: gr.key,
198+
}
199+
g.cache = append(g.cache, row)
200+
return row.Data
201+
}
202+
203+
func (g *GraphHandler) LookupTableGraphData(gr *Graph, filter string) []db.RequestSummary {
204+
data, err := g.a.db.LookupApiKeyUserStats(gr.key, gr.kind, filter)
205+
if err != nil && data != nil {
206+
log.Println(err)
207+
http.Error(gr.w, "Could not get Data from DB for User "+string(gr.key), 500)
208+
return nil
209+
}
210+
return data
211+
}
212+
213+
func (g *GraphHandler) SetTableGraphData(w http.ResponseWriter, d []db.RequestSummary, filter string) (*TableData, error) {
161214

162215
td := &TableData{
163216
data: make([]opts.LineData, 0),
@@ -188,7 +241,7 @@ func (a *ApiHandler) GetAdminTableGraphData(w http.ResponseWriter, d []db.Reques
188241
totalTokens = item.TokenCountComplete + item.TokenCountPrompt
189242
td.data = append(td.data, opts.LineData{Value: totalTokens})
190243
td.totalCount = td.totalCount + totalTokens
191-
loc, err := time.LoadLocation(a.timeZone)
244+
loc, err := time.LoadLocation(g.a.timeZone)
192245
if err != nil {
193246
log.Println("Error Displaying Timezone, maybe the TIMEZONE env is wrongly set")
194247
td.timeAxis = append(td.timeAxis, item.RequestTime.Format(format))

apiproxy/response.go

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"strings"
1212
)
1313

14+
// Process Response of Azure and Write Key Parameters to DB
15+
1416
type ResponseConf struct {
1517
db *db.Database
1618
}

cmd/main.go

+7
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ import (
66
api "openai-api-proxy/api"
77
proxy "openai-api-proxy/apiproxy"
88
auth "openai-api-proxy/auth"
9+
costs "openai-api-proxy/costs"
910
db "openai-api-proxy/db"
1011
web "openai-api-proxy/webui"
1112
"os"
1213
"os/signal"
1314
"syscall"
15+
"time"
1416

1517
"github.com/joho/godotenv"
1618
)
@@ -23,6 +25,11 @@ func main() {
2325
}
2426
db := db.DatabaseInit()
2527
defer db.Close()
28+
29+
ticker := time.NewTicker(24 * time.Hour)
30+
31+
go costs.GetAllCosts(db, ticker.C)
32+
2633
osExit(db)
2734
defer log.Println("Closing DB Clients :)")
2835

costs/azureCostCollector.go

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package costs
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"log"
9+
"net/http"
10+
db "openai-api-proxy/db"
11+
"regexp"
12+
"time"
13+
)
14+
15+
// Collect Cost of Azure DB and Write them to DB
16+
17+
type BaseJson struct {
18+
Costs []Costs `json:"Items"`
19+
}
20+
type Costs struct {
21+
ModelName string
22+
RetailPrice float32 `json:"retailPrice"`
23+
SKUName string `json:"skuName"`
24+
UnitOfMeasure string `json:"unitOfMeasure"`
25+
TokenType string
26+
Currency string
27+
IsRegional bool
28+
}
29+
30+
const MoneyUnit = 10000000
31+
32+
func GetAllCosts(d *db.Database, g <-chan time.Time) {
33+
for ; true; <-g {
34+
log.Println("Azure: Collecting Prices")
35+
totalCosts := []*Costs{}
36+
tokenTypes := []string{"Input", "Output"}
37+
models := d.LookupModels()
38+
cleanModel := regexp.MustCompile(`-[0-9]{4}-[0-9]{2}-[0-9]{2}`)
39+
if len(models) == 0 {
40+
log.Println("No Models in DB yet, no costs can be calculated")
41+
return
42+
}
43+
44+
for _, model := range models {
45+
modelName := cleanModel.ReplaceAllString(model, "")
46+
for _, tokenType := range tokenTypes {
47+
c := GetCosts(modelName, tokenType)
48+
c.TokenType = tokenType
49+
c.ModelName = modelName
50+
totalCosts = append(totalCosts, c)
51+
}
52+
}
53+
dbcosts := []*db.Costs{}
54+
for _, costs := range totalCosts {
55+
dbc := &db.Costs{
56+
ModelName: costs.ModelName,
57+
RetailPrice: int64(costs.RetailPrice * float32(MoneyUnit)),
58+
TokenType: costs.TokenType,
59+
UnitOfMeasure: costs.UnitOfMeasure,
60+
IsRegional: costs.IsRegional,
61+
BackendName: "azure",
62+
}
63+
dbcosts = append(dbcosts, dbc)
64+
}
65+
err := d.WriteCosts(dbcosts)
66+
if err != nil {
67+
log.Println(err)
68+
}
69+
}
70+
}
71+
72+
func GetCosts(modelName string, tokenType string) *Costs {
73+
currency := "EUR"
74+
regionName := "swedencentral"
75+
productName := "Azure OpenAI"
76+
regional := true
77+
errorContext := fmt.Sprintf("Check ENV - regional: %s, regionName: %s, currency: %s", regional, regionName, currency)
78+
79+
var region string
80+
81+
if regional {
82+
region = "regional"
83+
} else {
84+
region = "global"
85+
}
86+
87+
rq, err := http.NewRequest("GET", "https://prices.azure.com/api/retail/prices", nil)
88+
if err != nil {
89+
log.Println("Could not get latest Prices from API", errorContext)
90+
}
91+
92+
filter := fmt.Sprintf("productName eq '%s' and armRegionName eq '%s' and skuName eq '%s-%s-%s'", productName, regionName, modelName, tokenType, region)
93+
q := rq.URL.Query()
94+
q.Add("currencyCode", currency)
95+
q.Add("$filter", filter)
96+
rq.URL.RawQuery = q.Encode()
97+
98+
client := &http.Client{}
99+
resp, err := client.Do(rq)
100+
if err != nil {
101+
log.Println("Error Requesting Data from Azure API.", errorContext)
102+
}
103+
defer resp.Body.Close()
104+
body, err := io.ReadAll(resp.Body)
105+
if err != nil {
106+
log.Println("Error Parsing Body")
107+
log.Println(err)
108+
}
109+
var base *BaseJson
110+
111+
resp.Body = io.NopCloser(bytes.NewReader(body))
112+
json.Unmarshal(body, &base)
113+
c := &base.Costs[0]
114+
c.IsRegional = regional
115+
c.Currency = currency
116+
return c
117+
}

0 commit comments

Comments
 (0)