diff --git a/auth_oidc/__manifest__.py b/auth_oidc/__manifest__.py
index 5e046a73f3..50c8524e47 100644
--- a/auth_oidc/__manifest__.py
+++ b/auth_oidc/__manifest__.py
@@ -16,6 +16,10 @@
"summary": "Allow users to login through OpenID Connect Provider",
"external_dependencies": {"python": ["python-jose"]},
"depends": ["auth_oauth"],
- "data": ["views/auth_oauth_provider.xml", "data/auth_oauth_data.xml"],
+ "data": [
+ "security/ir.model.access.csv",
+ "views/auth_oauth_provider.xml",
+ "data/auth_oauth_data.xml",
+ ],
"demo": ["demo/local_keycloak.xml"],
}
diff --git a/auth_oidc/demo/local_keycloak.xml b/auth_oidc/demo/local_keycloak.xml
index 919754db99..92588dc952 100644
--- a/auth_oidc/demo/local_keycloak.xml
+++ b/auth_oidc/demo/local_keycloak.xml
@@ -17,4 +17,9 @@
name="jwks_uri"
>http://localhost:8080/auth/realms/master/protocol/openid-connect/certs
+
+
+
+ token['name'] == 'test'
+
diff --git a/auth_oidc/models/auth_oauth_provider.py b/auth_oidc/models/auth_oauth_provider.py
index ac498a7cdb..15a5ff7f5e 100644
--- a/auth_oidc/models/auth_oauth_provider.py
+++ b/auth_oidc/models/auth_oauth_provider.py
@@ -2,12 +2,13 @@
# Copyright 2021 ACSONE SA/NV
# License: AGPL-3.0 or later (http://www.gnu.org/licenses/agpl)
+import collections
import logging
import secrets
import requests
-from odoo import fields, models, tools
+from odoo import api, exceptions, fields, models, tools
try:
from jose import jwt
@@ -46,6 +47,11 @@ class AuthOauthProvider(models.Model):
string="Token URL", help="Required for OpenID Connect authorization code flow."
)
jwks_uri = fields.Char(string="JWKS URL", help="Required for OpenID Connect.")
+ group_line_ids = fields.One2many(
+ "auth.oauth.provider.group_line",
+ "provider_id",
+ string="Group mappings",
+ )
@tools.ormcache("self.jwks_uri", "kid")
def _get_keys(self, kid):
@@ -104,3 +110,34 @@ def _decode_id_token(self, access_token, id_token, kid):
if error:
raise error
return {}
+
+
+class AuthOauthProviderGroupLine(models.Model):
+ _name = "auth.oauth.provider.group_line"
+
+ provider_id = fields.Many2one("auth.oauth.provider", required=True)
+ group_id = fields.Many2one("res.groups", required=True)
+ expression = fields.Char(required=True, help="Variables: user, token")
+
+ @api.constrains("expression")
+ def _check_expression(self):
+ for this in self:
+ try:
+ this._eval_expression(self.env.user, {})
+ except (AttributeError, KeyError, NameError) as e:
+ raise exceptions.ValidationError("\n".join(e.args))
+
+ def _eval_expression(self, user, token):
+ self.ensure_one()
+
+ class Defaultdict2(collections.defaultdict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(Defaultdict2, *args, **kwargs)
+
+ return tools.safe_eval.safe_eval(
+ self.expression,
+ {
+ "user": user,
+ "token": Defaultdict2(token),
+ },
+ )
diff --git a/auth_oidc/models/res_users.py b/auth_oidc/models/res_users.py
index 1684480fa4..eb3bede25f 100644
--- a/auth_oidc/models/res_users.py
+++ b/auth_oidc/models/res_users.py
@@ -64,6 +64,12 @@ def auth_oauth(self, provider, params):
_logger.error("No id_token in response.")
raise AccessDenied()
validation = oauth_provider._parse_id_token(id_token, access_token)
+ if oauth_provider.data_endpoint:
+ data = requests.get(
+ oauth_provider.data_endpoint,
+ headers={"Authorization": "Bearer %s" % access_token},
+ ).json()
+ validation.update(data)
# required check
if "sub" in validation and "user_id" not in validation:
# set user_id for auth_oauth, user_id is not an OpenID Connect standard
@@ -80,3 +86,22 @@ def auth_oauth(self, provider, params):
raise AccessDenied()
# return user credentials
return (self.env.cr.dbname, login, access_token)
+
+ @api.model
+ def _auth_oauth_signin(self, provider, validation, params):
+ login = super()._auth_oauth_signin(provider, validation, params)
+ user = self.search([("login", "=", login)])
+ if user:
+ group_updates = []
+ for group_line in (
+ self.env["auth.oauth.provider"].browse(provider).group_line_ids
+ ):
+ if group_line._eval_expression(user, validation):
+ if group_line.group_id not in user.groups_id:
+ group_updates.append((4, group_line.group_id.id))
+ else:
+ if group_line.group_id in user.groups_id:
+ group_updates.append((3, group_line.group_id.id))
+ if group_updates:
+ user.write({"groups_id": group_updates})
+ return login
diff --git a/auth_oidc/security/ir.model.access.csv b/auth_oidc/security/ir.model.access.csv
new file mode 100644
index 0000000000..503e4c7529
--- /dev/null
+++ b/auth_oidc/security/ir.model.access.csv
@@ -0,0 +1,2 @@
+id,name,model_id:id,group_id:id,perm_read,perm_write,perm_create,perm_unlink
+access_auth_oauth_provider_group_line,auth_oauth_provider,model_auth_oauth_provider_group_line,base.group_system,1,1,1,1
diff --git a/auth_oidc/tests/test_auth_oidc_auth_code.py b/auth_oidc/tests/test_auth_oidc_auth_code.py
index a1a08b0a71..6e9054c884 100644
--- a/auth_oidc/tests/test_auth_oidc_auth_code.py
+++ b/auth_oidc/tests/test_auth_oidc_auth_code.py
@@ -308,3 +308,9 @@ def test_login_with_jwk_format(self):
)
self.assertEqual(token, "122/3")
self.assertEqual(login, user.login)
+
+ def test_group_expression(self):
+ """Test that group expressions evaluate correctly"""
+ group_line = self.env.ref("auth_oidc.local_keycloak").group_line_ids[:1]
+ group_line.expression = 'token["test"]["test"] == 1'
+ self.assertFalse(group_line._eval_expression(self.env.user, {}))
diff --git a/auth_oidc/views/auth_oauth_provider.xml b/auth_oidc/views/auth_oauth_provider.xml
index 90c931b417..dbdeadd8ef 100644
--- a/auth_oidc/views/auth_oauth_provider.xml
+++ b/auth_oidc/views/auth_oauth_provider.xml
@@ -19,6 +19,16 @@
+
+
+
+
+
+
+
+
+
+