1
- import os
1
+ from __future__ import annotations as _annotations
2
+
2
3
import logging
4
+ import os
3
5
from pathlib import Path
4
- from typing import TYPE_CHECKING , Any , Dict , Mapping , Optional , Tuple
6
+ from typing import TYPE_CHECKING , Any
5
7
6
- from botocore .exceptions import ClientError
7
- from botocore .client import Config
8
8
import boto3
9
-
10
- from pydantic import BaseSettings
11
- from pydantic .typing import StrPath , get_origin , is_union
12
- from pydantic .utils import deep_update
13
- from pydantic .fields import ModelField
9
+ from botocore .client import Config
10
+ from botocore .exceptions import ClientError
11
+ from pydantic import BaseModel
12
+ from pydantic ._internal ._utils import lenient_issubclass
13
+ from pydantic .fields import FieldInfo
14
+ from pydantic_settings import BaseSettings
15
+ from pydantic_settings .sources import (
16
+ EnvSettingsSource ,
17
+ )
14
18
15
19
if TYPE_CHECKING :
16
20
from mypy_boto3_ssm .client import SSMClient
@@ -23,16 +27,28 @@ class SettingsError(ValueError):
23
27
pass
24
28
25
29
26
- class AwsSsmSettingsSource :
27
- __slots__ = ("ssm_prefix" , "env_nested_delimiter" )
28
-
30
+ class AwsSsmSettingsSource (EnvSettingsSource ):
29
31
def __init__ (
30
32
self ,
31
- ssm_prefix : Optional [StrPath ],
32
- env_nested_delimiter : Optional [str ] = None ,
33
+ settings_cls : type [BaseSettings ],
34
+ case_sensitive : bool = None ,
35
+ ssm_prefix : str = None ,
33
36
):
34
- self .ssm_prefix : Optional [StrPath ] = ssm_prefix
35
- self .env_nested_delimiter : Optional [str ] = env_nested_delimiter
37
+ # Ideally would retrieve ssm_prefix from self.config
38
+ # but need the superclass to be initialized for that
39
+ ssm_prefix_ = (
40
+ ssm_prefix
41
+ if ssm_prefix is not None
42
+ else settings_cls .model_config .get ("ssm_prefix" , "/" )
43
+ )
44
+ super ().__init__ (
45
+ settings_cls ,
46
+ case_sensitive = case_sensitive ,
47
+ env_prefix = ssm_prefix_ ,
48
+ env_nested_delimiter = "/" , # SSM only accepts / as a delimiter
49
+ )
50
+ self .ssm_prefix = ssm_prefix_
51
+ assert self .ssm_prefix == self .env_prefix
36
52
37
53
@property
38
54
def client (self ) -> "SSMClient" :
@@ -43,124 +59,103 @@ def client_config(self) -> Config:
43
59
timeout = float (os .environ .get ("SSM_TIMEOUT" , 0.5 ))
44
60
return Config (connect_timeout = timeout , read_timeout = timeout )
45
61
46
- def load_from_ssm (self , secrets_path : Path , case_sensitive : bool ):
47
-
48
- if not secrets_path .is_absolute ():
62
+ def _load_env_vars (
63
+ self ,
64
+ ):
65
+ """
66
+ Access env_prefix instead of ssm_prefix
67
+ """
68
+ if not Path (self .env_prefix ).is_absolute ():
49
69
raise ValueError ("SSM prefix must be absolute path" )
50
70
51
- logger .debug (f"Building SSM settings with prefix of { secrets_path = } " )
71
+ logger .debug (f"Building SSM settings with prefix of { self . env_prefix = } " )
52
72
53
73
output = {}
54
74
try :
55
75
paginator = self .client .get_paginator ("get_parameters_by_path" )
56
76
response_iterator = paginator .paginate (
57
- Path = str ( secrets_path ) , WithDecryption = True
77
+ Path = self . env_prefix , WithDecryption = True , Recursive = True
58
78
)
59
79
60
80
for page in response_iterator :
61
81
for parameter in page ["Parameters" ]:
62
- key = Path (parameter ["Name" ]).relative_to (secrets_path ).as_posix ()
63
- output [key if case_sensitive else key .lower ()] = parameter ["Value" ]
82
+ key = (
83
+ Path (parameter ["Name" ]).relative_to (self .env_prefix ).as_posix ()
84
+ )
85
+ output [
86
+ self .env_prefix + key
87
+ if self .case_sensitive
88
+ else self .env_prefix .lower () + key .lower ()
89
+ ] = parameter ["Value" ]
64
90
65
91
except ClientError :
66
- logger .exception ("Failed to get parameters from %s" , secrets_path )
92
+ logger .exception ("Failed to get parameters from %s" , self . env_prefix )
67
93
68
94
return output
69
95
70
- def __call__ (self , settings : BaseSettings ) -> Dict [str , Any ]:
71
- """
72
- Returns SSM values for all settings.
73
- """
74
- d : Dict [str , Optional [Any ]] = {}
75
-
76
- if self .ssm_prefix is None :
77
- return d
78
-
79
- ssm_values = self .load_from_ssm (
80
- secrets_path = Path (self .ssm_prefix ),
81
- case_sensitive = settings .__config__ .case_sensitive ,
82
- )
96
+ def __repr__ (self ) -> str :
97
+ return f"AwsSsmSettingsSource(ssm_prefix={ self .env_prefix !r} )"
83
98
84
- # The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa
85
- for field in settings .__fields__ .values ():
86
- env_val : Optional [str ] = None
87
- for env_name in field .field_info .extra ["env_names" ]:
88
- env_val = ssm_values .get (env_name )
89
- if env_val is not None :
90
- break
91
-
92
- is_complex , allow_json_failure = self ._field_is_complex (field )
93
- if is_complex :
94
- if env_val is None :
95
- # field is complex but no value found so far, try explode_env_vars
96
- env_val_built = self ._explode_ssm_values (field , ssm_values )
97
- if env_val_built :
98
- d [field .alias ] = env_val_built
99
- else :
100
- # field is complex and there's a value, decode that as JSON, then
101
- # add explode_env_vars
102
- try :
103
- env_val = settings .__config__ .json_loads (env_val )
104
- except ValueError as e :
105
- if not allow_json_failure :
106
- raise SettingsError (
107
- f'error parsing JSON for "{ env_name } "'
108
- ) from e
109
-
110
- if isinstance (env_val , dict ):
111
- d [field .alias ] = deep_update (
112
- env_val , self ._explode_ssm_values (field , ssm_values )
113
- )
114
- else :
115
- d [field .alias ] = env_val
116
- elif env_val is not None :
117
- # simplest case, field is not complex, we only need to add the
118
- # value if it was found
119
- d [field .alias ] = env_val
120
-
121
- return d
122
-
123
- def _field_is_complex (self , field : ModelField ) -> Tuple [bool , bool ]:
124
- """
125
- Find out if a field is complex, and if so whether JSON errors should be ignored
99
+ def get_field_value (
100
+ self , field : FieldInfo , field_name : str
101
+ ) -> tuple [Any , str , bool ]:
126
102
"""
127
- if field .is_complex ():
128
- allow_json_failure = False
129
- elif (
130
- is_union (get_origin (field .type_ ))
131
- and field .sub_fields
132
- and any (f .is_complex () for f in field .sub_fields )
133
- ):
134
- allow_json_failure = True
135
- else :
136
- return False , False
103
+ Gets the value for field from environment variables and a flag to
104
+ determine whether value is complex.
137
105
138
- return True , allow_json_failure
106
+ Args:
107
+ field: The field.
108
+ field_name: The field name.
139
109
140
- def _explode_ssm_values (
141
- self , field : ModelField , env_vars : Mapping [ str , Optional [ str ]]
142
- ) -> Dict [ str , Any ]:
110
+ Returns:
111
+ A tuple contains the key, value if the file exists otherwise `None`, and
112
+ a flag to determine whether value is complex.
143
113
"""
144
- Process env_vars and extract the values of keys containing
145
- env_nested_delimiter into nested dictionaries.
146
114
147
- This is applied to a single field, hence filtering by env_var prefix.
148
- """
149
- prefixes = [
150
- f"{ env_name } { self .env_nested_delimiter } "
151
- for env_name in field .field_info .extra ["env_names" ]
152
- ]
153
- result : Dict [str , Any ] = {}
154
- for env_name , env_val in env_vars .items ():
155
- if not any (env_name .startswith (prefix ) for prefix in prefixes ):
156
- continue
157
- _ , * keys , last_key = env_name .split (self .env_nested_delimiter )
158
- env_var = result
159
- for key in keys :
160
- env_var = env_var .setdefault (key , {})
161
- env_var [last_key ] = env_val
162
-
163
- return result
115
+ # env_name = /asdf/foo
116
+ # env_vars = {foo:xyz}
117
+ env_val : str | None = None
118
+ for field_key , env_name , value_is_complex in self ._extract_field_info (
119
+ field , field_name
120
+ ):
121
+ env_val = self .env_vars .get (env_name )
122
+ if env_val is not None :
123
+ break
124
+
125
+ return env_val , field_key , value_is_complex
126
+
127
+ def __call__ (self ) -> dict [str , Any ]:
128
+ data : dict [str , Any ] = {}
129
+
130
+ for field_name , field in self .settings_cls .model_fields .items ():
131
+ try :
132
+ field_value , field_key , value_is_complex = self .get_field_value (
133
+ field , field_name
134
+ )
135
+ except Exception as e :
136
+ raise SettingsError (
137
+ f'error getting value for field "{ field_name } " from source "{ self .__class__ .__name__ } "' # noqa
138
+ ) from e
139
+
140
+ try :
141
+ field_value = self .prepare_field_value (
142
+ field_name , field , field_value , value_is_complex
143
+ )
144
+ except ValueError as e :
145
+ raise SettingsError (
146
+ f'error parsing value for field "{ field_name } " from source "{ self .__class__ .__name__ } "' # noqa
147
+ ) from e
148
+
149
+ if field_value is not None :
150
+ if (
151
+ not self .case_sensitive
152
+ and lenient_issubclass (field .annotation , BaseModel )
153
+ and isinstance (field_value , dict )
154
+ ):
155
+ data [field_key ] = self ._replace_field_names_case_insensitively (
156
+ field , field_value
157
+ )
158
+ else :
159
+ data [field_key ] = field_value
164
160
165
- def __repr__ (self ) -> str :
166
- return f"AwsSsmSettingsSource(ssm_prefix={ self .ssm_prefix !r} )"
161
+ return data
0 commit comments