XAI Benchmarks #93

Merged: 5 commits, Dec 6, 2022

Changes from all commits
94 changes: 70 additions & 24 deletions kgcnn/data/base.py
@@ -3,8 +3,10 @@
import tensorflow as tf
import pandas as pd
import os
import typing as t
from typing import Union, List, Callable
from collections.abc import MutableMapping

from kgcnn.data.utils import save_pickle_file, load_pickle_file, ragged_tensor_from_nested_numpy
from kgcnn.graph.base import GraphDict

@@ -172,6 +174,7 @@ def length(self, value: int):
raise ValueError("Can not set length. Please use 'empty()' to initialize an empty list.")

def _to_tensor(self, item, make_copy=True):
# TODO: Document this
if not make_copy:
self.logger.warning("At the moment a copy is always made for tensor().")
props = self.obtain_property(item["name"]) # Will be list.
@@ -592,8 +595,12 @@ def set_methods(self, method_list: List[dict]) -> None:
else:
self.error("Class does not have method '%s'." % method)

def get_split_indices(self, name: str = "kfold", return_as_train_test: bool = True,
shuffle: bool = True, seed: int = None):
def get_split_indices(self,
name: str = "kfold",
return_as_train_test: bool = True,
shuffle: bool = True,
seed: int = None
):
"""Gather split ids from split graph property and return k-fold splits.

Args:
@@ -611,7 +618,6 @@ def check_and_extend_splits(to_split):
if to_split - len(split_indices) + 1 > 0:
for _ in range(to_split - len(split_indices) + 1):
split_indices.append([])

graphs = self.obtain_property(name)
for i, s in enumerate(graphs):
if s is None:
@@ -640,10 +646,28 @@

return train_test
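
As a usage sketch (an editor's illustration, not part of the diff): assuming a dataset object whose graphs each carry a "kfold" property listing the fold indices they belong to, the k-fold splits could be consumed like this:

# Hypothetical usage sketch; "dataset" and the "kfold" property are assumptions.
splits = dataset.get_split_indices(name="kfold", return_as_train_test=True, shuffle=True, seed=42)
for train_index, test_index in splits:
    print(f"train size: {len(train_index)}, test size: {len(test_index)}")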

def get_train_test_indices(self, train: str = "train", test: str = "test", valid: str = None,
split_index: Union[int, list] = 1, shuffle: bool = False, seed: int = None):
"""Get train and test indices from graph list. The 'train' and 'test' properties must be set on the graph.
They can also be a list of split assignment if more than one train-test split is required.
def get_train_test_indices(self,
train: str = "train",
test: str = "test",
valid: t.Optional[str] = None,
split_index: Union[int, list] = 1,
shuffle: bool = False,
seed: int = None
) -> t.List[t.Union[t.Tuple[int, int], t.Tuple[int, int, int]]]:
"""
Get train and test indices from graph list.
The 'train' and 'test' properties must be set on the graph, and optionally an additional property
for the validation split may be present. All of these properties may have either of the following
values:
- The property is a boolean integer value indicating whether the corresponding element of the
dataset belongs to that part of the split (train / test)
- The property is a list containing integer split indices, where each split index present within
that list implies that the corresponding dataset element is part of that particular split.
In this case the ``split_index`` parameter may also be a list of split indices that specifies
for which of these split indices the train test index split is to be returned by this method.

The return value of this method is a list with one entry per requested split index; since
``split_index`` defaults to the single index 1, the returned list has length 1 by default.

Args:
train (str): Name of graph property that has train split assignment. Defaults to 'train'.
@@ -658,27 +682,49 @@ def get_train_test_indices(self, train: str = "train", test: str = "test", valid
"""
out_indices = []
if not isinstance(split_index, (list, tuple)):
split_index = [split_index]
split_index_list: t.List[int] = [split_index]
else:
split_index_list: t.List[int] = split_index

for split_index in split_index_list:

# This list will itself contain numpy arrays which are filled with graph indices of the dataset;
# each element of this list corresponds to one property name (train, test...).
graph_index_split_list: t.List[np.ndarray] = []

for j in split_index:
split_list = []
for s in [train, test, valid]:
if s is None:
# out_indices.append(None)
for property_name in [train, test, valid]:

# It may be that only train and test indices are sought and no validation indices; in that
# case the property name for validation is None and is simply skipped here.
if property_name is None:
continue
s_list = []
split_prop = self.obtain_property(s)
for i, x in enumerate(split_prop):
if x is not None:
if j in x:
s_list.append(i)
s_list = np.array(s_list)
split_list.append(s_list)

# This list will contain all the indices of the dataset elements (graphs) which are
# associated with the current iteration's split index for the current iteration's
# property name (train, test...)
graph_index_list: t.List[int] = []

# "obtain_property" returns a list which contains only the property values corresponding to
# the given property name for each graph inside the dataset in the same order.
# In this case, this is supposed to be a split list, which is a list that contains integer
# indices, each representing one particular dataset split. The split list of each graph
# only contains those split indices to which that graph is associated.
split_prop: t.List[t.List[int]] = self.obtain_property(property_name)
for index, split_list in enumerate(split_prop):
if split_list is not None:
if split_index in split_list:
graph_index_list.append(index)

graph_index_array = np.array(graph_index_list)
graph_index_split_list.append(graph_index_array)

if shuffle:
np.random.seed(seed)
for x in split_list:
np.random.shuffle(x)
out_indices.append(split_list)
for graph_index_array in graph_index_split_list:
np.random.shuffle(graph_index_array)

out_indices.append(graph_index_split_list)

return out_indices
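
To illustrate the property semantics described in the docstring above (an editor's sketch, with the dataset setup assumed): a graph whose "train" property is the list [1, 2] belongs to the train set of splits 1 and 2, and requesting two split indices yields one index pair per split:

# Hypothetical sketch; "dataset" is assumed to carry list-valued "train"/"test" properties.
results = dataset.get_train_test_indices(train="train", test="test", split_index=[1, 2], shuffle=True, seed=0)
for train_indices, test_indices in results:
    print(train_indices.shape, test_indices.shape)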


7 changes: 7 additions & 0 deletions kgcnn/data/datasets/VgdMockDataset.py
@@ -0,0 +1,7 @@
from kgcnn.data.visual_graph_dataset import VisualGraphDataset


class VgdMockDataset(VisualGraphDataset):

def __init__(self):
super(VgdMockDataset, self).__init__(name='mock')
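
As a hypothetical usage sketch (an editor's illustration, not part of the diff), the intended loading flow for such a dataset would be:

# Download (if needed) and load the 'mock' visual graph dataset defined above.
dataset = VgdMockDataset()
dataset.ensure()          # locate the dataset folder on disk, downloading it if necessary
dataset.read_in_memory()  # load every element as a GraphDict into the memory list
train_test_list = dataset.get_train_test_indices(train='train', test='test')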
171 changes: 171 additions & 0 deletions kgcnn/data/visual_graph_dataset.py
@@ -0,0 +1,171 @@
"""
Module for handling Visual Graph Datasets (VGD).
"""
import os
import typing as t
from functools import cache

import numpy as np

from visual_graph_datasets.config import Config
from visual_graph_datasets.util import get_dataset_path, ensure_folder
from visual_graph_datasets.web import PROVIDER_CLASS_MAP, AbstractFileShare
from visual_graph_datasets.data import load_visual_graph_dataset
from visual_graph_datasets.visualization.importances import create_importances_pdf

from kgcnn.data.base import MemoryGraphDataset
from kgcnn.graph.base import GraphDict


class VisualGraphDataset(MemoryGraphDataset):

def __init__(self,
name: str):
super(VisualGraphDataset, self).__init__(
dataset_name=name,
file_directory=None,
file_name=None,
data_directory=None
)

self.vgd_config = Config()
self.vgd_config.load()

self.index_data_map: t.Dict[int, dict] = {}

def ensure(self) -> None:
"""
This method ensures that the raw dataset exists on the disk. After it has been called, the
dataset folder is guaranteed to exist on the disk and its path to be known.

This is either because an existing dataset folder has been found or because the dataset was
downloaded.

Returns:
None
"""
# First of all we try to load the dataset, as it might already exist on the system.
try:
# This function will try to locate a dataset with the given name inside the system's global
# default folder where all the visual graph datasets are stored. If it does not find a
# corresponding dataset folder there, an exception is raised.
self.data_directory = get_dataset_path(self.dataset_name)
return
except (FileNotFoundError, IndexError) as e:
self.logger.info(f'the visual graph dataset "{self.dataset_name}" was not found on the disk. '
f'The following exception was raised during lookup:')
self.logger.info(str(e))

# At this point we know that the folder does not already exist, which means we need to
# download the dataset.

# For this we will first check if a dataset with the given name is even available at the remote
# file share provider.
ensure_folder(self.vgd_config.get_datasets_path())
file_share_provider: str = self.vgd_config.get_provider()
file_share_class: type = PROVIDER_CLASS_MAP[file_share_provider]
file_share: AbstractFileShare = file_share_class(config=self.vgd_config, logger=self.logger)
file_share.check_dataset(self.dataset_name)

# If the dataset is available, then we can download it and then finally load the path
file_share.download_dataset(self.dataset_name, self.vgd_config.get_datasets_path())
self.data_directory = get_dataset_path(self.dataset_name)
self.logger.info(f'visual graph dataset found @ {self.data_directory}')

def read_in_memory(self) -> None:
"""
Actually loads the dataset from the file representations into the working memory as GraphDicts
within the internal MemoryGraphList.

Returns:
None
"""
name_data_map, self.index_data_map = load_visual_graph_dataset(
self.data_directory,
logger=self.logger,
metadata_contains_index=True
)
dataset_length = len(self.index_data_map)

self.empty(dataset_length)
self.logger.info(f'initialized empty list with {len(self.index_data_map)} elements')

for index, data in sorted(self.index_data_map.items(), key=lambda item: item[0]):
g = data['metadata']['graph']

# In the visual_graph_dataset framework, the train and test split indications are part of the
# metadata and not of the graph itself. However, we need to set them as graph properties with
# these specific names so that the existing base method "get_train_test_indices" of the
# dataset class can be used later on.
g['train'] = data['metadata']['train_split']
g['test'] = data['metadata']['test_split']

# Otherwise the basic structure of the dict from the visual graph dataset should be compatible
# such that it can be directly used as a GraphDict
graph_dict = GraphDict(g)
self[index] = graph_dict

self.logger.info('loaded dataset as MemoryGraphList')

def visualize_importances(self,
output_path: str,
gt_importances_suffix: str,
node_importances_list: t.List[np.ndarray],
edge_importances_list: t.List[np.ndarray],
indices: t.List[int],
title: str = 'Model',
) -> None:
data_list = [self.index_data_map[index] for index in indices]

suffix = gt_importances_suffix
create_importances_pdf(
output_path=output_path,
graph_list=[data['metadata']['graph'] for data in data_list],
image_path_list=[data['image_path'] for data in data_list],
node_positions_list=[data['metadata']['graph']['image_node_positions'] for data in data_list],
importances_map={
'Ground Truth': (
[data['metadata']['graph'][f'node_importances_{suffix}'] for data in data_list],
[data['metadata']['graph'][f'edge_importances_{suffix}'] for data in data_list],
),
title: (
node_importances_list,
edge_importances_list
)
}
)

# This method loops through all the elements of the dataset, which makes it computationally
# expensive; the result is therefore cached so that repeated, redundant calls with the same
# suffix stay cheap.
@cache
def has_importances(self, suffix: t.Union[str, int]) -> bool:
suffix = str(suffix)

node_condition = all([f'node_importances_{suffix}' in data['metadata']['graph']
for data in self.index_data_map.values()])

edge_condition = all([f'edge_importances_{suffix}' in data['metadata']['graph']
for data in self.index_data_map.values()])

return node_condition and edge_condition

def get_importances(self,
suffix: str,
indices: t.Optional[t.List[int]] = None
) -> t.Tuple[t.List[np.ndarray], t.List[np.ndarray]]:
if indices is None:
indices = list(self.index_data_map.keys())

node_importances_list: t.List[np.ndarray] = []
edge_importances_list: t.List[np.ndarray] = []
for index in indices:
data = self.index_data_map[index]
g = data['metadata']['graph']
node_importances_list.append(g[f'node_importances_{suffix}'])
edge_importances_list.append(g[f'edge_importances_{suffix}'])

return node_importances_list, edge_importances_list

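# Note: defining __hash__ via the dataset name makes instances of this class hashable, which
# is what allows "functools.cache" above to memoize "has_importances" with the bound instance
# as part of the cache key.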
def __hash__(self):
return hash(self.dataset_name)
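
A hedged sketch of how the importance-related methods above fit together (the model outputs and the suffix value are assumptions for illustration):

# Hypothetical sketch: compare model explanations against the ground-truth importance
# annotations, assuming the dataset carries "node_importances_2"/"edge_importances_2".
if dataset.has_importances(suffix=2):
    gt_node_importances, gt_edge_importances = dataset.get_importances(suffix='2')
    dataset.visualize_importances(
        output_path='importances.pdf',
        gt_importances_suffix='2',
        node_importances_list=model_node_importances,  # assumed model outputs
        edge_importances_list=model_edge_importances,
        indices=list(range(10)),
        title='MEGAN',
    )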
5 changes: 4 additions & 1 deletion kgcnn/literature/MEGAN.py
@@ -179,6 +179,8 @@ def __init__(self,
use_bias=use_bias
)
self.node_importance_layers.append(lay)

self.lay_sparsity = ExplanationSparsityRegularization(factor=self.sparsity_factor)

self.lay_sparsity = ExplanationSparsityRegularization(factor=self.sparsity_factor)

@@ -410,8 +412,10 @@ def train_step(self, data):
else:
out_pred = shifted_sigmoid(
outs,

# shift=self.importance_multiplier,
# multiplier=(self.importance_multiplier / 5)

shift=self.importance_multiplier,
multiplier=1,
)
@@ -459,4 +463,3 @@ def make_model(inputs: t.Optional[list] = None,
outputs = megan(layer_inputs)
model = ks.models.Model(inputs=layer_inputs, outputs=outputs)

return
23 changes: 23 additions & 0 deletions kgcnn/utils/cli.py
@@ -0,0 +1,23 @@
"""
Module containing utility methods for the command line interface (CLI).
"""
import click

CHECK_MARK = '✓'


# == CLICK SPECIFIC UTILS ==

def echo_info(content: str, verbose: bool = True):
if verbose:
click.secho(f'... {content}')


def echo_success(content: str, verbose: bool = True):
if verbose:
click.secho(f'[{CHECK_MARK}] {content}', fg='green')


def echo_error(content: str, verbose: bool = True):
if verbose:
click.secho(f'[!] {content}', fg='red')
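
For illustration, the helpers above could back a click command like this hypothetical sketch (the command and its body are assumptions):

import click
from kgcnn.utils.cli import echo_info, echo_success, echo_error

@click.command()
@click.option('--verbose/--no-verbose', default=True)
def process(verbose: bool):
    echo_info('starting processing', verbose)
    try:
        # ... actual work would go here ...
        echo_success('processing finished', verbose)
    except Exception as e:
        echo_error(f'processing failed: {e}', verbose)

if __name__ == '__main__':
    process()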
Empty file added kgcnn/xai/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions kgcnn/xai/base.py
@@ -0,0 +1,22 @@
import typing as t

import tensorflow as tf


class AbstractExplanationMixin:

def explain(self, x):
raise NotImplementedError()


class ImportanceExplanationMixin:

def explain(self, x, **kwargs):
return self.explain_importances(x, **kwargs)

# Returns a tuple of ragged tensors (node_importances, edge_importances)
def explain_importances(self,
x: t.Sequence[tf.Tensor],
**kwargs
) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]:
raise NotImplementedError
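
A minimal sketch of how a model might opt into this interface (all model internals are assumed):

import typing as t
import tensorflow as tf
from kgcnn.xai.base import ImportanceExplanationMixin

# Hypothetical sketch: a keras model exposing importance-based explanations by
# implementing the one method required by the mixin.
class ExplainableModel(tf.keras.Model, ImportanceExplanationMixin):

    def explain_importances(self,
                            x: t.Sequence[tf.Tensor],
                            **kwargs
                            ) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]:
        # Assumed: the forward pass returns (prediction, node_importances, edge_importances)
        _, node_importances, edge_importances = self(x, training=False)
        return node_importances, edge_importances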