Labels: bug (Something isn't working)
Description
Related to DeepLearningExamples/PyTorch/Forecasting/TFT/
Describe the bug
When I run "inference.py", it fails because "unscaled_predictions" is a numpy.ndarray rather than a torch.Tensor, so tensor methods such as new_full are not available on it. Code needs to be added to convert "unscaled_predictions" to a tensor before it is used.
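Below is a minimal sketch of the conversion. It assumes the array only needs to be wrapped as a tensor before the new_full call. The helper name ensure_tensor, the array shape, and the NaN fill value are hypothetical stand-ins. I have not traced exactly where inference.py calls new_full.

```python
import numpy as np
import torch


def ensure_tensor(x, device=None):
    """Return `x` as a torch.Tensor (hypothetical helper, not part of the repo).

    numpy.ndarray inputs are wrapped with torch.from_numpy, which shares memory
    with the array; tensors pass through, optionally moved to `device`.
    """
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    return x.to(device) if device is not None else x


# Stand-in for the array inference.py ends up holding (shape is made up).
unscaled_predictions = np.random.rand(4, 24, 3).astype(np.float32)

# Calling .new_full on the raw ndarray reproduces the reported error.
try:
    unscaled_predictions.new_full((4, 1, 3), float("nan"))
except AttributeError as err:
    print(err)  # 'numpy.ndarray' object has no attribute 'new_full'

# After conversion, tensor methods such as new_full work as expected.
unscaled_predictions = ensure_tensor(unscaled_predictions)
padding = unscaled_predictions.new_full((4, 1, 3), float("nan"))
print(type(unscaled_predictions).__name__, padding.shape, padding.dtype)
```

torch.from_numpy avoids a copy. If the original array must stay independent of the tensor, torch.tensor(x) makes a copy instead.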
To Reproduce
Steps to reproduce the behavior:
python inference.py \
--checkpoint /results/TFT_electricity_bs8x1024_lr1e-3/seed_1/checkpoint.pt \
--data /data/processed/electricity_bin/test.csv \
--tgt_scalers /data/processed/electricity_bin/tgt_scalers.bin \
--cat_encodings /data/processed/electricity_bin/cat_encodings.bin \
--visualize \
--save_predictions
Expected behavior
inference.py should run to completion. Instead, it fails with:
AttributeError: 'numpy.ndarray' object has no attribute 'new_full'
Environment
- GPUs in the system: NVIDIA GeForce RTX 3090
- CUDA driver version: 520.61.05