Skip to content

Commit

Permalink
token-cache-PLT-19 (#19)
Browse files Browse the repository at this point in the history
Added option to cache access tokens
  • Loading branch information
jakobsvenning authored Apr 6, 2020
1 parent fb6cc6c commit c89d6ec
Show file tree
Hide file tree
Showing 8 changed files with 624 additions and 69 deletions.
45 changes: 45 additions & 0 deletions include/oauth2c.hrl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
-record(client, {
grant_type = undefined :: binary() | undefined,
auth_url = undefined :: binary() | undefined,
access_token = undefined :: binary() | undefined,
token_type = undefined :: token_type() | undefined,
refresh_token = undefined :: binary() | undefined,
id = undefined :: binary() | undefined,
secret = undefined :: binary() | undefined,
scope = undefined :: binary() | undefined,
expiry_time = undefined :: integer() | undefined
}).

-type method() :: head |
get |
put |
patch |
post |
trace |
options |
delete.
-type url() :: binary().
%% <<"password">> or <<"client_credentials">>
-type at_type() :: binary().
-type headers() :: [header()].
-type header() :: {binary(), binary()}.
-type status_codes() :: [status_code()].
-type status_code() :: integer().
-type reason() :: term().
-type content_type() :: json | xml | percent.
-type property() :: atom() | tuple().
-type proplist() :: [property()].
-type options() :: proplist().
-type body() :: proplist().
-type restc_response() :: { ok
, Status::status_code()
, Headers::headers()
, Body::body()} |
{ error
, Status::status_code()
, Headers::headers()
, Body::body()} |
{ error, Reason::reason()}.
-type response() :: {restc_response(), #client{}}.
-type token_type() :: bearer | unsupported.
-type client() :: #client{}.
1 change: 1 addition & 0 deletions src/oauth2c.app.src
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
{description, "Erlang OAuth2 Client"},
{vsn, git},
{registered, []},
{mod, {oauth2c_app, []}},
{applications, [
kernel,
stdlib,
Expand Down
136 changes: 74 additions & 62 deletions src/oauth2c.erl
Original file line number Diff line number Diff line change
Expand Up @@ -41,51 +41,9 @@
-export([request/8]).

-define(DEFAULT_ENCODING, json).
-define(TOKEN_CACHE_SERVER, oauth2c_token_cache).

-record(client, {
grant_type = undefined :: binary() | undefined,
auth_url = undefined :: binary() | undefined,
access_token = undefined :: binary() | undefined,
token_type = undefined :: token_type() | undefined,
refresh_token = undefined :: binary() | undefined,
id = undefined :: binary() | undefined,
secret = undefined :: binary() | undefined,
scope = undefined :: binary() | undefined
}).

-type method() :: head |
get |
put |
patch |
post |
trace |
options |
delete.
-type url() :: binary().
%% <<"password">> or <<"client_credentials">>
-type at_type() :: binary().
-type headers() :: [header()].
-type header() :: {binary(), binary()}.
-type status_codes() :: [status_code()].
-type status_code() :: integer().
-type reason() :: term().
-type content_type() :: json | xml | percent.
-type property() :: atom() | tuple().
-type proplist() :: [property()].
-type options() :: proplist().
-type body() :: proplist().
-type restc_response() :: { ok
, Status::status_code()
, Headers::headers()
, Body::body()} |
{ error
, Status::status_code()
, Headers::headers()
, Body::body()} |
{ error, Reason::reason()}.
-type response() :: {restc_response(), #client{}}.
-type token_type() :: bearer | unsupported.
-type client() :: #client{}.
-include("oauth2c.hrl").

%%% API ========================================================================

Expand Down Expand Up @@ -137,7 +95,7 @@ retrieve_access_token(Type, Url, ID, Secret, Scope) ->
ID :: binary(),
Secret :: binary(),
Scope :: binary() | undefined,
Options :: list().
Options :: options().
retrieve_access_token(Type, Url, ID, Secret, Scope, Options) ->
Client = #client{
grant_type = Type
Expand Down Expand Up @@ -205,15 +163,28 @@ request(Method, Type, Url, Expect, Headers, Body, Client) ->
Body :: body(),
Options :: options(),
Client :: client().
request(Method, Type, Url, Expect, Headers, Body, Options, Client) ->
case do_request(Method,Type,Url,Expect,Headers,Body,Options,Client) of

request(Method, Type, Url, Expect, Headers, Body, Options, Client0) ->
Client1 = ensure_client_has_access_token(Client0, Options),
case do_request(Method,Type,Url,Expect,Headers,Body,Options,Client1) of
{{_, 401, _, _}, Client2} ->
{ok, _RetrHeaders, Client3} =
do_retrieve_access_token(Client2, Options),
do_request(Method,Type,Url,Expect,Headers,Body,Options,Client3);
{ok, Client3} = get_access_token(Client2#client{access_token = undefined},
[force_revalidate | Options]),
do_request(Method, Type, Url, Expect, Headers, Body, Options, Client3);
Result -> Result
end.

%%% INTERNAL ===================================================================

ensure_client_has_access_token(Client0, Options) ->
case Client0 of
#client{access_token = undefined} ->
{ok, Client} = get_access_token(Client0, Options),
Client;
_ ->
Client0
end.

do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) ->
Payload0 = [
{<<"grant_type">>, Client#client.grant_type}
Expand All @@ -231,6 +202,7 @@ do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) ->
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
RefreshToken = proplists:get_value(<<"refresh_token">>, Body),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = case RefreshToken of
undefined ->
#client{
Expand All @@ -240,6 +212,7 @@ do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) ->
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
};
_ ->
#client{
Expand All @@ -248,6 +221,7 @@ do_retrieve_access_token(#client{grant_type = <<"password">>} = Client, Opts) ->
,access_token = AccessToken
,refresh_token = RefreshToken
,scope = Client#client.scope
,expiry_time = ExpiryTime
}
end,
{ok, Headers, Result};
Expand All @@ -272,6 +246,7 @@ do_retrieve_access_token(#client{grant_type = <<"client_credentials">>,
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
TokenType = proplists:get_value(<<"token_type">>, Body, ""),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = #client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
Expand All @@ -280,6 +255,7 @@ do_retrieve_access_token(#client{grant_type = <<"client_credentials">>,
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
},
{ok, Headers, Result};
{error, _, _, Reason} ->
Expand All @@ -303,6 +279,7 @@ do_retrieve_access_token(#client{grant_type = <<"azure_client_credentials">>,
{ok, _, Headers, Body} ->
AccessToken = proplists:get_value(<<"access_token">>, Body),
TokenType = proplists:get_value(<<"token_type">>, Body, ""),
ExpiryTime = proplists:get_value(<<"expiry_time">>, Body),
Result = #client{
grant_type = Client#client.grant_type
,auth_url = Client#client.auth_url
Expand All @@ -311,6 +288,7 @@ do_retrieve_access_token(#client{grant_type = <<"azure_client_credentials">>,
,id = Client#client.id
,secret = Client#client.secret
,scope = Client#client.scope
,expiry_time = ExpiryTime
},
{ok, Headers, Result};
{error, _, _, Reason} ->
Expand All @@ -327,19 +305,9 @@ get_token_type(Type) ->
get_str_token_type("bearer") -> bearer;
get_str_token_type(_Else) -> unsupported.

do_request(Method, Type, Url, Expect, Headers, Body, Options, Client0) ->
{Headers2, Client} = add_auth_header(Headers, Client0, Options),
{restc:request(Method, Type, Url, Expect, Headers2, Body, Options), Client}.

add_auth_header(Headers0,
#client{access_token = undefined} = Client0,
Options) ->
{ok, _RetrHeaders, Client} = do_retrieve_access_token(Client0, Options),
Headers = add_auth_header(Headers0, Client),
{Headers, Client};
add_auth_header(Headers0, Client, _) ->
do_request(Method, Type, Url, Expect, Headers0, Body, Options, Client) ->
Headers = add_auth_header(Headers0, Client),
{Headers, Client}.
{restc:request(Method, Type, Url, Expect, Headers, Body, Options), Client}.

add_auth_header(Headers, #client{grant_type = <<"azure_client_credentials">>,
access_token = AccessToken}) ->
Expand All @@ -353,6 +321,50 @@ add_auth_header(Headers, #client{access_token = AccessToken}) ->
AH = {<<"Authorization">>, <<"token ", AccessToken/binary>>},
[AH | proplists:delete(<<"Authorization">>, Headers)].

retrieve_access_token_fun(Client0, Options) ->
fun() ->
case do_retrieve_access_token(Client0, Options) of
{ok, _Headers, Client} -> {ok, Client, Client#client.expiry_time};
{error, Reason} -> {error, Reason}
end
end.

get_access_token(#client{expiry_time = ExpiryTime} = Client0, Options) ->
case {proplists:get_value(cache_token, Options, false),
proplists:get_value(force_revalidate, Options, false)}
of
{false, _} ->
{ok, _Headers, Client} = do_retrieve_access_token(Client0, Options),
{ok, Client};
{true, false} ->
Key = hash_client(Client0),
case oauth2c_token_cache:get(Key) of
{error, not_found} ->
RevalidateFun = retrieve_access_token_fun(Client0, Options),
oauth2c_token_cache:set_and_get(Key, RevalidateFun);
{ok, Client} ->
{ok, Client}
end;
{true, true} ->
Key = hash_client(Client0),
RevalidateFun = retrieve_access_token_fun(Client0, Options),
oauth2c_token_cache:set_and_get(Key, RevalidateFun, ExpiryTime)
end.

hash_client(#client{grant_type = Type,
auth_url = AuthUrl,
id = ID,
secret = Secret,
scope = Scope}) ->
erlang:phash2({Type, AuthUrl, ID, Secret, Scope}).

%%%_ * Tests -------------------------------------------------------

-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").

-endif.

%%%_* Emacs ============================================================
%%% Local Variables:
%%% allout-layout: t
Expand Down
16 changes: 16 additions & 0 deletions src/oauth2c_app.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
%%%-------------------------------------------------------------------
%% @doc oauth2c application callback
%% @end
%%%-------------------------------------------------------------------

-module(oauth2c_app).

-behaviour(application).

-export([start/2, stop/1]).

start(_StartType, _StartArgs) ->
oauth2c_sup:start_link().

stop(_State) ->
ok.
31 changes: 31 additions & 0 deletions src/oauth2c_sup.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%%%-------------------------------------------------------------------
%% @doc Top level supervisor of oauth2c.
%% @end
%%%-------------------------------------------------------------------

%%%_* Module declaration ===============================================
-module(oauth2c_sup).
-behaviour(supervisor).

%%%_* Exports ==========================================================
-export([start_link/0]).
-export([init/1]).

%%%_* Code =============================================================
%%%_ * API -------------------------------------------------------------

start_link() ->
supervisor:start_link({local, ?MODULE}, ?MODULE, []).

init([]) ->
Strategy = #{strategy => one_for_one,
intensity => 2,
period => 60},
ChildSpecs = [#{id => oauth2c_token_cache,
start => {oauth2c_token_cache, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker,
modules => [oauth2c_token_cache]
}],
{ok, {Strategy, ChildSpecs}}.
Loading

0 comments on commit c89d6ec

Please sign in to comment.