Skip to content

Commit

Permalink
Support client credentials in body (#21)
Browse files Browse the repository at this point in the history
* fix: correct name is expires_in

* feat: refactor and add credential_in_body option

* fix: reinstate expire_time
  • Loading branch information
alexandre-kivra authored Apr 17, 2020
1 parent c89d6ec commit b49a963
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 147 deletions.
21 changes: 10 additions & 11 deletions include/oauth2c.hrl
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
-record(client, {
grant_type = undefined :: binary() | undefined,
auth_url = undefined :: binary() | undefined,
access_token = undefined :: binary() | undefined,
token_type = undefined :: token_type() | undefined,
refresh_token = undefined :: binary() | undefined,
id = undefined :: binary() | undefined,
secret = undefined :: binary() | undefined,
scope = undefined :: binary() | undefined,
expiry_time = undefined :: integer() | undefined
}).
-record(client, {grant_type = undefined :: binary() | undefined,
auth_url = undefined :: binary() | undefined,
access_token = undefined :: binary() | undefined,
token_type = undefined :: token_type() | undefined,
refresh_token = undefined :: binary() | undefined,
id = undefined :: binary() | undefined,
secret = undefined :: binary() | undefined,
scope = undefined :: binary() | undefined,
expire_time = undefined :: integer() | undefined
}).

-type method() :: head |
get |
Expand Down
192 changes: 81 additions & 111 deletions src/oauth2c.erl
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,12 @@ retrieve_access_token(Type, Url, ID, Secret, Scope) ->
Scope :: binary() | undefined,
Options :: options().
retrieve_access_token(Type, Url, ID, Secret, Scope, Options) ->
Client = #client{
grant_type = Type
,auth_url = Url
,id = ID
,secret = Secret
,scope = Scope
},
Client = #client{ grant_type = Type
, auth_url = Url
, id = ID
, secret = Secret
, scope = Scope
},
do_retrieve_access_token(Client, Options).

-spec request(Method, Url, Client) -> Response::response() when
Expand Down Expand Up @@ -185,118 +184,89 @@ ensure_client_has_access_token(Client0, Options) ->
Client0
end.

do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) ->
Payload0 = [
{<<"grant_type">>, Client#client.grant_type}
,{<<"username">>, Client#client.id}
,{<<"password">>, Client#client.secret}
],
Payload = case Client#client.scope of
undefined -> Payload0;
Scope -> [{<<"scope">>, Scope}|Payload0]
end,
Response =
restc:request(post, percent, Client#client.auth_url, [200], [],
Payload, Opts),
case Response of
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
RefreshToken = proplists:get_value(<<"refresh_token">>, Body),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = case RefreshToken of
undefined ->
#client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
,access_token = AccessToken
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
};
_ ->
#client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
,access_token = AccessToken
,refresh_token = RefreshToken
,scope = Client#client.scope
,expiry_time = ExpiryTime
}
end,
{ok, Headers, Result};
{error, _, _, Reason} ->
{error, Reason};
{error, Reason} ->
{error, Reason}
end;
do_retrieve_access_token(#client{grant_type = <<"client_credentials">>,
id = Id, secret = Secret} = Client, Opts) ->
Payload0 = [{<<"grant_type">>, Client#client.grant_type}],
Payload = case Client#client.scope of
undefined ->
Payload0;
Scope ->
[{<<"scope">>, Scope}|Payload0]
end,
Auth = base64:encode(<<Id/binary, ":", Secret/binary>>),
Header = [{<<"Authorization">>, <<"Basic ", Auth/binary>>}],
do_retrieve_access_token(Client, Opts) ->
#{headers := RequestHeaders,
body := RequestBody} = prepare_token_request(Client, Opts),
case restc:request(post, percent, Client#client.auth_url,
[200], Header, Payload, Opts) of
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
TokenType = proplists:get_value(<<"token_type">>, Body, ""),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = #client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
,access_token = AccessToken
,token_type = get_token_type(TokenType)
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
},
{ok, Headers, Result};
{error, _, _, Reason} ->
{error, Reason};
{error, Reason} ->
{error, Reason}
end;
do_retrieve_access_token(#client{grant_type = <<"azure_client_credentials">>,
id = Id, secret = Secret} = Client, Opts) ->
Payload0 = [{<<"grant_type">>, <<"client_credentials">>},
{<<"client_id">>, Id},
{<<"client_secret">>, Secret}],
Payload = case Client#client.scope of
undefined ->
Payload0;
Scope ->
[{<<"resource">>, Scope}|Payload0]
end,
case restc:request(post, percent, Client#client.auth_url,
[200], [], Payload, Opts) of
[200], RequestHeaders, RequestBody, Opts)
of
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
TokenType = proplists:get_value(<<"token_type">>, Body, ""),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = #client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
,access_token = AccessToken
,token_type = get_token_type(TokenType)
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
},
ExpireTime =
case proplists:get_value(<<"expires_in">>, Body) of
undefined -> undefined;
ExpiresIn -> erlang:system_time(second) + ExpiresIn
end,
RefreshToken = proplists:get_value(<<"refresh_token">>,
Body,
Client#client.refresh_token),
Result = #client{ grant_type = Client#client.grant_type
, auth_url = Client#client.auth_url
, access_token = AccessToken
, refresh_token = RefreshToken
, token_type = get_token_type(TokenType)
, id = Client#client.id
, secret = Client#client.secret
, scope = Client#client.scope
, expire_time = ExpireTime
},
{ok, Headers, Result};
{error, _, _, Reason} ->
{error, Reason};
{error, Reason} ->
{error, Reason}
end.

prepare_token_request(Client, Opts) ->
BaseRequest = base_request(Client),
Request0 = add_client(BaseRequest, Client, Opts),
add_fields(Request0, Client).

base_request(#client{grant_type = <<"azure_client_credentials">>}) ->
#{headers => [], body => [{<<"grant_type">>, <<"client_credentials">>}]};
base_request(#client{grant_type = GrantType}) ->
#{headers => [], body => [{<<"grant_type">>, GrantType}]}.

add_client(Request0, Client, Opts) ->
#client{id = Id, secret = Secret} = Client,
case
{Client#client.grant_type =:= <<"password">>,
Client#client.grant_type =:= <<"azure_client_credentials">> orelse
proplists:get_value(credentials_in_body, Opts, false)}
of
{false, false} ->
#{headers := Headers0} = Request0,
Auth = base64:encode(<<Id/binary, ":", Secret/binary>>),
Headers = [{<<"Authorization">>, <<"Basic ", Auth/binary>>}
| Headers0],
Request0#{headers => Headers};
{false, true} ->
#{body := Body} = Request0,
Request0#{body => [{<<"client_id">>, Id},
{<<"client_secret">>, Secret}
| Body]};
%% This clause is to still support password grant "as is" but
%% in the future this should be changed in order to support
%% client authentication in the password grant. Right now we
%% are assuming that if the grant is password then the client is public
%% which is not a fair assumption.
{true, _} ->
#{body := Body} = Request0,
Request0#{body => [{<<"username">>, Id},
{<<"password">>, Secret} | Body]}
end.

add_fields(Request, #client{scope=undefined}) ->
Request;
add_fields(Request, #client{grant_type = <<"azure_client_credentials">>,
scope = Scope}) ->
#{body := Body} = Request,
Request#{body => [{<<"resource">>, Scope} | Body]};
add_fields(Request, #client{scope = Scope}) ->
#{body := Body} = Request,
Request#{body => [{<<"scope">>, Scope} | Body]}.

-spec get_token_type(binary()) -> token_type().
get_token_type(Type) ->
get_str_token_type(string:to_lower(binary_to_list(Type))).
Expand Down Expand Up @@ -324,12 +294,12 @@ add_auth_header(Headers, #client{access_token = AccessToken}) ->
retrieve_access_token_fun(Client0, Options) ->
fun() ->
case do_retrieve_access_token(Client0, Options) of
{ok, _Headers, Client} -> {ok, Client, Client#client.expiry_time};
{ok, _Headers, Client} -> {ok, Client, Client#client.expire_time};
{error, Reason} -> {error, Reason}
end
end.

get_access_token(#client{expiry_time = ExpiryTime} = Client0, Options) ->
get_access_token(#client{expire_time = ExpireTime} = Client0, Options) ->
case {proplists:get_value(cache_token, Options, false),
proplists:get_value(force_revalidate, Options, false)}
of
Expand All @@ -348,7 +318,7 @@ get_access_token(#client{expiry_time = ExpiryTime} = Client0, Options) ->
{true, true} ->
Key = hash_client(Client0),
RevalidateFun = retrieve_access_token_fun(Client0, Options),
oauth2c_token_cache:set_and_get(Key, RevalidateFun, ExpiryTime)
oauth2c_token_cache:set_and_get(Key, RevalidateFun, ExpireTime)
end.

hash_client(#client{grant_type = Type,
Expand Down
17 changes: 8 additions & 9 deletions src/oauth2c_token_cache.erl
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ handle_call({set_and_get, Key, LazyValue,
{reply, {ok, Result}, State};
{error, not_found} ->
case LazyValue() of
{ok, Result, ExpiryTime0} ->
ExpiryTime = get_expiry_time(ExpiryTime0, DefaultTTL),
{ok, Result, ExpireTime} ->
ExpiryTime = get_expire_time(ExpireTime, DefaultTTL),
ets:insert(?TOKEN_CACHE_ID, {Key, Result, ExpiryTime}),
{reply, {ok, Result}, State};
{error, Reason} -> {reply, {error, Reason}, State}
Expand All @@ -113,8 +113,7 @@ get_token(Key) ->
get_token(Key, undefined).
get_token(Key, ExpiryTimeLowerLimit) ->
Now = erlang:system_time(second),
case ets:lookup(?TOKEN_CACHE_ID, Key)
of
case ets:lookup(?TOKEN_CACHE_ID, Key) of
% Only return cache entry if
% (1) It has not expired
% (2) Its expiry time is greater than ExpiryTimeLowerLimit
Expand All @@ -128,20 +127,20 @@ get_token(Key, ExpiryTimeLowerLimit) ->
{error, not_found}
end.

get_expiry_time(undefined, DefaultTTL) ->
get_expire_time(undefined, DefaultTTL) ->
erlang:system_time(second) + DefaultTTL;
get_expiry_time(ExpiryTime, _DefaultTTL) ->
ExpiryTime.
get_expire_time(ExpireTime, _DefaultTTL) ->
ExpireTime.

%%%_ * Tests -------------------------------------------------------

-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").

get_expiry_time_test_() ->
get_expires_in_test_() ->
[fun() ->
{T, Default} = Input,
Actual = get_expiry_time(T, Default),
Actual = get_expire_time(T, Default),
?assertEqual(Expected, Actual)
end
|| {Input, Expected} <- [{{1, 100}, 1}]
Expand Down
Loading

0 comments on commit b49a963

Please sign in to comment.