From 5aba3c9aa4ea9b3f388df125f9c66495b43c5c9e Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 31 May 2023 15:43:10 -0700 Subject: [PATCH] Migrate HTTP server to net/http (#6248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Convert pprof server to net/http Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * ๐Ÿงน Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Converted some middlewares to net/http Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fixed tests Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Migrated metrics middleware Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Working on tracing middleware Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Completed converting middlewares Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * ๐Ÿ’„ Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Changed the web server Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Removed streaming from HTTP API since it is not working at this time See #6246 Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fixed useMaxBodySize middleware Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fixed tracing headers Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Should fix tests Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Should actually fix tests Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Should fix limiting request body Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Add unit tests for limitreadcloser Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Document changes in responsewriter Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * status codes use net/http Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Method constants use net/http too Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Handle ErrServerClosed better Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Tweak SDK pipeline Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Added InvokeMethodRequest.WithHTTPHeaders Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fix propagation of trace context between net/http and fasthttp Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fix handling of error responses from app in CallLocalStream Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * ๐Ÿ’„ Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --------- Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Loong Dai --- .github/workflows/dapr-test-sdk.yml | 6 - go.mod | 4 +- go.sum | 4 +- pkg/diagnostics/http_monitoring.go | 34 +- pkg/diagnostics/http_monitoring_test.go | 80 ++-- pkg/diagnostics/http_tracing.go | 154 ++++--- pkg/diagnostics/http_tracing_test.go | 178 ++++---- pkg/diagnostics/utils/trace_utils.go | 42 +- pkg/diagnostics/utils/trace_utils_test.go | 18 +- pkg/grpc/api_daprinternal.go | 2 + pkg/grpc/api_test.go | 36 +- pkg/http/api.go | 392 ++++++++---------- pkg/http/api_test.go | 37 +- pkg/http/config.go | 4 +- pkg/http/responses.go | 47 --- pkg/http/responses_test.go | 10 - pkg/http/server.go | 145 ++++--- pkg/http/server_test.go | 78 ++-- pkg/http/universalapi.go | 6 +- pkg/messaging/direct_messaging.go | 2 +- pkg/messaging/v1/invoke_method_request.go | 7 + pkg/messaging/v1/invoke_method_response.go | 3 + pkg/messaging/v1/util.go | 32 +- pkg/messaging/v1/util_test.go | 19 +- pkg/runtime/runtime.go | 4 +- tests/apps/service_invocation/app.go | 11 +- .../service_invocation_test.go | 22 +- utils/fasthttpadaptor/adaptor.go | 20 +- utils/nethttpadaptor/nethttpadaptor.go | 23 +- utils/responsewriter/README.md | 29 ++ utils/responsewriter/response_writer.go | 179 ++++++++ utils/responsewriter/response_writer_test.go | 215 ++++++++++ utils/streams/limitreadcloser.go | 42 +- utils/streams/limitreadcloser_test.go | 124 ++++++ 34 files changed, 1250 insertions(+), 759 deletions(-) create mode 100644 utils/responsewriter/README.md create mode 100644 utils/responsewriter/response_writer.go create mode 100644 utils/responsewriter/response_writer_test.go create mode 100644 utils/streams/limitreadcloser_test.go diff --git a/.github/workflows/dapr-test-sdk.yml b/.github/workflows/dapr-test-sdk.yml index de471313d97..fe865ff5802 100644 --- a/.github/workflows/dapr-test-sdk.yml +++ b/.github/workflows/dapr-test-sdk.yml @@ -79,8 +79,6 @@ jobs: with: header: ${{ github.run_id }}-python number: ${{ env.PR_NUMBER }} - hide: true - hide_classify: OUTDATED GITHUB_TOKEN: ${{ secrets.DAPR_BOT_TOKEN }} message: | # Dapr SDK Python test @@ -210,8 +208,6 @@ jobs: with: header: ${{ github.run_id }}-java number: ${{ env.PR_NUMBER }} - hide: true - hide_classify: OUTDATED GITHUB_TOKEN: ${{ secrets.DAPR_BOT_TOKEN }} message: | # Dapr SDK Java test @@ -383,8 +379,6 @@ jobs: with: header: ${{ github.run_id }}-js number: ${{ env.PR_NUMBER }} - hide: true - hide_classify: OUTDATED GITHUB_TOKEN: ${{ secrets.DAPR_BOT_TOKEN }} message: | # Dapr SDK JS test diff --git a/go.mod b/go.mod index 4cf993c4cdb..29f454d526a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( contrib.go.opencensus.io/exporter/prometheus v0.4.2 - github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a github.com/PaesslerAG/jsonpath v0.1.1 github.com/PuerkitoBio/purell v1.2.0 github.com/argoproj/argo-rollouts v1.4.1 @@ -13,6 +12,7 @@ require ( github.com/dapr/kit v0.0.5 github.com/evanphx/json-patch/v5 v5.6.0 github.com/fasthttp/router v1.4.18 + github.com/go-chi/cors v1.2.1 github.com/go-logr/logr v1.2.4 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.3 @@ -24,7 +24,6 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/hashicorp/go-hclog v1.5.0 github.com/hashicorp/go-msgpack/v2 v2.1.0 - github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru/v2 v2.0.2 github.com/hashicorp/raft v1.4.0 github.com/hashicorp/raft-boltdb v0.0.0-20230125174641-2a8082862702 @@ -241,6 +240,7 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-msgpack v0.5.5 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-rootcerts v1.0.2 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect diff --git a/go.sum b/go.sum index b9c80963467..2349e7f65fe 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,6 @@ github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMb github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= github.com/99designs/keyring v1.2.1 h1:tYLp1ULvO7i3fI5vE21ReQuj99QFSs7lGm0xWyJo87o= github.com/99designs/keyring v1.2.1/go.mod h1:fc+wB5KTk9wQ9sDx0kFXB3A0MaeGHM9AwRStKOQ5vOA= -github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1:XVdatQFSP2YhJGjqLLIfW8QBk4loz/SCe/PxkXDiW+s= -github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw= github.com/AthenZ/athenz v1.10.39 h1:mtwHTF/v62ewY2Z5KWhuZgVXftBej1/Tn80zx4DcawY= github.com/AthenZ/athenz v1.10.39/go.mod h1:3Tg8HLsiQZp81BJY58JBeU2BR6B/H4/0MQGfCwhHNEA= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= @@ -529,6 +527,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-asn1-ber/asn1-ber v1.3.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-co-op/gocron v1.9.0/go.mod h1:DbJm9kdgr1sEvWpHCA7dFFs/PGHPMil9/97EXCRPr4k= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= diff --git a/pkg/diagnostics/http_monitoring.go b/pkg/diagnostics/http_monitoring.go index 967d32ca73f..5415f5b4cca 100644 --- a/pkg/diagnostics/http_monitoring.go +++ b/pkg/diagnostics/http_monitoring.go @@ -15,16 +15,17 @@ package diagnostics import ( "context" + "net/http" "strconv" "strings" "time" - "github.com/valyala/fasthttp" "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils" + "github.com/dapr/dapr/utils/responsewriter" ) // To track the metrics for fasthttp using opencensus, this implementation is inspired by @@ -209,27 +210,32 @@ func (h *httpMetrics) Init(appID string) error { ) } -// FastHTTPMiddleware is the middleware to track http server-side requests. -func (h *httpMetrics) FastHTTPMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - reqContentSize := ctx.Request.Header.ContentLength() - if reqContentSize < 0 { - reqContentSize = 0 +// HTTPMiddleware is the middleware to track HTTP server-side requests. +func (h *httpMetrics) HTTPMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var reqContentSize int64 + if cl := r.Header.Get("content-length"); cl != "" { + reqContentSize, _ = strconv.ParseInt(cl, 10, 64) + if reqContentSize < 0 { + reqContentSize = 0 + } } - method := string(ctx.Method()) - path := h.convertPathToMetricLabel(string(ctx.Path())) + path := h.convertPathToMetricLabel(r.URL.Path) - h.ServerRequestReceived(ctx, method, path, int64(reqContentSize)) + h.ServerRequestReceived(r.Context(), r.Method, path, reqContentSize) + + // Wrap the writer in a ResponseWriter so we can collect stats such as status code and size + w = responsewriter.EnsureResponseWriter(w) start := time.Now() - next(ctx) + next(w, r) - status := strconv.Itoa(ctx.Response.StatusCode()) elapsed := float64(time.Since(start) / time.Millisecond) - respSize := int64(len(ctx.Response.Body())) - h.ServerRequestCompleted(ctx, method, path, status, respSize, elapsed) + status := strconv.Itoa(w.(responsewriter.ResponseWriter).Status()) + respSize := int64(w.(responsewriter.ResponseWriter).Size()) + h.ServerRequestCompleted(r.Context(), r.Method, path, status, respSize, elapsed) } } diff --git a/pkg/diagnostics/http_monitoring_test.go b/pkg/diagnostics/http_monitoring_test.go index a9b660a4841..11cfe2c56a0 100644 --- a/pkg/diagnostics/http_monitoring_test.go +++ b/pkg/diagnostics/http_monitoring_test.go @@ -1,34 +1,34 @@ package diagnostics import ( - "net" + "net/http" + "net/http/httptest" + "strconv" + "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" "go.opencensus.io/stats/view" ) -func TestFastHTTPMiddleware(t *testing.T) { +func TestHTTPMiddleware(t *testing.T) { requestBody := "fake_requestDaprBody" responseBody := "fake_responseDaprBody" - testRequestCtx := fakeFastHTTPRequestCtx(requestBody) - - fakeHandler := func(ctx *fasthttp.RequestCtx) { - time.Sleep(100 * time.Millisecond) - ctx.Response.SetBodyRaw([]byte(responseBody)) - } + testRequest := fakeHTTPRequest(requestBody) // create test httpMetrics testHTTP := newHTTPMetrics() testHTTP.Init("fakeID") - handler := testHTTP.FastHTTPMiddleware(fakeHandler) + handler := testHTTP.HTTPMiddleware(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.Write([]byte(responseBody)) + }) // act - handler(testRequestCtx) + handler(httptest.NewRecorder(), testRequest) // assert rows, err := view.RetrieveData("http/server/request_count") @@ -46,12 +46,12 @@ func TestFastHTTPMiddleware(t *testing.T) { assert.Equal(t, 1, len(rows)) assert.Equal(t, "app_id", rows[0].Tags[0].Key.Name()) assert.Equal(t, "fakeID", rows[0].Tags[0].Value) - assert.True(t, (rows[0].Data).(*view.DistributionData).Min == float64(len([]byte(requestBody)))) + assert.Equal(t, float64(len(requestBody)), (rows[0].Data).(*view.DistributionData).Min) rows, err = view.RetrieveData("http/server/response_bytes") assert.NoError(t, err) assert.Equal(t, 1, len(rows)) - assert.True(t, (rows[0].Data).(*view.DistributionData).Min == float64(len([]byte(responseBody)))) + assert.Equal(t, float64(len(responseBody)), (rows[0].Data).(*view.DistributionData).Min) rows, err = view.RetrieveData("http/server/latency") assert.NoError(t, err) @@ -59,16 +59,11 @@ func TestFastHTTPMiddleware(t *testing.T) { assert.True(t, (rows[0].Data).(*view.DistributionData).Min >= 100.0) } -func TestFastHTTPMiddlewareWhenMetricsDisabled(t *testing.T) { +func TestHTTPMiddlewareWhenMetricsDisabled(t *testing.T) { requestBody := "fake_requestDaprBody" responseBody := "fake_responseDaprBody" - testRequestCtx := fakeFastHTTPRequestCtx(requestBody) - - fakeHandler := func(ctx *fasthttp.RequestCtx) { - time.Sleep(100 * time.Millisecond) - ctx.Response.SetBodyRaw([]byte(responseBody)) - } + testRequest := fakeHTTPRequest(requestBody) // create test httpMetrics testHTTP := newHTTPMetrics() @@ -79,10 +74,13 @@ func TestFastHTTPMiddlewareWhenMetricsDisabled(t *testing.T) { views := []*view.View{v} view.Unregister(views...) - handler := testHTTP.FastHTTPMiddleware(fakeHandler) + handler := testHTTP.HTTPMiddleware(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.Write([]byte(responseBody)) + }) // act - handler(testRequestCtx) + handler(httptest.NewRecorder(), testRequest) // assert rows, err := view.RetrieveData("http/server/request_count") @@ -121,34 +119,16 @@ func TestConvertPathToMethodName(t *testing.T) { } } -func fakeFastHTTPRequestCtx(expectedBody string) *fasthttp.RequestCtx { - expectedMethod := fasthttp.MethodPost - expectedRequestURI := "/invoke/method/testmethod" - expectedTransferEncoding := "encoding" - expectedHost := "dapr.io" - expectedRemoteAddr := "1.2.3.4:6789" - expectedHeader := map[string]string{ - "Correlation-ID": "e6f4bb20-96c0-426a-9e3d-991ba16a3ebb", - "XXX-Remote-Addr": "192.168.0.100", +func fakeHTTPRequest(body string) *http.Request { + req, err := http.NewRequest(http.MethodPost, "http://dapr.io/invoke/method/testmethod", strings.NewReader(body)) + if err != nil { + panic(err) } + req.Header.Set("Correlation-ID", "e6f4bb20-96c0-426a-9e3d-991ba16a3ebb") + req.Header.Set("XXX-Remote-Addr", "192.168.0.100") + req.Header.Set("Transfer-Encoding", "encoding") + // This is normally set automatically when the request is sent to a server, but in this case we are not using a real server + req.Header.Set("Content-Length", strconv.FormatInt(req.ContentLength, 10)) - var ctx fasthttp.RequestCtx - var req fasthttp.Request - - req.Header.SetMethod(expectedMethod) - req.SetRequestURI(expectedRequestURI) - req.Header.SetHost(expectedHost) - req.Header.Add(fasthttp.HeaderTransferEncoding, expectedTransferEncoding) - req.Header.SetContentLength(len([]byte(expectedBody))) - req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck - - for k, v := range expectedHeader { - req.Header.Set(k, v) - } - - remoteAddr, _ := net.ResolveTCPAddr("tcp", expectedRemoteAddr) - - ctx.Init(&req, remoteAddr, nil) - - return &ctx + return req } diff --git a/pkg/diagnostics/http_tracing.go b/pkg/diagnostics/http_tracing.go index b01c0fc0d7f..7037e31e02a 100644 --- a/pkg/diagnostics/http_tracing.go +++ b/pkg/diagnostics/http_tracing.go @@ -15,17 +15,16 @@ package diagnostics import ( "net/http" - "net/textproto" "strconv" "strings" "github.com/valyala/fasthttp" - otelcodes "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" "github.com/dapr/dapr/pkg/config" diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils" + "github.com/dapr/dapr/utils/responsewriter" ) // We have leveraged the code from opencensus-go plugin to adhere the w3c trace context. @@ -39,67 +38,76 @@ const ( ) // HTTPTraceMiddleware sets the trace context or starts the trace client span based on request. -func HTTPTraceMiddleware(next fasthttp.RequestHandler, appID string, spec config.TracingSpec) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - path := string(ctx.Request.URI().Path()) +func HTTPTraceMiddleware(next http.Handler, appID string, spec config.TracingSpec) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path if isHealthzRequest(path) { - next(ctx) + next.ServeHTTP(w, r) return } - ctx, span := startTracingClientSpanFromHTTPContext(ctx, path, spec) + span := startTracingClientSpanFromHTTPRequest(r, path, spec) - next(ctx) + // Wrap the writer in a ResponseWriter so we can collect stats such as status code and size + rw := responsewriter.EnsureResponseWriter(w) - // Add span attributes only if it is sampled, which reduced the perf impact. - if span.SpanContext().IsSampled() { - AddAttributesToSpan(span, userDefinedHTTPHeaders(ctx)) - spanAttr := spanAttributesMapFromHTTPContext(ctx) - AddAttributesToSpan(span, spanAttr) + // Before the response is written, we need to add the tracing headers + rw.Before(func(rw responsewriter.ResponseWriter) { + // Add span attributes only if it is sampled, which reduced the perf impact. + if span.SpanContext().IsSampled() { + AddAttributesToSpan(span, userDefinedHTTPHeaders(r)) + spanAttr := spanAttributesMapFromHTTPContext(rw, r) + AddAttributesToSpan(span, spanAttr) - // Correct the span name based on API. - if sname, ok := spanAttr[daprAPISpanNameInternal]; ok { - span.SetName(sname) + // Correct the span name based on API. + if sname, ok := spanAttr[daprAPISpanNameInternal]; ok { + span.SetName(sname) + } } - } - // Check if response has traceparent header and add if absent - if ctx.Response.Header.Peek(TraceparentHeader) == nil { - span = diagUtils.SpanFromContext(ctx) - // Using Header.Set here because we want to overwrite any header that may exist - SpanContextToHTTPHeaders(span.SpanContext(), ctx.Response.Header.Set) - } + // Check if response has traceparent header and add if absent + if rw.Header().Get(TraceparentHeader) == "" { + span = diagUtils.SpanFromContext(r.Context()) + // Using Header.Set here because we know the traceparent header isn't set + SpanContextToHTTPHeaders(span.SpanContext(), rw.Header().Set) + } - UpdateSpanStatusFromHTTPStatus(span, ctx.Response.StatusCode()) - span.End() - } + UpdateSpanStatusFromHTTPStatus(span, rw.Status()) + span.End() + }) + + next.ServeHTTP(rw, r) + }) } // userDefinedHTTPHeaders returns dapr- prefixed header from incoming metadata. // Users can add dapr- prefixed headers that they want to see in span attributes. -func userDefinedHTTPHeaders(reqCtx *fasthttp.RequestCtx) map[string]string { - m := map[string]string{} +func userDefinedHTTPHeaders(r *http.Request) map[string]string { + // Allocate this with enough memory for a pessimistic case + m := make(map[string]string, len(r.Header)) - reqCtx.Request.Header.VisitAll(func(key []byte, value []byte) { - if len(key) < (len(daprHeaderPrefix) + 1) { - return + for key, vSlice := range r.Header { + if len(vSlice) < 1 || len(key) < (len(daprHeaderBinSuffix)+1) { + continue } - ks := strings.ToLower(string(key)) - if ks[0:len(daprHeaderPrefix)] == daprHeaderPrefix { - m[ks] = string(value) + + key = strings.ToLower(key) + if strings.HasPrefix(key, daprHeaderPrefix) { + // Get the last value for each key + m[key] = vSlice[len(vSlice)-1] } - }) + } return m } -func startTracingClientSpanFromHTTPContext(ctx *fasthttp.RequestCtx, spanName string, spec config.TracingSpec) (*fasthttp.RequestCtx, trace.Span) { - sc, _ := SpanContextFromRequest(&ctx.Request) - netCtx := trace.ContextWithRemoteSpanContext(ctx, sc) +func startTracingClientSpanFromHTTPRequest(r *http.Request, spanName string, spec config.TracingSpec) trace.Span { + sc := SpanContextFromRequest(r) + ctx := trace.ContextWithRemoteSpanContext(r.Context(), sc) kindOption := trace.WithSpanKind(trace.SpanKindClient) - _, span := tracer.Start(netCtx, spanName, kindOption) - diagUtils.SpanToFastHTTPContext(ctx, span) - return ctx, span + _, span := tracer.Start(ctx, spanName, kindOption) + diagUtils.AddSpanToRequest(r, span) + return span } func StartProducerSpanChildFromParent(ctx *fasthttp.RequestCtx, parentSpan trace.Span) trace.Span { @@ -111,17 +119,17 @@ func StartProducerSpanChildFromParent(ctx *fasthttp.RequestCtx, parentSpan trace } // SpanContextFromRequest extracts a span context from incoming requests. -func SpanContextFromRequest(req *fasthttp.Request) (sc trace.SpanContext, ok bool) { - h, ok := getRequestHeader(req, TraceparentHeader) - if !ok { - return trace.SpanContext{}, false +func SpanContextFromRequest(r *http.Request) (sc trace.SpanContext) { + h := r.Header.Get(TraceparentHeader) + if h == "" { + return trace.SpanContext{} } - sc, ok = SpanContextFromW3CString(h) + sc, ok := SpanContextFromW3CString(h) if ok { - ts := tracestateFromRequest(req) + ts := tracestateFromRequest(r) sc = sc.WithTraceState(*ts) } - return sc, ok + return sc } func isHealthzRequest(name string) bool { @@ -152,17 +160,8 @@ func traceStatusFromHTTPCode(httpCode int) (otelcodes.Code, string) { return code, "" } -func getRequestHeader(req *fasthttp.Request, name string) (string, bool) { - s := string(req.Header.Peek(textproto.CanonicalMIMEHeaderKey(name))) - if s == "" { - return "", false - } - - return s, true -} - -func tracestateFromRequest(req *fasthttp.Request) *trace.TraceState { - h, _ := getRequestHeader(req, TracestateHeader) +func tracestateFromRequest(r *http.Request) *trace.TraceState { + h := r.Header.Get(TracestateHeader) return TraceStateFromW3CString(h) } @@ -183,11 +182,6 @@ func tracestateToHeader(sc trace.SpanContext, setHeader func(string, string)) { } } -func getContextValue(ctx *fasthttp.RequestCtx, key string) string { - v, _ := ctx.UserValue(key).(string) - return v -} - func getAPIComponent(apiPath string) (string, string) { // Dapr API reference : https://docs.dapr.io/reference/api/ // example : apiPath /v1.0/state/statestore @@ -205,11 +199,11 @@ func getAPIComponent(apiPath string) (string, string) { return tokens[1], tokens[2] } -func spanAttributesMapFromHTTPContext(ctx *fasthttp.RequestCtx) map[string]string { +func spanAttributesMapFromHTTPContext(rw responsewriter.ResponseWriter, r *http.Request) map[string]string { // Span Attribute reference https://github.com/open-telemetry/opentelemetry-specification/tree/master/specification/trace/semantic_conventions - path := string(ctx.Request.URI().Path()) - method := string(ctx.Request.Header.Method()) - statusCode := ctx.Response.StatusCode() + path := r.URL.Path + method := r.Method + statusCode := rw.Status() m := map[string]string{} _, componentType := getAPIComponent(path) @@ -218,29 +212,29 @@ func spanAttributesMapFromHTTPContext(ctx *fasthttp.RequestCtx) map[string]strin switch componentType { case "state": dbType = stateBuildingBlockType - m[dbNameSpanAttributeKey] = getContextValue(ctx, "storeName") + m[dbNameSpanAttributeKey] = rw.UserValueString("storeName") case "secrets": dbType = secretBuildingBlockType - m[dbNameSpanAttributeKey] = getContextValue(ctx, "secretStoreName") + m[dbNameSpanAttributeKey] = rw.UserValueString("secretStoreName") case "bindings": dbType = bindingBuildingBlockType - m[dbNameSpanAttributeKey] = getContextValue(ctx, "name") + m[dbNameSpanAttributeKey] = rw.UserValueString("name") case "invoke": m[gRPCServiceSpanAttributeKey] = daprGRPCServiceInvocationService - targetID := getContextValue(ctx, "id") + targetID := rw.UserValueString("id") m[netPeerNameSpanAttributeKey] = targetID - m[daprAPISpanNameInternal] = "CallLocal/" + targetID + "/" + getContextValue(ctx, "method") + m[daprAPISpanNameInternal] = "CallLocal/" + targetID + "/" + rw.UserValueString("method") case "publish": m[messagingSystemSpanAttributeKey] = pubsubBuildingBlockType - m[messagingDestinationSpanAttributeKey] = getContextValue(ctx, "topic") + m[messagingDestinationSpanAttributeKey] = rw.UserValueString("topic") m[messagingDestinationKindSpanAttributeKey] = messagingDestinationTopicKind case "actors": - dbType = populateActorParams(ctx, m) + dbType = populateActorParams(rw, r, m) } // Populate the rest of database attributes. @@ -258,14 +252,14 @@ func spanAttributesMapFromHTTPContext(ctx *fasthttp.RequestCtx) map[string]strin return m } -func populateActorParams(ctx *fasthttp.RequestCtx, m map[string]string) string { - actorType := getContextValue(ctx, "actorType") - actorID := getContextValue(ctx, "actorId") +func populateActorParams(rw responsewriter.ResponseWriter, r *http.Request, m map[string]string) string { + actorType := rw.UserValueString("actorType") + actorID := rw.UserValueString("actorId") if actorType == "" || actorID == "" { return "" } - path := string(ctx.Request.URI().Path()) + path := r.URL.Path // Split up to 7 delimiters in '/v1.0/actors/{actorType}/{actorId}/method/{method}' // to get component api type and value tokens := strings.SplitN(path, "/", 7) @@ -280,7 +274,7 @@ func populateActorParams(ctx *fasthttp.RequestCtx, m map[string]string) string { case "method": m[gRPCServiceSpanAttributeKey] = daprGRPCServiceInvocationService m[netPeerNameSpanAttributeKey] = m[daprAPIActorTypeID] - m[daprAPISpanNameInternal] = "CallActor/" + actorType + "/" + getContextValue(ctx, "method") + m[daprAPISpanNameInternal] = "CallActor/" + actorType + "/" + rw.UserValueString("method") case "state": dbType = stateBuildingBlockType diff --git a/pkg/diagnostics/http_tracing_test.go b/pkg/diagnostics/http_tracing_test.go index 3d87dc0e1a1..f5caa3efc1c 100644 --- a/pkg/diagnostics/http_tracing_test.go +++ b/pkg/diagnostics/http_tracing_test.go @@ -16,14 +16,16 @@ package diagnostics import ( "context" "fmt" - "net" - "net/textproto" + "net/http" + "net/http/httptest" + "net/url" + "strconv" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" otelcodes "go.opentelemetry.io/otel/codes" @@ -32,6 +34,7 @@ import ( "github.com/dapr/dapr/pkg/config" diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils" + "github.com/dapr/dapr/utils/responsewriter" ) func TestSpanContextFromRequest(t *testing.T) { @@ -83,10 +86,12 @@ func TestSpanContextFromRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := &fasthttp.Request{} + req := &http.Request{ + Header: make(http.Header), + } req.Header.Add("traceparent", tt.header) - gotSc, _ := SpanContextFromRequest(req) + gotSc := SpanContextFromRequest(req) wantSc := trace.NewSpanContext(tt.wantSc) assert.Equalf(t, wantSc, gotSc, "SpanContextFromRequest gotSc = %v, want %v", gotSc, wantSc) }) @@ -94,12 +99,14 @@ func TestSpanContextFromRequest(t *testing.T) { } func TestUserDefinedHTTPHeaders(t *testing.T) { - reqCtx := &fasthttp.RequestCtx{} - reqCtx.Request.Header.Add("dapr-userdefined-1", "value1") - reqCtx.Request.Header.Add("dapr-userdefined-2", "value2") - reqCtx.Request.Header.Add("no-attr", "value3") + req := &http.Request{ + Header: make(http.Header), + } + req.Header.Add("dapr-userdefined-1", "value1") + req.Header.Add("dapr-userdefined-2", "value2") + req.Header.Add("no-attr", "value3") - m := userDefinedHTTPHeaders(reqCtx) + m := userDefinedHTTPHeaders(req) assert.Equal(t, 2, len(m)) assert.Equal(t, "value1", m["dapr-userdefined-1"]) @@ -120,22 +127,22 @@ func TestSpanContextToHTTPHeaders(t *testing.T) { } for _, tt := range tests { t.Run("SpanContextToHTTPHeaders", func(t *testing.T) { - req := &fasthttp.Request{} + req, _ := http.NewRequest(http.MethodGet, "http://test.local/path", nil) wantSc := trace.NewSpanContext(tt.sc) SpanContextToHTTPHeaders(wantSc, req.Header.Set) - got, _ := SpanContextFromRequest(req) + got := SpanContextFromRequest(req) assert.Equalf(t, got, wantSc, "SpanContextToHTTPHeaders() got = %v, want %v", got, wantSc) }) } t.Run("empty span context", func(t *testing.T) { - req := &fasthttp.Request{} + req, _ := http.NewRequest(http.MethodGet, "http://test.local/path", nil) sc := trace.SpanContext{} SpanContextToHTTPHeaders(sc, req.Header.Set) - assert.Nil(t, req.Header.Peek(TraceparentHeader)) + assert.Empty(t, req.Header.Get(TraceparentHeader)) }) } @@ -245,25 +252,25 @@ func TestGetSpanAttributesMapFromHTTPContext(t *testing.T) { for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { + var err error req := getTestHTTPRequest() - resp := &fasthttp.Response{} - resp.SetStatusCode(200) - req.SetRequestURI(tt.path) - reqCtx := &fasthttp.RequestCtx{} - req.CopyTo(&reqCtx.Request) - - reqCtx.SetUserValue("storeName", "statestore") - reqCtx.SetUserValue("secretStoreName", "keyvault") - reqCtx.SetUserValue("topic", "topicA") - reqCtx.SetUserValue("name", "kafka") - reqCtx.SetUserValue("id", "fakeApp") - reqCtx.SetUserValue("method", "add") - reqCtx.SetUserValue("actorType", "demo_actor") - reqCtx.SetUserValue("actorId", "1") - - got := spanAttributesMapFromHTTPContext(reqCtx) + resp := responsewriter.EnsureResponseWriter(httptest.NewRecorder()) + resp.WriteHeader(http.StatusOK) + req.URL, err = url.Parse("http://test.local" + tt.path) + require.NoError(t, err) + + resp.SetUserValue("storeName", "statestore") + resp.SetUserValue("secretStoreName", "keyvault") + resp.SetUserValue("topic", "topicA") + resp.SetUserValue("name", "kafka") + resp.SetUserValue("id", "fakeApp") + resp.SetUserValue("method", "add") + resp.SetUserValue("actorType", "demo_actor") + resp.SetUserValue("actorId", "1") + + got := spanAttributesMapFromHTTPContext(responsewriter.EnsureResponseWriter(resp), req) for k, v := range tt.out { - assert.Equal(t, v, got[k]) + assert.Equalf(t, v, got[k], "key: %v", k) } }) } @@ -283,11 +290,11 @@ func TestSpanContextToResponse(t *testing.T) { } for _, tt := range tests { t.Run("SpanContextToResponse", func(t *testing.T) { - resp := &fasthttp.Response{} + resp := httptest.NewRecorder() wantSc := trace.NewSpanContext(tt.scConfig) - SpanContextToHTTPHeaders(wantSc, resp.Header.Set) + SpanContextToHTTPHeaders(wantSc, resp.Header().Set) - h := string(resp.Header.Peek(textproto.CanonicalMIMEHeaderKey("traceparent"))) + h := resp.Header().Get("traceparent") got, _ := SpanContextFromW3CString(h) assert.Equalf(t, got, wantSc, "SpanContextToResponse() got = %v, want %v", got, wantSc) @@ -295,13 +302,11 @@ func TestSpanContextToResponse(t *testing.T) { } } -func getTestHTTPRequest() *fasthttp.Request { - req := &fasthttp.Request{} - req.SetRequestURI("/v1.0/state/statestore/key") +func getTestHTTPRequest() *http.Request { + req, _ := http.NewRequest(http.MethodGet, "http://test.local/v1.0/state/statestore/key", nil) req.Header.Set("dapr-testheaderkey", "dapr-testheadervalue") req.Header.Set("x-testheaderkey1", "dapr-testheadervalue") req.Header.Set("daprd-testheaderkey2", "dapr-testheadervalue") - req.Header.SetMethod(fasthttp.MethodGet) var ( tid = trace.TraceID{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 4, 8, 16, 32, 64, 128} @@ -322,10 +327,10 @@ func TestHTTPTraceMiddleware(t *testing.T) { requestBody := "fake_requestDaprBody" responseBody := "fake_responseDaprBody" - fakeHandler := func(ctx *fasthttp.RequestCtx) { + fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(100 * time.Millisecond) - ctx.Response.SetBodyRaw([]byte(responseBody)) - } + w.Write([]byte(responseBody)) + }) rate := config.TracingSpec{SamplingRate: "1"} handler := HTTPTraceMiddleware(fakeHandler, "fakeAppID", rate) @@ -339,15 +344,15 @@ func TestHTTPTraceMiddleware(t *testing.T) { otel.SetTracerProvider(tp) t.Run("traceparent is given in request and sampling is enabled", func(t *testing.T) { - testRequestCtx := newTraceFastHTTPRequestCtx( + r := newTraceRequest( requestBody, "/v1.0/state/statestore", map[string]string{ "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", }, - map[string]string{}, ) - handler(testRequestCtx) - span := diagUtils.SpanFromContext(testRequestCtx) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + span := diagUtils.SpanFromContext(r.Context()) sc := span.SpanContext() traceID := sc.TraceID() spanID := sc.SpanID() @@ -356,15 +361,15 @@ func TestHTTPTraceMiddleware(t *testing.T) { }) t.Run("traceparent is not given in request", func(t *testing.T) { - testRequestCtx := newTraceFastHTTPRequestCtx( + r := newTraceRequest( requestBody, "/v1.0/state/statestore", map[string]string{ "dapr-userdefined": "value", }, - map[string]string{}, ) - handler(testRequestCtx) - span := diagUtils.SpanFromContext(testRequestCtx) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + span := diagUtils.SpanFromContext(r.Context()) sc := span.SpanContext() traceID := sc.TraceID() spanID := sc.SpanID() @@ -373,53 +378,51 @@ func TestHTTPTraceMiddleware(t *testing.T) { }) t.Run("traceparent not given in response", func(t *testing.T) { - testRequestCtx := newTraceFastHTTPRequestCtx( + r := newTraceRequest( requestBody, "/v1.0/state/statestore", map[string]string{ "dapr-userdefined": "value", }, - map[string]string{}, ) - handler(testRequestCtx) - span := diagUtils.SpanFromContext(testRequestCtx) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + span := diagUtils.SpanFromContext(r.Context()) sc := span.SpanContext() - assert.Equal(t, testRequestCtx.Response.Header.Peek(TraceparentHeader), []byte(SpanContextToW3CString(sc))) + assert.Equal(t, w.Header().Get(TraceparentHeader), SpanContextToW3CString(sc)) }) t.Run("traceparent given in response", func(t *testing.T) { - testRequestCtx := newTraceFastHTTPRequestCtx( + r := newTraceRequest( requestBody, "/v1.0/state/statestore", map[string]string{ "dapr-userdefined": "value", }, - map[string]string{ - TraceparentHeader: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", - TracestateHeader: "xyz=t61pCWkhMzZ", - }, ) - handler(testRequestCtx) - span := diagUtils.SpanFromContext(testRequestCtx) + w := httptest.NewRecorder() + w.Header().Set(TraceparentHeader, "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") + w.Header().Set(TracestateHeader, "xyz=t61pCWkhMzZ") + handler.ServeHTTP(w, r) + span := diagUtils.SpanFromContext(r.Context()) sc := span.SpanContext() - assert.NotEqual(t, testRequestCtx.Response.Header.Peek(TraceparentHeader), []byte(SpanContextToW3CString(sc))) + assert.NotEqual(t, w.Header().Get(TraceparentHeader), SpanContextToW3CString(sc)) }) t.Run("path is /v1.0/invoke/*", func(t *testing.T) { - testRequestCtx := newTraceFastHTTPRequestCtx( + r := newTraceRequest( requestBody, "/v1.0/invoke/callee/method/method1", map[string]string{}, - map[string]string{}, ) - testRequestCtx.SetUserValue("id", "callee") - testRequestCtx.SetUserValue("method", "method1") - // act - handler(testRequestCtx) + w := responsewriter.EnsureResponseWriter(httptest.NewRecorder()) + w.SetUserValue("id", "callee") + w.SetUserValue("method", "method1") - // assert - span := diagUtils.SpanFromContext(testRequestCtx) + handler.ServeHTTP(w, r) + + span := diagUtils.SpanFromContext(r.Context()) sc := span.SpanContext() spanString := fmt.Sprintf("%v", span) - assert.True(t, strings.Contains(spanString, "CallLocal/callee/method1")) + assert.Truef(t, strings.Contains(spanString, "CallLocal/callee/method1"), "spanString is %s", spanString) traceID := sc.TraceID() spanID := sc.SpanID() assert.NotEmpty(t, fmt.Sprintf("%x", traceID[:])) @@ -458,33 +461,12 @@ func TestTraceStatusFromHTTPCode(t *testing.T) { } } -func newTraceFastHTTPRequestCtx(expectedBody, expectedRequestURI string, expectedRequestHeader map[string]string, expectedResponseHeader map[string]string) *fasthttp.RequestCtx { - expectedMethod := fasthttp.MethodPost - expectedTransferEncoding := "encoding" - expectedHost := "dapr.io" - expectedRemoteAddr := "1.2.3.4:6789" - - var ctx fasthttp.RequestCtx - var req fasthttp.Request - - req.Header.SetMethod(expectedMethod) - req.SetRequestURI(expectedRequestURI) - req.Header.SetHost(expectedHost) - req.Header.Add(fasthttp.HeaderTransferEncoding, expectedTransferEncoding) - req.Header.SetContentLength(len([]byte(expectedBody))) - req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck - - for k, v := range expectedRequestHeader { +func newTraceRequest(body, requestPath string, requestHeader map[string]string) *http.Request { + req, _ := http.NewRequest(http.MethodPost, "http://dapr.io"+requestPath, strings.NewReader(body)) + req.Header.Set("Transfer-Encoding", "encoding") + req.Header.Set("Content-Length", strconv.Itoa(len(body))) + for k, v := range requestHeader { req.Header.Set(k, v) } - - remoteAddr, _ := net.ResolveTCPAddr("tcp", expectedRemoteAddr) - - ctx.Init(&req, remoteAddr, nil) - - for k, v := range expectedResponseHeader { - ctx.Response.Header.Set(k, v) - } - - return &ctx + return req } diff --git a/pkg/diagnostics/utils/trace_utils.go b/pkg/diagnostics/utils/trace_utils.go index a4d3c44eee4..9bb1bbf55f0 100644 --- a/pkg/diagnostics/utils/trace_utils.go +++ b/pkg/diagnostics/utils/trace_utils.go @@ -15,6 +15,7 @@ package utils import ( "context" + "net/http" "strconv" "github.com/valyala/fasthttp" @@ -24,11 +25,12 @@ import ( "github.com/dapr/kit/logger" ) +type daprContextKey string + const ( defaultSamplingRate = 1e-4 - // daprFastHTTPContextKey is the context value of span in fasthttp.RequestCtx. - daprFastHTTPContextKey = "daprSpanContextKey" + spanContextKey daprContextKey = "span" ) var emptySpanContext trace.SpanContext @@ -93,27 +95,37 @@ func IsTracingEnabled(rate string) bool { return GetTraceSamplingRate(rate) != 0 } -// SpanFromContext returns the SpanContext stored in a context, or nil or trace.nooSpan{} if there isn't one. - TODO +// SpanFromContext returns the Span stored in a context, or nil or trace.noopSpan{} if there isn't one. func SpanFromContext(ctx context.Context) trace.Span { + // TODO: Remove fasthttp compatibility when no HTTP API using contexts depend on fasthttp + var val any if reqCtx, ok := ctx.(*fasthttp.RequestCtx); ok { - val := reqCtx.UserValue(daprFastHTTPContextKey) - if val != nil { - return val.(trace.Span) - } + val = reqCtx.UserValue(spanContextKey) } else { - val := ctx.Value(daprFastHTTPContextKey) - if val != nil { - return val.(trace.Span) + val = ctx.Value(spanContextKey) + } + + if val != nil { + span, ok := val.(trace.Span) + if ok { + return span } } - span := trace.SpanFromContext(ctx) - return span + // Return the default span, which can be a noop + return trace.SpanFromContext(ctx) +} + +// AddSpanToFasthttpContext adds the span to the fasthttp request context. +// TODO: Remove fasthttp compatibility when no HTTP API using contexts depend on fasthttp. +func AddSpanToFasthttpContext(ctx *fasthttp.RequestCtx, span trace.Span) { + ctx.SetUserValue(spanContextKey, span) } -// SpanToFastHTTPContext sets span into fasthttp.RequestCtx. -func SpanToFastHTTPContext(ctx *fasthttp.RequestCtx, span trace.Span) { - ctx.SetUserValue(daprFastHTTPContextKey, span) +// AddSpanToRequest sets span into a request context. +func AddSpanToRequest(r *http.Request, span trace.Span) { + ctx := context.WithValue(r.Context(), spanContextKey, span) + *r = *(r.WithContext(ctx)) } // BinaryFromSpanContext returns the binary format representation of a SpanContext. diff --git a/pkg/diagnostics/utils/trace_utils_test.go b/pkg/diagnostics/utils/trace_utils_test.go index e6835b08783..ff82f9e778d 100644 --- a/pkg/diagnostics/utils/trace_utils_test.go +++ b/pkg/diagnostics/utils/trace_utils_test.go @@ -15,29 +15,29 @@ package utils import ( "context" + "net/http" "reflect" "testing" "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/trace" ) func TestSpanFromContext(t *testing.T) { - t.Run("fasthttp.RequestCtx, not nil span", func(t *testing.T) { - ctx := &fasthttp.RequestCtx{} + t.Run("not nil span", func(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "http://test.local/method", nil) var sp trace.Span - SpanToFastHTTPContext(ctx, sp) + AddSpanToRequest(r, sp) - assert.NotNil(t, SpanFromContext(ctx)) + assert.NotNil(t, SpanFromContext(r.Context())) }) - t.Run("fasthttp.RequestCtx, nil span", func(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - SpanToFastHTTPContext(ctx, nil) - sp := SpanFromContext(ctx) + t.Run("nil span", func(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "http://test.local/method", nil) + AddSpanToRequest(r, nil) + sp := SpanFromContext(r.Context()) expectedType := "trace.noopSpan" gotType := reflect.TypeOf(sp).String() assert.Equal(t, expectedType, gotType) diff --git a/pkg/grpc/api_daprinternal.go b/pkg/grpc/api_daprinternal.go index 200204122b0..26d20ed6a6f 100644 --- a/pkg/grpc/api_daprinternal.go +++ b/pkg/grpc/api_daprinternal.go @@ -170,9 +170,11 @@ func (a *api) CallLocalStream(stream internalv1pb.ServiceInvocation_CallLocalStr // Submit the request to the app res, err := a.appChannel.InvokeMethod(ctx, req) if err != nil { + statusCode = int32(codes.Internal) return status.Errorf(codes.Internal, messages.ErrChannelInvoke, err) } defer res.Close() + statusCode = res.Status().Code // Respond to the caller buf := invokev1.BufPool.Get().(*[]byte) diff --git a/pkg/grpc/api_test.go b/pkg/grpc/api_test.go index 2a7deeba51c..24aed6b20af 100644 --- a/pkg/grpc/api_test.go +++ b/pkg/grpc/api_test.go @@ -1779,7 +1779,7 @@ func TestUnSubscribeConfiguration(t *testing.T) { } resp, err := client.SubscribeConfigurationAlpha1(context.Background(), req) - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") retry := 3 count := 0 var subscribeID string @@ -1790,8 +1790,8 @@ func TestUnSubscribeConfiguration(t *testing.T) { count++ time.Sleep(time.Millisecond * 10) rsp, recvErr := resp.Recv() - assert.NotNil(t, rsp) - assert.NoError(t, recvErr) + require.NoError(t, recvErr) + require.NotNil(t, rsp) if rsp.Items != nil { assert.Equal(t, tt.expectedResponse, rsp.Items) } else { @@ -1799,12 +1799,12 @@ func TestUnSubscribeConfiguration(t *testing.T) { } subscribeID = rsp.Id } - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") _, err = client.UnsubscribeConfigurationAlpha1(context.Background(), &runtimev1pb.UnsubscribeConfigurationRequest{ StoreName: tt.storeName, Id: subscribeID, }) - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") count = 0 for { if err != nil && err.Error() == "EOF" { @@ -1827,7 +1827,7 @@ func TestUnSubscribeConfiguration(t *testing.T) { } resp, err := client.SubscribeConfiguration(context.Background(), req) - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") retry := 3 count := 0 var subscribeID string @@ -1847,12 +1847,12 @@ func TestUnSubscribeConfiguration(t *testing.T) { } subscribeID = rsp.Id } - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") _, err = client.UnsubscribeConfiguration(context.Background(), &runtimev1pb.UnsubscribeConfigurationRequest{ StoreName: tt.storeName, Id: subscribeID, }) - assert.Nil(t, err, "Error should be nil") + assert.NoError(t, err, "Error should be nil") count = 0 for { if err != nil && err.Error() == "EOF" { @@ -2370,7 +2370,7 @@ func TestPublishTopic(t *testing.T) { PubsubName: "pubsub", Topic: "topic", }) - assert.Nil(t, err) + assert.NoError(t, err) }) t.Run("no err: publish event request with topic, pubsub and ce metadata override", func(t *testing.T) { @@ -2383,7 +2383,7 @@ func TestPublishTopic(t *testing.T) { "cloudevent.pubsub": "overridepubsub", // noop -- if this modified the envelope the test would fail }, }) - assert.Nil(t, err) + assert.NoError(t, err) }) t.Run("err: publish event request with error-topic and pubsub", func(t *testing.T) { @@ -2473,7 +2473,7 @@ func TestPublishTopic(t *testing.T) { PubsubName: "pubsub", Topic: "topic", }) - assert.Nil(t, err) + assert.NoError(t, err) }) t.Run("err: bulk publish event request with error-topic and pubsub", func(t *testing.T) { @@ -2561,7 +2561,7 @@ func TestBulkPublish(t *testing.T) { Topic: "topic", Entries: sampleEntries, }) - assert.Nil(t, err) + assert.NoError(t, err) assert.Empty(t, res.FailedEntries) }) @@ -2576,7 +2576,7 @@ func TestBulkPublish(t *testing.T) { "cloudevent.pubsub": "overridepubsub", // noop -- if this modified the envelope the test would fail }, }) - assert.Nil(t, err) + assert.NoError(t, err) assert.Empty(t, res.FailedEntries) }) @@ -2588,7 +2588,7 @@ func TestBulkPublish(t *testing.T) { }) t.Log(res) // Full failure from component, so expecting no error - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, res) assert.Equal(t, 4, len(res.FailedEntries)) }) @@ -2600,7 +2600,7 @@ func TestBulkPublish(t *testing.T) { Entries: sampleEntries, }) // Partial failure, so expecting no error - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, res) assert.Equal(t, 2, len(res.FailedEntries)) }) @@ -2623,13 +2623,13 @@ func TestInvokeBinding(t *testing.T) { client := runtimev1pb.NewDaprClient(clientConn) _, err := client.InvokeBinding(context.Background(), &runtimev1pb.InvokeBindingRequest{}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = client.InvokeBinding(context.Background(), &runtimev1pb.InvokeBindingRequest{Name: "error-binding"}) assert.Equal(t, codes.Internal, status.Code(err)) ctx := grpcMetadata.AppendToOutgoingContext(context.Background(), "traceparent", "Test") resp, err := client.InvokeBinding(ctx, &runtimev1pb.InvokeBindingRequest{Metadata: map[string]string{"userMetadata": "val1"}}) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, resp) assert.Contains(t, resp.Metadata, "traceparent") assert.Equal(t, resp.Metadata["traceparent"], "Test") @@ -4089,7 +4089,7 @@ func TestTryLock(t *testing.T) { ExpiryInSeconds: 1, } resp, err := api.TryLockAlpha1(context.Background(), req) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, true, resp.Success) }) } diff --git a/pkg/http/api.go b/pkg/http/api.go index 63db766f579..8cd170a6779 100644 --- a/pkg/http/api.go +++ b/pkg/http/api.go @@ -85,7 +85,6 @@ type api struct { outboundReadyStatus bool tracingSpec config.TracingSpec maxRequestBodySize int64 // In bytes - isStreamingEnabled bool } const ( @@ -131,7 +130,6 @@ type APIOpts struct { Shutdown func() GetComponentsCapabilitiesFn func() map[string][]string MaxRequestBodySize int64 // In bytes - IsStreamingEnabled bool } // NewAPI returns a new API. @@ -145,7 +143,6 @@ func NewAPI(opts APIOpts) API { sendToOutputBindingFn: opts.SendToOutputBindingFn, tracingSpec: opts.TracingSpec, maxRequestBodySize: opts.MaxRequestBodySize, - isStreamingEnabled: opts.IsStreamingEnabled, universal: &universalapi.UniversalAPI{ AppID: opts.AppID, Logger: log, @@ -207,43 +204,43 @@ func (a *api) MarkStatusAsOutboundReady() { func (a *api) constructWorkflowEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "workflows/{workflowComponent}/{instanceID}", Version: apiVersionV1alpha1, Handler: a.onGetWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{instanceID}/raiseEvent/{eventName}", Version: apiVersionV1alpha1, Handler: a.onRaiseEventWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{workflowName}/start", Version: apiVersionV1alpha1, Handler: a.onStartWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{instanceID}/pause", Version: apiVersionV1alpha1, Handler: a.onPauseWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{instanceID}/resume", Version: apiVersionV1alpha1, Handler: a.onResumeWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{instanceID}/terminate", Version: apiVersionV1alpha1, Handler: a.onTerminateWorkflowHandler(), }, { - Methods: []string{fasthttp.MethodPost}, + Methods: []string{nethttp.MethodPost}, Route: "workflows/{workflowComponent}/{instanceID}/purge", Version: apiVersionV1alpha1, Handler: a.onPurgeWorkflowHandler(), @@ -254,37 +251,37 @@ func (a *api) constructWorkflowEndpoints() []Endpoint { func (a *api) constructStateEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "state/{storeName}/{key}", Version: apiVersionV1, Handler: a.onGetState, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "state/{storeName}", Version: apiVersionV1, Handler: a.onPostState, }, { - Methods: []string{fasthttp.MethodDelete}, + Methods: []string{nethttp.MethodDelete}, Route: "state/{storeName}/{key}", Version: apiVersionV1, Handler: a.onDeleteState, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "state/{storeName}/bulk", Version: apiVersionV1, Handler: a.onBulkGetState, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "state/{storeName}/transaction", Version: apiVersionV1, Handler: a.onPostStateTransaction, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "state/{storeName}/query", Version: apiVersionV1alpha1, Handler: a.onQueryStateHandler(), @@ -295,13 +292,13 @@ func (a *api) constructStateEndpoints() []Endpoint { func (a *api) constructPubSubEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "publish/{pubsubname}/{topic:*}", Version: apiVersionV1, Handler: a.onPublish, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "publish/bulk/{pubsubname}/{topic:*}", Version: apiVersionV1alpha1, Handler: a.onBulkPublish, @@ -312,7 +309,7 @@ func (a *api) constructPubSubEndpoints() []Endpoint { func (a *api) constructBindingsEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "bindings/{name}", Version: apiVersionV1, Handler: a.onOutputBindingMessage, @@ -336,55 +333,55 @@ func (a *api) constructDirectMessagingEndpoints() []Endpoint { func (a *api) constructActorEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "actors/{actorType}/{actorId}/state", Version: apiVersionV1, Handler: a.onActorStateTransaction, }, { - Methods: []string{fasthttp.MethodGet, fasthttp.MethodPost, fasthttp.MethodDelete, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodGet, nethttp.MethodPost, nethttp.MethodDelete, nethttp.MethodPut}, Route: "actors/{actorType}/{actorId}/method/{method}", Version: apiVersionV1, Handler: a.onDirectActorMessage, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "actors/{actorType}/{actorId}/state/{key}", Version: apiVersionV1, Handler: a.onGetActorState, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "actors/{actorType}/{actorId}/reminders/{name}", Version: apiVersionV1, Handler: a.onCreateActorReminder, }, { - Methods: []string{fasthttp.MethodPost, fasthttp.MethodPut}, + Methods: []string{nethttp.MethodPost, nethttp.MethodPut}, Route: "actors/{actorType}/{actorId}/timers/{name}", Version: apiVersionV1, Handler: a.onCreateActorTimer, }, { - Methods: []string{fasthttp.MethodDelete}, + Methods: []string{nethttp.MethodDelete}, Route: "actors/{actorType}/{actorId}/reminders/{name}", Version: apiVersionV1, Handler: a.onDeleteActorReminder, }, { - Methods: []string{fasthttp.MethodDelete}, + Methods: []string{nethttp.MethodDelete}, Route: "actors/{actorType}/{actorId}/timers/{name}", Version: apiVersionV1, Handler: a.onDeleteActorTimer, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "actors/{actorType}/{actorId}/reminders/{name}", Version: apiVersionV1, Handler: a.onGetActorReminder, }, { - Methods: []string{fasthttp.MethodPatch}, + Methods: []string{nethttp.MethodPatch}, Route: "actors/{actorType}/{actorId}/reminders/{name}", Version: apiVersionV1, Handler: a.onRenameActorReminder, @@ -395,7 +392,7 @@ func (a *api) constructActorEndpoints() []Endpoint { func (a *api) constructHealthzEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "healthz", Version: apiVersionV1, Handler: a.onGetHealthz, @@ -403,7 +400,7 @@ func (a *api) constructHealthzEndpoints() []Endpoint { IsHealthCheck: true, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "healthz/outbound", Version: apiVersionV1, Handler: a.onGetOutboundHealthz, @@ -416,37 +413,37 @@ func (a *api) constructHealthzEndpoints() []Endpoint { func (a *api) constructConfigurationEndpoints() []Endpoint { return []Endpoint{ { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}", Version: apiVersionV1alpha1, Handler: a.onGetConfiguration, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}", Version: apiVersionV1, Handler: a.onGetConfiguration, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}/subscribe", Version: apiVersionV1alpha1, Handler: a.onSubscribeConfiguration, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}/subscribe", Version: apiVersionV1, Handler: a.onSubscribeConfiguration, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}/{configurationSubscribeID}/unsubscribe", Version: apiVersionV1alpha1, Handler: a.onUnsubscribeConfiguration, }, { - Methods: []string{fasthttp.MethodGet}, + Methods: []string{nethttp.MethodGet}, Route: "configuration/{storeName}/{configurationSubscribeID}/unsubscribe", Version: apiVersionV1, Handler: a.onUnsubscribeConfiguration, @@ -462,7 +459,7 @@ func (a *api) onOutputBindingMessage(reqCtx *fasthttp.RequestCtx) { err := json.Unmarshal(body, &req) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -470,7 +467,7 @@ func (a *api) onOutputBindingMessage(reqCtx *fasthttp.RequestCtx) { b, err := json.Marshal(req.Data) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST_DATA", fmt.Sprintf(messages.ErrMalformedRequestData, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -502,7 +499,7 @@ func (a *api) onOutputBindingMessage(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_INVOKE_OUTPUT_BINDING", fmt.Sprintf(messages.ErrInvokeOutputBinding, name, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -510,7 +507,7 @@ func (a *api) onOutputBindingMessage(reqCtx *fasthttp.RequestCtx) { if resp == nil { respond(reqCtx, withEmpty()) } else { - respond(reqCtx, withMetadata(resp.Metadata), withJSON(fasthttp.StatusOK, resp.Data)) + respond(reqCtx, withMetadata(resp.Metadata), withJSON(nethttp.StatusOK, resp.Data)) } } @@ -525,7 +522,7 @@ func (a *api) onBulkGetState(reqCtx *fasthttp.RequestCtx) { err = json.Unmarshal(reqCtx.PostBody(), &req) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -543,7 +540,7 @@ func (a *api) onBulkGetState(reqCtx *fasthttp.RequestCtx) { bulkResp := make([]BulkGetResponse, len(req.Keys)) if len(req.Keys) == 0 { b, _ := json.Marshal(bulkResp) - respond(reqCtx, withJSON(fasthttp.StatusOK, b)) + respond(reqCtx, withJSON(nethttp.StatusOK, b)) return } @@ -553,7 +550,7 @@ func (a *api) onBulkGetState(reqCtx *fasthttp.RequestCtx) { key, err = stateLoader.GetModifiedStateKey(k, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(err) return } @@ -579,7 +576,7 @@ func (a *api) onBulkGetState(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_STATE_BULK_GET", err.Error()) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -616,7 +613,7 @@ func (a *api) onBulkGetState(reqCtx *fasthttp.RequestCtx) { } b, _ := json.Marshal(bulkResp) - respond(reqCtx, withJSON(fasthttp.StatusOK, b)) + respond(reqCtx, withJSON(nethttp.StatusOK, b)) } func (a *api) getStateStoreWithRequestValidation(reqCtx *fasthttp.RequestCtx) (state.Store, string, error) { @@ -669,7 +666,7 @@ func (a *api) onStartWorkflowHandler() fasthttp.RequestHandler { in.Input = reqCtx.PostBody() return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -696,7 +693,7 @@ func (a *api) onTerminateWorkflowHandler() fasthttp.RequestHandler { in.InstanceId = reqCtx.UserValue(instanceID).(string) return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -717,7 +714,7 @@ func (a *api) onRaiseEventWorkflowHandler() fasthttp.RequestHandler { in.EventData = reqCtx.PostBody() return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -731,7 +728,7 @@ func (a *api) onPauseWorkflowHandler() fasthttp.RequestHandler { in.InstanceId = reqCtx.UserValue(instanceID).(string) return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -745,7 +742,7 @@ func (a *api) onResumeWorkflowHandler() fasthttp.RequestHandler { in.InstanceId = reqCtx.UserValue(instanceID).(string) return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -758,7 +755,7 @@ func (a *api) onPurgeWorkflowHandler() fasthttp.RequestHandler { in.InstanceId = reqCtx.UserValue(instanceID).(string) return in, nil }, - SuccessStatusCode: fasthttp.StatusAccepted, + SuccessStatusCode: nethttp.StatusAccepted, }) } @@ -776,7 +773,7 @@ func (a *api) onGetState(reqCtx *fasthttp.RequestCtx) { k, err := stateLoader.GetModifiedStateKey(key, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(err) return } @@ -801,7 +798,7 @@ func (a *api) onGetState(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_STATE_GET", fmt.Sprintf(messages.ErrStateGet, key, storeName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -815,7 +812,7 @@ func (a *api) onGetState(reqCtx *fasthttp.RequestCtx) { val, err := encryption.TryDecryptValue(storeName, resp.Data) if err != nil { msg := NewErrorResponse("ERR_STATE_GET", fmt.Sprintf(messages.ErrStateGet, key, storeName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -823,13 +820,17 @@ func (a *api) onGetState(reqCtx *fasthttp.RequestCtx) { resp.Data = val } - respond(reqCtx, withJSON(fasthttp.StatusOK, resp.Data), withEtag(resp.ETag), withMetadata(resp.Metadata)) + if resp.ETag != nil { + reqCtx.Response.Header.Add(etagHeader, *resp.ETag) + } + + respond(reqCtx, withJSON(nethttp.StatusOK, resp.Data), withMetadata(resp.Metadata)) } func (a *api) getConfigurationStoreWithRequestValidation(reqCtx *fasthttp.RequestCtx) (configuration.Store, string, error) { if a.universal.CompStore.ConfigurationsLen() == 0 { msg := NewErrorResponse("ERR_CONFIGURATION_STORE_NOT_CONFIGURED", messages.ErrConfigurationStoresNotConfigured) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return nil, "", errors.New(msg.Message) } @@ -839,7 +840,7 @@ func (a *api) getConfigurationStoreWithRequestValidation(reqCtx *fasthttp.Reques conf, ok := a.universal.CompStore.GetConfiguration(storeName) if !ok { msg := NewErrorResponse("ERR_CONFIGURATION_STORE_NOT_FOUND", fmt.Sprintf(messages.ErrConfigurationStoreNotFound, storeName)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return nil, "", errors.New(msg.Message) } @@ -913,7 +914,7 @@ func (a *api) onSubscribeConfiguration(reqCtx *fasthttp.RequestCtx) { } if a.appChannel == nil { msg := NewErrorResponse("ERR_APP_CHANNEL_NIL", "app channel is not initialized. cannot subscribe to configuration updates") - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -956,14 +957,14 @@ func (a *api) onSubscribeConfiguration(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_CONFIGURATION_SUBSCRIBE", fmt.Sprintf(messages.ErrConfigurationSubscribe, keys, storeName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } respBytes, _ := json.Marshal(&subscribeConfigurationResponse{ ID: subscribeID, }) - respond(reqCtx, withJSON(fasthttp.StatusOK, respBytes)) + respond(reqCtx, withJSON(nethttp.StatusOK, respBytes)) } func (a *api) onUnsubscribeConfiguration(reqCtx *fasthttp.RequestCtx) { @@ -993,14 +994,14 @@ func (a *api) onUnsubscribeConfiguration(reqCtx *fasthttp.RequestCtx) { Ok: false, Message: msg.Message, }) - respond(reqCtx, withJSON(fasthttp.StatusInternalServerError, errRespBytes)) + respond(reqCtx, withJSON(nethttp.StatusInternalServerError, errRespBytes)) log.Debug(msg) return } respBytes, _ := json.Marshal(&UnsubscribeConfigurationResponse{ Ok: true, }) - respond(reqCtx, withJSON(fasthttp.StatusOK, respBytes)) + respond(reqCtx, withJSON(nethttp.StatusOK, respBytes)) } func (a *api) onGetConfiguration(reqCtx *fasthttp.RequestCtx) { @@ -1035,7 +1036,7 @@ func (a *api) onGetConfiguration(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_CONFIGURATION_GET", fmt.Sprintf(messages.ErrConfigurationGet, keys, storeName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1047,7 +1048,7 @@ func (a *api) onGetConfiguration(reqCtx *fasthttp.RequestCtx) { respBytes, _ := json.Marshal(getResponse.Items) - respond(reqCtx, withJSON(fasthttp.StatusOK, respBytes)) + respond(reqCtx, withJSON(nethttp.StatusOK, respBytes)) } func extractEtag(reqCtx *fasthttp.RequestCtx) (hasEtag bool, etag string) { @@ -1078,7 +1079,7 @@ func (a *api) onDeleteState(reqCtx *fasthttp.RequestCtx) { k, err := stateLoader.GetModifiedStateKey(key, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(err) return } @@ -1129,7 +1130,7 @@ func (a *api) onPostState(reqCtx *fasthttp.RequestCtx) { err = json.Unmarshal(reqCtx.PostBody(), &reqs) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1153,7 +1154,7 @@ func (a *api) onPostState(reqCtx *fasthttp.RequestCtx) { reqs[i].Key, err = stateLoader.GetModifiedStateKey(r.Key, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(err) return } @@ -1209,7 +1210,7 @@ func (a *api) stateErrorResponse(err error, errorCode string) (int, string, Erro } message = err.Error() - return fasthttp.StatusInternalServerError, message, r + return nethttp.StatusInternalServerError, message, r } // etagError checks if the error from the state store is an etag error and returns a bool for indication, @@ -1219,9 +1220,9 @@ func (a *api) etagError(err error) (bool, int, string) { if errors.As(err, &etagErr) { switch etagErr.Kind() { case state.ETagMismatch: - return true, fasthttp.StatusConflict, etagErr.Error() + return true, nethttp.StatusConflict, etagErr.Error() case state.ETagInvalid: - return true, fasthttp.StatusBadRequest, etagErr.Error() + return true, nethttp.StatusBadRequest, etagErr.Error() } } return false, -1, "" @@ -1248,102 +1249,76 @@ func (a *api) isHTTPEndpoint(appID string) bool { // getBaseURL takes an app id and checks if the app id is an HTTP endpoint CRD. // It returns the baseURL if found. func (a *api) getBaseURL(targetAppID string) string { - if endpoint, ok := a.universal.CompStore.GetHTTPEndpoint(targetAppID); ok && endpoint.Name == targetAppID { + endpoint, ok := a.universal.CompStore.GetHTTPEndpoint(targetAppID) + if ok && endpoint.Name == targetAppID { return endpoint.Spec.BaseURL } return "" } func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { - // Need a context specific to this request. See: https://github.com/valyala/fasthttp/issues/1350 - // Because this can respond with `withStream()`, we can't defer a call to cancel() here - ctx, cancel := context.WithCancel(reqCtx) - targetID := a.findTargetID(reqCtx) if targetID == "" { msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeNoAppID) - respond(reqCtx, withError(fasthttp.StatusNotFound, msg)) - cancel() + respond(reqCtx, withError(nethttp.StatusNotFound, msg)) return } verb := strings.ToUpper(string(reqCtx.Method())) if a.directMessaging == nil { msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeNotReady) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) - cancel() + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) return } - var req *invokev1.InvokeMethodRequest - var policyDef *resiliency.PolicyDefinition - var invokeMethodName string + var ( + policyDef *resiliency.PolicyDefinition + invokeMethodName string + ) switch { - // overwritten URL, so targetID = baseURL case strings.HasPrefix(targetID, "http://") || strings.HasPrefix(targetID, "https://"): - baseURL := targetID + // overwritten URL, so targetID = baseURL invokeMethodNameWithPrefix := reqCtx.UserValue(methodParam).(string) - prefix := "v1.0/invoke/" + baseURL + "/" + methodParam + prefix := "v1.0/invoke/" + targetID + "/" + methodParam if len(invokeMethodNameWithPrefix) <= len(prefix) { msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeMethod) - respond(reqCtx, withError(fasthttp.StatusNotFound, msg)) - cancel() + respond(reqCtx, withError(nethttp.StatusNotFound, msg)) return } invokeMethodName = invokeMethodNameWithPrefix[len(prefix):] policyDef = a.resiliency.EndpointPolicy(targetID, targetID+"/"+invokeMethodNameWithPrefix) - req = invokev1.NewInvokeMethodRequest(invokeMethodName). - WithHTTPExtension(verb, reqCtx.QueryArgs().String()). - WithRawDataBytes(reqCtx.Request.Body()). - WithContentType(string(reqCtx.Request.Header.ContentType())). - // Save headers to internal metadata - WithFastHTTPHeaders(&reqCtx.Request.Header) - if policyDef != nil { - req.WithReplay(policyDef.HasRetries()) - } - defer req.Close() - // http endpoint CRD resource is detected being used for service invocation + case a.isHTTPEndpoint(targetID): + // http endpoint CRD resource is detected being used for service invocation baseURL := a.getBaseURL(targetID) policyDef = a.resiliency.EndpointPolicy(targetID, targetID+":"+baseURL) invokeMethodName = reqCtx.UserValue(methodParam).(string) if invokeMethodName == "" { msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeMethod) - respond(reqCtx, withError(fasthttp.StatusNotFound, msg)) - cancel() + respond(reqCtx, withError(nethttp.StatusNotFound, msg)) return } - req = invokev1.NewInvokeMethodRequest(invokeMethodName). - WithHTTPExtension(verb, reqCtx.QueryArgs().String()). - WithRawDataBytes(reqCtx.Request.Body()). - WithContentType(string(reqCtx.Request.Header.ContentType())). - // Save headers to internal metadata - WithFastHTTPHeaders(&reqCtx.Request.Header) - if policyDef != nil { - req.WithReplay(policyDef.HasRetries()) - } - defer req.Close() - // regular service to service invocation default: + // regular service to service invocation invokeMethodName = reqCtx.UserValue(methodParam).(string) policyDef = a.resiliency.EndpointPolicy(targetID, targetID+":"+invokeMethodName) + } - req = invokev1.NewInvokeMethodRequest(invokeMethodName). - WithHTTPExtension(verb, reqCtx.QueryArgs().String()). - WithRawDataBytes(reqCtx.Request.Body()). - WithContentType(string(reqCtx.Request.Header.ContentType())). - // Save headers to internal metadata - WithFastHTTPHeaders(&reqCtx.Request.Header) - if policyDef != nil { - req.WithReplay(policyDef.HasRetries()) - } - defer req.Close() + req := invokev1.NewInvokeMethodRequest(invokeMethodName). + WithHTTPExtension(verb, reqCtx.QueryArgs().String()). + WithRawDataBytes(reqCtx.Request.Body()). + WithContentType(string(reqCtx.Request.Header.ContentType())). + // Save headers to internal metadata + WithFastHTTPHeaders(&reqCtx.Request.Header) + if policyDef != nil { + req.WithReplay(policyDef.HasRetries()) } + defer req.Close() policyRunner := resiliency.NewRunnerWithOptions( - ctx, policyDef, + reqCtx, policyDef, resiliency.RunnerOpts[*invokev1.InvokeMethodResponse]{ Disposer: resiliency.DisposerCloser[*invokev1.InvokeMethodResponse], }, @@ -1355,7 +1330,7 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { // Allowlist policies that are applied on the callee side can return a Permission Denied error. // For everything else, treat it as a gRPC transport error invokeErr := invokeError{ - statusCode: fasthttp.StatusInternalServerError, + statusCode: nethttp.StatusInternalServerError, msg: NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, rErr)), } @@ -1369,7 +1344,7 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { resStatus := rResp.Status() if !rResp.IsHTTPResponse() { statusCode := int32(invokev1.HTTPStatusFromCode(codes.Code(resStatus.Code))) - if statusCode != fasthttp.StatusOK { + if statusCode != nethttp.StatusOK { // Close the response to replace the body _ = rResp.Close() var body []byte @@ -1378,7 +1353,7 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { resStatus.Code = statusCode if rErr != nil { return rResp, invokeError{ - statusCode: fasthttp.StatusInternalServerError, + statusCode: nethttp.StatusInternalServerError, msg: NewErrorResponse("ERR_MALFORMED_RESPONSE", rErr.Error()), } } @@ -1395,8 +1370,7 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { // Special case for timeouts/circuit breakers since they won't go through the rest of the logic. if errors.Is(err, context.DeadlineExceeded) || breaker.IsErrorPermanent(err) { - respond(reqCtx, withError(fasthttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", err.Error()))) - cancel() + respond(reqCtx, withError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", err.Error()))) return } @@ -1410,7 +1384,6 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { invokeErr := invokeError{} if errors.As(err, &invokeErr) { respond(reqCtx, withError(invokeErr.statusCode, invokeErr.msg)) - cancel() if resp != nil { _ = resp.Close() } @@ -1418,33 +1391,21 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) { } if resp == nil { - respond(reqCtx, withError(fasthttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, "response object is nil")))) - cancel() + respond(reqCtx, withError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, "response object is nil")))) return } + defer resp.Close() statusCode := int(resp.Status().Code) - // TODO @ItalyPaleAle: Make this the only path once streaming is finalized - if a.isStreamingEnabled { - // This will also close the response stream automatically; no need to invoke resp.Close() - // Likewise, it calls "cancel" to cancel the context at the end - respond(reqCtx, withStream(statusCode, resp.RawData(), cancel)) - return - } - - defer resp.Close() - body, err := resp.RawDataFull() if err != nil { - respond(reqCtx, withError(fasthttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, err)))) - cancel() + respond(reqCtx, withError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, err)))) return } reqCtx.Response.Header.SetContentType(resp.ContentType()) respond(reqCtx, with(statusCode, body)) - cancel() } // findTargetID tries to find ID of the target service from the following four places: @@ -1481,8 +1442,7 @@ func (a *api) findTargetID(reqCtx *fasthttp.RequestCtx) string { // parts[3]: http: // parts[4]: api.github.com // parts[5]: method - targetURL := parts[3] + "//" + parts[4] - return targetURL + return parts[3] + "//" + parts[4] } return "" @@ -1491,7 +1451,7 @@ func (a *api) findTargetID(reqCtx *fasthttp.RequestCtx) string { func (a *api) onCreateActorReminder(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) return } @@ -1503,7 +1463,7 @@ func (a *api) onCreateActorReminder(reqCtx *fasthttp.RequestCtx) { err := json.Unmarshal(reqCtx.PostBody(), &req) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1515,7 +1475,7 @@ func (a *api) onCreateActorReminder(reqCtx *fasthttp.RequestCtx) { err = a.universal.Actors.CreateReminder(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_REMINDER_CREATE", fmt.Sprintf(messages.ErrActorReminderCreate, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1525,7 +1485,7 @@ func (a *api) onCreateActorReminder(reqCtx *fasthttp.RequestCtx) { func (a *api) onRenameActorReminder(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) return } @@ -1537,7 +1497,7 @@ func (a *api) onRenameActorReminder(reqCtx *fasthttp.RequestCtx) { err := json.Unmarshal(reqCtx.PostBody(), &req) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1549,7 +1509,7 @@ func (a *api) onRenameActorReminder(reqCtx *fasthttp.RequestCtx) { err = a.universal.Actors.RenameReminder(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_REMINDER_RENAME", fmt.Sprintf(messages.ErrActorReminderRename, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1559,7 +1519,7 @@ func (a *api) onRenameActorReminder(reqCtx *fasthttp.RequestCtx) { func (a *api) onCreateActorTimer(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1572,7 +1532,7 @@ func (a *api) onCreateActorTimer(reqCtx *fasthttp.RequestCtx) { err := json.Unmarshal(reqCtx.PostBody(), &req) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1584,7 +1544,7 @@ func (a *api) onCreateActorTimer(reqCtx *fasthttp.RequestCtx) { err = a.universal.Actors.CreateTimer(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_TIMER_CREATE", fmt.Sprintf(messages.ErrActorTimerCreate, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1594,7 +1554,7 @@ func (a *api) onCreateActorTimer(reqCtx *fasthttp.RequestCtx) { func (a *api) onDeleteActorReminder(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1612,7 +1572,7 @@ func (a *api) onDeleteActorReminder(reqCtx *fasthttp.RequestCtx) { err := a.universal.Actors.DeleteReminder(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_REMINDER_DELETE", fmt.Sprintf(messages.ErrActorReminderDelete, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1622,7 +1582,7 @@ func (a *api) onDeleteActorReminder(reqCtx *fasthttp.RequestCtx) { func (a *api) onActorStateTransaction(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1635,7 +1595,7 @@ func (a *api) onActorStateTransaction(reqCtx *fasthttp.RequestCtx) { err := json.Unmarshal(body, &ops) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1647,7 +1607,7 @@ func (a *api) onActorStateTransaction(reqCtx *fasthttp.RequestCtx) { if !hosted { msg := NewErrorResponse("ERR_ACTOR_INSTANCE_MISSING", messages.ErrActorInstanceMissing) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1661,7 +1621,7 @@ func (a *api) onActorStateTransaction(reqCtx *fasthttp.RequestCtx) { err = a.universal.Actors.TransactionalStateOperation(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_STATE_TRANSACTION_SAVE", fmt.Sprintf(messages.ErrActorStateTransactionSave, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1671,7 +1631,7 @@ func (a *api) onActorStateTransaction(reqCtx *fasthttp.RequestCtx) { func (a *api) onGetActorReminder(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1687,25 +1647,25 @@ func (a *api) onGetActorReminder(reqCtx *fasthttp.RequestCtx) { }) if err != nil { msg := NewErrorResponse("ERR_ACTOR_REMINDER_GET", fmt.Sprintf(messages.ErrActorReminderGet, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } b, err := json.Marshal(resp) if err != nil { msg := NewErrorResponse("ERR_ACTOR_REMINDER_GET", fmt.Sprintf(messages.ErrActorReminderGet, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } - respond(reqCtx, withJSON(fasthttp.StatusOK, b)) + respond(reqCtx, withJSON(nethttp.StatusOK, b)) } func (a *api) onDeleteActorTimer(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1722,7 +1682,7 @@ func (a *api) onDeleteActorTimer(reqCtx *fasthttp.RequestCtx) { err := a.universal.Actors.DeleteTimer(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_TIMER_DELETE", fmt.Sprintf(messages.ErrActorTimerDelete, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -1732,7 +1692,7 @@ func (a *api) onDeleteActorTimer(reqCtx *fasthttp.RequestCtx) { func (a *api) onDirectActorMessage(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1773,14 +1733,14 @@ func (a *api) onDirectActorMessage(reqCtx *fasthttp.RequestCtx) { }) if err != nil && !errors.Is(err, actors.ErrDaprResponseHeader) { msg := NewErrorResponse("ERR_ACTOR_INVOKE_METHOD", fmt.Sprintf(messages.ErrActorInvoke, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } if resp == nil { msg := NewErrorResponse("ERR_ACTOR_INVOKE_METHOD", fmt.Sprintf(messages.ErrActorInvoke, "failed to cast response")) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1791,7 +1751,7 @@ func (a *api) onDirectActorMessage(reqCtx *fasthttp.RequestCtx) { body, err := resp.RawDataFull() if err != nil { msg := NewErrorResponse("ERR_ACTOR_INVOKE_METHOD", fmt.Sprintf(messages.ErrActorInvoke, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1808,7 +1768,7 @@ func (a *api) onDirectActorMessage(reqCtx *fasthttp.RequestCtx) { func (a *api) onGetActorState(reqCtx *fasthttp.RequestCtx) { if a.universal.Actors == nil { msg := NewErrorResponse("ERR_ACTOR_RUNTIME_NOT_FOUND", messages.ErrActorRuntimeNotFound) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1824,7 +1784,7 @@ func (a *api) onGetActorState(reqCtx *fasthttp.RequestCtx) { if !hosted { msg := NewErrorResponse("ERR_ACTOR_INSTANCE_MISSING", messages.ErrActorInstanceMissing) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -1838,14 +1798,14 @@ func (a *api) onGetActorState(reqCtx *fasthttp.RequestCtx) { resp, err := a.universal.Actors.GetState(reqCtx, &req) if err != nil { msg := NewErrorResponse("ERR_ACTOR_STATE_GET", fmt.Sprintf(messages.ErrActorStateGet, err)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { if resp == nil || len(resp.Data) == 0 { respond(reqCtx, withEmpty()) return } - respond(reqCtx, withJSON(fasthttp.StatusOK, resp.Data)) + respond(reqCtx, withJSON(nethttp.StatusOK, resp.Data)) } } @@ -1891,7 +1851,7 @@ func (a *api) onPublish(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_PUBSUB_CLOUD_EVENTS_SER", fmt.Sprintf(messages.ErrPubsubCloudEventCreation, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1904,7 +1864,7 @@ func (a *api) onPublish(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_PUBSUB_CLOUD_EVENTS_SER", fmt.Sprintf(messages.ErrPubsubCloudEventsSer, topic, pubsubName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -1924,18 +1884,18 @@ func (a *api) onPublish(reqCtx *fasthttp.RequestCtx) { diag.DefaultComponentMonitoring.PubsubEgressEvent(context.Background(), pubsubName, topic, err == nil, elapsed) if err != nil { - status := fasthttp.StatusInternalServerError + status := nethttp.StatusInternalServerError msg := NewErrorResponse("ERR_PUBSUB_PUBLISH_MESSAGE", fmt.Sprintf(messages.ErrPubsubPublishMessage, topic, pubsubName, err.Error())) if errors.As(err, &runtimePubsub.NotAllowedError{}) { msg = NewErrorResponse("ERR_PUBSUB_FORBIDDEN", err.Error()) - status = fasthttp.StatusForbidden + status = nethttp.StatusForbidden } if errors.As(err, &runtimePubsub.NotFoundError{}) { msg = NewErrorResponse("ERR_PUBSUB_NOT_FOUND", err.Error()) - status = fasthttp.StatusBadRequest + status = nethttp.StatusBadRequest } respond(reqCtx, withError(status, msg)) @@ -1946,7 +1906,7 @@ func (a *api) onPublish(reqCtx *fasthttp.RequestCtx) { } type bulkPublishMessageEntry struct { - EntryId string `json:"entryId,omitempty"` //nolint:stylecheck + EntryID string `json:"entryId,omitempty"` Event interface{} `json:"event"` ContentType string `json:"contentType"` Metadata map[string]string `json:"metadata,omitempty"` @@ -1981,23 +1941,22 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_PUBSUB_EVENTS_SER", fmt.Sprintf(messages.ErrPubsubUnmarshal, topic, pubsubName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } entries := make([]pubsub.BulkMessageEntry, len(incomingEntries)) - entryIdSet := map[string]struct{}{} //nolint:stylecheck + entryIDSet := map[string]struct{}{} for i, entry := range incomingEntries { var dBytes []byte - - dBytes, cErr := ConvertEventToBytes(entry.Event, entry.ContentType) - if cErr != nil { + dBytes, err = ConvertEventToBytes(entry.Event, entry.ContentType) + if err != nil { msg := NewErrorResponse("ERR_PUBSUB_EVENTS_SER", - fmt.Sprintf(messages.ErrPubsubMarshal, topic, pubsubName, cErr.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + fmt.Sprintf(messages.ErrPubsubMarshal, topic, pubsubName, err.Error())) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -2010,16 +1969,16 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { // override request level metadata. entries[i].Metadata = utils.PopulateMetadataForBulkPublishEntry(metadata, entry.Metadata) } - if _, ok := entryIdSet[entry.EntryId]; ok || entry.EntryId == "" { + if _, ok := entryIDSet[entry.EntryID]; ok || entry.EntryID == "" { msg := NewErrorResponse("ERR_PUBSUB_EVENTS_SER", fmt.Sprintf(messages.ErrPubsubMarshal, topic, pubsubName, "error: entryId is duplicated or not present for entry")) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } - entryIdSet[entry.EntryId] = struct{}{} - entries[i].EntryId = entry.EntryId + entryIDSet[entry.EntryID] = struct{}{} + entries[i].EntryId = entry.EntryID } spanMap := map[int]trace.Span{} @@ -2039,7 +1998,8 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { corID := diag.SpanContextToW3CString(childSpan.SpanContext()) spanMap[i] = childSpan - envelope, envelopeErr := runtimePubsub.NewCloudEvent(&runtimePubsub.CloudEvent{ + var envelope map[string]interface{} + envelope, err = runtimePubsub.NewCloudEvent(&runtimePubsub.CloudEvent{ Source: a.universal.AppID, Topic: topic, DataContentType: entries[i].ContentType, @@ -2048,10 +2008,10 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { TraceState: traceState, Pubsub: pubsubName, }, entries[i].Metadata) - if envelopeErr != nil { + if err != nil { msg := NewErrorResponse("ERR_PUBSUB_CLOUD_EVENTS_SER", - fmt.Sprintf(messages.ErrPubsubCloudEventCreation, envelopeErr.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg), closeChildSpans) + fmt.Sprintf(messages.ErrPubsubCloudEventCreation, err.Error())) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg), closeChildSpans) log.Debug(msg) return @@ -2063,7 +2023,7 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_PUBSUB_CLOUD_EVENTS_SER", fmt.Sprintf(messages.ErrPubsubCloudEventsSer, topic, pubsubName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg), closeChildSpans) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg), closeChildSpans) log.Debug(msg) return @@ -2101,12 +2061,12 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { } bulkRes.FailedEntries = append(bulkRes.FailedEntries, resEntry) } - status := fasthttp.StatusInternalServerError + status := nethttp.StatusInternalServerError bulkRes.ErrorCode = "ERR_PUBSUB_PUBLISH_MESSAGE" if errors.As(err, &runtimePubsub.NotAllowedError{}) { msg := NewErrorResponse("ERR_PUBSUB_FORBIDDEN", err.Error()) - status = fasthttp.StatusForbidden + status = nethttp.StatusForbidden respond(reqCtx, withError(status, msg), closeChildSpans) log.Debug(msg) @@ -2115,7 +2075,7 @@ func (a *api) onBulkPublish(reqCtx *fasthttp.RequestCtx) { if errors.As(err, &runtimePubsub.NotFoundError{}) { msg := NewErrorResponse("ERR_PUBSUB_NOT_FOUND", err.Error()) - status = fasthttp.StatusBadRequest + status = nethttp.StatusBadRequest respond(reqCtx, withError(status, msg), closeChildSpans) log.Debug(msg) @@ -2138,30 +2098,30 @@ func (a *api) validateAndGetPubsubAndTopic(reqCtx *fasthttp.RequestCtx) (pubsub. if a.pubsubAdapter == nil { msg := NewErrorResponse("ERR_PUBSUB_NOT_CONFIGURED", messages.ErrPubsubNotConfigured) - return nil, "", "", fasthttp.StatusBadRequest, &msg + return nil, "", "", nethttp.StatusBadRequest, &msg } pubsubName := reqCtx.UserValue(pubsubnameparam).(string) if pubsubName == "" { msg := NewErrorResponse("ERR_PUBSUB_EMPTY", messages.ErrPubsubEmpty) - return nil, "", "", fasthttp.StatusNotFound, &msg + return nil, "", "", nethttp.StatusNotFound, &msg } thepubsub := a.pubsubAdapter.GetPubSub(pubsubName) if thepubsub == nil { msg := NewErrorResponse("ERR_PUBSUB_NOT_FOUND", fmt.Sprintf(messages.ErrPubsubNotFound, pubsubName)) - return nil, "", "", fasthttp.StatusNotFound, &msg + return nil, "", "", nethttp.StatusNotFound, &msg } topic := reqCtx.UserValue(topicParam).(string) if topic == "" { msg := NewErrorResponse("ERR_TOPIC_EMPTY", fmt.Sprintf(messages.ErrTopicEmpty, pubsubName)) - return nil, "", "", fasthttp.StatusNotFound, &msg + return nil, "", "", nethttp.StatusNotFound, &msg } - return thepubsub, pubsubName, topic, fasthttp.StatusOK, nil + return thepubsub, pubsubName, topic, nethttp.StatusOK, nil } // GetStatusCodeFromMetadata extracts the http status code from the metadata if it exists. @@ -2174,13 +2134,13 @@ func GetStatusCodeFromMetadata(metadata map[string]string) int { } } - return fasthttp.StatusOK + return nethttp.StatusOK } func (a *api) onGetHealthz(reqCtx *fasthttp.RequestCtx) { if !a.readyStatus { msg := NewErrorResponse("ERR_HEALTH_NOT_READY", messages.ErrHealthNotReady) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -2190,7 +2150,7 @@ func (a *api) onGetHealthz(reqCtx *fasthttp.RequestCtx) { func (a *api) onGetOutboundHealthz(reqCtx *fasthttp.RequestCtx) { if !a.outboundReadyStatus { msg := NewErrorResponse("ERR_OUTBOUND_HEALTH_NOT_READY", messages.ErrOutboundHealthNotReady) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) @@ -2240,7 +2200,7 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { transactionalStore, ok := store.(state.TransactionalStore) if !ok { msg := NewErrorResponse("ERR_STATE_STORE_NOT_SUPPORTED", fmt.Sprintf(messages.ErrStateStoreNotSupported, storeName)) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) return } @@ -2249,7 +2209,7 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { var req stateTransactionRequestBody if err := json.Unmarshal(body, &req); err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -2277,14 +2237,14 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } upsertReq.Key, err = stateLoader.GetModifiedStateKey(upsertReq.Key, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(err) return } @@ -2295,14 +2255,14 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", fmt.Sprintf(messages.ErrMalformedRequest, err.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } delReq.Key, err = stateLoader.GetModifiedStateKey(delReq.Key, storeName, a.universal.AppID) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -2311,7 +2271,7 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { msg := NewErrorResponse( "ERR_NOT_SUPPORTED_STATE_OPERATION", fmt.Sprintf(messages.ErrNotSupportedStateOperation, o.Operation)) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -2327,7 +2287,7 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { msg := NewErrorResponse( "ERR_SAVE_STATE", fmt.Sprintf(messages.ErrStateSave, storeName, err.Error())) - respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) + respond(reqCtx, withError(nethttp.StatusBadRequest, msg)) log.Debug(msg) return } @@ -2355,7 +2315,7 @@ func (a *api) onPostStateTransaction(reqCtx *fasthttp.RequestCtx) { if err != nil { msg := NewErrorResponse("ERR_STATE_TRANSACTION", fmt.Sprintf(messages.ErrStateTransaction, err.Error())) - respond(reqCtx, withError(fasthttp.StatusInternalServerError, msg)) + respond(reqCtx, withError(nethttp.StatusInternalServerError, msg)) log.Debug(msg) } else { respond(reqCtx, withEmpty()) diff --git a/pkg/http/api_test.go b/pkg/http/api_test.go index 46313c2e52f..50608cf487b 100644 --- a/pkg/http/api_test.go +++ b/pkg/http/api_test.go @@ -349,7 +349,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { bulkRequest := []bulkPublishMessageEntry{ { - EntryId: "1", + EntryID: "1", Event: map[string]string{ "key": "first", "value": "first value", @@ -357,7 +357,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { ContentType: "application/json", }, { - EntryId: "2", + EntryID: "2", Event: map[string]string{ "key": "second", "value": "second value", @@ -369,7 +369,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { }, }, { - EntryId: "3", + EntryID: "3", Event: map[string]string{ "key": "third", "value": "third value", @@ -466,7 +466,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { errBulkRequest := []bulkPublishMessageEntry{} for _, entry := range bulkRequest { // Fail entries 2 and 3 - if entry.EntryId == "2" || entry.EntryId == "3" { + if entry.EntryID == "2" || entry.EntryID == "3" { if entry.Metadata == nil { entry.Metadata = map[string]string{} } @@ -555,7 +555,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { t.Run("Bulk Publish with duplicate entryId - 400", func(t *testing.T) { reqWithoutEntryId := []bulkPublishMessageEntry{ //nolint:stylecheck { - EntryId: "1", + EntryID: "1", Event: map[string]string{ "key": "first", "value": "first value", @@ -563,7 +563,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { ContentType: "application/json", }, { - EntryId: "1", + EntryID: "1", Event: map[string]string{ "key": "second", "value": "second value", @@ -604,7 +604,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { t.Run("Bulk Publish invalid cloudevent - 400", func(t *testing.T) { reqInvalidCE := []bulkPublishMessageEntry{ { - EntryId: "1", + EntryID: "1", Event: map[string]string{ "key": "first", "value": "first value", @@ -612,7 +612,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { ContentType: "application/json", }, { - EntryId: "2", + EntryID: "2", Event: "this is not a cloudevent!", ContentType: "application/cloudevents+json", Metadata: map[string]string{ @@ -637,7 +637,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { t.Run("Bulk Publish dataContentType mismatch - 400", func(t *testing.T) { dCTMismatch := []bulkPublishMessageEntry{ { - EntryId: "1", + EntryID: "1", Event: map[string]string{ "key": "first", "value": "first value", @@ -645,7 +645,7 @@ func TestBulkPubSubEndpoints(t *testing.T) { ContentType: "application/json", }, { - EntryId: "2", + EntryID: "2", Event: map[string]string{ "key": "second", "value": "second value", @@ -4135,10 +4135,12 @@ func (f *fakeHTTPServer) StartServer(endpoints []Endpoint) { } func (f *fakeHTTPServer) StartServerWithTracing(spec config.TracingSpec, endpoints []Endpoint) { - router := f.getRouter(endpoints) + router := nethttpadaptor.NewNetHTTPHandlerFunc(f.getRouter(endpoints).Handler) f.ln = fasthttputil.NewInmemoryListener() go func() { - if err := fasthttp.Serve(f.ln, diag.HTTPTraceMiddleware(router.Handler, "fakeAppID", spec)); err != nil { + h := fasthttpadaptor.NewFastHTTPHandler(diag.HTTPTraceMiddleware(router, "fakeAppID", spec)) + err := fasthttp.Serve(f.ln, h) + if err != nil { panic(fmt.Errorf("failed to set tracing span context: %v", err)) } }() @@ -4155,8 +4157,10 @@ func (f *fakeHTTPServer) StartServerWithTracing(spec config.TracingSpec, endpoin func (f *fakeHTTPServer) StartServerWithAPIToken(endpoints []Endpoint) { router := f.getRouter(endpoints) f.ln = fasthttputil.NewInmemoryListener() + h := nethttpadaptor.NewNetHTTPHandlerFunc(router.Handler) go func() { - if err := fasthttp.Serve(f.ln, useAPIAuthentication(router.Handler)); err != nil { + err := gohttp.Serve(f.ln, useAPIAuthentication(h)) //nolint:gosec + if err != nil { panic(fmt.Errorf("failed to serve: %v", err)) } }() @@ -4175,11 +4179,12 @@ func (f *fakeHTTPServer) StartServerWithTracingAndPipeline(spec config.TracingSp f.ln = fasthttputil.NewInmemoryListener() go func() { handler := fasthttpadaptor.NewFastHTTPHandler( - pipeline.Apply( + diag.HTTPTraceMiddleware(pipeline.Apply( nethttpadaptor.NewNetHTTPHandlerFunc(router.Handler), - ), + ), "fakeAppID", spec), ) - if err := fasthttp.Serve(f.ln, diag.HTTPTraceMiddleware(handler, "fakeAppID", spec)); err != nil { + err := fasthttp.Serve(f.ln, handler) + if err != nil { panic(fmt.Errorf("failed to serve tracing span context: %v", err)) } }() diff --git a/pkg/http/config.go b/pkg/http/config.go index 8500b740c48..9207f824c63 100644 --- a/pkg/http/config.go +++ b/pkg/http/config.go @@ -23,9 +23,9 @@ type ServerConfig struct { ProfilePort int AllowedOrigins string EnableProfiling bool - MaxRequestBodySize int + MaxRequestBodySizeMB int UnixDomainSocket string - ReadBufferSize int + ReadBufferSizeKB int EnableAPILogging bool APILoggingObfuscateURLs bool APILogHealthChecks bool diff --git a/pkg/http/responses.go b/pkg/http/responses.go index 3dce8e4143e..65ed0128b17 100644 --- a/pkg/http/responses.go +++ b/pkg/http/responses.go @@ -15,8 +15,6 @@ package http import ( "encoding/json" - "io" - "net" "github.com/valyala/fasthttp" ) @@ -65,15 +63,6 @@ type QueryItem struct { type option = func(ctx *fasthttp.RequestCtx) -// withEtag sets etag header. -func withEtag(etag *string) option { - return func(ctx *fasthttp.RequestCtx) { - if etag != nil { - ctx.Response.Header.Add(etagHeader, *etag) - } - } -} - // withMetadata sets metadata headers. func withMetadata(metadata map[string]string) option { return func(ctx *fasthttp.RequestCtx) { @@ -118,42 +107,6 @@ func with(code int, obj []byte) option { } } -// withStream is like "with" but accepts a stream -// The stream is closed at the end if it implements the Close() method -func withStream(code int, r io.Reader, onDone func()) option { - return func(ctx *fasthttp.RequestCtx) { - if len(ctx.Response.Header.ContentType()) == 0 { - ctx.Response.Header.SetContentType(jsonContentTypeHeader) - } - ctx.Response.SetStatusCode(code) - - // This is a bit hacky (there's literally "hijack" in the name), but it seems to be the only way we can actually send data to the client in a streamed way - // (believe me, I've spent over a day on this and I'm not exaggerating) - ctx.HijackSetNoResponse(true) - ctx.Hijack(func(c net.Conn) { - // Write the headers - c.Write(ctx.Response.Header.Header()) - - // Send the data as a stream - _, err := io.Copy(c, r) - if err != nil { - log.Warn("Error while copying response into connection: ", err) - } - - // Close the stream if it implements io.Closer - if rc, ok := r.(io.Closer); ok { - _ = rc.Close() - } - - // Call the "onDone" method (usually a context.Cancel function) - // Note: "c" (net.Conn) is closed automatically, no need to close that - if onDone != nil { - onDone() - } - }) - } -} - func respond(ctx *fasthttp.RequestCtx, options ...option) { for _, option := range options { option(ctx) diff --git a/pkg/http/responses_test.go b/pkg/http/responses_test.go index 356900894c8..202274c2cfd 100644 --- a/pkg/http/responses_test.go +++ b/pkg/http/responses_test.go @@ -18,8 +18,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/valyala/fasthttp" - - "github.com/dapr/kit/ptr" ) func TestHeaders(t *testing.T) { @@ -38,14 +36,6 @@ func TestHeaders(t *testing.T) { assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType())) }) - t.Run("Respond with ETag JSON", func(t *testing.T) { - ctx := &fasthttp.RequestCtx{Request: fasthttp.Request{}} - etagValue := "etagValue" - respond(ctx, withJSON(200, nil), withEtag(ptr.Of(etagValue))) - - assert.Equal(t, etagValue, string(ctx.Response.Header.Peek(etagHeader))) - }) - t.Run("Respond with metadata and JSON", func(t *testing.T) { ctx := &fasthttp.RequestCtx{Request: fasthttp.Request{}} respond(ctx, withJSON(200, nil), withMetadata(map[string]string{"key": "value"})) diff --git a/pkg/http/server.go b/pkg/http/server.go index 2a9513dd070..738cd01aeb7 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -14,6 +14,7 @@ limitations under the License. package http import ( + "context" "errors" "fmt" "io" @@ -23,12 +24,16 @@ import ( "regexp" "strconv" "strings" + "time" + + // Import pprof that automatically registers itself in the default server mux. + // Putting "nolint:gosec" here because the linter points out this is automatically exposed on the default server mux, but we only use that in the profiling server. + //nolint:gosec + _ "net/http/pprof" - cors "github.com/AdhityaRamadhanus/fasthttpcors" routing "github.com/fasthttp/router" - "github.com/hashicorp/go-multierror" + "github.com/go-chi/cors" "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/pprofhandler" "github.com/dapr/dapr/pkg/config" corsDapr "github.com/dapr/dapr/pkg/cors" @@ -37,8 +42,8 @@ import ( httpMiddleware "github.com/dapr/dapr/pkg/middleware/http" auth "github.com/dapr/dapr/pkg/runtime/security" authConsts "github.com/dapr/dapr/pkg/runtime/security/consts" - "github.com/dapr/dapr/utils/fasthttpadaptor" "github.com/dapr/dapr/utils/nethttpadaptor" + "github.com/dapr/dapr/utils/streams" "github.com/dapr/kit/logger" ) @@ -60,7 +65,7 @@ type server struct { pipeline httpMiddleware.Pipeline api API apiSpec config.APISpec - servers []*fasthttp.Server + servers []*http.Server profilingListeners []net.Listener } @@ -90,11 +95,14 @@ func NewServer(opts NewServerOpts) Server { // StartNonBlocking starts a new server in a goroutine. func (s *server) StartNonBlocking() error { handler := s.useRouter() - handler = s.useComponents(handler) - handler = s.useCors(handler) - handler = useAPIAuthentication(handler) - handler = s.useMetrics(handler) - handler = s.useTracing(handler) + + // These middlewares use net/http handlers + netHTTPHandler := s.useComponents(nethttpadaptor.NewNetHTTPHandlerFunc(handler)) + netHTTPHandler = s.useCors(netHTTPHandler) + netHTTPHandler = useAPIAuthentication(netHTTPHandler) + netHTTPHandler = s.useMetrics(netHTTPHandler) + netHTTPHandler = s.useTracing(netHTTPHandler) + netHTTPHandler = s.useMaxBodySize(netHTTPHandler) var listeners []net.Listener var profilingListeners []net.Listener @@ -123,18 +131,17 @@ func (s *server) StartNonBlocking() error { } for _, listener := range listeners { - // customServer is created in a loop because each instance + // srv is created in a loop because each instance // has a handle on the underlying listener. - customServer := &fasthttp.Server{ - Handler: handler, - MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024, - ReadBufferSize: s.config.ReadBufferSize * 1024, - NoDefaultServerHeader: true, + srv := &http.Server{ + Handler: netHTTPHandler, + ReadHeaderTimeout: 10 * time.Second, + MaxHeaderBytes: s.config.ReadBufferSizeKB << 10, // To bytes } - s.servers = append(s.servers, customServer) + s.servers = append(s.servers, srv) go func(l net.Listener) { - if err := customServer.Serve(l); err != nil { + if err := srv.Serve(l); err != http.ErrServerClosed { log.Fatal(err) } }(listener) @@ -142,18 +149,21 @@ func (s *server) StartNonBlocking() error { if s.config.PublicPort != nil { publicHandler := s.usePublicRouter() - publicHandler = s.useMetrics(publicHandler) - publicHandler = s.useTracing(publicHandler) - healthServer := &fasthttp.Server{ - Handler: publicHandler, - MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024, - NoDefaultServerHeader: true, + // Convert to net/http + netHTTPPublicHandler := s.useMetrics(nethttpadaptor.NewNetHTTPHandlerFunc(publicHandler)) + netHTTPPublicHandler = s.useTracing(netHTTPPublicHandler) + + healthServer := &http.Server{ + Addr: fmt.Sprintf(":%d", *s.config.PublicPort), + Handler: netHTTPPublicHandler, + ReadHeaderTimeout: 10 * time.Second, + MaxHeaderBytes: s.config.ReadBufferSizeKB << 10, // To bytes } s.servers = append(s.servers, healthServer) go func() { - if err := healthServer.ListenAndServe(fmt.Sprintf(":%d", *s.config.PublicPort)); err != nil { + if err := healthServer.ListenAndServe(); err != http.ErrServerClosed { log.Fatal(err) } }() @@ -179,15 +189,16 @@ func (s *server) StartNonBlocking() error { for _, listener := range profilingListeners { // profServer is created in a loop because each instance // has a handle on the underlying listener. - profServer := &fasthttp.Server{ - Handler: pprofhandler.PprofHandler, - MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024, - NoDefaultServerHeader: true, + profServer := &http.Server{ + // pprof is automatically registered in the DefaultServerMux + Handler: http.DefaultServeMux, + ReadHeaderTimeout: 10 * time.Second, + MaxHeaderBytes: s.config.ReadBufferSizeKB << 10, // To bytes } s.servers = append(s.servers, profServer) go func(l net.Listener) { - if err := profServer.Serve(l); err != nil { + if err := profServer.Serve(l); err != http.ErrServerClosed { log.Fatal(err) } }(listener) @@ -198,19 +209,24 @@ func (s *server) StartNonBlocking() error { } func (s *server) Close() error { - var merr error + var err error for _, ln := range s.servers { // This calls `Close()` on the underlying listener. - if err := ln.Shutdown(); err != nil { - merr = multierror.Append(merr, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + shutdownErr := ln.Shutdown(ctx) + // Error will be ErrServerClosed if everything went well + if errors.Is(shutdownErr, http.ErrServerClosed) { + shutdownErr = nil } + err = errors.Join(err, shutdownErr) + cancel() } - return merr + return err } -func (s *server) useTracing(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func (s *server) useTracing(next http.Handler) http.Handler { if diagUtils.IsTracingEnabled(s.tracingSpec.SamplingRate) { log.Infof("enabled tracing http middleware") return diag.HTTPTraceMiddleware(next, s.config.AppID, s.tracingSpec) @@ -218,16 +234,27 @@ func (s *server) useTracing(next fasthttp.RequestHandler) fasthttp.RequestHandle return next } -func (s *server) useMetrics(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func (s *server) useMetrics(next http.Handler) http.Handler { if s.metricSpec.Enabled { log.Infof("enabled metrics http middleware") - return diag.DefaultHTTPMonitoring.FastHTTPMiddleware(next) + return diag.DefaultHTTPMonitoring.HTTPMiddleware(next.ServeHTTP) } return next } +func (s *server) useMaxBodySize(next http.Handler) http.Handler { + if s.config.MaxRequestBodySizeMB > 0 { + maxSize := int64(s.config.MaxRequestBodySizeMB) << 20 // To bytes + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = streams.LimitReadCloser(r.Body, maxSize) + next.ServeHTTP(w, r) + }) + } + return next +} + func (s *server) apiLoggingInfo(route string, next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { fields := make(map[string]any, 2) @@ -259,47 +286,39 @@ func (s *server) usePublicRouter() fasthttp.RequestHandler { return router.Handler } -func (s *server) useComponents(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return fasthttpadaptor.NewFastHTTPHandler( - s.pipeline.Apply( - nethttpadaptor.NewNetHTTPHandlerFunc(next), - ), - ) +func (s *server) useComponents(next http.Handler) http.Handler { + return s.pipeline.Apply(next) } -func (s *server) useCors(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func (s *server) useCors(next http.Handler) http.Handler { + // TODO: Technically, if "AllowedOrigins" is "*", all origins should be allowed + // This behavior is not quite correct as in this case we are disallowing all origins if s.config.AllowedOrigins == corsDapr.DefaultAllowedOrigins { return next } - log.Infof("enabled cors http middleware") - origins := strings.Split(s.config.AllowedOrigins, ",") - corsHandler := s.getCorsHandler(origins) - return corsHandler.CorsMiddleware(next) + log.Infof("Enabled cors http middleware") + return cors.New(cors.Options{ + AllowedOrigins: strings.Split(s.config.AllowedOrigins, ","), + Debug: false, + }).Handler(next) } -func useAPIAuthentication(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func useAPIAuthentication(next http.Handler) http.Handler { token := auth.GetAPIToken() if token == "" { return next } - log.Info("enabled token authentication on http server") + log.Info("Enabled token authentication on HTTP server") - return func(ctx *fasthttp.RequestCtx) { - v := ctx.Request.Header.Peek(authConsts.APITokenHeader) - if auth.ExcludedRoute(string(ctx.Request.URI().FullURI())) || string(v) == token { - ctx.Request.Header.Del(authConsts.APITokenHeader) - next(ctx) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + v := r.Header.Get(authConsts.APITokenHeader) + if auth.ExcludedRoute(r.URL.String()) || v == token { + r.Header.Del(authConsts.APITokenHeader) + next.ServeHTTP(w, r) } else { - ctx.Error("invalid api token", http.StatusUnauthorized) + http.Error(w, "invalid api token", http.StatusUnauthorized) } - } -} - -func (s *server) getCorsHandler(allowedOrigins []string) *cors.CorsHandler { - return cors.NewCorsHandler(cors.Options{ - AllowedOrigins: allowedOrigins, - Debug: false, }) } diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go index ded6f66e1c4..267c3418033 100644 --- a/pkg/http/server_test.go +++ b/pkg/http/server_test.go @@ -18,6 +18,8 @@ import ( "bytes" "encoding/json" "fmt" + "net/http" + "net/http/httptest" "runtime" "testing" "time" @@ -35,17 +37,6 @@ import ( "github.com/dapr/kit/logger" ) -type mockHost struct { - hasCORS bool -} - -func (m *mockHost) mockHandler() fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - b := ctx.Response.Header.Peek("Access-Control-Allow-Origin") - m.hasCORS = len(b) > 0 - } -} - func newServer() server { return server{ config: ServerConfig{}, @@ -53,33 +44,42 @@ func newServer() server { } func TestCorsHandler(t *testing.T) { + hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + t.Run("with default cors, middleware not enabled", func(t *testing.T) { srv := newServer() srv.config.AllowedOrigins = cors.DefaultAllowedOrigins - mh := mockHost{} - h := srv.useCors(mh.mockHandler()) - r := &fasthttp.RequestCtx{ - Request: fasthttp.Request{}, + h := srv.useCors(hf) + w := httptest.NewRecorder() + r := &http.Request{ + Method: http.MethodOptions, + Header: http.Header{ + "Origin": []string{"*"}, + }, } - r.Request.Header.Set("Origin", "*") - h(r) + h.ServeHTTP(w, r) - assert.False(t, mh.hasCORS) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) }) t.Run("with custom cors, middleware enabled", func(t *testing.T) { srv := newServer() srv.config.AllowedOrigins = "http://test.com" - mh := mockHost{} - h := srv.useCors(mh.mockHandler()) - r := &fasthttp.RequestCtx{ - Request: fasthttp.Request{}, + h := srv.useCors(hf) + w := httptest.NewRecorder() + r := &http.Request{ + Method: http.MethodOptions, + Header: http.Header{ + "Origin": []string{"http://test.com"}, + }, } - r.Request.Header.Set("Origin", "http://test.com") - h(r) - assert.True(t, mh.hasCORS) + h.ServeHTTP(w, r) + + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Origin")) }) } @@ -411,13 +411,13 @@ func TestClose(t *testing.T) { port, err := freeport.GetFreePort() require.NoError(t, err) serverConfig := ServerConfig{ - AppID: "test", - HostAddress: "127.0.0.1", - Port: port, - APIListenAddresses: []string{"127.0.0.1"}, - MaxRequestBodySize: 4, - ReadBufferSize: 4, - EnableAPILogging: true, + AppID: "test", + HostAddress: "127.0.0.1", + Port: port, + APIListenAddresses: []string{"127.0.0.1"}, + MaxRequestBodySizeMB: 4, + ReadBufferSizeKB: 4, + EnableAPILogging: true, } a := &api{} server := NewServer(NewServerOpts{ @@ -437,13 +437,13 @@ func TestClose(t *testing.T) { port, err := freeport.GetFreePort() require.NoError(t, err) serverConfig := ServerConfig{ - AppID: "test", - HostAddress: "127.0.0.1", - Port: port, - APIListenAddresses: []string{"127.0.0.1"}, - MaxRequestBodySize: 4, - ReadBufferSize: 4, - EnableAPILogging: false, + AppID: "test", + HostAddress: "127.0.0.1", + Port: port, + APIListenAddresses: []string{"127.0.0.1"}, + MaxRequestBodySizeMB: 4, + ReadBufferSizeKB: 4, + EnableAPILogging: false, } a := &api{} server := NewServer(NewServerOpts{ diff --git a/pkg/http/universalapi.go b/pkg/http/universalapi.go index b3d00ecf1f9..efcba39633d 100644 --- a/pkg/http/universalapi.go +++ b/pkg/http/universalapi.go @@ -64,6 +64,8 @@ func UniversalFastHTTPHandler[T proto.Message, U proto.Message]( } return func(reqCtx *fasthttp.RequestCtx) { + var err error + // Need to use some reflection magic to allocate a value for the pointer of the generic type T in := reflect.New(rt).Interface().(T) @@ -72,7 +74,7 @@ func UniversalFastHTTPHandler[T proto.Message, U proto.Message]( // Read the response body and decode it as JSON using protojson body := reqCtx.PostBody() if len(body) > 0 { - err := pjsonDec.Unmarshal(body, in) + err = pjsonDec.Unmarshal(body, in) if err != nil { msg := NewErrorResponse("ERR_MALFORMED_REQUEST", err.Error()) respond(reqCtx, withError(fasthttp.StatusBadRequest, msg)) @@ -82,8 +84,6 @@ func UniversalFastHTTPHandler[T proto.Message, U proto.Message]( } } - var err error - // If we have an inModifier function, invoke it now if opts.InModifier != nil { in, err = opts.InModifier(reqCtx, in) diff --git a/pkg/messaging/direct_messaging.go b/pkg/messaging/direct_messaging.go index 45c396d2c58..baecd8bb403 100644 --- a/pkg/messaging/direct_messaging.go +++ b/pkg/messaging/direct_messaging.go @@ -440,7 +440,7 @@ func (d *directMessaging) invokeRemoteStream(ctx context.Context, clientV1 inter } return nil, err } - if chunk.Response == nil || chunk.Response.Status == nil || chunk.Response.Headers == nil { + if chunk.Response == nil || chunk.Response.Status == nil { return nil, errors.New("response does not contain the required fields in the leading chunk") } pr, pw := io.Pipe() diff --git a/pkg/messaging/v1/invoke_method_request.go b/pkg/messaging/v1/invoke_method_request.go index 9af34d84c98..0a57172944c 100644 --- a/pkg/messaging/v1/invoke_method_request.go +++ b/pkg/messaging/v1/invoke_method_request.go @@ -18,6 +18,7 @@ import ( "encoding/json" "errors" "io" + "net/http" "strings" "github.com/valyala/fasthttp" @@ -86,6 +87,12 @@ func (imr *InvokeMethodRequest) WithMetadata(md map[string][]string) *InvokeMeth return imr } +// WithHTTPHeaders sets HTTP request headers. +func (imr *InvokeMethodRequest) WithHTTPHeaders(header http.Header) *InvokeMethodRequest { + imr.r.Metadata = httpHeadersToInternalMetadata(header) + return imr +} + // WithFastHTTPHeaders sets fasthttp request headers. func (imr *InvokeMethodRequest) WithFastHTTPHeaders(header *fasthttp.RequestHeader) *InvokeMethodRequest { imr.r.Metadata = fasthttpHeadersToInternalMetadata(header) diff --git a/pkg/messaging/v1/invoke_method_response.go b/pkg/messaging/v1/invoke_method_response.go index 2aeadb1e282..f5dc1c3c537 100644 --- a/pkg/messaging/v1/invoke_method_response.go +++ b/pkg/messaging/v1/invoke_method_response.go @@ -51,6 +51,9 @@ func InternalInvokeResponse(pb *internalv1pb.InternalInvokeResponse) (*InvokeMet if pb.Message == nil { pb.Message = &commonv1pb.InvokeResponse{Data: nil} } + if pb.Headers == nil { + pb.Headers = map[string]*internalv1pb.ListStringValue{} + } return rsp, nil } diff --git a/pkg/messaging/v1/util.go b/pkg/messaging/v1/util.go index 89bc4585778..2057d11bb6a 100644 --- a/pkg/messaging/v1/util.go +++ b/pkg/messaging/v1/util.go @@ -118,13 +118,29 @@ func metadataToInternalMetadata(md map[string][]string) DaprInternalMetadata { return internalMD } +// httpHeadersToInternalMetadata converts http headers to Dapr internal metadata map. +func httpHeadersToInternalMetadata(header http.Header) DaprInternalMetadata { + internalMD := make(DaprInternalMetadata, len(header)) + for key, val := range header { + // Note: HTTP headers can never be binary (only gRPC supports binary headers) + if internalMD[key] == nil || len(internalMD[key].Values) == 0 { + internalMD[key] = &internalv1pb.ListStringValue{ + Values: val, + } + } else { + internalMD[key].Values = append(internalMD[key].Values, val...) + } + } + return internalMD +} + // Covers *fasthttp.RequestHeader and *fasthttp.ResponseHeader type fasthttpHeaders interface { Len() int VisitAll(f func(key []byte, value []byte)) } -// fasthttpHeadersToInternalMetadata converts fasthtt headers to Dapr internal metadata map. +// fasthttpHeadersToInternalMetadata converts fasthttp headers to Dapr internal metadata map. func fasthttpHeadersToInternalMetadata(header fasthttpHeaders) DaprInternalMetadata { internalMD := make(DaprInternalMetadata, header.Len()) header.VisitAll(func(key []byte, value []byte) { @@ -141,20 +157,6 @@ func fasthttpHeadersToInternalMetadata(header fasthttpHeaders) DaprInternalMetad return internalMD } -// Converts a fasthttp.RequestHeader to a map. -func fasthttpHeadersToMap(header fasthttpHeaders) map[string][]string { - md := map[string][]string{} - header.VisitAll(func(key []byte, value []byte) { - keyStr := string(key) - if len(md[keyStr]) == 0 { - md[keyStr] = []string{string(value)} - } else { - md[keyStr] = append(md[keyStr], string(value)) - } - }) - return md -} - // isPermanentHTTPHeader checks whether hdr belongs to the list of // permanent request headers maintained by IANA. // http://www.iana.org/assignments/message-headers/message-headers.xml diff --git a/pkg/messaging/v1/util_test.go b/pkg/messaging/v1/util_test.go index 5a582b93641..9183163bfcc 100644 --- a/pkg/messaging/v1/util_test.go +++ b/pkg/messaging/v1/util_test.go @@ -17,6 +17,7 @@ import ( "context" "encoding/base64" "fmt" + "net/http" "sort" "strings" "testing" @@ -389,17 +390,19 @@ func TestFasthttpHeadersToInternalMetadata(t *testing.T) { assert.Equal(t, []string{"test2", "test3"}, imd["Bar"].Values) } -func TestFasthttpHeadersToMap(t *testing.T) { - header := &fasthttp.RequestHeader{} +func TestHttpHeadersToInternalMetadata(t *testing.T) { + header := http.Header{} header.Add("foo", "test") header.Add("bar", "test2") header.Add("bar", "test3") - md := fasthttpHeadersToMap(header) + imd := httpHeadersToInternalMetadata(header) - require.NotEmpty(t, md) - require.NotEmpty(t, md["Foo"]) - assert.Equal(t, []string{"test"}, md["Foo"]) - require.NotEmpty(t, md["Bar"]) - assert.Equal(t, []string{"test2", "test3"}, md["Bar"]) + require.NotEmpty(t, imd) + require.NotEmpty(t, imd["Foo"]) + require.NotEmpty(t, imd["Foo"].Values) + assert.Equal(t, []string{"test"}, imd["Foo"].Values) + require.NotEmpty(t, imd["Bar"]) + require.NotEmpty(t, imd["Bar"].Values) + assert.Equal(t, []string{"test2", "test3"}, imd["Bar"].Values) } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index fd1afb61c33..b0a4cd5fe2c 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -1550,9 +1550,9 @@ func (a *DaprRuntime) startHTTPServer(port int, publicPort *int, profilePort int ProfilePort: profilePort, AllowedOrigins: allowedOrigins, EnableProfiling: a.runtimeConfig.EnableProfiling, - MaxRequestBodySize: a.runtimeConfig.MaxRequestBodySize, + MaxRequestBodySizeMB: a.runtimeConfig.MaxRequestBodySize, UnixDomainSocket: a.runtimeConfig.UnixDomainSocket, - ReadBufferSize: a.runtimeConfig.ReadBufferSize, + ReadBufferSizeKB: a.runtimeConfig.ReadBufferSize, EnableAPILogging: a.runtimeConfig.EnableAPILogging, APILoggingObfuscateURLs: a.globalConfig.Spec.LoggingSpec.APILogging.ObfuscateURLs, APILogHealthChecks: !a.globalConfig.Spec.LoggingSpec.APILogging.OmitHealthChecks, diff --git a/tests/apps/service_invocation/app.go b/tests/apps/service_invocation/app.go index d4b7bc82404..f20728b4d4a 100644 --- a/tests/apps/service_invocation/app.go +++ b/tests/apps/service_invocation/app.go @@ -196,7 +196,12 @@ func withBodyHandler(w http.ResponseWriter, r *http.Request) { onBadRequest(w, err) return } - fmt.Printf("withBodyHandler body: %s\n", string(body)) + + if len(body) > 100 { + fmt.Printf("withBodyHandler body (first 100 bytes): %s\n", string(body[:100])) + } else { + fmt.Printf("withBodyHandler body: %s\n", string(body)) + } var s string err = json.Unmarshal(body, &s) if err != nil { @@ -1363,7 +1368,7 @@ func largeDataErrorServiceCall(w http.ResponseWriter, r *http.Request, isHTTP bo name: "4MB", }, { - size: 1024*1024*3 - 1, + size: 1024*1024*3 + 10, name: "4MB+", }, { @@ -1379,7 +1384,7 @@ func largeDataErrorServiceCall(w http.ResponseWriter, r *http.Request, isHTTP bo body := make([]byte, test.size) jsonBody, _ := json.Marshal(body) - fmt.Printf("largeDataErrorServiceCall - Request size: %d\n", len(jsonBody)) + fmt.Printf("largeDataErrorServiceCall %s - Request size: %d\n", test.name, len(jsonBody)) if isHTTP { resp, err := httpClient.Post(sanitizeHTTPURL(url), jsonContentType, bytes.NewReader(jsonBody)) diff --git a/tests/e2e/service_invocation/service_invocation_test.go b/tests/e2e/service_invocation/service_invocation_test.go index 2f4526a0eb0..cae4dda1369 100644 --- a/tests/e2e/service_invocation/service_invocation_test.go +++ b/tests/e2e/service_invocation/service_invocation_test.go @@ -854,8 +854,6 @@ func TestHeaders(t *testing.T) { require.NoError(t, err) - _ = assert.NotEmpty(t, requestHeaders["dapr-host"]) && - assert.True(t, strings.HasPrefix(requestHeaders["dapr-host"][0], "localhost:")) _ = assert.NotEmpty(t, requestHeaders["content-type"]) && assert.Equal(t, "application/grpc", requestHeaders["content-type"][0]) _ = assert.NotEmpty(t, requestHeaders[":authority"]) && @@ -1491,20 +1489,20 @@ func TestNegativeCases(t *testing.T) { var testResults negativeTestResult json.Unmarshal(resp, &testResults) - require.Nil(t, err) + require.NoError(t, err) require.True(t, testResults.MainCallSuccessful) require.Len(t, testResults.Results, 4) for _, result := range testResults.Results { switch result.TestCase { case "1MB": - require.True(t, result.CallSuccessful) + assert.True(t, result.CallSuccessful) case "4MB": - require.True(t, result.CallSuccessful) + assert.True(t, result.CallSuccessful) case "4MB+": - require.False(t, result.CallSuccessful) + assert.False(t, result.CallSuccessful) case "8MB": - require.False(t, result.CallSuccessful) + assert.False(t, result.CallSuccessful) } } }) @@ -1522,20 +1520,20 @@ func TestNegativeCases(t *testing.T) { var testResults negativeTestResult json.Unmarshal(resp, &testResults) - require.Nil(t, err) + require.NoError(t, err) require.True(t, testResults.MainCallSuccessful) require.Len(t, testResults.Results, 4) for _, result := range testResults.Results { switch result.TestCase { case "1MB": - require.True(t, result.CallSuccessful) + assert.True(t, result.CallSuccessful) case "4MB": - require.True(t, result.CallSuccessful) + assert.True(t, result.CallSuccessful) case "4MB+": - require.False(t, result.CallSuccessful) + assert.False(t, result.CallSuccessful) case "8MB": - require.False(t, result.CallSuccessful) + assert.False(t, result.CallSuccessful) } } }) diff --git a/utils/fasthttpadaptor/adaptor.go b/utils/fasthttpadaptor/adaptor.go index 08b698c0261..d26b5bb9d36 100644 --- a/utils/fasthttpadaptor/adaptor.go +++ b/utils/fasthttpadaptor/adaptor.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/valyala/fasthttp" + + "github.com/dapr/dapr/utils/responsewriter" ) // NewFastHTTPHandlerFunc wraps net/http handler func to fasthttp @@ -54,10 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { return } - w := NetHTTPResponseWriter{w: ctx.Response.BodyWriter()} - h.ServeHTTP(&w, r.WithContext(ctx)) + w := responsewriter.EnsureResponseWriter(&NetHTTPResponseWriter{w: ctx.Response.BodyWriter()}) + h.ServeHTTP(w, r.WithContext(ctx)) - ctx.SetStatusCode(w.StatusCode()) + ctx.SetStatusCode(w.Status()) haveContentType := false for k, vv := range w.Header() { if k == fasthttp.HeaderContentType { @@ -80,7 +82,7 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l])) } - for k, v := range w.userValues { + for k, v := range w.AllUserValues() { ctx.SetUserValue(k, v) } } @@ -90,7 +92,6 @@ type NetHTTPResponseWriter struct { statusCode int h http.Header w io.Writer - userValues map[any]any } func (w *NetHTTPResponseWriter) StatusCode() int { @@ -114,12 +115,3 @@ func (w *NetHTTPResponseWriter) WriteHeader(statusCode int) { func (w *NetHTTPResponseWriter) Write(p []byte) (int, error) { return w.w.Write(p) } - -func (w *NetHTTPResponseWriter) SetUserValue(key any, value any) { - if w.userValues == nil { - w.userValues = map[any]any{} - } - w.userValues[key] = value -} - -func (w *NetHTTPResponseWriter) Flush() {} diff --git a/utils/nethttpadaptor/nethttpadaptor.go b/utils/nethttpadaptor/nethttpadaptor.go index 9f81a539e4e..17735cc9a94 100644 --- a/utils/nethttpadaptor/nethttpadaptor.go +++ b/utils/nethttpadaptor/nethttpadaptor.go @@ -14,6 +14,7 @@ limitations under the License. package nethttpadaptor import ( + "fmt" "io" "net" "net/http" @@ -21,7 +22,7 @@ import ( "github.com/valyala/fasthttp" - "github.com/dapr/dapr/utils/fasthttpadaptor" + diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils" "github.com/dapr/kit/logger" ) @@ -41,7 +42,9 @@ func NewNetHTTPHandlerFunc(h fasthttp.RequestHandler) http.HandlerFunc { if r.Body != nil { reqBody, err := io.ReadAll(r.Body) if err != nil { - log.Errorf("error reading request body, %+v", err) + msg := fmt.Sprintf("error reading request body: %v", err) + log.Errorf(msg) + http.Error(w, msg, http.StatusBadRequest) return } c.Request.SetBody(reqBody) @@ -69,19 +72,25 @@ func NewNetHTTPHandlerFunc(h fasthttp.RequestHandler) http.HandlerFunc { } } - ctx := r.Context() - reqCtx, ok := ctx.(*fasthttp.RequestCtx) - if ok { + // Ensure user values are propagated if the context is a fasthttp.RequestCtx already + if reqCtx, ok := r.Context().(*fasthttp.RequestCtx); ok { reqCtx.VisitUserValuesAll(func(k any, v any) { c.SetUserValue(k, v) }) } + // Propagate the context + span := diagUtils.SpanFromContext(r.Context()) + if span != nil { + diagUtils.AddSpanToFasthttpContext(&c, span) + } + + // Invoke the handler h(&c) - if faw, ok := w.(*fasthttpadaptor.NetHTTPResponseWriter); ok { + if uvw, ok := w.(interface{ SetUserValue(key any, value any) }); ok { c.VisitUserValuesAll(func(k any, v any) { - faw.SetUserValue(k, v) + uvw.SetUserValue(k, v) }) } diff --git a/utils/responsewriter/README.md b/utils/responsewriter/README.md new file mode 100644 index 00000000000..bbb686b5726 --- /dev/null +++ b/utils/responsewriter/README.md @@ -0,0 +1,29 @@ +# responsewriter + +This package contains code forked from [`github.com/urfave/negroni`](https://github.com/urfave/negroni). It includes extensive changes, including the removal of features that depend on the rest of the framework, and additions of things we need to support integrations with the rest of Dapr, such as support for user values. + +Source commit: [b935227](https://github.com/urfave/negroni/tree/b935227d493b8a257f6e0b3c8d98ae576c90cd4a/) + +## License + +> The MIT License (MIT) +> +> Copyright (c) 2014 Jeremy Saenz +> +> Permission is hereby granted, free of charge, to any person obtaining a copy +> of this software and associated documentation files (the "Software"), to deal +> in the Software without restriction, including without limitation the rights +> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +> copies of the Software, and to permit persons to whom the Software is +> furnished to do so, subject to the following conditions: +> +> The above copyright notice and this permission notice shall be included in all +> copies or substantial portions of the Software. +> +> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +> SOFTWARE. \ No newline at end of file diff --git a/utils/responsewriter/response_writer.go b/utils/responsewriter/response_writer.go new file mode 100644 index 00000000000..19df31bd618 --- /dev/null +++ b/utils/responsewriter/response_writer.go @@ -0,0 +1,179 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package responsewriter + +import ( + "io" + "net/http" +) + +// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about +// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter +// if the functionality calls for it. +type ResponseWriter interface { + http.ResponseWriter + + // Status returns the status code of the response or 0 if the response has + // not been written + Status() int + // Written returns whether or not the ResponseWriter has been written. + Written() bool + // Size returns the size of the response body. + Size() int + // Before allows for a function to be called before the ResponseWriter has been written to. This is + // useful for setting headers or any other operations that must happen before a response has been written. + Before(func(ResponseWriter)) + // UserValue retrieves values from the object. + UserValue(key any) any + // UserValueString retrieves the user value and casts it as string. + // If the value is not a string, returns an empty string. + UserValueString(key any) string + // AllUserValues retrieves all user values. + AllUserValues() map[any]any + // SetUserValue sets arbitrary values in the object. + SetUserValue(key any, value any) +} + +type beforeFunc func(ResponseWriter) + +// NewResponseWriter creates a ResponseWriter that wraps a http.ResponseWriter +func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { + return &responseWriter{ + ResponseWriter: rw, + } +} + +// EnsureResponseWriter creates a ResponseWriter that wraps a http.ResponseWriter, unless it's already a ResponseWriter. +func EnsureResponseWriter(rw http.ResponseWriter) ResponseWriter { + rwObj, ok := rw.(ResponseWriter) + if ok { + return rwObj + } + + return NewResponseWriter(rw) +} + +type responseWriter struct { + http.ResponseWriter + pendingStatus int + status int + size int + beforeFuncs []beforeFunc + callingBefores bool + userValues map[any]any +} + +func (rw *responseWriter) WriteHeader(s int) { + if rw.Written() { + return + } + + rw.pendingStatus = s + rw.callBefore() + + // Any of the rw.beforeFuncs may have written a header, + // so check again to see if any work is necessary. + if rw.Written() { + return + } + + rw.status = s + rw.ResponseWriter.WriteHeader(s) +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.Written() { + // The status will be StatusOK if WriteHeader has not been called yet + rw.WriteHeader(http.StatusOK) + } + size, err := rw.ResponseWriter.Write(b) + rw.size += size + return size, err +} + +// ReadFrom exposes underlying http.ResponseWriter to io.Copy and if it implements +// io.ReaderFrom, it can take advantage of optimizations such as sendfile, io.Copy +// with sync.Pool's buffer which is in http.(*response).ReadFrom and so on. +func (rw *responseWriter) ReadFrom(r io.Reader) (n int64, err error) { + if !rw.Written() { + // The status will be StatusOK if WriteHeader has not been called yet + rw.WriteHeader(http.StatusOK) + } + n, err = io.Copy(rw.ResponseWriter, r) + rw.size += int(n) + return +} + +// Satisfy http.ResponseController support (Go 1.20+) +func (rw *responseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +func (rw *responseWriter) Status() int { + if rw.Written() { + return rw.status + } + + return rw.pendingStatus +} + +func (rw *responseWriter) Size() int { + return rw.size +} + +func (rw *responseWriter) Written() bool { + return rw.status != 0 +} + +func (rw *responseWriter) Before(before func(ResponseWriter)) { + rw.beforeFuncs = append(rw.beforeFuncs, before) +} + +func (rw *responseWriter) callBefore() { + // Don't recursively call before() functions, to avoid infinite looping if + // one of them calls rw.WriteHeader again. + if rw.callingBefores { + return + } + + rw.callingBefores = true + defer func() { rw.callingBefores = false }() + + for i := len(rw.beforeFuncs) - 1; i >= 0; i-- { + rw.beforeFuncs[i](rw) + } +} + +func (rw *responseWriter) SetUserValue(key any, value any) { + if rw.userValues == nil { + rw.userValues = map[any]any{} + } + rw.userValues[key] = value +} + +func (rw *responseWriter) UserValue(key any) any { + if rw.userValues == nil { + return nil + } + return rw.userValues[key] +} + +func (rw *responseWriter) UserValueString(key any) string { + v, _ := rw.UserValue(key).(string) + return v +} + +func (rw *responseWriter) AllUserValues() map[any]any { + return rw.userValues +} diff --git a/utils/responsewriter/response_writer_test.go b/utils/responsewriter/response_writer_test.go new file mode 100644 index 00000000000..a74710ce49d --- /dev/null +++ b/utils/responsewriter/response_writer_test.go @@ -0,0 +1,215 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package responsewriter + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResponseWriterBeforeWrite(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + require.Equal(t, rw.Status(), 0) + require.Equal(t, rw.Written(), false) +} + +func TestResponseWriterBeforeFuncHasAccessToStatus(t *testing.T) { + var status int + + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.Before(func(w ResponseWriter) { + status = w.Status() + }) + rw.WriteHeader(http.StatusCreated) + + require.Equal(t, status, http.StatusCreated) +} + +func TestResponseWriterBeforeFuncCanChangeStatus(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + // Always respond with 200. + rw.Before(func(w ResponseWriter) { + w.WriteHeader(http.StatusOK) + }) + + rw.WriteHeader(http.StatusBadRequest) + require.Equal(t, rec.Code, http.StatusOK) +} + +func TestResponseWriterBeforeFuncChangesStatusMultipleTimes(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.Before(func(w ResponseWriter) { + w.WriteHeader(http.StatusInternalServerError) + }) + rw.Before(func(w ResponseWriter) { + w.WriteHeader(http.StatusNotFound) + }) + + rw.WriteHeader(http.StatusOK) + require.Equal(t, rec.Code, http.StatusNotFound) +} + +func TestResponseWriterWritingString(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.Write([]byte("Hello world")) + + require.Equal(t, rec.Code, rw.Status()) + require.Equal(t, rec.Body.String(), "Hello world") + require.Equal(t, rw.Status(), http.StatusOK) + require.Equal(t, rw.Size(), 11) + require.Equal(t, rw.Written(), true) +} + +func TestResponseWriterWritingStrings(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.Write([]byte("Hello world")) + rw.Write([]byte("foo bar bat baz")) + + require.Equal(t, rec.Code, rw.Status()) + require.Equal(t, rec.Body.String(), "Hello worldfoo bar bat baz") + require.Equal(t, rw.Status(), http.StatusOK) + require.Equal(t, rw.Size(), 26) +} + +func TestResponseWriterWritingHeader(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.WriteHeader(http.StatusNotFound) + + require.Equal(t, rec.Code, rw.Status()) + require.Equal(t, rec.Body.String(), "") + require.Equal(t, rw.Status(), http.StatusNotFound) + require.Equal(t, rw.Size(), 0) +} + +func TestResponseWriterWritingHeaderTwice(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + rw.WriteHeader(http.StatusNotFound) + rw.WriteHeader(http.StatusInternalServerError) + + require.Equal(t, rec.Code, rw.Status()) + require.Equal(t, rec.Body.String(), "") + require.Equal(t, rw.Status(), http.StatusNotFound) + require.Equal(t, rw.Size(), 0) +} + +func TestResponseWriterBefore(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + result := "" + + rw.Before(func(ResponseWriter) { + result += "foo" + }) + rw.Before(func(ResponseWriter) { + result += "bar" + }) + + rw.WriteHeader(http.StatusNotFound) + + require.Equal(t, rec.Code, rw.Status()) + require.Equal(t, rec.Body.String(), "") + require.Equal(t, rw.Status(), http.StatusNotFound) + require.Equal(t, rw.Size(), 0) + require.Equal(t, result, "barfoo") +} + +func TestResponseWriterUnwrap(t *testing.T) { + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + switch v := rw.(type) { + case interface{ Unwrap() http.ResponseWriter }: + require.Equal(t, v.Unwrap(), rec) + default: + t.Error("Does not implement Unwrap()") + } +} + +// mockReader only implements io.Reader without other methods like WriterTo +type mockReader struct { + readStr string + eof bool +} + +func (r *mockReader) Read(p []byte) (n int, err error) { + if r.eof { + return 0, io.EOF + } + copy(p, r.readStr) + r.eof = true + return len(r.readStr), nil +} + +func TestResponseWriterWithoutReadFrom(t *testing.T) { + writeString := "Hello world" + + rec := httptest.NewRecorder() + rw := NewResponseWriter(rec) + + n, err := io.Copy(rw, &mockReader{readStr: writeString}) + require.Equal(t, err, nil) + require.Equal(t, rw.Status(), http.StatusOK) + require.Equal(t, rw.Written(), true) + require.Equal(t, rw.Size(), len(writeString)) + require.Equal(t, int(n), len(writeString)) + require.Equal(t, rec.Body.String(), writeString) +} + +type mockResponseWriterWithReadFrom struct { + *httptest.ResponseRecorder + writtenStr string +} + +func (rw *mockResponseWriterWithReadFrom) ReadFrom(r io.Reader) (n int64, err error) { + bytes, err := io.ReadAll(r) + if err != nil { + return 0, err + } + rw.writtenStr = string(bytes) + rw.ResponseRecorder.Write(bytes) + return int64(len(bytes)), nil +} + +func TestResponseWriterWithReadFrom(t *testing.T) { + writeString := "Hello world" + mrw := &mockResponseWriterWithReadFrom{ResponseRecorder: httptest.NewRecorder()} + rw := NewResponseWriter(mrw) + n, err := io.Copy(rw, &mockReader{readStr: writeString}) + require.Equal(t, err, nil) + require.Equal(t, rw.Status(), http.StatusOK) + require.Equal(t, rw.Written(), true) + require.Equal(t, rw.Size(), len(writeString)) + require.Equal(t, int(n), len(writeString)) + require.Equal(t, mrw.Body.String(), writeString) + require.Equal(t, mrw.writtenStr, writeString) +} diff --git a/utils/streams/limitreadcloser.go b/utils/streams/limitreadcloser.go index d9bb0032434..a03573ee8d3 100644 --- a/utils/streams/limitreadcloser.go +++ b/utils/streams/limitreadcloser.go @@ -1,5 +1,5 @@ /* -Copyright 2021 The Dapr Authors +Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -14,6 +14,7 @@ limitations under the License. package streams import ( + "errors" "io" ) @@ -23,7 +24,10 @@ Copyright 2009 The Go Authors. All rights reserved. License: BSD (https://github.com/golang/go/blob/go1.18.3/LICENSE) */ -// LimitReadCloser returns a ReadCloser that reads from r but stops with EOF after n bytes. +// ErrStreamTooLarge is returned by LimitReadCloser when the stream is too large. +var ErrStreamTooLarge = errors.New("stream too large") + +// LimitReadCloser returns a ReadCloser that reads from r but stops with ErrStreamTooLarge after n bytes. func LimitReadCloser(r io.ReadCloser, n int64) io.ReadCloser { return &limitReadCloser{ R: r, @@ -32,22 +36,46 @@ func LimitReadCloser(r io.ReadCloser, n int64) io.ReadCloser { } type limitReadCloser struct { - R io.ReadCloser - N int64 + R io.ReadCloser + N int64 + closed bool } func (l *limitReadCloser) Read(p []byte) (n int, err error) { - if l.N <= 0 || l.R == nil { + if l.N < 0 || l.R == nil { + return 0, ErrStreamTooLarge + } + if len(p) == 0 { + return 0, nil + } + if l.closed { return 0, io.EOF } - if int64(len(p)) > l.N { - p = p[0:l.N] + if int64(len(p)) > (l.N + 1) { + p = p[0:(l.N + 1)] } n, err = l.R.Read(p) l.N -= int64(n) + if l.N < 0 { + // Special case if we just read the "l.N+1" byte + if l.N == -1 { + n-- + } + if err == nil { + err = ErrStreamTooLarge + } + if !l.closed { + l.closed = true + l.R.Close() + } + } return } func (l *limitReadCloser) Close() error { + if l.closed { + return nil + } + l.closed = true return l.R.Close() } diff --git a/utils/streams/limitreadcloser_test.go b/utils/streams/limitreadcloser_test.go new file mode 100644 index 00000000000..92aec110c6a --- /dev/null +++ b/utils/streams/limitreadcloser_test.go @@ -0,0 +1,124 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package streams + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLimitReadCloser(t *testing.T) { + t.Run("stream shorter than limit", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("e ho guardato dentro un'emozione")), 1000) + read, err := io.ReadAll(s) + require.NoError(t, err) + require.Equal(t, "e ho guardato dentro un'emozione", string(read)) + + // Reading again should return io.EOF + n, err := s.Read(read) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + }) + + t.Run("stream has same length as limit", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("e ci ho visto dentro tanto amore")), 32) + read, err := io.ReadAll(s) + require.NoError(t, err) + require.Equal(t, "e ci ho visto dentro tanto amore", string(read)) + + // Reading again should return io.EOF + n, err := s.Read(read) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + }) + + t.Run("stream longer than limit", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("che ho capito perche' non si comanda al cuore")), 21) + read, err := io.ReadAll(s) + require.Error(t, err) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, "che ho capito perche'", string(read)) + + // Reading again should return ErrStreamTooLarge again + n, err := s.Read(read) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, 0, n) + }) + + t.Run("stream longer than limit, read with byte slice", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("e va bene cosi'")), 4) + + read := make([]byte, 100) + n, err := s.Read(read) + require.Error(t, err) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, "e va", string(read[0:n])) + + // Reading again should return ErrStreamTooLarge again + n, err = s.Read(read) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, 0, n) + }) + + t.Run("read in two segments", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("senza parole")), 9) + + read := make([]byte, 5) + + n, err := s.Read(read) + require.NoError(t, err) + require.Equal(t, "senza", string(read[0:n])) + + n, err = s.Read(read) + require.Error(t, err) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, " par", string(read[0:n])) + + // Reading again should return ErrStreamTooLarge again + n, err = s.Read(read) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, 0, n) + }) + + t.Run("close early", func(t *testing.T) { + s := LimitReadCloser(io.NopCloser(strings.NewReader("senza parole")), 10) + + // Read 5 bytes then close + read := make([]byte, 5) + n, err := s.Read(read) + require.NoError(t, err) + require.Equal(t, "senza", string(read[0:n])) + + // Reading should now return io.EOF + err = s.Close() + require.NoError(t, err) + + n, err = s.Read(read) + require.Error(t, err) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + }) + + t.Run("stream is nil", func(t *testing.T) { + s := LimitReadCloser(nil, 10) + + // Reading should return ErrStreamTooLarge again + n, err := s.Read(make([]byte, 10)) + require.ErrorIs(t, err, ErrStreamTooLarge) + require.Equal(t, 0, n) + }) +}