Skip to content

Commit

Permalink
#41 refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Munz <[email protected]>
  • Loading branch information
chris-1187 committed Nov 25, 2024
1 parent 9508683 commit c0c9780
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 32 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"openai>=1.13.3,<2.0.0",
"pydantic>=2.6.0,<3.0.0",
"statsmodels>=0.14.1,<0.15.0",
"pmdarima>=2.0.4",
]

PYSPARK_PACKAGES = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,8 @@ def _get_source_model(self, source_df) -> ARIMAResults:
.tail(self.rows_to_analyze)
)

if not self.arima_auto:
if self.arima_auto:
# Default case: False
model = ARIMA(
endog=input_data,
order=self.order,
seasonal_order=self.seasonal_order,
trend=self.trend,
enforce_stationarity=self.enforce_stationarity,
enforce_invertibility=self.enforce_invertibility,
concentrate_scale=self.concentrate_scale,
trend_offset=self.trend_offset,
missing=self.missing,
)
return model.fit()

else:
auto_model = auto_arima(
y=input_data,
seasonal=any(self.seasonal_order),
Expand All @@ -283,23 +269,28 @@ def _get_source_model(self, source_df) -> ARIMAResults:
max_order=None,
)

auto_source_order = auto_model.order
auto_source_seasonal_order = auto_model.seasonal_order

trend = "c" if auto_source_order[1] == 0 else "t"

model = ARIMA(
endog=input_data,
order=auto_source_order,
seasonal_order=auto_source_seasonal_order,
trend=trend,
enforce_stationarity=self.enforce_stationarity,
enforce_invertibility=self.enforce_invertibility,
concentrate_scale=self.concentrate_scale,
trend_offset=self.trend_offset,
missing=self.missing,
)
return model.fit()
order = auto_model.order
seasonal_order = auto_model.seasonal_order
trend = "c" if order[1] == 0 else "t"

else:
order = self.order
seasonal_order = self.seasonal_order
trend = self.trend

model = ARIMA(
endog=input_data,
order=order,
seasonal_order=seasonal_order,
trend=trend,
enforce_stationarity=self.enforce_stationarity,
enforce_invertibility=self.enforce_invertibility,
concentrate_scale=self.concentrate_scale,
trend_offset=self.trend_offset,
missing=self.missing,
)

return model.fit()

def _split_by_source(self) -> dict:
"""
Expand Down

0 comments on commit c0c9780

Please sign in to comment.