|  | 
|  | 1 | +// Copyright 2024 The gVisor Authors. | 
|  | 2 | +// | 
|  | 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +// you may not use this file except in compliance with the License. | 
|  | 5 | +// You may obtain a copy of the License at | 
|  | 6 | +// | 
|  | 7 | +//     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +// | 
|  | 9 | +// Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +// distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +// See the License for the specific language governing permissions and | 
|  | 13 | +// limitations under the License. | 
|  | 14 | + | 
|  | 15 | +// A simple `curl`-like HTTP client that prints metrics after the request. | 
|  | 16 | +// All of its output is structured to be unambiguous even if stdout/stderr | 
|  | 17 | +// is combined, as is the case for Kubernetes logs. | 
|  | 18 | +// Useful for communicating with SGLang. | 
|  | 19 | +package main | 
|  | 20 | + | 
|  | 21 | +import ( | 
|  | 22 | +	"bufio" | 
|  | 23 | +	"bytes" | 
|  | 24 | +	"encoding/base64" | 
|  | 25 | +	"encoding/json" | 
|  | 26 | +	"flag" | 
|  | 27 | +	"fmt" | 
|  | 28 | +	"net/http" | 
|  | 29 | +	"os" | 
|  | 30 | +	"sort" | 
|  | 31 | +	"strings" | 
|  | 32 | +	"time" | 
|  | 33 | +) | 
|  | 34 | + | 
// Command-line flags controlling the single HTTP request this program makes.
var (
	// url is the target of the request; it must be non-empty.
	url = flag.String("url", "", "HTTP request URL.")

	// method selects the HTTP verb; only GET and POST are handled.
	method = flag.String("method", "GET", "HTTP request method (GET or POST).")

	// postDataBase64 carries the POST body, base64-encoded so that arbitrary
	// bytes can be passed on the command line.
	postDataBase64 = flag.String("post_base64", "", "HTTP request POST data in base64 format; ignored for GET requests.")

	// timeout bounds the request; the zero value disables the timeout.
	timeout = flag.Duration("timeout", 0, "HTTP request timeout; 0 for no timeout.")
)
|  | 42 | + | 
// bufSize is the size, in bytes, of the buffers used for HTTP requests and
// responses.
const bufSize = 1 << 20 // 1MiB
|  | 45 | + | 
// fatalf reports an unrecoverable error and terminates the program.
// The formatted message is written to stderr as a single line prefixed with
// "FATAL: ", after which the process exits with status 1.
func fatalf(format string, values ...any) {
	line := fmt.Sprintf("FATAL: "+format+"\n", values...)
	fmt.Fprint(os.Stderr, line)
	os.Exit(1)
}
|  | 51 | + | 
// Metrics is the set of request timestamps that gets exported as JSON.
// The JSON form is parsed by the sglang library at
// `test/gpu/sglang/sglang.go`.
type Metrics struct {
	ProgramStarted   time.Time `json:"program_started"`   // When the program started.
	RequestSent      time.Time `json:"request_sent"`      // When the HTTP request was sent.
	ResponseReceived time.Time `json:"response_received"` // When the HTTP response headers were received.
	FirstByteRead    time.Time `json:"first_byte_read"`   // When the first HTTP response body byte was read.
	LastByteRead     time.Time `json:"last_byte_read"`    // When the last HTTP response body byte was read.
}
|  | 66 | + | 
|  | 67 | +func main() { | 
|  | 68 | +	var metrics Metrics | 
|  | 69 | +	metrics.ProgramStarted = time.Now() | 
|  | 70 | +	flag.Parse() | 
|  | 71 | +	if *url == "" { | 
|  | 72 | +		fatalf("--url is required") | 
|  | 73 | +	} | 
|  | 74 | +	client := http.Client{ | 
|  | 75 | +		Transport: &http.Transport{ | 
|  | 76 | +			MaxIdleConns:    1, | 
|  | 77 | +			IdleConnTimeout: *timeout, | 
|  | 78 | +			ReadBufferSize:  bufSize, | 
|  | 79 | +			WriteBufferSize: bufSize, | 
|  | 80 | +		}, | 
|  | 81 | +		Timeout: *timeout, | 
|  | 82 | +	} | 
|  | 83 | +	var request *http.Request | 
|  | 84 | +	var err error | 
|  | 85 | +	switch *method { | 
|  | 86 | +	case "GET": | 
|  | 87 | +		request, err = http.NewRequest("GET", *url, nil) | 
|  | 88 | +	case "POST": | 
|  | 89 | +		postData, postDataErr := base64.StdEncoding.DecodeString(*postDataBase64) | 
|  | 90 | +		if postDataErr != nil { | 
|  | 91 | +			fatalf("cannot decode POST data: %v", postDataErr) | 
|  | 92 | +		} | 
|  | 93 | +		request, err = http.NewRequest("POST", *url, bytes.NewBuffer(postData)) | 
|  | 94 | +	default: | 
|  | 95 | +		err = fmt.Errorf("unknown method %q", *method) | 
|  | 96 | +	} | 
|  | 97 | +	if err != nil { | 
|  | 98 | +		fatalf("cannot create request: %v", err) | 
|  | 99 | +	} | 
|  | 100 | +	orderedReqHeaders := make([]string, 0, len(request.Header)) | 
|  | 101 | +	for k := range request.Header { | 
|  | 102 | +		orderedReqHeaders = append(orderedReqHeaders, k) | 
|  | 103 | +	} | 
|  | 104 | +	sort.Strings(orderedReqHeaders) | 
|  | 105 | +	for _, k := range orderedReqHeaders { | 
|  | 106 | +		for _, v := range request.Header[k] { | 
|  | 107 | +			fmt.Fprintf(os.Stderr, "REQHEADER: %s: %s\n", k, v) | 
|  | 108 | +		} | 
|  | 109 | +	} | 
|  | 110 | +	metrics.RequestSent = time.Now() | 
|  | 111 | +	resp, err := client.Do(request) | 
|  | 112 | +	metrics.ResponseReceived = time.Now() | 
|  | 113 | +	if err != nil { | 
|  | 114 | +		fatalf("cannot make request: %v", err) | 
|  | 115 | +	} | 
|  | 116 | +	gotFirstByte := false | 
|  | 117 | +	scanner := bufio.NewScanner(resp.Body) | 
|  | 118 | +	for scanner.Scan() { | 
|  | 119 | +		if !gotFirstByte { | 
|  | 120 | +			metrics.FirstByteRead = time.Now() | 
|  | 121 | +			gotFirstByte = true | 
|  | 122 | +		} | 
|  | 123 | +		if scanner.Text() == "" { | 
|  | 124 | +			continue | 
|  | 125 | +		} | 
|  | 126 | +		fmt.Printf("BODY: %q\n", strings.TrimPrefix(scanner.Text(), "data: ")) | 
|  | 127 | +	} | 
|  | 128 | +	// Check for any errors that may have occurred during scanning | 
|  | 129 | +	if err := scanner.Err(); err != nil { | 
|  | 130 | +		fatalf("error reading response body: %v", err) | 
|  | 131 | +	} | 
|  | 132 | +	metrics.LastByteRead = time.Now() | 
|  | 133 | +	if err := resp.Body.Close(); err != nil { | 
|  | 134 | +		fatalf("cannot close response body: %v", err) | 
|  | 135 | +	} | 
|  | 136 | +	orderedRespHeaders := make([]string, 0, len(resp.Header)) | 
|  | 137 | +	for k := range resp.Header { | 
|  | 138 | +		orderedRespHeaders = append(orderedRespHeaders, k) | 
|  | 139 | +	} | 
|  | 140 | +	sort.Strings(orderedRespHeaders) | 
|  | 141 | +	for _, k := range orderedRespHeaders { | 
|  | 142 | +		for _, v := range resp.Header[k] { | 
|  | 143 | +			fmt.Fprintf(os.Stderr, "RESPHEADER: %s: %s\n", k, v) | 
|  | 144 | +		} | 
|  | 145 | +	} | 
|  | 146 | +	metricsBytes, err := json.Marshal(&metrics) | 
|  | 147 | +	if err != nil { | 
|  | 148 | +		fatalf("cannot marshal metrics: %v", err) | 
|  | 149 | +	} | 
|  | 150 | +	fmt.Fprintf(os.Stderr, "STATS: %s\n", string(metricsBytes)) | 
|  | 151 | +} | 
0 commit comments