Skip to content

Commit bf54de8

Browse files
fix type warning
1 parent 04d272d commit bf54de8

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

Diff for: nequip/data/AtomicData.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,12 @@ def _process_dict(kwargs, ignore_fields=[]):
157157
# ^ this tensor is a scalar; we need to give it
158158
# a data dimension to play nice with irreps
159159
kwargs[k] = v
160+
elif isinstance(v, torch.Tensor):
161+
# This is a tensor, so we just don't do anything except avoid the warning in the `else`
162+
pass
160163
else:
161164
warnings.warn(
162-
f"Value for field {k} was of unsupported type {type(k)} (value was {v})"
165+
f"Value for field {k} was of unsupported type {type(v)} (value was {v})"
163166
)
164167

165168
if AtomicDataDict.BATCH_KEY in kwargs:

0 commit comments

Comments
 (0)