diff --git a/.gitignore b/.gitignore index 4cb4003c26e..f74a8d23353 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ github.com/ google test_report* +coverage.txt # Go Workspaces (introduced in Go 1.18+) go.work diff --git a/pkg/actors/actors.go b/pkg/actors/actors.go index 9348420105b..34ef166a8b5 100644 --- a/pkg/actors/actors.go +++ b/pkg/actors/actors.go @@ -329,7 +329,10 @@ type lookupActorRes struct { } func (a *actorsRuntime) Call(ctx context.Context, req *invokev1.InvokeMethodRequest) (*invokev1.InvokeMethodResponse, error) { - a.placement.WaitUntilPlacementTableIsReady() + err := a.placement.WaitUntilPlacementTableIsReady(ctx) + if err != nil { + return nil, fmt.Errorf("failed to wait for placement table readiness: %w", err) + } actor := req.Actor() // Retry here to allow placement table dissemination/rebalancing to happen. @@ -1031,13 +1034,14 @@ func (a *actorsRuntime) executeReminder(reminder *Reminder) error { } func (a *actorsRuntime) reminderRequiresUpdate(req *CreateReminderRequest, reminder *Reminder) bool { - if reminder.ActorID == req.ActorID && reminder.ActorType == req.ActorType && reminder.Name == req.Name && - (!reflect.DeepEqual(reminder.Data, req.Data) || reminder.DueTime != req.DueTime || reminder.Period != req.Period || - len(req.TTL) != 0 || (len(reminder.ExpirationTime) != 0 && len(req.TTL) == 0)) { - return true - } - - return false + return reminder.ActorID == req.ActorID && + reminder.ActorType == req.ActorType && + reminder.Name == req.Name && + (!reflect.DeepEqual(reminder.Data, req.Data) || + reminder.DueTime != req.DueTime || + reminder.Period != req.Period || + len(req.TTL) != 0 || + (len(reminder.ExpirationTime) != 0 && len(req.TTL) == 0)) } func (a *actorsRuntime) getReminder(reminderName string, actorType string, actorID string) (*Reminder, bool) { diff --git a/pkg/actors/actors_mock.go b/pkg/actors/actors_mock.go index bf9d5e5a533..38ef198e2f3 100644 --- a/pkg/actors/actors_mock.go +++ b/pkg/actors/actors_mock.go @@ -20,6 +20,7 @@ package actors import ( "context" + "errors" mock "github.com/stretchr/testify/mock" @@ -238,10 +239,19 @@ type FailingActors struct { } func (f *FailingActors) Call(ctx context.Context, req *v1.InvokeMethodRequest) (*v1.InvokeMethodResponse, error) { - if err := f.Failure.PerformFailure(req.Actor().ActorId); err != nil { + proto := req.Proto() + if proto == nil || proto.Actor == nil { + return nil, errors.New("proto.Actor is nil") + } + if err := f.Failure.PerformFailure(proto.Actor.ActorId); err != nil { return nil, err } - resp := v1.NewInvokeMethodResponse(200, "Success", nil) + var data []byte + if proto.Message != nil && proto.Message.Data != nil { + data = proto.Message.Data.Value + } + resp := v1.NewInvokeMethodResponse(200, "Success", nil). + WithRawData(data, "") return resp, nil } diff --git a/pkg/actors/internal/placement.go b/pkg/actors/internal/placement.go index 6d7980d9703..588880066cb 100644 --- a/pkg/actors/internal/placement.go +++ b/pkg/actors/internal/placement.go @@ -14,6 +14,7 @@ limitations under the License. package internal import ( + "context" "net" "sync" "sync/atomic" @@ -224,9 +225,15 @@ func (p *ActorPlacement) Stop() { } // WaitUntilPlacementTableIsReady waits until placement table is until table lock is unlocked. -func (p *ActorPlacement) WaitUntilPlacementTableIsReady() { - if p.tableIsBlocked.Load() { - <-p.unblockSignal +func (p *ActorPlacement) WaitUntilPlacementTableIsReady(ctx context.Context) error { + if !p.tableIsBlocked.Load() { + return nil + } + select { + case <-p.unblockSignal: + return nil + case <-ctx.Done(): + return ctx.Err() } } diff --git a/pkg/actors/internal/placement_test.go b/pkg/actors/internal/placement_test.go index 68516ee6112..3cba8a56d1c 100644 --- a/pkg/actors/internal/placement_test.go +++ b/pkg/actors/internal/placement_test.go @@ -14,6 +14,7 @@ limitations under the License. package internal import ( + "context" "fmt" "io" "net" @@ -24,6 +25,7 @@ import ( "github.com/phayes/freeport" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -201,24 +203,69 @@ func TestWaitUntilPlacementTableIsReady(t *testing.T) { []string{"actorOne", "actorTwo"}, appHealthFunc, tableUpdateFunc) - testPlacement.onPlacementOrder(&placementv1pb.PlacementOrder{Operation: "lock"}) + t.Run("already unlocked", func(t *testing.T) { + require.False(t, testPlacement.tableIsBlocked.Load()) - asserted := atomic.Bool{} - asserted.Store(false) - go func() { - testPlacement.WaitUntilPlacementTableIsReady() - asserted.Store(true) - }() + err := testPlacement.WaitUntilPlacementTableIsReady(context.Background()) + assert.NoError(t, err) + }) + + t.Run("wait until ready", func(t *testing.T) { + testPlacement.onPlacementOrder(&placementv1pb.PlacementOrder{Operation: "lock"}) + + testSuccessCh := make(chan struct{}) + go func() { + err := testPlacement.WaitUntilPlacementTableIsReady(context.Background()) + if assert.NoError(t, err) { + testSuccessCh <- struct{}{} + } + }() + + time.Sleep(50 * time.Millisecond) + require.True(t, testPlacement.tableIsBlocked.Load()) + + // unlock + testPlacement.onPlacementOrder(&placementv1pb.PlacementOrder{Operation: "unlock"}) + + // ensure that it is unlocked + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("placement table not unlocked in 500ms") + case <-testSuccessCh: + // all good + } - time.Sleep(50 * time.Millisecond) - assert.False(t, asserted.Load()) + assert.False(t, testPlacement.tableIsBlocked.Load()) + }) - // unlock - testPlacement.onPlacementOrder(&placementv1pb.PlacementOrder{Operation: "unlock"}) + t.Run("abort on context canceled", func(t *testing.T) { + testPlacement.onPlacementOrder(&placementv1pb.PlacementOrder{Operation: "lock"}) + + testSuccessCh := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + err := testPlacement.WaitUntilPlacementTableIsReady(ctx) + if assert.ErrorIs(t, err, context.Canceled) { + testSuccessCh <- struct{}{} + } + }() + + time.Sleep(50 * time.Millisecond) + require.True(t, testPlacement.tableIsBlocked.Load()) + + // cancel context + cancel() + + // ensure that it is still locked + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("did not return in 500ms") + case <-testSuccessCh: + // all good + } - // ensure that it is unlocked - time.Sleep(50 * time.Millisecond) - assert.True(t, asserted.Load()) + assert.True(t, testPlacement.tableIsBlocked.Load()) + }) } func TestLookupActor(t *testing.T) { diff --git a/pkg/channel/grpc/grpc_channel.go b/pkg/channel/grpc/grpc_channel.go index df267de0cf9..6606f012533 100644 --- a/pkg/channel/grpc/grpc_channel.go +++ b/pkg/channel/grpc/grpc_channel.go @@ -114,14 +114,12 @@ func (g *Channel) invokeMethodV1(ctx context.Context, req *invokev1.InvokeMethod var header, trailer grpcMetadata.MD - var opts []grpc.CallOption - opts = append( - opts, + opts := []grpc.CallOption{ grpc.Header(&header), grpc.Trailer(&trailer), - grpc.MaxCallSendMsgSize(g.maxRequestBodySizeMB<<20), - grpc.MaxCallRecvMsgSize(g.maxRequestBodySizeMB<<20), - ) + grpc.MaxCallSendMsgSize(g.maxRequestBodySizeMB << 20), + grpc.MaxCallRecvMsgSize(g.maxRequestBodySizeMB << 20), + } resp, err := g.appCallbackClient.OnInvoke(ctx, req.Message(), opts...) @@ -139,9 +137,11 @@ func (g *Channel) invokeMethodV1(ctx context.Context, req *invokev1.InvokeMethod rsp = invokev1.NewInvokeMethodResponse(int32(codes.OK), "", nil) } - rsp.WithHeaders(header).WithTrailers(trailer) + rsp.WithHeaders(header). + WithTrailers(trailer). + WithMessage(resp) - return rsp.WithMessage(resp), nil + return rsp, nil } // HealthProbe performs a health probe. diff --git a/pkg/channel/http/http_channel_test.go b/pkg/channel/http/http_channel_test.go index 53e5cee38f4..6a224782a82 100644 --- a/pkg/channel/http/http_channel_test.go +++ b/pkg/channel/http/http_channel_test.go @@ -21,11 +21,11 @@ import ( "net/http/httptest" "strconv" "sync" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "github.com/dapr/dapr/pkg/config" invokev1 "github.com/dapr/dapr/pkg/messaging/v1" @@ -41,13 +41,13 @@ type testConcurrencyHandler struct { } func (t *testConcurrencyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - cur := t.currentCalls.Inc() + cur := t.currentCalls.Add(1) if cur > t.maxCalls { t.testFailed = true } - t.currentCalls.Dec() + t.currentCalls.Add(-1) io.WriteString(w, r.URL.RawQuery) } @@ -365,7 +365,7 @@ func TestInvokeMethodMaxConcurrency(t *testing.T) { t.Run("single concurrency", func(t *testing.T) { handler := testConcurrencyHandler{ maxCalls: 1, - currentCalls: atomic.NewInt32(0), + currentCalls: &atomic.Int32{}, } server := httptest.NewServer(&handler) c := Channel{ @@ -398,7 +398,7 @@ func TestInvokeMethodMaxConcurrency(t *testing.T) { t.Run("10 concurrent calls", func(t *testing.T) { handler := testConcurrencyHandler{ maxCalls: 10, - currentCalls: atomic.NewInt32(0), + currentCalls: &atomic.Int32{}, } server := httptest.NewServer(&handler) c := Channel{ @@ -431,7 +431,7 @@ func TestInvokeMethodMaxConcurrency(t *testing.T) { t.Run("introduce failures", func(t *testing.T) { handler := testConcurrencyHandler{ maxCalls: 5, - currentCalls: atomic.NewInt32(0), + currentCalls: &atomic.Int32{}, } server := httptest.NewServer(&handler) c := Channel{ diff --git a/pkg/components/bindings/input_pluggable_test.go b/pkg/components/bindings/input_pluggable_test.go index f258fc14a6b..64c35d92fbc 100644 --- a/pkg/components/bindings/input_pluggable_test.go +++ b/pkg/components/bindings/input_pluggable_test.go @@ -21,25 +21,20 @@ import ( "os" "runtime" "sync" + "sync/atomic" "testing" + guuid "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "github.com/dapr/components-contrib/bindings" contribMetadata "github.com/dapr/components-contrib/metadata" - "github.com/dapr/dapr/pkg/components/pluggable" proto "github.com/dapr/dapr/pkg/proto/components/v1" testingGrpc "github.com/dapr/dapr/pkg/testing/grpc" - - guuid "github.com/google/uuid" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "go.uber.org/atomic" - "github.com/dapr/kit/logger" - - "google.golang.org/grpc" ) type inputBindingServer struct { diff --git a/pkg/grpc/api.go b/pkg/grpc/api.go index 654795aac62..51789d156ca 100644 --- a/pkg/grpc/api.go +++ b/pkg/grpc/api.go @@ -38,7 +38,6 @@ import ( "github.com/dapr/components-contrib/secretstores" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/workflows" - "github.com/dapr/dapr/pkg/acl" "github.com/dapr/dapr/pkg/actors" componentsV1alpha "github.com/dapr/dapr/pkg/apis/components/v1alpha1" "github.com/dapr/dapr/pkg/channel" @@ -333,83 +332,6 @@ func NewAPI(opts APIOpts) API { } } -// CallLocal is used for internal dapr to dapr calls. It is invoked by another Dapr instance with a request to the local app. -func (a *api) CallLocal(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { - if a.appChannel == nil { - return nil, status.Error(codes.Internal, messages.ErrChannelNotFound) - } - - req, err := invokev1.InternalInvokeRequest(in) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, messages.ErrInternalInvokeRequest, err.Error()) - } - - if a.accessControlList != nil { - // An access control policy has been specified for the app. Apply the policies. - operation := req.Message().Method - var httpVerb commonv1pb.HTTPExtension_Verb //nolint:nosnakecase - // Get the http verb in case the application protocol is http - if a.appProtocol == config.HTTPProtocol && req.Metadata() != nil && len(req.Metadata()) > 0 { - httpExt := req.Message().GetHttpExtension() - if httpExt != nil { - httpVerb = httpExt.GetVerb() - } - } - callAllowed, errMsg := acl.ApplyAccessControlPolicies(ctx, operation, httpVerb, a.appProtocol, a.accessControlList) - - if !callAllowed { - return nil, status.Errorf(codes.PermissionDenied, errMsg) - } - } - - var callerAppID string - callerIDHeader, ok := req.Metadata()[invokev1.CallerIDHeader] - if ok && len(callerIDHeader.Values) > 0 { - callerAppID = callerIDHeader.Values[0] - } else { - callerAppID = "unknown" - } - - diag.DefaultMonitoring.ServiceInvocationRequestReceived(callerAppID, req.Message().Method) - - var statusCode int32 - defer func() { - diag.DefaultMonitoring.ServiceInvocationResponseSent(callerAppID, req.Message().Method, statusCode) - }() - - // stausCode will be read by the deferred method above - resp, err := a.appChannel.InvokeMethod(ctx, req) - if err != nil { - statusCode = int32(codes.Internal) - return nil, status.Errorf(codes.Internal, messages.ErrChannelInvoke, err) - } else { - statusCode = resp.Status().Code - } - - return resp.Proto(), nil -} - -// CallActor invokes a virtual actor. -func (a *api) CallActor(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { - req, err := invokev1.InternalInvokeRequest(in) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, messages.ErrInternalInvokeRequest, err.Error()) - } - - // We don't do resiliency here as it is handled in the API layer. See InvokeActor(). - resp, err := a.actor.Call(ctx, req) - if err != nil { - // We have to remove the error to keep the body, so callers must re-inspect for the header in the actual response. - if errors.Is(err, actors.ErrDaprResponseHeader) { - return resp.Proto(), nil - } - - err = status.Errorf(codes.Internal, messages.ErrActorInvoke, err) - return nil, err - } - return resp.Proto(), nil -} - // validateAndGetPubsbuAndTopic validates the request parameters and returns the pubsub interface, pubsub name, topic name, rawPayload metadata if set // or an error. func (a *api) validateAndGetPubsubAndTopic(pubsubName, topic string, reqMeta map[string]string) (pubsub.PubSub, string, string, bool, error) { @@ -2030,8 +1952,6 @@ func (a *api) SubscribeConfigurationAlpha1(request *runtimev1pb.SubscribeConfigu subscribeKeys := make([]string, 0) // TODO(@halspang) provide a switch to use just resiliency or this. - newCtx, cancel := context.WithCancel(context.Background()) - defer cancel() if len(request.Keys) > 0 { subscribeKeys = append(subscribeKeys, request.Keys...) @@ -2050,7 +1970,7 @@ func (a *api) SubscribeConfigurationAlpha1(request *runtimev1pb.SubscribeConfigu // TODO(@laurence) deal with failed subscription and retires start := time.Now() - policyRunner := resiliency.NewRunner[string](newCtx, + policyRunner := resiliency.NewRunner[string](configurationServer.Context(), a.resiliency.ComponentOutboundPolicy(request.StoreName, resiliency.Configuration), ) subscribeID, err := policyRunner(func(ctx context.Context) (string, error) { @@ -2130,3 +2050,15 @@ func (a *api) UnsubscribeConfigurationAlpha1(ctx context.Context, request *runti Ok: true, }, nil } + +func (a *api) Close() error { + a.configurationSubscribeLock.Lock() + defer a.configurationSubscribeLock.Unlock() + + for k, stop := range a.configurationSubscribe { + close(stop) + delete(a.configurationSubscribe, k) + } + + return nil +} diff --git a/pkg/grpc/api_actor_test.go b/pkg/grpc/api_actor_test.go index df922b00b5d..204fd2250e4 100644 --- a/pkg/grpc/api_actor_test.go +++ b/pkg/grpc/api_actor_test.go @@ -66,7 +66,7 @@ func TestRegisterActorReminder(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.RegisterActorReminder(context.TODO(), &runtimev1pb.RegisterActorReminderRequest{}) + _, err := client.RegisterActorReminder(context.Background(), &runtimev1pb.RegisterActorReminderRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) } @@ -83,7 +83,7 @@ func TestUnregisterActorTimer(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.UnregisterActorTimer(context.TODO(), &runtimev1pb.UnregisterActorTimerRequest{}) + _, err := client.UnregisterActorTimer(context.Background(), &runtimev1pb.UnregisterActorTimerRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) } @@ -100,7 +100,7 @@ func TestRegisterActorTimer(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.RegisterActorTimer(context.TODO(), &runtimev1pb.RegisterActorTimerRequest{}) + _, err := client.RegisterActorTimer(context.Background(), &runtimev1pb.RegisterActorTimerRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) } @@ -117,7 +117,7 @@ func TestGetActorState(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.GetActorState(context.TODO(), &runtimev1pb.GetActorStateRequest{}) + _, err := client.GetActorState(context.Background(), &runtimev1pb.GetActorStateRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) @@ -150,7 +150,7 @@ func TestGetActorState(t *testing.T) { client := runtimev1pb.NewDaprClient(clientConn) // act - res, err := client.GetActorState(context.TODO(), &runtimev1pb.GetActorStateRequest{ + res, err := client.GetActorState(context.Background(), &runtimev1pb.GetActorStateRequest{ ActorId: "fakeActorID", ActorType: "fakeActorType", Key: "key1", @@ -176,7 +176,7 @@ func TestExecuteActorStateTransaction(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.ExecuteActorStateTransaction(context.TODO(), &runtimev1pb.ExecuteActorStateTransactionRequest{}) + _, err := client.ExecuteActorStateTransaction(context.Background(), &runtimev1pb.ExecuteActorStateTransactionRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) @@ -220,7 +220,8 @@ func TestExecuteActorStateTransaction(t *testing.T) { client := runtimev1pb.NewDaprClient(clientConn) // act - res, err := client.ExecuteActorStateTransaction(context.TODO(), + res, err := client.ExecuteActorStateTransaction( + context.Background(), &runtimev1pb.ExecuteActorStateTransactionRequest{ ActorId: "fakeActorID", ActorType: "fakeActorType", @@ -256,7 +257,7 @@ func TestUnregisterActorReminder(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.UnregisterActorReminder(context.TODO(), &runtimev1pb.UnregisterActorReminderRequest{}) + _, err := client.UnregisterActorReminder(context.Background(), &runtimev1pb.UnregisterActorReminderRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) } @@ -273,7 +274,7 @@ func TestInvokeActor(t *testing.T) { defer clientConn.Close() client := runtimev1pb.NewDaprClient(clientConn) - _, err := client.InvokeActor(context.TODO(), &runtimev1pb.InvokeActorRequest{}) + _, err := client.InvokeActor(context.Background(), &runtimev1pb.InvokeActorRequest{}) assert.Equal(t, codes.Internal, status.Code(err)) }) } diff --git a/pkg/grpc/api_daprinternal.go b/pkg/grpc/api_daprinternal.go new file mode 100644 index 00000000000..e68f62216eb --- /dev/null +++ b/pkg/grpc/api_daprinternal.go @@ -0,0 +1,118 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpc + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/dapr/dapr/pkg/acl" + "github.com/dapr/dapr/pkg/actors" + "github.com/dapr/dapr/pkg/config" + diag "github.com/dapr/dapr/pkg/diagnostics" + "github.com/dapr/dapr/pkg/messages" + invokev1 "github.com/dapr/dapr/pkg/messaging/v1" + commonv1pb "github.com/dapr/dapr/pkg/proto/common/v1" + internalv1pb "github.com/dapr/dapr/pkg/proto/internals/v1" +) + +// CallLocal is used for internal dapr to dapr calls. It is invoked by another Dapr instance with a request to the local app. +func (a *api) CallLocal(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { + if a.appChannel == nil { + return nil, status.Error(codes.Internal, messages.ErrChannelNotFound) + } + + req, err := invokev1.InternalInvokeRequest(in) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, messages.ErrInternalInvokeRequest, err.Error()) + } + + err = a.callLocalValidateACL(ctx, req) + if err != nil { + return nil, err + } + + var callerAppID string + callerIDHeader, ok := req.Metadata()[invokev1.CallerIDHeader] + if ok && len(callerIDHeader.Values) > 0 { + callerAppID = callerIDHeader.Values[0] + } else { + callerAppID = "unknown" + } + + diag.DefaultMonitoring.ServiceInvocationRequestReceived(callerAppID, req.Message().Method) + + var statusCode int32 + defer func() { + diag.DefaultMonitoring.ServiceInvocationResponseSent(callerAppID, req.Message().Method, statusCode) + }() + + // stausCode will be read by the deferred method above + resp, err := a.appChannel.InvokeMethod(ctx, req) + if err != nil { + statusCode = int32(codes.Internal) + return nil, status.Errorf(codes.Internal, messages.ErrChannelInvoke, err) + } else { + statusCode = resp.Status().Code + } + + return resp.Proto(), nil +} + +// CallActor invokes a virtual actor. +func (a *api) CallActor(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { + req, err := invokev1.InternalInvokeRequest(in) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, messages.ErrInternalInvokeRequest, err.Error()) + } + + // We don't do resiliency here as it is handled in the API layer. See InvokeActor(). + resp, err := a.actor.Call(ctx, req) + if err != nil { + // We have to remove the error to keep the body, so callers must re-inspect for the header in the actual response. + if errors.Is(err, actors.ErrDaprResponseHeader) { + return resp.Proto(), nil + } + + err = status.Errorf(codes.Internal, messages.ErrActorInvoke, err) + return nil, err + } + return resp.Proto(), nil +} + +// Used by CallLocal and CallLocalStream to check the request against the access control list +func (a *api) callLocalValidateACL(ctx context.Context, req *invokev1.InvokeMethodRequest) error { + if a.accessControlList != nil { + // An access control policy has been specified for the app. Apply the policies. + operation := req.Message().Method + var httpVerb commonv1pb.HTTPExtension_Verb //nolint:nosnakecase + // Get the HTTP verb in case the application protocol is "http" + if a.appProtocol == config.HTTPProtocol && req.Metadata() != nil && len(req.Metadata()) > 0 { + httpExt := req.Message().GetHttpExtension() + if httpExt != nil { + httpVerb = httpExt.GetVerb() + } + } + callAllowed, errMsg := acl.ApplyAccessControlPolicies(ctx, operation, httpVerb, a.appProtocol, a.accessControlList) + + if !callAllowed { + return status.Errorf(codes.PermissionDenied, errMsg) + } + } + + return nil +} diff --git a/pkg/grpc/api_test.go b/pkg/grpc/api_test.go index 1ee61c9a36d..247ccf4170b 100644 --- a/pkg/grpc/api_test.go +++ b/pkg/grpc/api_test.go @@ -32,6 +32,7 @@ import ( "github.com/phayes/freeport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" epb "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" @@ -148,13 +149,13 @@ type mockGRPCAPI struct{} func (m *mockGRPCAPI) CallLocal(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { resp := invokev1.NewInvokeMethodResponse(0, "", nil) - resp.WithRawData(ExtractSpanContext(ctx), "text/plains") + resp.WithRawData(ExtractSpanContext(ctx), "text/plain") return resp.Proto(), nil } func (m *mockGRPCAPI) CallActor(ctx context.Context, in *internalv1pb.InternalInvokeRequest) (*internalv1pb.InternalInvokeResponse, error) { resp := invokev1.NewInvokeMethodResponse(0, "", nil) - resp.WithRawData(ExtractSpanContext(ctx), "text/plains") + resp.WithRawData(ExtractSpanContext(ctx), "text/plain") return resp.Proto(), nil } @@ -312,7 +313,10 @@ func startDaprAPIServer(port int, testAPIServer *api, token string) *grpc.Server } func createTestClient(port int) *grpc.ClientConn { - conn, err := grpc.Dial(fmt.Sprintf("localhost:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.Dial( + fmt.Sprintf("localhost:%d", port), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) if err != nil { panic(err) } @@ -401,7 +405,10 @@ func TestCallLocal(t *testing.T) { port, _ := freeport.GetFreePort() mockAppChannel := new(channelt.MockAppChannel) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(nil, status.Error(codes.Unknown, "unknown error")) + mockAppChannel.On("InvokeMethod", + mock.MatchedBy(matchContextInterface), + mock.AnythingOfType("*v1.InvokeMethodRequest"), + ).Return(nil, status.Error(codes.Unknown, "unknown error")) fakeAPI := &api{ id: "fakeAPI", appChannel: mockAppChannel, @@ -446,7 +453,7 @@ func TestAPIToken(t *testing.T) { // Set up direct messaging mock mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), "fakeAppID", mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(fakeResp, nil).Once() @@ -494,7 +501,7 @@ func TestAPIToken(t *testing.T) { // Set up direct messaging mock mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), "fakeAppID", mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(fakeResp, nil).Once() @@ -536,7 +543,7 @@ func TestAPIToken(t *testing.T) { // Set up direct messaging mock mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), "fakeAppID", mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(fakeResp, nil).Once() @@ -628,7 +635,7 @@ func TestInvokeServiceFromHTTPResponse(t *testing.T) { // Set up direct messaging mock mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), "fakeAppID", mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(fakeResp, nil).Once() @@ -698,7 +705,7 @@ func TestInvokeServiceFromGRPCResponse(t *testing.T) { // Set up direct messaging mock mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), "fakeAppID", mock.AnythingOfType("*v1.InvokeMethodRequest")).Return(fakeResp, nil).Once() @@ -1145,7 +1152,7 @@ func TestGetState(t *testing.T) { func TestGetConfiguration(t *testing.T) { fakeConfigurationStore := &daprt.MockConfigurationStore{} fakeConfigurationStore.On("Get", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.GetRequest) bool { return req.Keys[0] == goodKey })).Return( @@ -1157,7 +1164,7 @@ func TestGetConfiguration(t *testing.T) { }, }, nil) fakeConfigurationStore.On("Get", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.GetRequest) bool { return req.Keys[0] == "good-key1" && req.Keys[1] == goodKey2 && req.Keys[2] == "good-key3" })).Return( @@ -1175,7 +1182,7 @@ func TestGetConfiguration(t *testing.T) { }, }, nil) fakeConfigurationStore.On("Get", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.GetRequest) bool { return req.Keys[0] == "error-key" })).Return( @@ -1278,7 +1285,7 @@ func TestSubscribeConfiguration(t *testing.T) { fakeConfigurationStore := &daprt.MockConfigurationStore{} var tempReq *configuration.SubscribeRequest fakeConfigurationStore.On("Subscribe", - mock.AnythingOfType("*context.cancelCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.SubscribeRequest) bool { tempReq = req return len(tempReq.Keys) == 1 && tempReq.Keys[0] == goodKey @@ -1296,7 +1303,7 @@ func TestSubscribeConfiguration(t *testing.T) { return true })).Return("id", nil) fakeConfigurationStore.On("Subscribe", - mock.AnythingOfType("*context.cancelCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.SubscribeRequest) bool { tempReq = req return len(req.Keys) == 2 && req.Keys[0] == goodKey && req.Keys[1] == goodKey2 @@ -1317,7 +1324,7 @@ func TestSubscribeConfiguration(t *testing.T) { return true })).Return("id", nil) fakeConfigurationStore.On("Subscribe", - mock.AnythingOfType("*context.cancelCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.SubscribeRequest) bool { return req.Keys[0] == "error-key" }), @@ -1431,12 +1438,12 @@ func TestUnSubscribeConfiguration(t *testing.T) { defer close(stop) var tempReq *configuration.SubscribeRequest fakeConfigurationStore.On("Unsubscribe", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.UnsubscribeRequest) bool { return true })).Return(nil) fakeConfigurationStore.On("Subscribe", - mock.AnythingOfType("*context.cancelCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.SubscribeRequest) bool { tempReq = req return len(req.Keys) == 1 && req.Keys[0] == goodKey @@ -1468,7 +1475,7 @@ func TestUnSubscribeConfiguration(t *testing.T) { return true })).Return(mockSubscribeID, nil) fakeConfigurationStore.On("Subscribe", - mock.AnythingOfType("*context.cancelCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.SubscribeRequest) bool { tempReq = req return len(req.Keys) == 2 && req.Keys[0] == goodKey && req.Keys[1] == goodKey2 @@ -1606,7 +1613,7 @@ func TestUnSubscribeConfiguration(t *testing.T) { func TestUnsubscribeConfigurationErrScenario(t *testing.T) { fakeConfigurationStore := &daprt.MockConfigurationStore{} fakeConfigurationStore.On("Unsubscribe", - mock.AnythingOfType("*context.valueCtx"), + mock.MatchedBy(matchContextInterface), mock.MatchedBy(func(req *configuration.UnsubscribeRequest) bool { return req.ID == mockSubscribeID })).Return(nil) @@ -3248,16 +3255,20 @@ func TestServiceInvocationWithResiliency(t *testing.T) { client := runtimev1pb.NewDaprClient(clientConn) t.Run("Test invoke direct message retries with resiliency", func(t *testing.T) { - _, err := client.InvokeService(context.Background(), &runtimev1pb.InvokeServiceRequest{ + val := []byte("failingKey") + res, err := client.InvokeService(context.Background(), &runtimev1pb.InvokeServiceRequest{ Id: "failingApp", Message: &commonv1pb.InvokeRequest{ Method: "test", - Data: &anypb.Any{Value: []byte("failingKey")}, + Data: &anypb.Any{Value: val}, }, }) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 2, failingDirectMessaging.Failure.CallCount("failingKey")) + require.NotNil(t, res) + require.NotNil(t, res.Data) + assert.Equal(t, val, res.Data.Value) }) t.Run("Test invoke direct message fails with timeout", func(t *testing.T) { @@ -3623,3 +3634,8 @@ func TestUnlock(t *testing.T) { assert.Equal(t, runtimev1pb.UnlockResponse_SUCCESS, resp.Status) //nolint:nosnakecase }) } + +func matchContextInterface(v any) bool { + _, ok := v.(context.Context) + return ok +} diff --git a/pkg/grpc/metadata/metadata.go b/pkg/grpc/metadata/metadata.go index 85eab01cc0c..d27586d974f 100644 --- a/pkg/grpc/metadata/metadata.go +++ b/pkg/grpc/metadata/metadata.go @@ -37,7 +37,7 @@ func FromIncomingContext(ctx context.Context) (MD, bool) { } // SetMetadataInContextUnary sets the metadata in the context for an unary gRPC invocation. -func SetMetadataInContextUnary(ctx context.Context, req any, info *grpcGo.UnaryServerInfo, handler grpcGo.UnaryHandler) (any, error) { +func SetMetadataInContextUnary(ctx context.Context, req any, _ *grpcGo.UnaryServerInfo, handler grpcGo.UnaryHandler) (any, error) { // Because metadata.FromIncomingContext re-allocates the entire map every time to ensure the keys are lowercased, we can do it once and re-use that after meta, ok := grpcMetadata.FromIncomingContext(ctx) if ok && len(meta) > 0 { @@ -47,7 +47,7 @@ func SetMetadataInContextUnary(ctx context.Context, req any, info *grpcGo.UnaryS } // SetMetadataInTapHandle sets the metadata in the context for a streaming gRPC invocation. -func SetMetadataInTapHandle(ctx context.Context, info *tap.Info) (context.Context, error) { +func SetMetadataInTapHandle(ctx context.Context, _ *tap.Info) (context.Context, error) { // Because metadata.FromIncomingContext re-allocates the entire map every time to ensure the keys are lowercased, we can do it once and re-use that after meta, ok := grpcMetadata.FromIncomingContext(ctx) if ok && len(meta) > 0 { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 27dc22b9336..84b76f529c3 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -173,6 +173,12 @@ func (s *server) Close() error { server.GracefulStop() } + if s.api != nil { + if closer, ok := s.api.(io.Closer); ok { + closer.Close() + } + } + return nil } diff --git a/pkg/http/api_test.go b/pkg/http/api_test.go index f37548836fe..d938f3852b4 100644 --- a/pkg/http/api_test.go +++ b/pkg/http/api_test.go @@ -946,13 +946,6 @@ func TestV1OutputBindingsEndpointsWithTracer(t *testing.T) { } func TestV1DirectMessagingEndpoints(t *testing.T) { - headerMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - } fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -969,21 +962,18 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1055,11 +1045,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On( @@ -1089,11 +1074,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On( @@ -1115,21 +1095,18 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod?param1=val1¶m2=val2" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "param1=val1¶m2=val2") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1143,20 +1120,18 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { t.Run("Invoke direct messaging - HEAD - 200 OK", func(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod?param1=val1¶m2=val2" - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodHead, "") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("HEAD", apiPath, nil, nil) @@ -1170,10 +1145,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { t.Run("Invoke direct messaging route '/' - 200 OK", func(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/" - fakeReq := invokev1.NewInvokeMethodRequest("/") - fakeReq.WithHTTPExtension(gohttp.MethodGet, "") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", @@ -1198,11 +1169,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod?param1=val1¶m2=val2" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "param1=val1¶m2=val2") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", @@ -1227,11 +1193,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod?param1=val1¶m2=val2" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "param1=val1¶m2=val2") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", @@ -1256,14 +1217,6 @@ func TestV1DirectMessagingEndpoints(t *testing.T) { } func TestV1DirectMessagingEndpointsWithTracer(t *testing.T) { - headerMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - "X-Correlation-Id": {"fake-correlation-id"}, - } fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -1288,20 +1241,17 @@ func TestV1DirectMessagingEndpointsWithTracer(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1317,14 +1267,16 @@ func TestV1DirectMessagingEndpointsWithTracer(t *testing.T) { fakeData := []byte("fakeData") mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil, "dapr-app-id", "fakeAppID") @@ -1339,20 +1291,17 @@ func TestV1DirectMessagingEndpointsWithTracer(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod?param1=val1¶m2=val2" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "param1=val1¶m2=val2") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(headerMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1391,10 +1340,6 @@ func TestV1DirectMessagingEndpointsWithResiliency(t *testing.T) { apiPath := "v1.0/invoke/failingApp/method/fakeMethod" fakeData := []byte("failingKey") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1406,10 +1351,6 @@ func TestV1DirectMessagingEndpointsWithResiliency(t *testing.T) { apiPath := "v1.0/invoke/failingApp/method/fakeMethod" fakeData := []byte("timeoutKey") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - // act start := time.Now() resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1424,10 +1365,6 @@ func TestV1DirectMessagingEndpointsWithResiliency(t *testing.T) { apiPath := "v1.0/invoke/failingApp/method/fakeMethod" fakeData := []byte("extraFailingKey") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -1439,10 +1376,6 @@ func TestV1DirectMessagingEndpointsWithResiliency(t *testing.T) { apiPath := "v1.0/invoke/circuitBreakerApp/method/fakeMethod" fakeData := []byte("circuitBreakerKey") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - // Circuit Breaker trips on the 5th failure, stopping retries. resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) assert.Equal(t, 500, resp.StatusCode) @@ -2107,23 +2040,21 @@ func TestV1ActorEndpoints(t *testing.T) { t.Run("Direct Message - Forwards downstream status", func(t *testing.T) { apiPath := "v1.0/actors/fakeActorType/fakeActorID/method/method1" - headerMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - } mockActors := new(actors.MockActors) - invokeRequest := invokev1.NewInvokeMethodRequest("method1") - invokeRequest.WithActor("fakeActorType", "fakeActorID") fakeData := []byte("fakeData") - invokeRequest.WithHTTPExtension(gohttp.MethodPost, "") - invokeRequest.WithRawData(fakeData, "application/json") - invokeRequest.WithMetadata(headerMetadata) response := invokev1.NewInvokeMethodResponse(206, "OK", nil) - mockActors.On("Call", invokeRequest).Return(response, nil) + mockActors.On("Call", mock.MatchedBy(func(imr *invokev1.InvokeMethodRequest) bool { + m := imr.Proto() + if m.Actor == nil || m.Actor.ActorType != "fakeActorType" || m.Actor.ActorId != "fakeActorID" { + return false + } + + if m.Message == nil || m.Message.Data == nil || len(m.Message.Data.Value) == 0 || !bytes.Equal(m.Message.Data.Value, fakeData) { + return false + } + return true + })).Return(response, nil) testAPI.actor = mockActors @@ -2180,10 +2111,12 @@ func TestV1ActorEndpoints(t *testing.T) { t.Run("Direct Message - retries with resiliency", func(t *testing.T) { testAPI.actor = failingActors + msg := []byte("M'illumino d'immenso.") apiPath := fmt.Sprintf("v1.0/actors/failingActorType/%s/method/method1", "failingId") - resp := fakeServer.DoRequest("POST", apiPath, nil, nil) + resp := fakeServer.DoRequest("POST", apiPath, msg, nil) assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, msg, resp.RawBody) assert.Equal(t, 2, failingActors.Failure.CallCount("failingId")) }) @@ -2452,14 +2385,6 @@ func TestAPIToken(t *testing.T) { t.Setenv("DAPR_API_TOKEN", token) - fakeHeaderMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - } - fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -2477,20 +2402,17 @@ func TestAPIToken(t *testing.T) { apiPath := "v1.0/invoke/fakeDaprID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeDaprID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequestWithAPIToken("POST", apiPath, token, fakeData) @@ -2505,20 +2427,17 @@ func TestAPIToken(t *testing.T) { apiPath := "v1.0/invoke/fakeDaprID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeDaprID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequestWithAPIToken("POST", apiPath, "", fakeData) @@ -2533,20 +2452,17 @@ func TestAPIToken(t *testing.T) { apiPath := "v1.0/invoke/fakeDaprID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeDaprID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequestWithAPIToken("POST", apiPath, "4567", fakeData) @@ -2561,20 +2477,17 @@ func TestAPIToken(t *testing.T) { apiPath := "v1.0/invoke/fakeDaprID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeDaprID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) @@ -2587,15 +2500,6 @@ func TestAPIToken(t *testing.T) { } func TestEmptyPipelineWithTracer(t *testing.T) { - fakeHeaderMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - "X-Correlation-Id": {"fake-correlation-id"}, - } - fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -2620,11 +2524,6 @@ func TestEmptyPipelineWithTracer(t *testing.T) { apiPath := "v1.0/invoke/fakeDaprID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData(fakeData, "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", mock.MatchedBy(func(a context.Context) bool { @@ -2999,15 +2898,6 @@ func buildHTTPPineline(spec config.PipelineSpec) httpMiddleware.Pipeline { } func TestSinglePipelineWithTracer(t *testing.T) { - fakeHeaderMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - "X-Correlation-Id": {"fake-correlation-id"}, - } - fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -3041,11 +2931,6 @@ func TestSinglePipelineWithTracer(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData([]byte("FAKEDATA"), "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count mockDirectMessaging.On("Invoke", mock.MatchedBy(func(a context.Context) bool { @@ -3066,15 +2951,6 @@ func TestSinglePipelineWithTracer(t *testing.T) { } func TestSinglePipelineWithNoTracing(t *testing.T) { - fakeHeaderMetadata := map[string][]string{ - "Accept-Encoding": {"gzip"}, - "Content-Length": {"8"}, - "Content-Type": {"application/json"}, - "Host": {"localhost"}, - "User-Agent": {"Go-http-client/1.1"}, - "X-Correlation-Id": {"fake-correlation-id"}, - } - fakeDirectMessageResponse := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeDirectMessageResponse.WithRawData([]byte("fakeDirectMessageResponse"), "application/json") @@ -3108,20 +2984,17 @@ func TestSinglePipelineWithNoTracing(t *testing.T) { apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod" fakeData := []byte("fakeData") - fakeReq := invokev1.NewInvokeMethodRequest("fakeMethod") - fakeReq.WithHTTPExtension(gohttp.MethodPost, "") - fakeReq.WithRawData([]byte("FAKEDATA"), "application/json") - fakeReq.WithMetadata(fakeHeaderMetadata) - mockDirectMessaging.Calls = nil // reset call count - mockDirectMessaging.On("Invoke", + mockDirectMessaging.On( + "Invoke", mock.MatchedBy(func(a context.Context) bool { return true }), mock.MatchedBy(func(b string) bool { return b == "fakeAppID" }), mock.MatchedBy(func(c *invokev1.InvokeMethodRequest) bool { return true - })).Return(fakeDirectMessageResponse, nil).Once() + }), + ).Return(fakeDirectMessageResponse, nil).Once() // act resp := fakeServer.DoRequest("POST", apiPath, fakeData, nil) diff --git a/pkg/http/server.go b/pkg/http/server.go index 2c91c588bbf..1758a3690b7 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -21,6 +21,7 @@ import ( "net/http" "net/url" "regexp" + "strconv" "strings" cors "github.com/AdhityaRamadhanus/fasthttpcors" @@ -106,9 +107,10 @@ func (s *server) StartNonBlocking() error { listeners = append(listeners, l) } else { for _, apiListenAddress := range s.config.APIListenAddresses { - l, err := net.Listen("tcp", fmt.Sprintf("%s:%v", apiListenAddress, s.config.Port)) + addr := net.JoinHostPort(apiListenAddress, strconv.Itoa(s.config.Port)) + l, err := net.Listen("tcp", addr) if err != nil { - log.Warnf("Failed to listen on %v:%v with error: %v", apiListenAddress, s.config.Port, err) + log.Warnf("Failed to listen on %s with error: %v", addr, err) } else { listeners = append(listeners, l) } @@ -155,10 +157,11 @@ func (s *server) StartNonBlocking() error { if s.config.EnableProfiling { for _, apiListenAddress := range s.config.APIListenAddresses { - log.Infof("starting profiling server on %v:%v", apiListenAddress, s.config.ProfilePort) - pl, err := net.Listen("tcp", fmt.Sprintf("%s:%v", apiListenAddress, s.config.ProfilePort)) + addr := net.JoinHostPort(apiListenAddress, strconv.Itoa(s.config.ProfilePort)) + log.Infof("starting profiling server on %s", addr) + pl, err := net.Listen("tcp", addr) if err != nil { - log.Warnf("Failed to listen on %v:%v with error: %v", apiListenAddress, s.config.ProfilePort, err) + log.Warnf("Failed to listen on %s with error: %v", addr, err) } else { profilingListeners = append(profilingListeners, pl) } diff --git a/pkg/messaging/direct_messaging.go b/pkg/messaging/direct_messaging.go index d2feafc1068..a04853fa810 100644 --- a/pkg/messaging/direct_messaging.go +++ b/pkg/messaging/direct_messaging.go @@ -152,9 +152,8 @@ func (d *directMessaging) requestAppIDAndNamespace(targetAppID string) (string, } } -// invokeWithRetry will call a remote endpoint for the specified number of retries and will only retry in the case of transient failures -// TODO: check why https://github.com/grpc-ecosystem/go-grpc-middleware/blob/master/retry/examples_test.go doesn't recover the connection when target -// Server shuts down. +// invokeWithRetry will call a remote endpoint for the specified number of retries and will only retry in the case of transient failures. +// TODO: check why https://github.com/grpc-ecosystem/go-grpc-middleware/blob/master/retry/examples_test.go doesn't recover the connection when target server shuts down. func (d *directMessaging) invokeWithRetry( ctx context.Context, numRetries int, @@ -244,12 +243,10 @@ func (d *directMessaging) invokeRemote(ctx context.Context, appID, appNamespace, clientV1 := internalv1pb.NewServiceInvocationClient(conn) - var opts []grpc.CallOption - opts = append( - opts, - grpc.MaxCallRecvMsgSize(d.maxRequestBodySizeMB<<20), - grpc.MaxCallSendMsgSize(d.maxRequestBodySizeMB<<20), - ) + opts := []grpc.CallOption{ + grpc.MaxCallRecvMsgSize(d.maxRequestBodySizeMB << 20), + grpc.MaxCallSendMsgSize(d.maxRequestBodySizeMB << 20), + } start := time.Now() diag.DefaultMonitoring.ServiceInvocationRequestSent(appID, req.Message().Method) diff --git a/pkg/messaging/v1/invoke_method_response.go b/pkg/messaging/v1/invoke_method_response.go index 165fada9f67..5d323ebe395 100644 --- a/pkg/messaging/v1/invoke_method_response.go +++ b/pkg/messaging/v1/invoke_method_response.go @@ -115,7 +115,7 @@ func (imr *InvokeMethodResponse) IsHTTPResponse() bool { return imr.r.Status.Code >= 100 } -// Proto clones the internal InvokeMethodResponse pb object. +// Proto returns the internal InvokeMethodResponse Proto object. func (imr *InvokeMethodResponse) Proto() *internalv1pb.InternalInvokeResponse { return imr.r } diff --git a/pkg/operator/handlers/dapr_handler.go b/pkg/operator/handlers/dapr_handler.go index d9ec8ac2a0b..39a7a827462 100644 --- a/pkg/operator/handlers/dapr_handler.go +++ b/pkg/operator/handlers/dapr_handler.go @@ -16,11 +16,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" - "github.com/dapr/kit/logger" - "github.com/dapr/dapr/pkg/operator/monitoring" "github.com/dapr/dapr/pkg/validation" "github.com/dapr/dapr/utils" + "github.com/dapr/kit/logger" ) const ( diff --git a/pkg/resiliency/resiliency.go b/pkg/resiliency/resiliency.go index 9fa5a712e12..3b903bc59f4 100644 --- a/pkg/resiliency/resiliency.go +++ b/pkg/resiliency/resiliency.go @@ -23,20 +23,16 @@ import ( "sync" "time" - "github.com/dapr/dapr/utils" - "github.com/ghodss/yaml" grpcRetry "github.com/grpc-ecosystem/go-grpc-middleware/retry" lru "github.com/hashicorp/golang-lru" - - diag "github.com/dapr/dapr/pkg/diagnostics" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" resiliencyV1alpha "github.com/dapr/dapr/pkg/apis/resiliency/v1alpha1" + diag "github.com/dapr/dapr/pkg/diagnostics" operatorv1pb "github.com/dapr/dapr/pkg/proto/operator/v1" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/dapr/dapr/pkg/resiliency/breaker" + "github.com/dapr/dapr/utils" "github.com/dapr/kit/config" "github.com/dapr/kit/logger" "github.com/dapr/kit/retry" @@ -66,6 +62,8 @@ const ( Pubsub ComponentType = "Pubsub" Secretstore ComponentType = "Secretstore" Statestore ComponentType = "Statestore" + Inbound ComponentDirection = "Inbound" + Outbound ComponentDirection = "Outbound" ) // ActorCircuitBreakerScope indicates the scope of the circuit breaker for an actor. @@ -179,12 +177,13 @@ type ( getPolicyLevels() []string getPolicyTypeName() PolicyTypeName } - EndpointPolicy struct{} - ActorPolicy struct{} - ComponentType string - ComponentPolicy struct { + EndpointPolicy struct{} + ActorPolicy struct{} + ComponentType string + ComponentDirection string + ComponentPolicy struct { componentType ComponentType - componentDirection string + componentDirection ComponentDirection } ) @@ -785,7 +784,7 @@ func (r *Resiliency) ComponentInboundPolicy(name string, componentType Component diag.DefaultResiliencyMonitoring.PolicyExecuted(r.name, r.namespace, diag.CircuitBreakerPolicy) } } else { - if defaultPolicies, ok := r.getDefaultPolicy(&ComponentPolicy{componentType: componentType, componentDirection: "Inbound"}); ok { + if defaultPolicies, ok := r.getDefaultPolicy(&ComponentPolicy{componentType: componentType, componentDirection: Inbound}); ok { r.log.Debugf("Found Default Policy for Component: %s: %+v", name, defaultPolicies) if defaultPolicies.Timeout != "" { policyDef.t = r.timeouts[defaultPolicies.Timeout] @@ -827,18 +826,21 @@ func (r *Resiliency) GetPolicy(target string, policyType PolicyType) *PolicyDesc componentPolicy, exists = r.components[target] if exists { policy, _ := policyType.(*ComponentPolicy) - if policy.componentDirection == "Inbound" { + switch policy.componentDirection { + case Inbound: policyName = PolicyNames{ Retry: componentPolicy.Inbound.Retry, CircuitBreaker: componentPolicy.Inbound.CircuitBreaker, Timeout: componentPolicy.Inbound.Timeout, } - } else { + case Outbound: policyName = PolicyNames{ Retry: componentPolicy.Outbound.Retry, CircuitBreaker: componentPolicy.Outbound.CircuitBreaker, Timeout: componentPolicy.Outbound.Timeout, } + default: + panic(fmt.Errorf("invalid component policy direction: '%s'", policy.componentDirection)) } } case Actor: @@ -1052,9 +1054,9 @@ func (*ComponentPolicy) getPolicyTypeName() PolicyTypeName { } var ComponentInboundPolicy = ComponentPolicy{ - componentDirection: "Inbound", + componentDirection: Inbound, } var ComponentOutboundPolicy = ComponentPolicy{ - componentDirection: "Outbound", + componentDirection: Outbound, } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index d373edac2eb..e435879c455 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -20,6 +20,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net" @@ -36,25 +37,11 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" - - "github.com/dapr/components-contrib/lock" - "github.com/dapr/components-contrib/middleware" - - "github.com/dapr/dapr/pkg/components" - bindingsLoader "github.com/dapr/dapr/pkg/components/bindings" - configurationLoader "github.com/dapr/dapr/pkg/components/configuration" - lockLoader "github.com/dapr/dapr/pkg/components/lock" - httpMiddlewareLoader "github.com/dapr/dapr/pkg/components/middleware/http" - pubsubLoader "github.com/dapr/dapr/pkg/components/pubsub" - stateLoader "github.com/dapr/dapr/pkg/components/state" - httpMiddleware "github.com/dapr/dapr/pkg/middleware/http" - "github.com/ghodss/yaml" + "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/hashicorp/go-multierror" "github.com/phayes/freeport" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -71,21 +58,26 @@ import ( "github.com/dapr/components-contrib/bindings" "github.com/dapr/components-contrib/contenttype" + "github.com/dapr/components-contrib/lock" mdata "github.com/dapr/components-contrib/metadata" + "github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/nameresolution" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/components-contrib/secretstores" "github.com/dapr/components-contrib/state" - - "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" - componentsV1alpha1 "github.com/dapr/dapr/pkg/apis/components/v1alpha1" "github.com/dapr/dapr/pkg/apis/resiliency/v1alpha1" subscriptionsapi "github.com/dapr/dapr/pkg/apis/subscriptions/v1alpha1" channelt "github.com/dapr/dapr/pkg/channel/testing" + "github.com/dapr/dapr/pkg/components" + bindingsLoader "github.com/dapr/dapr/pkg/components/bindings" + configurationLoader "github.com/dapr/dapr/pkg/components/configuration" + lockLoader "github.com/dapr/dapr/pkg/components/lock" + httpMiddlewareLoader "github.com/dapr/dapr/pkg/components/middleware/http" nrLoader "github.com/dapr/dapr/pkg/components/nameresolution" + pubsubLoader "github.com/dapr/dapr/pkg/components/pubsub" secretstoresLoader "github.com/dapr/dapr/pkg/components/secretstores" + stateLoader "github.com/dapr/dapr/pkg/components/state" "github.com/dapr/dapr/pkg/config" "github.com/dapr/dapr/pkg/cors" diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils" @@ -93,6 +85,7 @@ import ( "github.com/dapr/dapr/pkg/expr" pb "github.com/dapr/dapr/pkg/grpc/proxy/testservice" invokev1 "github.com/dapr/dapr/pkg/messaging/v1" + httpMiddleware "github.com/dapr/dapr/pkg/middleware/http" "github.com/dapr/dapr/pkg/modes" operatorv1pb "github.com/dapr/dapr/pkg/proto/operator/v1" runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1" @@ -102,6 +95,8 @@ import ( "github.com/dapr/dapr/pkg/scopes" sentryConsts "github.com/dapr/dapr/pkg/sentry/consts" daprt "github.com/dapr/dapr/pkg/testing" + "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" ) const ( @@ -1379,17 +1374,18 @@ func TestInitPubSub(t *testing.T) { rt.appChannel = mockAppChannel // User App subscribes 2 topics via http app channel - fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe") - fakeReq.WithHTTPExtension(http.MethodGet, "") - fakeReq.WithRawData(nil, "application/json") + fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe"). + WithHTTPExtension(http.MethodGet, ""). + WithRawData(nil, "application/json") - fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString( []string{"topic0", "topic1"}, // first pubsub - []string{"topic0"}) // second pubsub - fakeResp.WithRawData([]byte(subs), "application/json") + []string{"topic0"}, // second pubsub + ) + fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil). + WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1415,15 +1411,15 @@ func TestInitPubSub(t *testing.T) { rt.appChannel = mockAppChannel // User App subscribes to a topic via http app channel - fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe") - fakeReq.WithHTTPExtension(http.MethodGet, "") - fakeReq.WithRawData(nil, "application/json") + fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe"). + WithHTTPExtension(http.MethodGet, ""). + WithRawData(nil, "application/json") - fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) sub := getSubscriptionCustom("topic0", "customroute/topic0") - fakeResp.WithRawData([]byte(sub), "application/json") + fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil). + WithRawData([]byte(sub), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1446,12 +1442,12 @@ func TestInitPubSub(t *testing.T) { mockAppChannel := new(channelt.MockAppChannel) rt.appChannel = mockAppChannel - fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe") - fakeReq.WithHTTPExtension(http.MethodGet, "") - fakeReq.WithRawData(nil, "application/json") + fakeReq := invokev1.NewInvokeMethodRequest("dapr/subscribe"). + WithHTTPExtension(http.MethodGet, ""). + WithRawData(nil, "application/json") fakeResp := invokev1.NewInvokeMethodResponse(404, "Not Found", nil) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1577,7 +1573,7 @@ func TestInitPubSub(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString([]string{"topic0"}, []string{"topic1"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1609,7 +1605,7 @@ func TestInitPubSub(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString([]string{"topic0", "topic1"}, []string{"topic0"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1641,7 +1637,7 @@ func TestInitPubSub(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString([]string{"topic3"}, []string{"topic5"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1673,7 +1669,7 @@ func TestInitPubSub(t *testing.T) { // topic0 is allowed, topic3 and topic5 are not subs := getSubscriptionsJSONString([]string{"topic0", "topic3"}, []string{"topic0", "topic5"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1816,7 +1812,7 @@ func TestInitPubSub(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString([]string{"topic0"}, []string{"topic1"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -1858,7 +1854,7 @@ func TestInitPubSub(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) subs := getSubscriptionsJSONString([]string{"topic0"}, []string{"topic0"}) fakeResp.WithRawData([]byte(subs), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act for _, comp := range pubsubComponents { @@ -2768,7 +2764,7 @@ func TestErrorPublishedNonCloudEventGRPC(t *testing.T) { response, ok := reply.(*runtimev1pb.TopicEventResponse) if !ok { - return errors.Errorf("unexpected reply type: %s", reflect.TypeOf(reply)) + return fmt.Errorf("unexpected reply type: %s", reflect.TypeOf(reply)) } response.Status = tc.Status @@ -2827,7 +2823,7 @@ func TestOnNewPublishedMessage(t *testing.T) { // User App subscribes 1 topics via http app channel fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2864,7 +2860,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeReqNoTraceID.WithHTTPExtension(http.MethodPost, "") fakeReqNoTraceID.WithRawData(message.data, contenttype.CloudEventContentType) fakeReqNoTraceID.WithCustomHTTPMetadata(testPubSubMessage.metadata) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeReqNoTraceID).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReqNoTraceID).Return(fakeResp, nil) // act err = rt.publishMessageHTTP(context.Background(), message) @@ -2882,7 +2878,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("OK"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2900,7 +2896,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"status\": \"SUCCESS\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2918,7 +2914,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"status\": \"RETRY\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2926,7 +2922,7 @@ func TestOnNewPublishedMessage(t *testing.T) { // assert var cloudEvent map[string]interface{} json.Unmarshal(testPubSubMessage.data, &cloudEvent) - expectedClientError := errors.Errorf("RETRY status returned from app while processing pub/sub event %v", cloudEvent["id"].(string)) + expectedClientError := fmt.Errorf("RETRY status returned from app while processing pub/sub event %v", cloudEvent["id"].(string)) assert.Equal(t, expectedClientError.Error(), err.Error()) mockAppChannel.AssertNumberOfCalls(t, "InvokeMethod", 1) }) @@ -2939,7 +2935,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"status\": \"DROP\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2957,7 +2953,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"status\": \"not_valid\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2975,7 +2971,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"message\": \"empty status\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -2993,7 +2989,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("{ \"message\": \"success\"}"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -3008,13 +3004,13 @@ func TestOnNewPublishedMessage(t *testing.T) { rt.appChannel = mockAppChannel invokeError := errors.New("error invoking method") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(nil, invokeError) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(nil, invokeError) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) // assert - expectedError := errors.Wrap(invokeError, "error from app channel while sending pub/sub event to app") + expectedError := fmt.Errorf("error from app channel while sending pub/sub event to app: %w", invokeError) assert.Equal(t, expectedError.Error(), err.Error(), "expected errors to match") mockAppChannel.AssertNumberOfCalls(t, "InvokeMethod", 1) }) @@ -3027,7 +3023,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(404, "Not Found", nil) fakeResp.WithRawData([]byte(clientError.Error()), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -3045,7 +3041,7 @@ func TestOnNewPublishedMessage(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(500, "Internal Error", nil) fakeResp.WithRawData([]byte(clientError.Error()), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) // act err := rt.publishMessageHTTP(context.Background(), testPubSubMessage) @@ -3053,7 +3049,7 @@ func TestOnNewPublishedMessage(t *testing.T) { // assert var cloudEvent map[string]interface{} json.Unmarshal(testPubSubMessage.data, &cloudEvent) - expectedClientError := errors.Errorf("retriable error returned from app while processing pub/sub event %v, topic: %v, body: Internal Error. status code returned: 500", cloudEvent["id"].(string), cloudEvent["topic"]) + expectedClientError := fmt.Errorf("retriable error returned from app while processing pub/sub event %v, topic: %v, body: Internal Error. status code returned: 500", cloudEvent["id"].(string), cloudEvent["topic"]) assert.Equal(t, expectedClientError.Error(), err.Error()) mockAppChannel.AssertNumberOfCalls(t, "InvokeMethod", 1) }) @@ -3806,9 +3802,13 @@ func TestPubSubDeadLetter(t *testing.T) { mockAppChannel := new(channelt.MockAppChannel) rt.appChannel = mockAppChannel - mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), req).Return(fakeResp, nil) + mockAppChannel. + On("InvokeMethod", mock.MatchedBy(matchContextInterface), req). + Return(fakeResp, nil) // Mock send message to app returns error. - mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(nil, errors.New("failed to send")) + mockAppChannel. + On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything). + Return(nil, errors.New("failed to send")) require.NoError(t, rt.initPubSub(pubsubComponent)) rt.startSubscriptions() @@ -3850,9 +3850,13 @@ func TestPubSubDeadLetter(t *testing.T) { mockAppChannel := new(channelt.MockAppChannel) rt.appChannel = mockAppChannel - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), req).Return(fakeResp, nil) + mockAppChannel. + On("InvokeMethod", mock.MatchedBy(matchContextInterface), req). + Return(fakeResp, nil) // Mock send message to app returns error. - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.timerCtx"), mock.Anything).Return(nil, errors.New("failed to send")) + mockAppChannel. + On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything). + Return(nil, errors.New("failed to send")) require.NoError(t, rt.initPubSub(pubsubComponent)) rt.startSubscriptions() @@ -4187,8 +4191,8 @@ func TestReadInputBindings(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("OK"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeBindingReq).Return(fakeBindingResp, nil) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeBindingReq).Return(fakeBindingResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) rt.appChannel = mockAppChannel @@ -4225,8 +4229,8 @@ func TestReadInputBindings(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(500, "Internal Error", nil) fakeResp.WithRawData([]byte("Internal Error"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeBindingReq).Return(fakeBindingResp, nil) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeBindingReq).Return(fakeBindingResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) rt.appChannel = mockAppChannel rt.inputBindingRoutes[testInputBindingName] = testInputBindingName @@ -4262,8 +4266,8 @@ func TestReadInputBindings(t *testing.T) { fakeResp := invokev1.NewInvokeMethodResponse(200, "OK", nil) fakeResp.WithRawData([]byte("OK"), "application/json") - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.emptyCtx"), fakeBindingReq).Return(fakeBindingResp, nil) - mockAppChannel.On("InvokeMethod", mock.AnythingOfType("*context.valueCtx"), fakeReq).Return(fakeResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeBindingReq).Return(fakeBindingResp, nil) + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), fakeReq).Return(fakeResp, nil) rt.appChannel = mockAppChannel rt.inputBindingRoutes[testInputBindingName] = testInputBindingName diff --git a/pkg/testing/directmessaging_mock.go b/pkg/testing/directmessaging_mock.go index de848788f40..599d83caf90 100644 --- a/pkg/testing/directmessaging_mock.go +++ b/pkg/testing/directmessaging_mock.go @@ -1,4 +1,15 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package testing @@ -22,10 +33,8 @@ func (_m *MockDirectMessaging) Invoke(ctx context.Context, targetAppID string, r var r0 *v1.InvokeMethodResponse if rf, ok := ret.Get(0).(func(context.Context, string, *v1.InvokeMethodRequest) *v1.InvokeMethodResponse); ok { r0 = rf(ctx, targetAppID, req) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*v1.InvokeMethodResponse) - } + } else if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.InvokeMethodResponse) } var r1 error @@ -51,5 +60,7 @@ func (f *FailingDirectMessaging) Invoke(ctx context.Context, targetAppID string, if err != nil { return &v1.InvokeMethodResponse{}, err } - return v1.NewInvokeMethodResponse(200, "OK", nil), nil + res := v1.NewInvokeMethodResponse(200, "OK", nil). + WithRawData(req.Message().Data.Value, "") + return res, nil } diff --git a/tests/apps/middleware/app.go b/tests/apps/middleware/app.go index 9aad9f70724..c9b3566610e 100644 --- a/tests/apps/middleware/app.go +++ b/tests/apps/middleware/app.go @@ -68,7 +68,7 @@ func testLogCall(w http.ResponseWriter, r *http.Request) { input := "hello" url := fmt.Sprintf("http://localhost:%d/v1.0/invoke/%s/method/logCall", daprPort, service) - resp, err := http.Post(url, "application/json", bytes.NewBuffer([]byte(input))) //nolint:gosec + resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(input))) //nolint:gosec if err != nil { log.Printf("Could not call service") w.WriteHeader(http.StatusInternalServerError) diff --git a/tests/apps/resiliencyapp/app.go b/tests/apps/resiliencyapp/app.go index 046dc6e1456..026a8604bdf 100644 --- a/tests/apps/resiliencyapp/app.go +++ b/tests/apps/resiliencyapp/app.go @@ -20,6 +20,8 @@ import ( "io" "log" "net/http" + "os" + "strconv" "time" commonv1pb "github.com/dapr/dapr/pkg/proto/common/v1" @@ -34,8 +36,14 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) -const ( - appPort = 3000 +var ( + appPort = 3000 + daprHttpPort = 3500 + daprGrpcPort = 50001 + + daprClient runtimev1pb.DaprClient + callTracking map[string][]CallRecord + httpClient = utils.NewHTTPClient() ) type FailureMessage struct { @@ -56,12 +64,20 @@ type PubsubResponse struct { Message string `json:"message,omitempty"` } -var ( - daprClient runtimev1pb.DaprClient - callTracking map[string][]CallRecord -) - -var httpClient = utils.NewHTTPClient() +func init() { + p := os.Getenv("DAPR_HTTP_PORT") + if p != "" && p != "0" { + daprHttpPort, _ = strconv.Atoi(p) + } + p = os.Getenv("DAPR_GRPC_PORT") + if p != "" && p != "0" { + daprGrpcPort, _ = strconv.Atoi(p) + } + p = os.Getenv("PORT") + if p != "" && p != "0" { + appPort, _ = strconv.Atoi(p) + } +} // Endpoint handling. func indexHandler(w http.ResponseWriter, r *http.Request) { @@ -125,10 +141,11 @@ func resiliencyBindingHandler(w http.ResponseWriter, r *http.Request) { return } + body, _ := io.ReadAll(r.Body) var message FailureMessage - json.NewDecoder(r.Body).Decode(&message) + json.Unmarshal(body, &message) - log.Printf("Binding received message %+v\n", message) + log.Printf("Binding received message %s\n", string(body)) callCount := 0 if records, ok := callTracking[message.ID]; ok { @@ -178,7 +195,7 @@ func resiliencyPubsubHandler(w http.ResponseWriter, r *http.Request) { rawDataBytes, _ := json.Marshal(rawData) var message FailureMessage json.Unmarshal(rawDataBytes, &message) - log.Printf("Pubsub received message %+v\n", message) + log.Printf("Pubsub received message %s\n", string(rawDataBytes)) callCount := 0 if records, ok := callTracking[message.ID]; ok { @@ -205,10 +222,11 @@ func resiliencyPubsubHandler(w http.ResponseWriter, r *http.Request) { } func resiliencyServiceInvocationHandler(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) var message FailureMessage - json.NewDecoder(r.Body).Decode(&message) + json.Unmarshal(body, &message) - log.Printf("Http invocation received message %+v\n", message) + log.Printf("HTTP invocation received message %s\n", string(body)) callCount := 0 if records, ok := callTracking[message.ID]; ok { @@ -235,10 +253,11 @@ func resiliencyServiceInvocationHandler(w http.ResponseWriter, r *http.Request) } func resiliencyActorMethodHandler(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) var message FailureMessage - json.NewDecoder(r.Body).Decode(&message) + json.Unmarshal(body, &message) - log.Printf("Actor received message %+v\n", message) + log.Printf("Actor received message %s\n", string(body)) callCount := 0 if records, ok := callTracking[message.ID]; ok { @@ -262,30 +281,6 @@ func resiliencyActorMethodHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// App startup/endpoint setup. -func initGRPCClient() { - url := fmt.Sprintf("localhost:%d", 50001) - log.Printf("Connecting to dapr using url %s", url) - var grpcConn *grpc.ClientConn - for retries := 10; retries > 0; retries-- { - var err error - grpcConn, err = grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials())) - if err == nil { - break - } - - if retries == 0 { - log.Printf("Could not connect to dapr: %v", err) - log.Panic(err) - } - - log.Printf("Could not connect to dapr: %v, retrying...", err) - time.Sleep(5 * time.Second) - } - - daprClient = runtimev1pb.NewDaprClient(grpcConn) -} - func appRouter() *mux.Router { router := mux.NewRouter().StrictSlash(true) @@ -322,7 +317,7 @@ func appRouter() *mux.Router { func main() { callTracking = map[string][]CallRecord{} - initGRPCClient() + daprClient = utils.GetGRPCClient(daprGrpcPort) log.Printf("Resiliency App - listening on http://localhost:%d", appPort) utils.StartServer(appPort, appRouter, true, false) @@ -434,7 +429,7 @@ func TestInvokeService(w http.ResponseWriter, r *http.Request) { if targetMethod == "" { targetMethod = "resiliencyInvocation" } - url := fmt.Sprintf("http://localhost:3500/v1.0/invoke/%s/method/%s", targetApp, targetMethod) + url := fmt.Sprintf("http://localhost:%d/v1.0/invoke/%s/method/%s", daprHttpPort, targetApp, targetMethod) req, _ := http.NewRequest("POST", url, r.Body) defer r.Body.Close() @@ -496,7 +491,11 @@ func TestInvokeService(w http.ResponseWriter, r *http.Request) { log.Printf("Proxying message: %+v", message) b, _ := json.Marshal(message) - conn, err := grpc.Dial("localhost:50001", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + conn, err := grpc.Dial( + fmt.Sprintf("localhost:%d", daprGrpcPort), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) if err != nil { log.Fatalf("did not connect: %v", err) } @@ -521,7 +520,7 @@ func TestInvokeActorMethod(w http.ResponseWriter, r *http.Request) { if protocol == "http" { httpClient.Timeout = time.Minute - url := "http://localhost:3500/v1.0/actors/resiliencyActor/1/method/resiliencyMethod" + url := fmt.Sprintf("http://localhost:%d/v1.0/actors/resiliencyActor/1/method/resiliencyMethod", daprHttpPort) req, _ := http.NewRequest("PUT", url, r.Body) defer r.Body.Close() diff --git a/tests/e2e/service_invocation/service_invocation_test.go b/tests/e2e/service_invocation/service_invocation_test.go index f78528f9937..beee60e0632 100644 --- a/tests/e2e/service_invocation/service_invocation_test.go +++ b/tests/e2e/service_invocation/service_invocation_test.go @@ -462,21 +462,31 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + Trailers string `json:"trailers"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} trailerHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) - json.Unmarshal([]byte(actualHeaders["trailers"]), &trailerHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Trailers), &trailerHeaders) require.NoError(t, err) - assert.Equal(t, "application/grpc", requestHeaders["content-type"][0]) - assert.Equal(t, "127.0.0.1:3000", requestHeaders[":authority"][0]) - assert.Equal(t, "DaprValue1", requestHeaders["daprtest-request-1"][0]) - assert.Equal(t, "DaprValue2", requestHeaders["daprtest-request-2"][0]) - assert.NotNil(t, requestHeaders["user-agent"][0]) + _ = assert.NotEmpty(t, requestHeaders["content-type"]) && + assert.Equal(t, "application/grpc", requestHeaders["content-type"][0]) + _ = assert.NotEmpty(t, requestHeaders[":authority"]) && + assert.Equal(t, "127.0.0.1:3000", requestHeaders[":authority"][0]) + _ = assert.NotEmpty(t, requestHeaders["daprtest-request-1"]) && + assert.Equal(t, "DaprValue1", requestHeaders["daprtest-request-1"][0]) + _ = assert.NotEmpty(t, requestHeaders["daprtest-request-2"]) && + assert.Equal(t, "DaprValue2", requestHeaders["daprtest-request-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["user-agent"]) && + assert.NotNil(t, requestHeaders["user-agent"][0]) grpcTraceBinRq := requestHeaders["grpc-trace-bin"] if assert.NotNil(t, grpcTraceBinRq, "grpc-trace-bin is missing from the request") { if assert.Equal(t, 1, len(grpcTraceBinRq), "grpc-trace-bin is missing from the request") { @@ -489,16 +499,24 @@ func TestHeaders(t *testing.T) { assert.NotEqual(t, "", traceParentRq[0], "traceparent is missing from the request") } } - assert.Equal(t, hostIP, requestHeaders["x-forwarded-for"][0]) - assert.Equal(t, hostname, requestHeaders["x-forwarded-host"][0]) - assert.Equal(t, expectedForwarded, requestHeaders["forwarded"][0]) - - assert.Equal(t, "serviceinvocation-caller", requestHeaders[invokev1.CallerIDHeader][0]) - assert.Equal(t, "grpcapp", requestHeaders[invokev1.CalleeIDHeader][0]) - - assert.Equal(t, "application/grpc", responseHeaders["content-type"][0]) - assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["daprtest-response-1"][0]) - assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["daprtest-response-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["x-forwarded-for"]) && + assert.Equal(t, hostIP, requestHeaders["x-forwarded-for"][0]) + _ = assert.NotEmpty(t, requestHeaders["x-forwarded-host"]) && + assert.Equal(t, hostname, requestHeaders["x-forwarded-host"][0]) + _ = assert.NotEmpty(t, requestHeaders["forwarded"]) && + assert.Equal(t, expectedForwarded, requestHeaders["forwarded"][0]) + + _ = assert.NotEmpty(t, requestHeaders[invokev1.CallerIDHeader]) && + assert.Equal(t, "serviceinvocation-caller", requestHeaders[invokev1.CallerIDHeader][0]) + _ = assert.NotEmpty(t, requestHeaders[invokev1.CalleeIDHeader]) && + assert.Equal(t, "grpcapp", requestHeaders[invokev1.CalleeIDHeader][0]) + + _ = assert.NotEmpty(t, responseHeaders["content-type"]) && + assert.Equal(t, "application/grpc", responseHeaders["content-type"][0]) + _ = assert.NotEmpty(t, responseHeaders["daprtest-response-1"]) && + assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["daprtest-response-1"][0]) + _ = assert.NotEmpty(t, responseHeaders["daprtest-response-2"]) && + assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["daprtest-response-2"][0]) grpcTraceBinRs := responseHeaders["grpc-trace-bin"] if assert.NotNil(t, grpcTraceBinRs, "grpc-trace-bin is missing from the response") { if assert.Equal(t, 1, len(grpcTraceBinRs), "grpc-trace-bin is missing from the response") { @@ -512,8 +530,10 @@ func TestHeaders(t *testing.T) { } } - assert.Equal(t, "DaprTest-Trailer-Value-1", trailerHeaders["daprtest-trailer-1"][0]) - assert.Equal(t, "DaprTest-Trailer-Value-2", trailerHeaders["daprtest-trailer-2"][0]) + _ = assert.NotEmpty(t, trailerHeaders["daprtest-trailer-1"]) && + assert.Equal(t, "DaprTest-Trailer-Value-1", trailerHeaders["daprtest-trailer-1"][0]) + _ = assert.NotEmpty(t, trailerHeaders["daprtest-trailer-2"]) && + assert.Equal(t, "DaprTest-Trailer-Value-2", trailerHeaders["daprtest-trailer-2"][0]) }) t.Run("grpc-to-http", func(t *testing.T) { @@ -532,33 +552,52 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) - assert.NotNil(t, requestHeaders["Content-Length"][0]) - assert.Equal(t, "text/plain; utf-8", requestHeaders["Content-Type"][0]) - assert.Equal(t, "localhost:50001", requestHeaders["Dapr-Authority"][0]) - assert.Equal(t, "DaprValue1", requestHeaders["Daprtest-Request-1"][0]) - assert.Equal(t, "DaprValue2", requestHeaders["Daprtest-Request-2"][0]) - assert.NotNil(t, requestHeaders["Traceparent"][0]) - assert.NotNil(t, requestHeaders["User-Agent"][0]) - assert.Equal(t, hostIP, requestHeaders["X-Forwarded-For"][0]) - assert.Equal(t, hostname, requestHeaders["X-Forwarded-Host"][0]) - assert.Equal(t, expectedForwarded, requestHeaders["Forwarded"][0]) - assert.Equal(t, "serviceinvocation-caller", requestHeaders["Dapr-Caller-App-Id"][0]) - assert.Equal(t, "serviceinvocation-callee-0", requestHeaders["Dapr-Callee-App-Id"][0]) - - assert.NotNil(t, responseHeaders["dapr-content-length"][0]) - assert.Equal(t, "application/grpc", responseHeaders["content-type"][0]) - assert.True(t, strings.HasPrefix(responseHeaders["dapr-content-type"][0], "application/json")) - assert.NotNil(t, responseHeaders["dapr-date"][0]) - assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["daprtest-response-1"][0]) - assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["daprtest-response-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["Content-Type"]) && + assert.Equal(t, "text/plain; utf-8", requestHeaders["Content-Type"][0]) + _ = assert.NotEmpty(t, requestHeaders["Dapr-Authority"]) && + assert.Equal(t, "localhost:50001", requestHeaders["Dapr-Authority"][0]) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Request-1"]) && + assert.Equal(t, "DaprValue1", requestHeaders["Daprtest-Request-1"][0]) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Request-2"]) && + assert.Equal(t, "DaprValue2", requestHeaders["Daprtest-Request-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["Traceparent"]) && + assert.NotNil(t, requestHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, requestHeaders["User-Agent"]) && + assert.NotNil(t, requestHeaders["User-Agent"][0]) + _ = assert.NotEmpty(t, requestHeaders["X-Forwarded-For"]) && + assert.Equal(t, hostIP, requestHeaders["X-Forwarded-For"][0]) + _ = assert.NotEmpty(t, requestHeaders["X-Forwarded-Host"]) && + assert.Equal(t, hostname, requestHeaders["X-Forwarded-Host"][0]) + _ = assert.NotEmpty(t, requestHeaders["Forwarded"]) && + assert.Equal(t, expectedForwarded, requestHeaders["Forwarded"][0]) + + _ = assert.NotEmpty(t, requestHeaders["Dapr-Caller-App-Id"]) && + assert.Equal(t, "serviceinvocation-caller", requestHeaders["Dapr-Caller-App-Id"][0]) + _ = assert.NotEmpty(t, requestHeaders["Dapr-Callee-App-Id"]) && + assert.Equal(t, "serviceinvocation-callee-0", requestHeaders["Dapr-Callee-App-Id"][0]) + + _ = assert.NotEmpty(t, responseHeaders["content-type"]) && + assert.Equal(t, "application/grpc", responseHeaders["content-type"][0]) + _ = assert.NotEmpty(t, responseHeaders["dapr-content-type"]) && + assert.True(t, strings.HasPrefix(responseHeaders["dapr-content-type"][0], "application/json")) + _ = assert.NotEmpty(t, responseHeaders["dapr-date"]) && + assert.NotNil(t, responseHeaders["dapr-date"][0]) + _ = assert.NotEmpty(t, responseHeaders["daprtest-response-1"]) && + assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["daprtest-response-1"][0]) + _ = assert.NotEmpty(t, responseHeaders["daprtest-response-2"]) && + assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["daprtest-response-2"][0]) grpcTraceBinRs := responseHeaders["grpc-trace-bin"] if assert.NotNil(t, grpcTraceBinRs, "grpc-trace-bin is missing from the response") { @@ -584,23 +623,31 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) - assert.Nil(t, requestHeaders["connection"]) - assert.Nil(t, requestHeaders["content-length"]) - assert.True(t, strings.HasPrefix(requestHeaders["dapr-host"][0], "localhost:")) - assert.Equal(t, "application/grpc", requestHeaders["content-type"][0]) - assert.True(t, strings.HasPrefix(requestHeaders[":authority"][0], "127.0.0.1:")) - assert.Equal(t, "DaprValue1", requestHeaders["daprtest-request-1"][0]) - assert.Equal(t, "DaprValue2", requestHeaders["daprtest-request-2"][0]) - assert.NotNil(t, requestHeaders["user-agent"][0]) + _ = assert.NotEmpty(t, requestHeaders["dapr-host"]) && + assert.True(t, strings.HasPrefix(requestHeaders["dapr-host"][0], "localhost:")) + _ = assert.NotEmpty(t, requestHeaders["content-type"]) && + assert.Equal(t, "application/grpc", requestHeaders["content-type"][0]) + _ = assert.NotEmpty(t, requestHeaders[":authority"]) && + assert.True(t, strings.HasPrefix(requestHeaders[":authority"][0], "127.0.0.1:")) + _ = assert.NotEmpty(t, requestHeaders["daprtest-request-1"]) && + assert.Equal(t, "DaprValue1", requestHeaders["daprtest-request-1"][0]) + _ = assert.NotEmpty(t, requestHeaders["daprtest-request-1"]) && + assert.Equal(t, "DaprValue2", requestHeaders["daprtest-request-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["user-agent"]) && + assert.NotNil(t, requestHeaders["user-agent"][0]) grpcTraceBinRq := requestHeaders["grpc-trace-bin"] if assert.NotNil(t, grpcTraceBinRq, "grpc-trace-bin is missing from the request") { if assert.Equal(t, 1, len(grpcTraceBinRq), "grpc-trace-bin is missing from the request") { @@ -613,19 +660,26 @@ func TestHeaders(t *testing.T) { assert.NotEqual(t, "", traceParentRq[0], "traceparent is missing from the request") } } - assert.Equal(t, hostIP, requestHeaders["x-forwarded-for"][0]) - assert.Equal(t, hostname, requestHeaders["x-forwarded-host"][0]) - assert.Equal(t, expectedForwarded, requestHeaders["forwarded"][0]) + _ = assert.NotEmpty(t, requestHeaders["x-forwarded-for"]) && + assert.Equal(t, hostIP, requestHeaders["x-forwarded-for"][0]) + _ = assert.NotEmpty(t, requestHeaders["x-forwarded-host"]) && + assert.Equal(t, hostname, requestHeaders["x-forwarded-host"][0]) + _ = assert.NotEmpty(t, requestHeaders["forwarded"]) && + assert.Equal(t, expectedForwarded, requestHeaders["forwarded"][0]) assert.Equal(t, "serviceinvocation-caller", requestHeaders[invokev1.CallerIDHeader][0]) assert.Equal(t, "grpcapp", requestHeaders[invokev1.CalleeIDHeader][0]) - assert.NotNil(t, responseHeaders["Content-Length"][0]) - assert.True(t, strings.HasPrefix(responseHeaders["Content-Type"][0], "application/json")) - assert.NotNil(t, responseHeaders["Date"][0]) - assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["Daprtest-Response-1"][0]) - assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["Daprtest-Response-2"][0]) - assert.NotNil(t, responseHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, responseHeaders["Content-Type"]) && + assert.True(t, strings.HasPrefix(responseHeaders["Content-Type"][0], "application/json")) + _ = assert.NotEmpty(t, responseHeaders["Date"]) && + assert.NotNil(t, responseHeaders["Date"][0]) + _ = assert.NotEmpty(t, responseHeaders["Daprtest-Response-1"]) && + assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["Daprtest-Response-1"][0]) + _ = assert.NotEmpty(t, responseHeaders["Daprtest-Response-2"]) && + assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["Daprtest-Response-2"][0]) + _ = assert.NotEmpty(t, responseHeaders["Traceparent"]) && + assert.NotNil(t, responseHeaders["Traceparent"][0]) }) /* Tracing specific tests */ @@ -674,14 +728,19 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + Trailers string `json:"trailers"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} trailerHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) - json.Unmarshal([]byte(actualHeaders["trailers"]), &trailerHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Trailers), &trailerHeaders) require.NoError(t, err) @@ -737,12 +796,16 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) @@ -771,17 +834,23 @@ func TestHeaders(t *testing.T) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) - assert.NotNil(t, requestHeaders["Traceparent"][0]) - assert.Equal(t, expectedTraceID, requestHeaders["Daprtest-Traceid"][0]) + _ = assert.NotEmpty(t, requestHeaders["Traceparent"]) && + assert.NotNil(t, requestHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Traceid"]) && + assert.Equal(t, expectedTraceID, requestHeaders["Daprtest-Traceid"][0]) grpcTraceBinRs := responseHeaders["grpc-trace-bin"] if assert.NotNil(t, grpcTraceBinRs, "grpc-trace-bin is missing from the response") { @@ -818,17 +887,23 @@ func verifyHTTPToHTTPTracing(t *testing.T, url string, expectedTraceID string) { t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) - assert.NotNil(t, requestHeaders["Traceparent"][0]) - assert.Equal(t, expectedTraceID, requestHeaders["Daprtest-Traceid"][0]) + _ = assert.NotEmpty(t, requestHeaders["Traceparent"]) && + assert.NotNil(t, requestHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Traceid"]) && + assert.Equal(t, expectedTraceID, requestHeaders["Daprtest-Traceid"][0]) traceParentRs := responseHeaders["Traceparent"] if assert.NotNil(t, traceParentRs, "Traceparent is missing from the response") { @@ -853,32 +928,43 @@ func verifyHTTPToHTTP(t *testing.T, hostIP string, hostname string, url string, t.Logf("unmarshalling..%s\n", string(resp)) err = json.Unmarshal(resp, &appResp) - actualHeaders := map[string]string{} - json.Unmarshal([]byte(appResp.Message), &actualHeaders) + actualHeaders := struct { + Request string `json:"request"` + Response string `json:"response"` + }{} + err = json.Unmarshal([]byte(appResp.Message), &actualHeaders) + require.NoError(t, err, "failed to unmarshal response: %s", appResp.Message) requestHeaders := map[string][]string{} responseHeaders := map[string][]string{} - json.Unmarshal([]byte(actualHeaders["request"]), &requestHeaders) - json.Unmarshal([]byte(actualHeaders["response"]), &responseHeaders) + json.Unmarshal([]byte(actualHeaders.Request), &requestHeaders) + json.Unmarshal([]byte(actualHeaders.Response), &responseHeaders) require.NoError(t, err) - assert.NotNil(t, requestHeaders["Accept-Encoding"][0]) - assert.NotNil(t, requestHeaders["Content-Length"][0]) - assert.True(t, strings.HasPrefix(requestHeaders["Content-Type"][0], "application/json")) - assert.Equal(t, "DaprValue1", requestHeaders["Daprtest-Request-1"][0]) - assert.Equal(t, "DaprValue2", requestHeaders["Daprtest-Request-2"][0]) - assert.NotNil(t, requestHeaders["Traceparent"][0]) - assert.NotNil(t, requestHeaders["User-Agent"][0]) - assert.Equal(t, hostIP, requestHeaders["X-Forwarded-For"][0]) - assert.Equal(t, hostname, requestHeaders["X-Forwarded-Host"][0]) - assert.Equal(t, expectedForwarded, requestHeaders["Forwarded"][0]) - assert.Equal(t, "serviceinvocation-caller", requestHeaders["Dapr-Caller-App-Id"][0]) - assert.Equal(t, "serviceinvocation-callee-0", requestHeaders["Dapr-Callee-App-Id"][0]) - - assert.NotNil(t, responseHeaders["Content-Length"][0]) - assert.True(t, strings.HasPrefix(responseHeaders["Content-Type"][0], "application/json")) - assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["Daprtest-Response-1"][0]) - assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["Daprtest-Response-2"][0]) - assert.NotNil(t, responseHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, requestHeaders["Content-Type"]) && + assert.True(t, strings.HasPrefix(requestHeaders["Content-Type"][0], "application/json")) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Request-1"]) && + assert.Equal(t, "DaprValue1", requestHeaders["Daprtest-Request-1"][0]) + _ = assert.NotEmpty(t, requestHeaders["Daprtest-Request-2"]) && + assert.Equal(t, "DaprValue2", requestHeaders["Daprtest-Request-2"][0]) + _ = assert.NotEmpty(t, requestHeaders["Traceparent"]) && + assert.NotNil(t, requestHeaders["Traceparent"][0]) + _ = assert.NotEmpty(t, requestHeaders["User-Agent"]) && + assert.NotNil(t, requestHeaders["User-Agent"][0]) + _ = assert.NotEmpty(t, requestHeaders["X-Forwarded-For"]) && + assert.Equal(t, hostIP, requestHeaders["X-Forwarded-For"][0]) + _ = assert.NotEmpty(t, requestHeaders["X-Forwarded-Host"]) && + assert.Equal(t, hostname, requestHeaders["X-Forwarded-Host"][0]) + _ = assert.NotEmpty(t, requestHeaders["Forwarded"]) && + assert.Equal(t, expectedForwarded, requestHeaders["Forwarded"][0]) + + _ = assert.NotEmpty(t, responseHeaders["Content-Type"]) && + assert.True(t, strings.HasPrefix(responseHeaders["Content-Type"][0], "application/json")) + _ = assert.NotEmpty(t, responseHeaders["Daprtest-Response-1"]) && + assert.Equal(t, "DaprTest-Response-Value-1", responseHeaders["Daprtest-Response-1"][0]) + _ = assert.NotEmpty(t, responseHeaders["Daprtest-Response-2"]) && + assert.Equal(t, "DaprTest-Response-Value-2", responseHeaders["Daprtest-Response-2"][0]) + _ = assert.NotEmpty(t, responseHeaders["Traceparent"]) && + assert.NotNil(t, responseHeaders["Traceparent"][0]) } func TestUppercaseMiddlewareServiceInvocation(t *testing.T) { diff --git a/utils/host.go b/utils/host.go index ab41ce3782b..dcebf1ed8a5 100644 --- a/utils/host.go +++ b/utils/host.go @@ -1,3 +1,16 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package utils import ( diff --git a/utils/host_test.go b/utils/host_test.go index 442aec2cd0a..7bd4312de91 100644 --- a/utils/host_test.go +++ b/utils/host_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package utils import ( diff --git a/utils/resolvconf.go b/utils/resolvconf.go index add2c056f8a..25fb7bb5056 100644 --- a/utils/resolvconf.go +++ b/utils/resolvconf.go @@ -1,3 +1,16 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package utils import ( diff --git a/utils/resolvconf_test.go b/utils/resolvconf_test.go index 41e03ba8c95..b4775999dd5 100644 --- a/utils/resolvconf_test.go +++ b/utils/resolvconf_test.go @@ -1,3 +1,16 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package utils import (