Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
sudo apt update --fix-missing --yes
sudo apt upgrade --yes
sudo apt-get install --yes git
sudo apt-get clean
sudo apt-get clean

- uses: actions/checkout@v4

Expand All @@ -88,9 +88,9 @@ jobs:
python --version
pip show setuptools
rm -rf ~/.cache/pip

- name: Print available disk space before graphnet install
run: |
run: |
df -h
- name: Upgrade packages in virtual environment
shell: bash
Expand Down Expand Up @@ -128,7 +128,7 @@ jobs:
pip show torch-scatter
pip show jammy_flows
- name: Print available disk space after graphnet install
run: |
run: |
df -h
- name: Run unit tests and generate coverage report
shell: bash
Expand Down
185 changes: 103 additions & 82 deletions src/graphnet/models/easy_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Suggested Model subclass that enables simple user syntax."""

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union, Type

import numpy as np
Expand All @@ -10,7 +9,7 @@
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import pandas as pd
from pytorch_lightning.loggers import Logger as LightningLogger
Expand Down Expand Up @@ -288,14 +287,59 @@ def train(self, mode: bool = True) -> "Model":
task.train_eval()
return self

def predict_step(self, *args: Any, **kwargs: Any) -> List[Any]:
"""Perform prediction step.

Returns a list whose first entries are the per-task prediction
tensors and whose trailing entries are numpy arrays for any
attributes requested via `_predict_additional_attributes`. Pulling
attributes here avoids a second pass over the dataloader in
`predict_as_dataframe`.
"""
batch = kwargs.get("batch", args[0])
pred = list(self(batch))

attrs = getattr(self, "_predict_additional_attributes", None)
if not attrs:
return pred

# Pulse- vs event-level: pred[0] has one row per node when
# predictions are on the pulse level.
pulse_level_predictions = len(pred[0]) > len(batch)
n_pulses = (
batch.n_pulses.detach().cpu().numpy()
if pulse_level_predictions
else None
)
for attr in attrs:
value = batch[attr]
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()
else:
value = np.asarray(value)
if (
pulse_level_predictions
and n_pulses is not None
and len(value) < n_pulses.sum()
):
value = np.repeat(value, n_pulses)
pred.append(value)
return pred

def predict(
self,
dataloader: DataLoader,
gpus: Optional[Union[List[int], int]] = None,
distribution_strategy: Optional[str] = "auto",
additional_attributes: Optional[List[str]] = None,
**trainer_kwargs: Any,
) -> List[Tensor]:
"""Return predictions for `dataloader`."""
) -> List[Union[Tensor, np.ndarray]]:
"""Return predictions for `dataloader`.

If `additional_attributes` is provided, the returned list has the
per-task prediction tensors followed by one numpy array per
requested attribute, gathered from the same dataloader pass.
"""
self.inference()
self.train(mode=False)

Expand All @@ -310,14 +354,30 @@ def predict(
**trainer_kwargs,
)

predictions_list = inference_trainer.predict(self, dataloader)
# Stash on self so predict_step can see it; always clear.
self._predict_additional_attributes = additional_attributes or None
try:
predictions_list = inference_trainer.predict(self, dataloader)
finally:
self._predict_additional_attributes = None
assert len(predictions_list), "Got no predictions"

# The trailing entries in each batch's output are the gathered
# additional attributes (numpy arrays); everything before is a
# per-task prediction tensor.
nb_attrs = len(additional_attributes or [])
nb_outputs = len(predictions_list[0])
predictions: List[Tensor] = [
torch.cat([preds[ix] for preds in predictions_list], dim=0)
for ix in range(nb_outputs)
]
nb_task_outputs = nb_outputs - nb_attrs
predictions: List[Union[Tensor, np.ndarray]] = []
for ix in range(nb_outputs):
if ix < nb_task_outputs:
predictions.append(
torch.cat([p[ix] for p in predictions_list], dim=0)
)
else:
predictions.append(
np.concatenate([p[ix] for p in predictions_list], axis=0)
)
return predictions

def predict_as_dataframe(
Expand All @@ -333,106 +393,67 @@ def predict_as_dataframe(
"""Return predictions for `dataloader` as a DataFrame.

Include `additional_attributes` as additional columns in the output
DataFrame.
DataFrame. Attributes are gathered during the prediction pass, so
the dataloader is iterated only once and shuffling is safe.
"""
if prediction_columns is None:
prediction_columns = self.prediction_labels

if additional_attributes is None:
additional_attributes = []
assert isinstance(additional_attributes, list)

if (
not isinstance(dataloader.sampler, SequentialSampler)
and additional_attributes
):
print(dataloader.sampler)
raise UserWarning(
"DataLoader has a `sampler` that is not `SequentialSampler`, "
"indicating that shuffling is enabled. Using "
"`predict_as_dataframe` with `additional_attributes` assumes "
"that the sequence of batches in `dataloader` are "
"deterministic. Either call this method a `dataloader` which "
"doesn't resample batches; or do not request "
"`additional_attributes`."
)
self.info(f"Column names for predictions are: \n {prediction_columns}")
predictions_torch = self.predict(

outputs = self.predict(
dataloader=dataloader,
gpus=gpus,
distribution_strategy=distribution_strategy,
additional_attributes=additional_attributes,
**trainer_kwargs,
)
predictions = (
torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
)

# `predict` returns task tensors first, then one np.ndarray per
# requested attribute — split on that boundary.
split = len(outputs) - len(additional_attributes)
pred_tensors = outputs[:split]
attr_arrays = outputs[split:]

predictions = torch.cat(pred_tensors, dim=1).detach().cpu().numpy()
assert len(prediction_columns) == predictions.shape[1], (
f"Number of provided column names ({len(prediction_columns)}) and "
f"number of output columns ({predictions.shape[1]}) don't match."
)

# Check if predictions are on event- or pulse-level
pulse_level_predictions = len(predictions) > len(dataloader.dataset)

# Get additional attributes
attributes: Dict[str, List[np.ndarray]] = OrderedDict(
[(attr, []) for attr in additional_attributes]
)
for batch in dataloader:
for attr in attributes:
attribute = batch[attr]
if isinstance(attribute, torch.Tensor):
attribute = attribute.detach().cpu().numpy()

# Check if node level predictions
# If true, additional attributes are repeated
# to make dimensions fit
if pulse_level_predictions:
if len(attribute) < np.sum(
batch.n_pulses.detach().cpu().numpy()
):
attribute = np.repeat(
attribute, batch.n_pulses.detach().cpu().numpy()
)
attributes[attr].extend(attribute)

# Confirm that attributes match length of predictions
skip_attributes = []
for attr in attributes.keys():
try:
assert len(attributes[attr]) == len(predictions)
except AssertionError:
columns: Dict[str, np.ndarray] = {
name: predictions[:, i]
for i, name in enumerate(prediction_columns)
}
for name, arr in zip(additional_attributes, attr_arrays):
if len(arr) != len(predictions):
self.warning_once(
"Could not automatically adjust length"
f" of additional attribute '{attr}' to match length of"
f" predictions.This error can be caused by heavy"
f" of additional attribute '{name}' to match length of"
" predictions. This error can be caused by heavy"
" disagreement between number of examples in the"
" dataset vs. actual events in the dataloader, e.g. "
" heavy filtering of events in `collate_fn` passed to"
" `dataloader`. This can also be caused by requesting"
" pulse-level attributes for `Task`s that produce"
" event-level predictions. Attribute skipped."
)
skip_attributes.append(attr)

# Remove bad attributes
for attr in skip_attributes:
attributes.pop(attr)
additional_attributes.remove(attr)

data = np.concatenate(
[predictions]
+ [
np.asarray(values)[:, np.newaxis]
for values in attributes.values()
],
axis=1,
)

results = pd.DataFrame(
data, columns=prediction_columns + additional_attributes
)
return results
continue
arr = np.asarray(arr)
if arr.ndim == 1:
columns[name] = arr
else:
# Multi-dim target (e.g. `direction` with x/y/z components):
# expand to one column per component so the DataFrame stays
# tabular.
flat = arr.reshape(len(arr), -1)
for i in range(flat.shape[1]):
columns[f"{name}_{i}"] = flat[:, i]

return pd.DataFrame(columns)

def _create_default_callbacks(
self,
Expand Down
Loading
Loading