Skip to content

Commit c84bc9f

Browse files
fixed small mistake in output_calculator.py
1 parent 7fa2457 commit c84bc9f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

optibess_algorithm/output_calculator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,11 +1115,10 @@ def __init__(self, tariff_table: np.ndarray[Any, np.dtype[np.float64]] | None =
11151115
super().__init__(*args, **kwargs)
11161116
self._fill_battery_from_grid = True
11171117
self.tariff_table = tariff_table
1118+
self._sell_prices = self._buy_prices = None
11181119
if sell_prices is not None:
11191120
self.sell_prices = sell_prices
11201121
self.buy_prices = buy_prices if buy_prices is not None else sell_prices
1121-
else:
1122-
self._sell_prices = self._buy_prices = None
11231122
self._model_session = onnxruntime.InferenceSession(os.path.join(os.path.dirname(__file__),
11241123
"schedule_model.onnx"),
11251124
providers=["CPUExecutionProvider"])
@@ -1246,6 +1245,8 @@ def _get_actions(self, year):
12461245
stride = self._df["pv_output"].values.strides[0]
12471246
split_hourly_power = ast(normalized_hourly_power, (YEAR_DAYS + day_add, DAY_LENGTH),
12481247
(DAY_LENGTH * stride, stride))
1248+
temp = normalized_hourly_power.reshape((YEAR_DAYS + day_add, DAY_LENGTH))
1249+
print(np.all(np.equal(split_hourly_power, temp)))
12491250
split_hourly_power = np.repeat(split_hourly_power, DAY_LENGTH, axis=0) + pos_encoding
12501251
split_prices = ast(sell_prices, (YEAR_DAYS + day_add, DAY_LENGTH), (DAY_LENGTH, sell_prices.strides[0]))
12511252
max_sell_prices = np.repeat(np.max(np.abs(split_prices), axis=1) + self.EPSILON, DAY_LENGTH)

tests/test_output_calculator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
22
from unittest.mock import Mock
3+
4+
import onnxruntime
35
import pandas as pd
46
import numpy as np
57
import numpy.testing as nptesting
@@ -789,8 +791,10 @@ def test_run(self):
789791
output._prod_trans_loss = 0.024
790792
output._charge_loss = 0.035
791793
output._grid_bess_loss = 0.04
794+
output._model_session = onnxruntime.InferenceSession("output_calculator/test_model.onnx",
795+
providers=["CPUExecutionProvider"])
792796
output.run()
793-
expected_result = [-0., -0., -0., -0., -0., -0., -0., 7000., 7000.,
794-
7000., 7000., 7000., 7000., 1978.25, 5345.75, 2786.48, 7000., 6974.29,
795-
6974.29, 3581.76, -0., -0., -0., -0.]
797+
expected_result = [-0., -0., -0., -0., -0., -0., -0., 6442.93, 6650.14,
798+
6656.05, 7000., 7000., 6745.98, 6830.97, 5384.28, 4091.53, 7000., 6974.29,
799+
6370.69, -0., -0., -0., -0., -0.]
796800
nptesting.assert_array_almost_equal(output.output[0][:24], expected_result, 2)

0 commit comments

Comments
 (0)