Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/funny-turkeys-sit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@effect/rpc": minor
---

Allow RPC Client Middleware to shortcircuit with a failure
35 changes: 21 additions & 14 deletions packages/platform-node/test/fixtures/rpc-schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ class StreamUsers extends Schema.TaggedRequest<StreamUsers>()("StreamUsers", {

class CurrentUser extends Context.Tag("CurrentUser")<CurrentUser, User>() {}

class Unauthorized extends Schema.TaggedError<Unauthorized>("Unauthorized")("Unauthorized", {}) {}

class AuthMiddleware extends RpcMiddleware.Tag<AuthMiddleware>()("AuthMiddleware", {
provides: CurrentUser,
failure: Unauthorized,
requiredForClient: true
}) {}
export class Unauthorized extends Schema.TaggedError<Unauthorized>("Unauthorized")("Unauthorized", {}) {}
export class InvalidClientCredentials
extends Schema.TaggedError<InvalidClientCredentials>("InvalidClientCredentials")("InvalidClientCredentials", {})
{}

class AuthMiddleware
extends RpcMiddleware.Tag<AuthMiddleware, { clientError: InvalidClientCredentials }>()("AuthMiddleware", {
provides: CurrentUser,
failure: Unauthorized,
requiredForClient: true
})
{}

class TimingMiddleware extends RpcMiddleware.Tag<TimingMiddleware>()("TimingMiddleware", {
wrap: true
Expand Down Expand Up @@ -78,9 +83,9 @@ export const UserRpcs = RpcGroup.make(
const AuthLive = Layer.succeed(
AuthMiddleware,
AuthMiddleware.of((options) =>
Effect.succeed(
new User({ id: options.headers.userid ?? "1", name: options.headers.name ?? "Fallback name" })
)
options.headers.userid && options.headers.userid === "-2" ?
new Unauthorized({ failedOn: "Server" }) :
Effect.succeed(new User({ id: options.headers.userid ?? "1", name: options.headers.name ?? "Fallback name" }))
)
)

Expand Down Expand Up @@ -153,10 +158,12 @@ export const RpcLive = RpcServer.layer(UserRpcs).pipe(
)

const AuthClient = RpcMiddleware.layerClient(AuthMiddleware, ({ request }) =>
Effect.succeed({
...request,
headers: Headers.set(request.headers, "name", "Logged in user")
}))
request.headers.userid === "-1" ?
new InvalidClientCredentials() :
Effect.succeed({
...request,
headers: Headers.set(request.headers, "name", "Logged in user")
}))

export class UsersClient extends Context.Tag("UsersClient")<
UsersClient,
Expand Down
16 changes: 15 additions & 1 deletion packages/platform-node/test/rpc-e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { RpcClient, RpcServer } from "@effect/rpc"
import { assert, describe, it } from "@effect/vitest"
import type { Layer } from "effect"
import { Cause, Effect, Fiber, Option, Stream } from "effect"
import { User, UsersClient } from "./fixtures/rpc-schemas.js"
import { InvalidClientCredentials, Unauthorized, User, UsersClient } from "./fixtures/rpc-schemas.js"

export const e2eSuite = <E>(
name: string,
Expand All @@ -18,6 +18,20 @@ export const e2eSuite = <E>(
assert.deepStrictEqual(user, new User({ id: "1", name: "Logged in user" }))
}).pipe(Effect.provide(layer)))

it.effect("should short circuit on client middleware failure", () =>
Effect.gen(function*() {
const client = yield* UsersClient
const failure = yield* client.GetUser({ id: "1" }, { headers: { userid: "-1" } }).pipe(Effect.flip)
assert.instanceOf(failure, InvalidClientCredentials)
}).pipe(Effect.provide(layer)))

it.effect("should fail on server middleware failure", () =>
Effect.gen(function*() {
const client = yield* UsersClient
const failure = yield* client.GetUser({ id: "1" }, { headers: { userid: "-2" } }).pipe(Effect.flip)
assert.instanceOf(failure, Unauthorized)
}).pipe(Effect.provide(layer)))

it.effect("nested method", () =>
Effect.gen(function*() {
const client = yield* UsersClient
Expand Down
47 changes: 30 additions & 17 deletions packages/rpc/src/RpcClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ export declare namespace RpcClient {
infer _Error,
infer _Middleware
> ? [_Success] extends [RpcSchema.Stream<infer _A, infer _E>] ? AsMailbox extends true ? Effect.Effect<
Mailbox.ReadonlyMailbox<_A["Type"], _E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"]>,
Mailbox.ReadonlyMailbox<
_A["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"]
>,
never,
| Scope.Scope
| _Payload["Context"]
Expand All @@ -128,12 +131,13 @@ export declare namespace RpcClient {
>
: Stream.Stream<
_A["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"],
_Payload["Context"] | _Success["Context"] | _Error["Context"] | _Middleware["failure"]["Context"]
>
: Effect.Effect<
Discard extends true ? void : _Success["Type"],
Discard extends true ? E : _Error["Type"] | E | _Middleware["failure"]["Type"],
Discard extends true ? E | _Middleware["~ClientError"]
: _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"],
_Payload["Context"] | _Success["Context"] | _Error["Context"] | _Middleware["failure"]["Context"]
> :
never
Expand Down Expand Up @@ -168,7 +172,10 @@ export declare namespace RpcClient {
infer _Error,
infer _Middleware
> ? [_Success] extends [RpcSchema.Stream<infer _A, infer _E>] ? AsMailbox extends true ? Effect.Effect<
Mailbox.ReadonlyMailbox<_A["Type"], _E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"]>,
Mailbox.ReadonlyMailbox<
_A["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"]
>,
never,
| Scope.Scope
| _Payload["Context"]
Expand All @@ -178,12 +185,13 @@ export declare namespace RpcClient {
>
: Stream.Stream<
_A["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"],
_E["Type"] | _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"],
_Payload["Context"] | _Success["Context"] | _Error["Context"] | _Middleware["failure"]["Context"]
>
: Effect.Effect<
Discard extends true ? void : _Success["Type"],
Discard extends true ? E : _Error["Type"] | E | _Middleware["failure"]["Type"],
Discard extends true ? E | _Middleware["~ClientError"]
: _Error["Type"] | E | _Middleware["failure"]["Type"] | _Middleware["~ClientError"],
_Payload["Context"] | _Success["Context"] | _Error["Context"] | _Middleware["failure"]["Context"]
> :
never
Expand Down Expand Up @@ -329,7 +337,7 @@ export const makeNoSerialization: <Rpcs extends Rpc.Any, E, const Flatten extend

const onEffectRequest = (
rpc: Rpc.AnyWithProps,
middleware: (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>>,
middleware: (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>, any>,
span: Span | undefined,
payload: any,
headers: Headers.Headers,
Expand All @@ -352,12 +360,15 @@ export const makeNoSerialization: <Rpcs extends Rpc.Any, E, const Flatten extend
headers: Headers.merge(parentFiber.getFiberRef(currentHeaders), headers)
})
if (discard) {
return Effect.flatMap(send, (message) =>
options.onFromClient({
message,
context,
discard
}))
return Effect.matchEffect(send, {
onFailure: () => Effect.void,
onSuccess: (message) =>
options.onFromClient({
message,
context,
discard
})
})
}
const runtime = Runtime.make({
context: parentFiber.currentContext,
Expand Down Expand Up @@ -414,7 +425,7 @@ export const makeNoSerialization: <Rpcs extends Rpc.Any, E, const Flatten extend

const onStreamRequest = Effect.fnUntraced(function*(
rpc: Rpc.AnyWithProps,
middleware: (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>>,
middleware: (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>, any>,
payload: any,
headers: Headers.Headers,
streamBufferSize: number,
Expand Down Expand Up @@ -483,8 +494,10 @@ export const makeNoSerialization: <Rpcs extends Rpc.Any, E, const Flatten extend
return mailbox
})

const getRpcClientMiddleware = (rpc: Rpc.AnyWithProps): (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>> => {
const middlewares: Array<RpcMiddleware.RpcMiddlewareClient> = []
const getRpcClientMiddleware = (
rpc: Rpc.AnyWithProps
): (request: Request<Rpcs>) => Effect.Effect<Request<Rpcs>, any> => {
const middlewares: Array<RpcMiddleware.RpcMiddlewareClient<never, any>> = []
for (const tag of rpc.middlewares.values()) {
const middleware = context.unsafeMap.get(`${tag.key}/Client`)
if (!middleware) continue
Expand All @@ -501,7 +514,7 @@ export const makeNoSerialization: <Rpcs extends Rpc.Any, E, const Flatten extend
middlewares[i]({
rpc,
request
}) as Effect.Effect<Request<Rpcs>>,
}) as Effect.Effect<Request<Rpcs>, any>,
step(nextRequest) {
request = nextRequest
i++
Expand Down
43 changes: 35 additions & 8 deletions packages/rpc/src/RpcMiddleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ export interface SuccessValue {
* @since 1.0.0
* @category models
*/
export interface RpcMiddlewareClient<R = never> {
export interface RpcMiddlewareClient<R = never, CE = never> {
(options: {
readonly rpc: Rpc.AnyWithProps
readonly request: Request<Rpc.Any>
}): Effect.Effect<Request<Rpc.Any>, never, R>
}): Effect.Effect<Request<Rpc.Any>, CE, R>
}

/**
Expand Down Expand Up @@ -98,12 +98,14 @@ export interface Any {
export interface TagClass<
Self,
Name extends string,
Options
Options,
ClientError
> extends
TagClass.Base<
Self,
Name,
Options,
ClientError,
TagClass.Wrap<Options> extends true ? RpcMiddlewareWrap<
TagClass.Provides<Options>,
TagClass.Failure<Options>
Expand Down Expand Up @@ -188,15 +190,22 @@ export declare namespace TagClass {
* @since 1.0.0
* @category models
*/
export interface Base<Self, Name extends string, Options, Service> extends Context.Tag<Self, Service> {
new(_: never): Context.TagClassShape<Name, Service>
export type ClientError<Options> = Options extends { readonly "~ClientError": any } ? Options["~ClientError"] : never

/**
* @since 1.0.0
* @category models
*/
export interface Base<Self, Name extends string, Options, ClientError, Service> extends Context.Tag<Self, Service> {
new(_: never): Context.TagClassShape<Name, Service> & { readonly "~ClientError": ClientError }
readonly [TypeId]: TypeId
readonly optional: Optional<Options>
readonly failure: FailureSchema<Options>
readonly provides: Options extends { readonly provides: Context.Tag<any, any> } ? Options["provides"]
: undefined
readonly requiredForClient: RequiredForClient<Options>
readonly wrap: Wrap<Options>
readonly "~ClientError": ClientError
}
}

Expand All @@ -211,6 +220,7 @@ export interface TagClassAny extends Context.Tag<any, any> {
readonly failure: Schema.Schema.All
readonly requiredForClient: boolean
readonly wrap: boolean
readonly "~ClientError": any
}

/**
Expand All @@ -224,13 +234,28 @@ export interface TagClassAnyWithProps extends Context.Tag<any, RpcMiddleware<any
readonly failure: Schema.Schema.All
readonly requiredForClient: boolean
readonly wrap: boolean
readonly "~ClientError": any
}

/**
* @since 4.0.0
* @category models
*/
export interface AnyId {
readonly [TypeId]: TypeId
readonly "~ClientError"?: any
}

/**
* @since 1.0.0
* @category tags
*/
export const Tag = <Self>(): <
export const Tag = <
Self,
const Config extends {
clientError?: any
} = { clientError: never }
>(): <
const Name extends string,
const Options extends {
readonly wrap?: boolean
Expand All @@ -242,7 +267,7 @@ export const Tag = <Self>(): <
>(
id: Name,
options?: Options | undefined
) => TagClass<Self, Name, Options> =>
) => TagClass<Self, Name, Options, "clientError" extends keyof Config ? Config["clientError"] : never> =>
(
id: string,
options?: {
Expand Down Expand Up @@ -285,7 +310,9 @@ export const Tag = <Self>(): <
*/
export const layerClient = <Id, S, R, EX = never, RX = never>(
tag: Context.Tag<Id, S>,
service: RpcMiddlewareClient<R> | Effect.Effect<RpcMiddlewareClient<R>, EX, RX>
service:
| RpcMiddlewareClient<R, TagClass.ClientError<Id>>
| Effect.Effect<RpcMiddlewareClient<R, TagClass.ClientError<Id>>, EX, RX>
): Layer.Layer<ForClient<Id>, EX, R | Exclude<RX, Scope>> =>
Layer.scopedContext(Effect.gen(function*() {
const context = (yield* Effect.context<R | Scope>()).pipe(
Expand Down