Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
services:
neon-proxy:
build: .
environment:
- ALLOW_ADDR_REGEX=.*
- LOG_TRAFFIC=true
- TLS_SKIP_VERIFY=true
ports:
- '5433:80'
72 changes: 69 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"crypto/tls"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -42,6 +43,9 @@ type Config struct {
UseHostHeader bool `env:"USE_HOST_HEADER" envDefault:"false"`
LogTraffic bool `env:"LOG_TRAFFIC" envDefault:"false"`
LogConnInfo bool `env:"LOG_CONN_INFO" envDefault:"true"`
UseTLS bool `env:"USE_TLS" envDefault:"true"`
TLSSkipVerify bool `env:"TLS_SKIP_VERIFY" envDefault:"false"`
TLSServerName string `env:"TLS_SERVER_NAME" envDefault:""`
}

var upgrader = websocket.Upgrader{
Expand Down Expand Up @@ -142,9 +146,71 @@ func (h *ProxyHandler) HandleWS(conn *websocket.Conn, addr string) error {
activeConnections.Inc()
defer activeConnections.Dec()

socket, err := net.Dial("tcp", addr)
if err != nil {
return err
var socket net.Conn
var err error

if h.cfg.UseTLS {
// First establish a plain TCP connection
socket, err = net.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("failed to establish TCP connection: %w", err)
}

// Send PostgreSQL SSL request
sslRequest := []byte{0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f}
_, err = socket.Write(sslRequest)
if err != nil {
socket.Close()
return fmt.Errorf("failed to send SSL request: %w", err)
}

// Read SSL response (1 byte)
response := make([]byte, 1)
_, err = socket.Read(response)
if err != nil {
socket.Close()
return fmt.Errorf("failed to read SSL response: %w", err)
}

if response[0] == 'S' {
// Server supports SSL, upgrade the connection
serverName := h.cfg.TLSServerName
if serverName == "" {
// Extract hostname from address if TLS_SERVER_NAME is not set
host, _, err := net.SplitHostPort(addr)
if err != nil {
// If SplitHostPort fails, use the full address as hostname
serverName = addr
} else {
serverName = host
}
}

tlsConfig := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: h.cfg.TLSSkipVerify,
}
tlsConn := tls.Client(socket, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
socket.Close()
return fmt.Errorf("failed to complete TLS handshake: %w", err)
}
socket = tlsConn
} else if response[0] == 'N' {
// Server doesn't support SSL
if h.cfg.LogConnInfo {
log.Printf("PostgreSQL server doesn't support SSL, continuing with plain connection")
}
} else {
socket.Close()
return fmt.Errorf("unexpected SSL response from server: %c", response[0])
}
} else {
socket, err = net.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("failed to establish TCP connection: %w", err)
}
}
defer socket.Close()

Expand Down