Skip to content

Commit 108fee7

Browse files
author
Daniel Casanueva
authored
Add Host API combinator (#1800)
1 parent 527e99e commit 108fee7

File tree

7 files changed

+67
-6
lines changed

7 files changed

+67
-6
lines changed

changelog.d/pr-1800

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
synopsis: Add Host API combinator
2+
packages: servant servant-client-core servant-client servant-server
3+
prs: #1800
4+
description: {
5+
Adding a Host combinator allows servant users to select APIs according
6+
to the Host header provided by clients.
7+
}

servant-client-core/src/Servant/Client/Core/HasClient.hs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ import Servant.API
6666
ReflectMethod (..),
6767
StreamBody',
6868
Verb,
69-
getResponse, AuthProtect, BasicAuth, BasicAuthData, Capture', CaptureAll, DeepQuery, Description, Fragment, FramingRender (..), FramingUnrender (..), Header', Headers (..), HttpVersion, MimeRender (mimeRender), NoContent (NoContent), QueryFlag, QueryParam', QueryParams, QueryString, Raw, RawM, RemoteHost, ReqBody', SBoolI, Stream, Summary, ToHttpApiData, ToSourceIO (..), Vault, WithNamedContext, WithResource, WithStatus (..), contentType, getHeadersHList, toEncodedUrlPiece, NamedRoutes)
69+
getResponse, AuthProtect, BasicAuth, BasicAuthData, Capture', CaptureAll, DeepQuery, Description, Fragment, FramingRender (..), FramingUnrender (..), Header', Headers (..), HttpVersion, MimeRender (mimeRender), NoContent (NoContent), QueryFlag, QueryParam', QueryParams, QueryString, Raw, RawM, RemoteHost, ReqBody', SBoolI, Stream, Summary, ToHttpApiData, ToSourceIO (..), Vault, WithNamedContext, WithResource, WithStatus (..), contentType, getHeadersHList, toEncodedUrlPiece, NamedRoutes, Host)
7070
import Servant.API.Generic
7171
(GenericMode(..), ToServant, ToServantApi
7272
, GenericServant, toServant, fromServant)
@@ -494,6 +494,15 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire
494494
hoistClientMonad pm _ f cl = \arg ->
495495
hoistClientMonad pm (Proxy :: Proxy api) f (cl arg)
496496

497+
instance (KnownSymbol sym, HasClient m api) => HasClient m (Host sym :> api) where
498+
type Client m (Host sym :> api) = Client m api
499+
500+
clientWithRoute pm Proxy req =
501+
clientWithRoute pm (Proxy :: Proxy api) $
502+
addHeader "Host" (symbolVal (Proxy :: Proxy sym)) req
503+
504+
hoistClientMonad pm _ = hoistClientMonad pm (Proxy :: Proxy api)
505+
497506
-- | Using a 'HttpVersion' combinator in your API doesn't affect the client
498507
-- functions.
499508
instance HasClient m api

servant-client/test/Servant/ClientTestUtils.hs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ import Servant.API
6868
JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender),
6969
NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam,
7070
QueryParams, QueryString, Raw, ReqBody, StdMethod (GET), ToHttpApiData (..),
71-
UVerb, Union, Verb, WithStatus (WithStatus), NamedRoutes, addHeader)
71+
UVerb, Union, Verb, WithStatus (WithStatus), NamedRoutes, addHeader, Host)
7272
import Servant.API.Generic ((:-))
7373
import Servant.API.QueryString (FromDeepQuery(..), ToDeepQuery(..))
7474
import Servant.Client
@@ -221,6 +221,7 @@ type Api =
221221
:<|> NamedRoutes RecordRoutes
222222
:<|> "multiple-choices-int" :> MultipleChoicesInt
223223
:<|> "captureVerbatim" :> Capture "someString" Verbatim :> Get '[PlainText] Text
224+
:<|> "host-test" :> Host "servant.example" :> Get '[JSON] Bool
224225

225226
api :: Proxy Api
226227
api = Proxy
@@ -256,6 +257,7 @@ uverbGetCreated :: ClientM (Union '[WithStatus 201 Person])
256257
recordRoutes :: RecordRoutes (AsClientT ClientM)
257258
multiChoicesInt :: Int -> ClientM MultipleChoicesIntResult
258259
captureVerbatim :: Verbatim -> ClientM Text
260+
getHost :: ClientM Bool
259261

260262
getRoot
261263
:<|> getGet
@@ -285,7 +287,8 @@ getRoot
285287
:<|> uverbGetCreated
286288
:<|> recordRoutes
287289
:<|> multiChoicesInt
288-
:<|> captureVerbatim = client api
290+
:<|> captureVerbatim
291+
:<|> getHost = client api
289292

290293
server :: Application
291294
server = serve api (
@@ -349,6 +352,7 @@ server = serve api (
349352
)
350353

351354
:<|> pure . decodeUtf8 . unVerbatim
355+
:<|> pure True
352356
)
353357

354358
-- * api for testing failures

servant-server/src/Servant/Server/Internal.hs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module Servant.Server.Internal
1616
) where
1717

1818
import Control.Monad
19-
(join, when)
19+
(join, when, unless)
2020
import Control.Monad.Trans
2121
(liftIO, lift)
2222
import Control.Monad.Trans.Resource
@@ -48,13 +48,13 @@ import Network.Socket
4848
(SockAddr)
4949
import Network.Wai
5050
(Application, Request, Response, ResponseReceived, httpVersion, isSecure, lazyRequestBody,
51-
queryString, remoteHost, getRequestBodyChunk, requestHeaders,
51+
queryString, remoteHost, getRequestBodyChunk, requestHeaders, requestHeaderHost,
5252
requestMethod, responseLBS, responseStream, vault)
5353
import Servant.API
5454
((:<|>) (..), (:>), Accept (..), BasicAuth, Capture',
5555
CaptureAll, DeepQuery, Description, EmptyAPI, Fragment,
5656
FramingRender (..), FramingUnrender (..), FromSourceIO (..),
57-
Header', If, IsSecure (..), NoContentVerb, QueryFlag,
57+
Host, Header', If, IsSecure (..), NoContentVerb, QueryFlag,
5858
QueryParam', QueryParams, QueryString, Raw, RawM, ReflectMethod (reflectMethod),
5959
RemoteHost, ReqBody', SBool (..), SBoolI (..), SourceIO,
6060
Stream, StreamBody', Summary, ToSourceIO (..), Vault, Verb,
@@ -461,6 +461,30 @@ instance
461461
<> headerName
462462
<> " failed: " <> e
463463

464+
instance
465+
( KnownSymbol sym
466+
, HasServer api context
467+
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
468+
) => HasServer (Host sym :> api) context where
469+
type ServerT (Host sym :> api) m = ServerT api m
470+
471+
hoistServerWithContext _ = hoistServerWithContext (Proxy :: Proxy api)
472+
473+
route _ context (Delayed {..}) = route (Proxy :: Proxy api) context $
474+
let formatError =
475+
headerParseErrorFormatter $ getContextEntry $ mkContextWithErrorFormatter context
476+
rep = typeRep (Proxy :: Proxy Host)
477+
targetHost = symbolVal (Proxy :: Proxy sym)
478+
hostCheck :: DelayedIO ()
479+
hostCheck = withRequest $ \req ->
480+
case requestHeaderHost req of
481+
Just hostBytes ->
482+
let host = BC8.unpack hostBytes
483+
in unless (host == targetHost) $
484+
delayedFail $ formatError rep req $ "Invalid host: " ++ host
485+
_ -> delayedFail $ formatError rep req "Host header missing"
486+
in Delayed { headersD = headersD <* hostCheck, .. }
487+
464488
-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
465489
-- this automatically requires your server-side handler to be a function
466490
-- that takes an argument of type @'Maybe' 'Text'@.

servant/servant.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ library
8989
Servant.API.Fragment
9090
Servant.API.Generic
9191
Servant.API.Header
92+
Servant.API.Host
9293
Servant.API.HttpVersion
9394
Servant.API.IsSecure
9495
Servant.API.Modifiers

servant/src/Servant/API.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ module Servant.API (
1414
module Servant.API.Capture,
1515
-- | Capturing parts of the url path as parsed values: @'Capture'@ and @'CaptureAll'@
1616
module Servant.API.Header,
17+
-- | Matching the @Host@ header.
18+
module Servant.API.Host,
1719
-- | Retrieving specific headers from the request
1820
module Servant.API.HttpVersion,
1921
-- | Retrieving the HTTP version of the request
@@ -110,6 +112,7 @@ import Servant.API.Generic
110112
ToServant, ToServantApi, fromServant, genericApi, toServant)
111113
import Servant.API.Header
112114
(Header, Header')
115+
import Servant.API.Host (Host)
113116
import Servant.API.HttpVersion
114117
(HttpVersion (..))
115118
import Servant.API.IsSecure

servant/src/Servant/API/Host.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module Servant.API.Host (Host) where
2+
3+
import Data.Typeable (Typeable)
4+
import GHC.TypeLits (Symbol)
5+
6+
-- | Match against the given host.
7+
--
8+
-- This allows you to define APIs over multiple domains. For example:
9+
--
10+
-- > type API = Host "api1.example" :> API1
11+
-- > :<|> Host "api2.example" :> API2
12+
--
13+
data Host (sym :: Symbol) deriving Typeable

0 commit comments

Comments
 (0)