-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
227 lines (193 loc) · 5.87 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
package main
import (
"compress/gzip"
"embed"
"encoding/json"
"flag"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"strconv"
"strings"
"time"
)
// cacheAge is the max-age advertised in the Cache-Control header of
// static-file responses. When it is 0 the static handler sends "no-cache"
// instead of a max-age.
var cacheAge = time.Hour * 2

// staticContent holds the files embedded from static/; the root handler
// resolves request paths against this tree (directories via index.html).
//
//go:embed static/*
var staticContent embed.FS

// mimeTypeContent is the raw JSON of mimetype.json, decoded into a
// mimeType at startup.
//
//go:embed mimetype.json
var mimeTypeContent []byte

// mimeType maps a file extension, in the form returned by path.Ext (leading
// dot included, e.g. ".html"), to the Content-Type value sent for files with
// that extension.
type mimeType map[string]string

// revProxiesContent is the raw JSON of revproxies.json, decoded at startup
// into a []revProxyHost.
//
//go:embed revproxies.json
var revProxiesContent []byte

// revProxyHost describes one reverse-proxy route read from revproxies.json.
type revProxyHost struct {
	// Scheme and Host identify the backend that requests are forwarded to;
	// Host is also sent as the outgoing Host header.
	Scheme string `json:"scheme"`
	Host   string `json:"host"`
	// Path is the URL pattern the route is registered under with http.Handle.
	// It is stripped (minus any trailing slash) from the request path
	// before the request is forwarded.
	Path string `json:"path"`
	// ReqHeaders are set on every request sent to the backend.
	ReqHeaders map[string]string `json:"reqHeaders"`
	// ResHeaders are set on every response returned from the backend.
	ResHeaders map[string]string `json:"resHeaders"`
}
func main() {
checkFlags()
// Reads the mimetype.json file and converts it to a MimeType type
var mimeTypes mimeType
err := json.Unmarshal(mimeTypeContent, &mimeTypes)
if err != nil {
log.Fatal(err)
}
var revProxyes []revProxyHost
err = json.Unmarshal(revProxiesContent, &revProxyes)
if err != nil {
log.Fatal(err)
}
port := os.Getenv("PORT")
if port == "" {
port = "8080"
}
// Reverse proxy handlers
for _, proxy := range revProxyes {
proxy := proxy
log.Println(proxy)
// Forward requests as a reverse proxy
apiProxy := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: proxy.Scheme,
Host: proxy.Host,
})
pp := strings.TrimSuffix(proxy.Path, "/")
// Overwrite Director function and edit request headers
apiProxy.Director = func(req *http.Request) {
req.Host = proxy.Host
req.URL.Scheme = proxy.Scheme
req.URL.Host = proxy.Host
req.URL.Path = strings.TrimPrefix(req.URL.Path, pp)
if req.URL.Path == "" {
req.URL.Path = "/"
}
// Add request headers
for name, value := range proxy.ReqHeaders {
req.Header.Set(name, value)
}
}
// Overwrite ModifyResponse function to edit response headers
apiProxy.ModifyResponse = func(res *http.Response) error {
// Add response headers
for name, value := range proxy.ResHeaders {
res.Header.Set(name, value)
}
if 300 <= res.StatusCode && res.StatusCode < 400 {
loc := res.Header.Get("Location")
parsedURL, err := url.Parse(loc)
if err != nil {
log.Printf("500: %s %s %s", res.Request.RemoteAddr, res.Request.Method, res.Request.RequestURI)
return err
}
// Check if the host is an IP loopback address or "localhost"
host := parsedURL.Hostname()
if host == "localhost" || isLoopbackIP(host) {
// Add the port number if one is not already specified
if parsedURL.Port() != port {
parsedURL.Host = host + ":" + port
}
if !strings.HasPrefix(parsedURL.Path, proxy.Path) {
parsedURL.Path = path.Join(proxy.Path, parsedURL.Path)
}
if strings.HasSuffix(loc, "/") && !strings.HasSuffix(parsedURL.Path, "/") {
parsedURL.Path += "/"
}
}
res.Header.Set("Location", parsedURL.String())
}
log.Printf("%d: %s %s %s", res.StatusCode, res.Request.RemoteAddr, res.Request.Method, res.Request.RequestURI)
return nil
}
http.Handle(proxy.Path, apiProxy)
}
// Root handler
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Get the path to the requested file
embeddedPath := path.Clean("static/" + r.URL.Path[1:])
// If the path is a directory, add "/index.html" at the end
var isDir bool
if info, err := fs.Stat(staticContent, embeddedPath); err == nil && info.IsDir() {
isDir = true
embeddedPath = path.Join(embeddedPath, "index.html")
}
// If the file exists, return its contents as a response
content, err := staticContent.ReadFile(embeddedPath)
if err != nil {
// If the file does not exist, return a 404 error
http.NotFound(w, r)
log.Printf("404: %s %s %s", r.RemoteAddr, r.Method, r.RequestURI)
return
}
// Redirect if it is a directory but the path does not end with /.
if isDir && !strings.HasSuffix(r.URL.Path, "/") {
redirectUrl := r.URL
redirectUrl.Path += "/"
http.Redirect(w, r, redirectUrl.String(), http.StatusMovedPermanently)
log.Printf("301: %s %s %s", r.RemoteAddr, r.Method, r.RequestURI)
return
}
// Set the Content-Type associated with the extension
ext := path.Ext(embeddedPath)
contentType, ok := mimeTypes[ext]
if ok {
w.Header().Set("Content-Type", contentType)
}
// Set up cache control
if cacheAge == 0 {
w.Header().Set("Cache-Control", "no-cache")
} else {
w.Header().Set("Cache-Control", "public, max-age="+strconv.Itoa(int(cacheAge.Seconds())))
}
// Add Access-Control-Allow-Origin to HTTP header to allow all CORS
// Get Origin from the request header
origin := r.Header.Get("Origin")
// If the request header contains Origin, use its value
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
} else {
// If Origin is not included, allow all origins
w.Header().Set("Access-Control-Allow-Origin", "*")
}
// Enable gzip compression
if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gz.Write(content)
} else {
fmt.Fprint(w, string(content))
}
log.Printf("200: %s %s %s", r.RemoteAddr, r.Method, r.RequestURI)
})
// Start the server
addr := fmt.Sprintf(":%s", port)
log.Printf("Starting server on %s", addr)
log.Fatal(http.ListenAndServe(addr, nil))
}
// isLoopbackIP reports whether ip is the textual form of a loopback address:
// anything in 127.0.0.0/8 (including IPv4-mapped IPv6 forms) or ::1.
// Hostnames such as "localhost" and unparsable strings yield false.
func isLoopbackIP(ip string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	// IsLoopback already matches 127.0.0.1 and ::1, so no separate equality
	// checks against those addresses are needed.
	return addr.IsLoopback()
}
// checkFlags parses the command line. When -v or its alias -version is
// given, it prints the version information and exits with status 0;
// otherwise it returns and startup continues.
func checkFlags() {
	showVersion := false
	// Both spellings write into the same variable, so either one enables it.
	for _, name := range []string{"v", "version"} {
		flag.BoolVar(&showVersion, name, false, "Print the version")
	}
	flag.Parse()

	if !showVersion {
		return
	}
	printVersion()
	os.Exit(0)
}
// printVersion writes the build's Version and Revision to stdout, one
// labelled value per line.
func printVersion() {
	fmt.Printf("Version: %v\nRevision: %v\n", Version, Revision)
}