@@ -323,15 +323,15 @@ def __init__(
323
323
"""
324
324
steps = []
325
325
if "entry_point" in kwargs :
326
- entry_point = kwargs [ "entry_point" ]
327
- source_dir = kwargs .get ("source_dir" )
328
- dependencies = kwargs .get ("dependencies" )
326
+ entry_point = kwargs . get ( "entry_point" , None )
327
+ source_dir = kwargs .get ("source_dir" , None )
328
+ dependencies = kwargs .get ("dependencies" , None )
329
329
repack_model_step = _RepackModelStep (
330
330
name = f"{ name } RepackModel" ,
331
331
depends_on = depends_on ,
332
332
retry_policies = repack_model_step_retry_policies ,
333
333
sagemaker_session = estimator .sagemaker_session ,
334
- role = estimator .sagemaker_session ,
334
+ role = estimator .role ,
335
335
model_data = model_data ,
336
336
entry_point = entry_point ,
337
337
source_dir = source_dir ,
@@ -357,7 +357,11 @@ def predict_wrapper(endpoint, session):
357
357
vpc_config = None ,
358
358
sagemaker_session = estimator .sagemaker_session ,
359
359
role = estimator .role ,
360
- ** kwargs ,
360
+ env = kwargs .get ("env" , None ),
361
+ name = kwargs .get ("name" , None ),
362
+ enable_network_isolation = kwargs .get ("enable_network_isolation" , None ),
363
+ model_kms_key = kwargs .get ("model_kms_key" , None ),
364
+ image_config = kwargs .get ("image_config" , None ),
361
365
)
362
366
model_step = CreateModelStep (
363
367
name = f"{ name } CreateModelStep" ,
0 commit comments