From b49a9638d4d1ce0e0b5b5147adca9258f7605469 Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues <52402833+alexandre-kivra@users.noreply.github.com> Date: Fri, 17 Apr 2020 08:53:57 +0200 Subject: [PATCH] Support client credentials in body (#21) * fix: correct name is expires_in * feat: refactor and add credential_in_body option * fix: reinstate expire_time --- include/oauth2c.hrl | 21 ++-- src/oauth2c.erl | 192 +++++++++++++++--------------------- src/oauth2c_token_cache.erl | 17 ++-- test/oauth2c_SUITE.erl | 59 ++++++++--- 4 files changed, 142 insertions(+), 147 deletions(-) diff --git a/include/oauth2c.hrl b/include/oauth2c.hrl index f58baf4..3efdd2e 100644 --- a/include/oauth2c.hrl +++ b/include/oauth2c.hrl @@ -1,14 +1,13 @@ --record(client, { - grant_type = undefined :: binary() | undefined, - auth_url = undefined :: binary() | undefined, - access_token = undefined :: binary() | undefined, - token_type = undefined :: token_type() | undefined, - refresh_token = undefined :: binary() | undefined, - id = undefined :: binary() | undefined, - secret = undefined :: binary() | undefined, - scope = undefined :: binary() | undefined, - expiry_time = undefined :: integer() | undefined -}). +-record(client, {grant_type = undefined :: binary() | undefined, + auth_url = undefined :: binary() | undefined, + access_token = undefined :: binary() | undefined, + token_type = undefined :: token_type() | undefined, + refresh_token = undefined :: binary() | undefined, + id = undefined :: binary() | undefined, + secret = undefined :: binary() | undefined, + scope = undefined :: binary() | undefined, + expire_time = undefined :: integer() | undefined + }). -type method() :: head | get | diff --git a/src/oauth2c.erl b/src/oauth2c.erl index 21eeaec..4ffd38c 100644 --- a/src/oauth2c.erl +++ b/src/oauth2c.erl @@ -97,13 +97,12 @@ retrieve_access_token(Type, Url, ID, Secret, Scope) -> Scope :: binary() | undefined, Options :: options(). retrieve_access_token(Type, Url, ID, Secret, Scope, Options) -> - Client = #client{ - grant_type = Type - ,auth_url = Url - ,id = ID - ,secret = Secret - ,scope = Scope - }, + Client = #client{ grant_type = Type + , auth_url = Url + , id = ID + , secret = Secret + , scope = Scope + }, do_retrieve_access_token(Client, Options). -spec request(Method, Url, Client) -> Response::response() when @@ -185,111 +184,33 @@ ensure_client_has_access_token(Client0, Options) -> Client0 end. -do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) -> - Payload0 = [ - {<<"grant_type">>, Client#client.grant_type} - ,{<<"username">>, Client#client.id} - ,{<<"password">>, Client#client.secret} - ], - Payload = case Client#client.scope of - undefined -> Payload0; - Scope -> [{<<"scope">>, Scope}|Payload0] - end, - Response = - restc:request(post, percent, Client#client.auth_url, [200], [], - Payload, Opts), - case Response of - {ok, _, Headers, Body} -> - AccessToken = proplists:get_value(<<"access_token">>, Body), - RefreshToken = proplists:get_value(<<"refresh_token">>, Body), - ExpiryTime = proplists:get_value(<<"expiry_time">>, Body), - Result = case RefreshToken of - undefined -> - #client{ - grant_type = Client#client.grant_type - ,auth_url = Client#client.auth_url - ,access_token = AccessToken - ,id = Client#client.id - ,secret = Client#client.secret - ,scope = Client#client.scope - ,expiry_time = ExpiryTime - }; - _ -> - #client{ - grant_type = Client#client.grant_type - ,auth_url = Client#client.auth_url - ,access_token = AccessToken - ,refresh_token = RefreshToken - ,scope = Client#client.scope - ,expiry_time = ExpiryTime - } - end, - {ok, Headers, Result}; - {error, _, _, Reason} -> - {error, Reason}; - {error, Reason} -> - {error, Reason} - end; -do_retrieve_access_token(#client{grant_type = <<"client_credentials">>, - id = Id, secret = Secret} = Client, Opts) -> - Payload0 = [{<<"grant_type">>, Client#client.grant_type}], - Payload = case Client#client.scope of - undefined -> - Payload0; - Scope -> - [{<<"scope">>, Scope}|Payload0] - end, - Auth = base64:encode(<>), - Header = [{<<"Authorization">>, <<"Basic ", Auth/binary>>}], +do_retrieve_access_token(Client, Opts) -> + #{headers := RequestHeaders, + body := RequestBody} = prepare_token_request(Client, Opts), case restc:request(post, percent, Client#client.auth_url, - [200], Header, Payload, Opts) of - {ok, _, Headers, Body} -> - AccessToken = proplists:get_value(<<"access_token">>, Body), - TokenType = proplists:get_value(<<"token_type">>, Body, ""), - ExpiryTime = proplists:get_value(<<"expiry_time">>, Body), - Result = #client{ - grant_type = Client#client.grant_type - ,auth_url = Client#client.auth_url - ,access_token = AccessToken - ,token_type = get_token_type(TokenType) - ,id = Client#client.id - ,secret = Client#client.secret - ,scope = Client#client.scope - ,expiry_time = ExpiryTime - }, - {ok, Headers, Result}; - {error, _, _, Reason} -> - {error, Reason}; - {error, Reason} -> - {error, Reason} - end; -do_retrieve_access_token(#client{grant_type = <<"azure_client_credentials">>, - id = Id, secret = Secret} = Client, Opts) -> - Payload0 = [{<<"grant_type">>, <<"client_credentials">>}, - {<<"client_id">>, Id}, - {<<"client_secret">>, Secret}], - Payload = case Client#client.scope of - undefined -> - Payload0; - Scope -> - [{<<"resource">>, Scope}|Payload0] - end, - case restc:request(post, percent, Client#client.auth_url, - [200], [], Payload, Opts) of + [200], RequestHeaders, RequestBody, Opts) + of {ok, _, Headers, Body} -> AccessToken = proplists:get_value(<<"access_token">>, Body), TokenType = proplists:get_value(<<"token_type">>, Body, ""), - ExpiryTime = proplists:get_value(<<"expiry_time">>, Body), - Result = #client{ - grant_type = Client#client.grant_type - ,auth_url = Client#client.auth_url - ,access_token = AccessToken - ,token_type = get_token_type(TokenType) - ,id = Client#client.id - ,secret = Client#client.secret - ,scope = Client#client.scope - ,expiry_time = ExpiryTime - }, + ExpireTime = + case proplists:get_value(<<"expires_in">>, Body) of + undefined -> undefined; + ExpiresIn -> erlang:system_time(second) + ExpiresIn + end, + RefreshToken = proplists:get_value(<<"refresh_token">>, + Body, + Client#client.refresh_token), + Result = #client{ grant_type = Client#client.grant_type + , auth_url = Client#client.auth_url + , access_token = AccessToken + , refresh_token = RefreshToken + , token_type = get_token_type(TokenType) + , id = Client#client.id + , secret = Client#client.secret + , scope = Client#client.scope + , expire_time = ExpireTime + }, {ok, Headers, Result}; {error, _, _, Reason} -> {error, Reason}; @@ -297,6 +218,55 @@ do_retrieve_access_token(#client{grant_type = <<"azure_client_credentials">>, {error, Reason} end. +prepare_token_request(Client, Opts) -> + BaseRequest = base_request(Client), + Request0 = add_client(BaseRequest, Client, Opts), + add_fields(Request0, Client). + +base_request(#client{grant_type = <<"azure_client_credentials">>}) -> + #{headers => [], body => [{<<"grant_type">>, <<"client_credentials">>}]}; +base_request(#client{grant_type = GrantType}) -> + #{headers => [], body => [{<<"grant_type">>, GrantType}]}. + +add_client(Request0, Client, Opts) -> + #client{id = Id, secret = Secret} = Client, + case + {Client#client.grant_type =:= <<"password">>, + Client#client.grant_type =:= <<"azure_client_credentials">> orelse + proplists:get_value(credentials_in_body, Opts, false)} + of + {false, false} -> + #{headers := Headers0} = Request0, + Auth = base64:encode(<>), + Headers = [{<<"Authorization">>, <<"Basic ", Auth/binary>>} + | Headers0], + Request0#{headers => Headers}; + {false, true} -> + #{body := Body} = Request0, + Request0#{body => [{<<"client_id">>, Id}, + {<<"client_secret">>, Secret} + | Body]}; + %% This clause is to still support password grant "as is" but + %% in the future this should be changed in order to support + %% client authentication in the password grant. Right now we + %% are assuming that if the grant is password then the client is public + %% which is not a fair assumption. + {true, _} -> + #{body := Body} = Request0, + Request0#{body => [{<<"username">>, Id}, + {<<"password">>, Secret} | Body]} + end. + +add_fields(Request, #client{scope=undefined}) -> + Request; +add_fields(Request, #client{grant_type = <<"azure_client_credentials">>, + scope = Scope}) -> + #{body := Body} = Request, + Request#{body => [{<<"resource">>, Scope} | Body]}; +add_fields(Request, #client{scope = Scope}) -> + #{body := Body} = Request, + Request#{body => [{<<"scope">>, Scope} | Body]}. + -spec get_token_type(binary()) -> token_type(). get_token_type(Type) -> get_str_token_type(string:to_lower(binary_to_list(Type))). @@ -324,12 +294,12 @@ add_auth_header(Headers, #client{access_token = AccessToken}) -> retrieve_access_token_fun(Client0, Options) -> fun() -> case do_retrieve_access_token(Client0, Options) of - {ok, _Headers, Client} -> {ok, Client, Client#client.expiry_time}; + {ok, _Headers, Client} -> {ok, Client, Client#client.expire_time}; {error, Reason} -> {error, Reason} end end. -get_access_token(#client{expiry_time = ExpiryTime} = Client0, Options) -> +get_access_token(#client{expire_time = ExpireTime} = Client0, Options) -> case {proplists:get_value(cache_token, Options, false), proplists:get_value(force_revalidate, Options, false)} of @@ -348,7 +318,7 @@ get_access_token(#client{expiry_time = ExpiryTime} = Client0, Options) -> {true, true} -> Key = hash_client(Client0), RevalidateFun = retrieve_access_token_fun(Client0, Options), - oauth2c_token_cache:set_and_get(Key, RevalidateFun, ExpiryTime) + oauth2c_token_cache:set_and_get(Key, RevalidateFun, ExpireTime) end. hash_client(#client{grant_type = Type, diff --git a/src/oauth2c_token_cache.erl b/src/oauth2c_token_cache.erl index 37e50ca..3f51a97 100644 --- a/src/oauth2c_token_cache.erl +++ b/src/oauth2c_token_cache.erl @@ -97,8 +97,8 @@ handle_call({set_and_get, Key, LazyValue, {reply, {ok, Result}, State}; {error, not_found} -> case LazyValue() of - {ok, Result, ExpiryTime0} -> - ExpiryTime = get_expiry_time(ExpiryTime0, DefaultTTL), + {ok, Result, ExpireTime} -> + ExpiryTime = get_expire_time(ExpireTime, DefaultTTL), ets:insert(?TOKEN_CACHE_ID, {Key, Result, ExpiryTime}), {reply, {ok, Result}, State}; {error, Reason} -> {reply, {error, Reason}, State} @@ -113,8 +113,7 @@ get_token(Key) -> get_token(Key, undefined). get_token(Key, ExpiryTimeLowerLimit) -> Now = erlang:system_time(second), - case ets:lookup(?TOKEN_CACHE_ID, Key) - of + case ets:lookup(?TOKEN_CACHE_ID, Key) of % Only return cache entry if % (1) It has not expired % (2) Its expiry time is greater than ExpiryTimeLowerLimit @@ -128,20 +127,20 @@ get_token(Key, ExpiryTimeLowerLimit) -> {error, not_found} end. -get_expiry_time(undefined, DefaultTTL) -> +get_expire_time(undefined, DefaultTTL) -> erlang:system_time(second) + DefaultTTL; -get_expiry_time(ExpiryTime, _DefaultTTL) -> - ExpiryTime. +get_expire_time(ExpireTime, _DefaultTTL) -> + ExpireTime. %%%_ * Tests ------------------------------------------------------- -ifdef(TEST). -include_lib("eunit/include/eunit.hrl"). -get_expiry_time_test_() -> +get_expires_in_test_() -> [fun() -> {T, Default} = Input, - Actual = get_expiry_time(T, Default), + Actual = get_expire_time(T, Default), ?assertEqual(Expected, Actual) end || {Input, Expected} <- [{{1, 100}, 1}] diff --git a/test/oauth2c_SUITE.erl b/test/oauth2c_SUITE.erl index 8b17246..cebfa0f 100644 --- a/test/oauth2c_SUITE.erl +++ b/test/oauth2c_SUITE.erl @@ -18,7 +18,9 @@ groups() -> []. -all() -> [ retrieve_access_token +all() -> [ client_credentials_in_body + , client_credentials_in_header + , retrieve_access_token , fetch_access_token_on_request , fetch_access_token_on_request , fetch_new_token_on_401 @@ -52,6 +54,34 @@ end_per_testcase(_TestCase, Config) -> oauth2c_token_cache:clear(), Config. +client_credentials_in_body(_Config) -> + oauth2c:retrieve_access_token(?CLIENT_CREDENTIALS_GRANT, + ?AUTH_URL, + <<"ID">>, + <<"SECRET">>, + undefined, + [credentials_in_body]), + ?assert(meck:called(restc, request, [post, + percent, + ?AUTH_URL, + '_', + [], %% empty headers + ['_', '_', '_'], %% grant_type + creds + '_'])). + +client_credentials_in_header(_Config) -> + oauth2c:retrieve_access_token(?CLIENT_CREDENTIALS_GRANT, + ?AUTH_URL, + <<"ID">>, + <<"SECRET">>), + ?assert(meck:called(restc, request, [post, + percent, + ?AUTH_URL, + '_', + ['_'], %% credentials + ['_'], %% grant_type + '_'])). + retrieve_access_token(_Config) -> Response = oauth2c:retrieve_access_token(?CLIENT_CREDENTIALS_GRANT, ?AUTH_URL, @@ -69,14 +99,12 @@ retrieve_cached_access_token(_Config) -> ])). retrieve_cached_expired_access_token(_Config) -> - Client0 = client(?AUTH_URL), - oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client0), + Client = client(?AUTH_URL), + oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client), % TTL is 1000ms for a cached entry, hence sleeping for 1050ms should % make the cached entry invalid. timer:sleep(1050), - % access_token is set to undefined to force oauth2c:request to look into cache - Client1 = Client0#client{access_token = undefined}, - oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client1), + oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client), ?assertEqual(2, meck:num_calls(restc, request, [ post, percent, ?AUTH_URL, '_', '_', '_', '_' @@ -146,11 +174,11 @@ retrieve_cached_token_on_401(_Config) -> ?assertMatch({{ok, 200, _, _}, _}, Response1), {_, Client1} = Response1, {_, Client2} = Response2, - ?assert(Client1#client.expiry_time < Client2#client.expiry_time). + ?assert(Client1#client.expire_time < Client2#client.expire_time). retrieve_cached_token_on_401_burst(_Config) -> Client = client(?AUTH_URL), - % First call to request will return a access token with expiry_time X, + % First call to request will return a access token with expires_in X, % and this token will be cached. {{ok, 200, _, _}, Client1} = oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client), @@ -181,7 +209,7 @@ retrieve_cached_token_on_401_burst(_Config) -> % processes, {{ok, 401, _, _}, Client2} = oauth2c:request(get, json, ?REQUEST_URL, [], [], [], [cache_token], Client1), - ?assert(Client1#client.expiry_time < Client2#client.expiry_time). + ?assert(Client1#client.expire_time < Client2#client.expire_time). fetch_access_token_and_do_request(_Config) -> {ok, _, Client} = oauth2c:retrieve_access_token(?CLIENT_CREDENTIALS_GRANT, @@ -223,7 +251,7 @@ mock_http_requests() -> meck:expect(restc, request, fun(post, percent, ?AUTH_URL, [200], _, _, _) -> Body = [{<<"access_token">>, ?VALID_TOKEN}, - {<<"expiry_time">>, erlang:system_time(second) + 1}, + {<<"expires_in">>, 1}, {<<"token_type">>, <<"bearer">>}], {ok, 200, [], Body}; (post, percent, ?INVALID_TOKEN_AUTH_URL, [200], _, _, _) -> @@ -246,10 +274,10 @@ mock_http_request_401() -> {[post, percent, ?AUTH_URL, [200], '_', '_', '_'], meck:seq([ {ok, 200, [], [{<<"access_token">>, <<"token1">>}, - {<<"expiry_time">>, erlang:system_time(second) + 1}, + {<<"expires_in">>, 1}, {<<"token_type">>, <<"bearer">>}]}, {ok, 200, [], [{<<"access_token">>, <<"token2">>}, - {<<"expiry_time">>, erlang:system_time(second) + 10}, + {<<"expires_in">>, 10}, {<<"token_type">>, <<"bearer">>}]} ]) }, @@ -258,7 +286,7 @@ mock_http_request_401() -> {ok, 200, [], [{<<"access_token">>, <<"invalid">>}, {<<"token_type">>, <<"bearer">>}]}, {ok, 401, [], [{<<"access_token">>, ?VALID_TOKEN}, - {<<"expiry_time">>, erlang:system_time(second) + 1}, + {<<"expires_in">>, 1}, {<<"token_type">>, <<"bearer">>}]} ]) } @@ -266,16 +294,15 @@ mock_http_request_401() -> ). mock_http_request_401_burst() -> - Now = erlang:system_time(second), meck:expect(restc, request, [ {[post, percent, ?AUTH_URL, [200], '_', '_', '_'], meck:seq([ {ok, 200, [], [{<<"access_token">>, ?VALID_TOKEN}, - {<<"expiry_time">>, Now + 10}, + {<<"expires_in">>, 10}, {<<"token_type">>, <<"bearer">>}]}, {ok, 200, [], [{<<"access_token">>, ?VALID_TOKEN}, - {<<"expiry_time">>, Now + 20}, + {<<"expires_in">>, 20}, {<<"token_type">>, <<"bearer">>}]} ]) },