Skip to content

Commit 32b7613

Browse files
authored
Merge pull request #22 from bat-bs/reporting
chore: fix to many open DB Connections
2 parents be9a125 + af650d1 commit 32b7613

File tree

8 files changed

+66
-30
lines changed

8 files changed

+66
-30
lines changed

.env.example

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ API_VERSION=2024-02-15-preview
66
RESSOURCE_NAME=gpt-35-turbo
77
BASE_URL=openai.azure.com
88

9-
DEFAULT_BACKEND=azure
9+
DEFAULT_BACKEND=azure
10+
11+
DATABASE_PATH=postgresql://user:password@server:5432/db?sslmode=disable

api/apiserver.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import (
1111
"github.com/Masterminds/sprig/v3"
1212
)
1313

14-
func ApiInit(mux *http.ServeMux, a *auth.Auth) {
14+
func ApiInit(mux *http.ServeMux, a *auth.Auth, db *db.Database) {
1515
timeZone, ok := os.LookupEnv("TIMEZONE")
1616
if !ok {
1717
timeZone = "Europe/Berlin"
1818
}
1919

20-
api := ApiHandler{db.NewDB(), a, timeZone}
20+
api := ApiHandler{db, a, timeZone}
2121
mux.HandleFunc("/api2/user/widget", api.GetUserWidget)
2222
mux.HandleFunc("/api2/user/logout", api.LogoutUser)
2323
mux.HandleFunc("/api2/admin/table/get", api.GetAdminTable)

apiproxy/proxy.go

+13-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"net/http/httputil"
1111
"net/url"
12+
db "openai-api-proxy/db"
1213
"os"
1314
"strings"
1415
)
@@ -42,7 +43,7 @@ var (
4243
backendProxy = make(map[string]*httputil.ReverseProxy)
4344
)
4445

45-
func Init(mux *http.ServeMux) {
46+
func Init(mux *http.ServeMux, db *db.Database) {
4647
// Setup Azure Vars and Connection String
4748
azconf := &AzureConfig{
4849
DeploymentName: os.Getenv("DEPLOYMENT_NAME"),
@@ -51,14 +52,21 @@ func Init(mux *http.ServeMux) {
5152
BaseUrl: os.Getenv("BASE_URL"),
5253
}
5354
defaultBackend = os.Getenv("DEFAULT_BACKEND")
54-
55-
h := &baseHandle{azconf}
55+
rc := &ResponseConf{
56+
db: db,
57+
}
58+
h := &baseHandle{
59+
db: db,
60+
az: azconf,
61+
rc: rc}
5662
mux.Handle("/api/", h)
5763

5864
}
5965

6066
type baseHandle struct {
67+
db *db.Database
6168
az *AzureConfig
69+
rc *ResponseConf
6270
}
6371

6472
func (h *baseHandle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -122,7 +130,7 @@ func (h *baseHandle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
122130

123131
// Since OpenAI and Azure API are not really compatible, we need 2 different handler functions
124132
func (h *baseHandle) HandleAzure(w http.ResponseWriter, r *http.Request, backend string) {
125-
azureToken := ValidateToken(w, r)
133+
azureToken := h.ValidateToken(w, r)
126134
if azureToken == "" {
127135
http.Error(w, "Error Processing Request", http.StatusUnauthorized)
128136
return
@@ -150,7 +158,7 @@ func (h *baseHandle) HandleAzure(w http.ResponseWriter, r *http.Request, backend
150158
log.Printf("Proxying request to Azure backend: %s", actualURL.String())
151159
r.Body.Close()
152160

153-
proxy.ModifyResponse = NewResponse
161+
proxy.ModifyResponse = h.rc.NewResponse
154162

155163
proxy.ServeHTTP(w, r)
156164

apiproxy/response.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ import (
1111
"strings"
1212
)
1313

14+
type ResponseConf struct {
15+
db *db.Database
16+
}
17+
1418
type Response struct {
1519
rs *http.Response
16-
db *db.Database
20+
rc *ResponseConf
1721
apiKeyID string
1822
content Content
1923
}
@@ -27,18 +31,18 @@ type Content struct {
2731
} `json:"usage"`
2832
}
2933

30-
func NewResponse(in *http.Response) error {
34+
func (rc *ResponseConf) NewResponse(in *http.Response) error {
3135

3236
r := Response{
3337
rs: in,
34-
db: db.NewDB(),
38+
rc: rc,
3539
}
3640

3741
err := r.ReadValues()
3842
if err != nil {
3943
return err
4044
}
41-
go r.ProcessValues()
45+
r.ProcessValues()
4246
return nil
4347
}
4448

@@ -47,7 +51,7 @@ func (r *Response) GetApiKeyUUID() string {
4751

4852
apiKey := strings.TrimPrefix(header, "Bearer ")
4953

50-
hashes, err := r.db.LookupApiKeys("*")
54+
hashes, err := r.rc.db.LookupApiKeys("*")
5155
if err != nil || len(hashes) == 0 {
5256
log.Println("Error while requesting API Keys from DB", err)
5357
}
@@ -69,7 +73,6 @@ func (r *Response) ReadValues() error {
6973
}
7074

7175
func (r *Response) ProcessValues() {
72-
defer r.db.Close()
7376
c := r.content
7477
if c.Object != "chat.completion" {
7578
log.Printf("Untested API Endpoint '%s' is used, check the Request in the DB: %s", c.Object, c.ID)
@@ -86,7 +89,7 @@ func (r *Response) ProcessValues() {
8689
TokenCountComplete: c.Usage.CompletionTokens,
8790
Model: c.Model,
8891
}
89-
if err := r.db.WriteRequest(&rq); err != nil {
92+
if err := r.rc.db.WriteRequest(&rq); err != nil {
9093
log.Println("error processing Response: Response could not be written to db DB")
9194
log.Println(err)
9295
return

apiproxy/validateToken.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func CompareToken(hashes []db.ApiKey, apiKey string) (string, error) {
2525
return "", err
2626
}
2727

28-
func ValidateToken(w http.ResponseWriter, r *http.Request) string {
28+
func (h *baseHandle) ValidateToken(w http.ResponseWriter, r *http.Request) string {
2929
header := r.Header.Get(authHeader)
3030

3131
apiKey := strings.TrimPrefix(header, "Bearer ")
@@ -36,9 +36,7 @@ func ValidateToken(w http.ResponseWriter, r *http.Request) string {
3636
return ""
3737
}
3838

39-
db := db.NewDB()
40-
41-
hashes, err := db.LookupApiKeys("*")
39+
hashes, err := h.db.LookupApiKeys("*")
4240
if err != nil || len(hashes) == 0 {
4341
log.Println("Error while requesting API Keys from DB", err)
4442
http.Error(w, "401 - Token Invalid", http.StatusUnauthorized)

auth/auth.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
"golang.org/x/oauth2"
1818
)
1919

20-
func Init(mux *http.ServeMux) (a *Auth) {
20+
func Init(mux *http.ServeMux, db *db.Database) (a *Auth) {
2121
ctx := context.Background()
2222
issuer, ok := os.LookupEnv("ISSUER")
2323
if !ok {
@@ -63,8 +63,10 @@ func Init(mux *http.ServeMux) (a *Auth) {
6363
}
6464

6565
var verifier = provider.Verifier(&oidc.Config{ClientID: clientId})
66+
6667
a = &Auth{
6768
oauth2Config: oauth2Config,
69+
db: db,
6870
ctx: context.Background(),
6971
verifier: verifier,
7072
provider: provider,
@@ -89,6 +91,7 @@ type Claims struct {
8991

9092
type Auth struct {
9193
oauth2Config *oauth2.Config
94+
db *db.Database
9295
ctx context.Context
9396
verifier *oidc.IDTokenVerifier
9497
provider *oidc.Provider
@@ -205,9 +208,8 @@ func (a *Auth) ValidateAdminSession(w http.ResponseWriter, r *http.Request) (boo
205208
log.Println("Claims not found")
206209
return false, err
207210
}
208-
db := db.NewDB()
209-
defer db.Close()
210-
user, err := db.GetUser(claims.Sub)
211+
212+
user, err := a.db.GetUser(claims.Sub)
211213
if err != nil {
212214
return false, err
213215
}

cmd/main.go

+28-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
auth "openai-api-proxy/auth"
99
db "openai-api-proxy/db"
1010
web "openai-api-proxy/webui"
11+
"os"
12+
"os/signal"
13+
"syscall"
1114

1215
"github.com/joho/godotenv"
1316
)
@@ -18,15 +21,35 @@ func main() {
1821
if err != nil {
1922
log.Println("Warning: not able to loading Env File", err)
2023
}
21-
db.DatabaseInit()
24+
db := db.DatabaseInit()
25+
defer db.Close()
26+
osExit(db)
27+
defer log.Println("Closing DB Clients :)")
28+
2229
mux := http.NewServeMux()
23-
a := auth.Init(mux)
30+
a := auth.Init(mux, db)
2431
//
25-
proxy.Init(mux) // Start AI Proxy
26-
go web.Init(mux, a) // Start Web UI
27-
go api.ApiInit(mux, a) // Start Backend API
32+
proxy.Init(mux, db) // Start AI Proxy
33+
go web.Init(mux, a) // Start Web UI
34+
go api.ApiInit(mux, a, db) // Start Backend API
2835

2936
log.Printf("Serving on http://localhost:%d", 8082)
3037
log.Fatal(http.ListenAndServe(":8082", mux))
3138

3239
}
40+
41+
// Close DB on Program Exit
42+
func osExit(db *db.Database) {
43+
sigc := make(chan os.Signal, 1)
44+
signal.Notify(sigc,
45+
syscall.SIGHUP,
46+
syscall.SIGINT,
47+
syscall.SIGTERM,
48+
syscall.SIGQUIT)
49+
go func() {
50+
s := <-sigc
51+
log.Printf("Exit: %s Closing DB Clients :)", s)
52+
db.Close()
53+
os.Exit(1)
54+
}()
55+
}

db/database.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@ type ApiKey struct {
2424
TokenCountComplete *int
2525
}
2626

27-
func DatabaseInit() {
27+
func DatabaseInit() *Database {
2828
createTable, err := os.ReadFile("db/schema.sql")
2929
if err != nil {
3030
log.Fatal("cannot load schema file: ", err)
3131
}
3232

3333
d := NewDB()
34-
defer d.Close()
3534
d.Migrate()
3635
if _, err := d.db.Exec(string(createTable)); err != nil {
3736
log.Fatal(err)
3837
}
38+
return d
3939

4040
}
4141

0 commit comments

Comments
 (0)