@@ -305,12 +305,13 @@ class SHAPConfig(ExplainabilityConfig):
305
305
306
306
def __init__ (
307
307
self ,
308
- baseline ,
309
- num_samples ,
310
- agg_method ,
308
+ baseline = None ,
309
+ num_samples = None ,
310
+ agg_method = None ,
311
311
use_logit = False ,
312
312
save_local_shap_values = True ,
313
313
seed = None ,
314
+ num_clusters = None ,
314
315
):
315
316
"""Initializes config for SHAP.
316
317
@@ -320,34 +321,49 @@ def __init__(
320
321
be the same as the dataset format. Each row should contain only the feature
321
322
columns/values and omit the label column/values. If None a baseline will be
322
323
calculated automatically by using K-means or K-prototypes in the input dataset.
323
- num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
324
+ num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
324
325
This number determines the size of the generated synthetic dataset to compute the
325
- SHAP values.
326
- agg_method (str): Aggregation method for global SHAP values. Valid values are
326
+ SHAP values. If not provided then Clarify job will choose a proper value according
327
+ to the count of features.
328
+ agg_method (None or str): Aggregation method for global SHAP values. Valid values are
327
329
"mean_abs" (mean of absolute SHAP values for all instances),
328
330
"median" (median of SHAP values for all instances) and
329
331
"mean_sq" (mean of squared SHAP values for all instances).
332
+ If not provided then Clarify job uses method "mean_abs"
330
333
use_logit (bool): Indicator of whether the logit function is to be applied to the model
331
334
predictions. Default is False. If "use_logit" is true then the SHAP values will
332
335
have log-odds units.
333
336
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
334
337
in the output location. Default is True.
335
338
seed (int): seed value to get deterministic SHAP values. Default is None.
339
+ num_clusters (None or int): If a baseline is not provided, Clarify automatically
340
+ computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
341
+ num_clusters is a parameter for this algorithm. num_clusters will be the resulting
342
+ size of the baseline dataset. If not provided, Clarify job will use a default value.
336
343
"""
337
- if agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
344
+ if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
338
345
raise ValueError (
339
346
f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
340
347
)
341
-
348
+ if num_clusters is not None and baseline is not None :
349
+ raise ValueError (
350
+ "Baseline and num_clusters cannot be provided together. "
351
+ "Please specify one of the two."
352
+ )
342
353
self .shap_config = {
343
- "baseline" : baseline ,
344
- "num_samples" : num_samples ,
345
- "agg_method" : agg_method ,
346
354
"use_logit" : use_logit ,
347
355
"save_local_shap_values" : save_local_shap_values ,
348
356
}
357
+ if baseline is not None :
358
+ self .shap_config ["baseline" ] = baseline
359
+ if num_samples is not None :
360
+ self .shap_config ["num_samples" ] = num_samples
361
+ if agg_method is not None :
362
+ self .shap_config ["agg_method" ] = agg_method
349
363
if seed is not None :
350
364
self .shap_config ["seed" ] = seed
365
+ if num_clusters is not None :
366
+ self .shap_config ["num_clusters" ] = num_clusters
351
367
352
368
def get_explainability_config (self ):
353
369
"""Returns config."""
0 commit comments