Skip to content

Add fallback support via on error hook #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ interface CircuitBreakerResult {
state: CircuitState // Current circuit breaker state
failureCount: number // Current failure count
executionTimeMs: number // Execution time in milliseconds
fallbackResponseProvided?: boolean // Whether a fallback response was provided
}
```

Expand Down Expand Up @@ -179,7 +180,10 @@ interface ProxyRequestOptions {
res: Response,
body?: ReadableStream | null,
) => void | Promise<void>
onError?: (req: Request, error: Error) => void | Promise<void>
onError?: (
req: Request,
error: Error,
) => void | Promise<void> | Promise<Response>
beforeCircuitBreakerExecution?: (
req: Request,
opts: ProxyRequestOptions,
Expand Down Expand Up @@ -547,6 +551,23 @@ proxy(req, undefined, {
})
```

#### Returning Fallback Responses

You can return a fallback response from the `onError` hook by resolving the hook with a `Response` object. This allows you to customize the error response sent to the client.

```typescript
proxy(req, undefined, {
onError: async (req, error) => {
// Log error
console.error("Proxy error:", error)

// Return a fallback response
console.log("Returning fallback response for:", req.url)
return new Response("Fallback response", { status: 200 })
},
})
```

## Performance Tips

1. **URL Caching**: Keep `cacheURLs` enabled (default 100) for better performance
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
"clean": "rm -rf lib/",
"prepublishOnly": "bun run clean && bun run build",
"example:benchmark": "bun run examples/local-gateway-server.ts",
"deploy": "bun run prepublishOnly && bun publish"
"deploy": "bun run prepublishOnly && bun publish",
"actions": "DOCKER_HOST=$(docker context inspect --format '{{.Endpoints.docker.Host}}') act"
},
"repository": {
"type": "git",
Expand Down
10 changes: 8 additions & 2 deletions src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ export class FetchProxy {
currentLogger.logRequestError(req, err, { requestId, executionTime })

// Execute error hooks
let fallbackResponse: Response | void = undefined
if (options.onError) {
await options.onError(req, err)
fallbackResponse = await options.onError(req, err)
}

// Execute circuit breaker completion hooks for failures
Expand All @@ -179,12 +180,17 @@ export class FetchProxy {
state: this.circuitBreaker.getState(),
failureCount: this.circuitBreaker.getFailures(),
executionTimeMs: executionTime,
fallbackResponseProvided: fallbackResponse instanceof Response,
},
options,
)

if (fallbackResponse instanceof Response) {
// If onError provided a fallback response, return it
return fallbackResponse
}
// Return appropriate error response
if (err.message.includes("Circuit breaker is OPEN")) {
else if (err.message.includes("Circuit breaker is OPEN")) {
return new Response("Service Unavailable", { status: 503 })
} else if (
err.message.includes("timeout") ||
Expand Down
6 changes: 5 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ export type AfterCircuitBreakerHook = (
result: CircuitBreakerResult,
) => void | Promise<void>

export type ErrorHook = (req: Request, error: Error) => void | Promise<void>
export type ErrorHook = (
req: Request,
error: Error,
) => void | Promise<void> | Promise<Response>

// Circuit breaker result information
export interface CircuitBreakerResult {
Expand All @@ -92,6 +95,7 @@ export interface CircuitBreakerResult {
state: CircuitState
failureCount: number
executionTimeMs: number
fallbackResponseProvided?: boolean
}

export enum CircuitState {
Expand Down
6 changes: 1 addition & 5 deletions tests/dos-prevention.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import { afterAll, describe, expect, test, mock } from "bun:test"

afterAll(() => {
mock.restore()
})
import { describe, expect, test } from "bun:test"

describe("DoS and Resource Exhaustion Security Tests", () => {
describe("Request Parameter Validation", () => {
Expand Down
26 changes: 13 additions & 13 deletions tests/enhanced-hooks.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@ import {
beforeEach,
jest,
afterAll,
mock,
spyOn,
} from "bun:test"
import { FetchProxy } from "../src/proxy"
import { CircuitState } from "../src/types"
import type { ProxyRequestOptions, CircuitBreakerResult } from "../src/types"

// Mock fetch for testing
const mockFetch = jest.fn()
;(global as any).fetch = mockFetch
// Spy on fetch for testing
let fetchSpy: ReturnType<typeof spyOn>

afterAll(() => {
mock.restore()
fetchSpy?.mockRestore()
})

describe("Enhanced Hook Naming Conventions", () => {
Expand All @@ -39,8 +38,9 @@ describe("Enhanced Hook Naming Conventions", () => {
headers: new Headers({ "content-type": "application/json" }),
})

mockFetch.mockClear()
mockFetch.mockResolvedValue(mockResponse)
fetchSpy = spyOn(global, "fetch")
fetchSpy.mockClear()
fetchSpy.mockResolvedValue(mockResponse)
})

describe("beforeRequest Hook", () => {
Expand All @@ -56,7 +56,7 @@ describe("Enhanced Hook Naming Conventions", () => {

expect(beforeRequestHook).toHaveBeenCalledTimes(1)
expect(beforeRequestHook).toHaveBeenCalledWith(request, options)
expect(mockFetch).toHaveBeenCalledTimes(1)
expect(fetchSpy).toHaveBeenCalledTimes(1)
})

it("should handle async beforeRequest hooks", async () => {
Expand Down Expand Up @@ -139,7 +139,7 @@ describe("Enhanced Hook Naming Conventions", () => {
const request = new Request("https://example.com/test")
const error = new Error("Network error")

mockFetch.mockRejectedValueOnce(error)
fetchSpy.mockRejectedValueOnce(error)

const options: ProxyRequestOptions = {
afterCircuitBreakerExecution: afterCircuitBreakerHook,
Expand All @@ -166,7 +166,7 @@ describe("Enhanced Hook Naming Conventions", () => {
const request = new Request("https://example.com/test")

// Add some delay to the fetch
mockFetch.mockImplementationOnce(
fetchSpy.mockImplementationOnce(
() =>
new Promise((resolve) => setTimeout(() => resolve(mockResponse), 50)),
)
Expand Down Expand Up @@ -296,8 +296,8 @@ describe("Enhanced Hook Naming Conventions", () => {
await proxy.proxy(request, undefined, options)

// Verify the mock was called (we can't easily verify exact headers due to internal processing)
expect(mockFetch).toHaveBeenCalledTimes(1)
expect(mockFetch).toHaveBeenCalledWith(
expect(fetchSpy).toHaveBeenCalledTimes(1)
expect(fetchSpy).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
headers: expect.any(Headers),
Expand All @@ -315,7 +315,7 @@ describe("Enhanced Hook Naming Conventions", () => {
},
})

mockFetch.mockResolvedValueOnce(originalResponse)
fetchSpy.mockResolvedValueOnce(originalResponse)

const request = new Request("https://example.com/test")

Expand Down
6 changes: 1 addition & 5 deletions tests/header-injection.test.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
/**
* Security tests for header injection vulnerabilities
*/
import { describe, expect, it, afterAll, mock } from "bun:test"
import { describe, expect, it } from "bun:test"

import { recordToHeaders } from "../src/utils"

afterAll(() => {
mock.restore()
})

describe("Header Injection Security Tests", () => {
describe("CRLF Header Injection", () => {
it("should reject header names with CRLF characters", () => {
Expand Down
26 changes: 21 additions & 5 deletions tests/http-method-validation.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { describe, it, expect, beforeEach, afterAll, mock } from "bun:test"
import { describe, it, expect, beforeEach, spyOn, afterEach } from "bun:test"
import { validateHttpMethod } from "../src/utils"
import { FetchProxy } from "../src/proxy"

afterAll(() => {
mock.restore()
})

describe("HTTP Method Validation Security Tests", () => {
let fetchSpy: ReturnType<typeof spyOn>

afterEach(() => {
fetchSpy?.mockRestore()
})
describe("Direct Method Validation", () => {
it("should reject CONNECT method", () => {
expect(() => {
Expand Down Expand Up @@ -75,6 +76,15 @@ describe("HTTP Method Validation Security Tests", () => {
base: "http://httpbin.org", // Use a real service for testing
circuitBreaker: { enabled: false },
})

// Mock fetch to return a successful response
fetchSpy = spyOn(global, "fetch").mockResolvedValue(
new Response("", {
status: 200,
statusText: "OK",
headers: new Headers({ "content-type": "text/plain" }),
}),
)
})

it("should reject CONNECT method in proxy (if runtime allows it)", async () => {
Expand Down Expand Up @@ -107,6 +117,9 @@ describe("HTTP Method Validation Security Tests", () => {
// The normalized request should work fine
const response = await proxy.proxy(request)
expect(response.status).toBe(200)

// Verify fetch was called
expect(fetchSpy).toHaveBeenCalledTimes(1)
})

it("should allow safe methods in proxy", async () => {
Expand All @@ -116,6 +129,9 @@ describe("HTTP Method Validation Security Tests", () => {

const response = await proxy.proxy(request)
expect(response.status).toBe(200)

// Verify fetch was called
expect(fetchSpy).toHaveBeenCalledTimes(1)
})

it("should validate methods when passed through request options", async () => {
Expand Down
15 changes: 7 additions & 8 deletions tests/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, it, expect, beforeAll, afterAll, mock } from "bun:test"
import { describe, it, expect, beforeAll, afterAll, spyOn } from "bun:test"
import createFetchGate, { FetchProxy } from "../src/index"
import {
buildURL,
Expand Down Expand Up @@ -55,7 +55,7 @@ describe("fetch-gate", () => {

afterAll(() => {
server?.stop()
mock.restore()
// No need for explicit restore with spyOn as it's automatically cleaned up
})

describe("createFetchGate", () => {
Expand Down Expand Up @@ -308,10 +308,9 @@ describe("fetch-gate", () => {

describe("Circuit Breaker Edge Cases", () => {
it("should transition to HALF_OPEN state after reset timeout", async () => {
// Custom mock for Date.now()
const originalDateNow = Date.now
let now = originalDateNow()
global.Date.now = () => now
// Spy on Date.now()
let now = Date.now()
const dateNowSpy = spyOn(Date, "now").mockImplementation(() => now)

const circuitBreaker = new CircuitBreaker({
failureThreshold: 1,
Expand All @@ -330,8 +329,8 @@ describe("fetch-gate", () => {

expect(circuitBreaker.getState()).toBe(CircuitState.HALF_OPEN)

// Restore original Date.now()
global.Date.now = originalDateNow
// Restore Date.now() spy
dateNowSpy.mockRestore()
})

it("should reset failures after successful execution in HALF_OPEN state", async () => {
Expand Down
16 changes: 4 additions & 12 deletions tests/logging.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
import {
describe,
expect,
it,
beforeEach,
spyOn,
afterAll,
mock,
} from "bun:test"
import { describe, expect, it, beforeEach, spyOn, afterAll } from "bun:test"
import { FetchProxy } from "../src/proxy"
import {
ProxyLogger,
Expand All @@ -15,11 +7,11 @@ import {
} from "../src/logger"
import { CircuitState } from "../src/types"

// Mock fetch for testing
const originalFetch = global.fetch
// Spy on fetch for testing
let fetchSpy: ReturnType<typeof spyOn>

afterAll(() => {
mock.restore()
fetchSpy?.mockRestore()
})

describe("Logging Integration", () => {
Expand Down
6 changes: 1 addition & 5 deletions tests/path-traversal.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import { describe, expect, test, mock, afterAll } from "bun:test"
import { describe, expect, test } from "bun:test"
import { normalizeSecurePath } from "../src/utils"

afterAll(() => {
mock.restore()
})

describe("Path Traversal Security", () => {
describe("normalizeSecurePath", () => {
test("should normalize simple valid paths", () => {
Expand Down
Loading