Skip to content

Commit e70cfb3

Browse files
author
Brandon Lefore
committed
Version 0.78.0
1 parent 7e9be52 commit e70cfb3

File tree

225 files changed

+3046
-82
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

225 files changed

+3046
-82
lines changed

abacusai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .streaming_client import StreamingClient
55

66

7-
__version__ = "0.77.8"
7+
__version__ = "0.78.0"

abacusai/api_class/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class PersonalizationTrainingConfig(TrainingConfig):
5555
compute_session_metrics (bool): Evaluate models based on how well they are able to predict the next session of interactions.
5656
max_user_history_len_percentile (int): Filter out users with history length above this percentile.
5757
downsample_item_popularity_percentile (float): Downsample items more popular than this percentile.
58-
58+
allow_duplicate_action_types (List[str]): event types which will not be deduplicated.
5959
"""
6060
# top-level params
6161
objective: enums.PersonalizationObjective = dataclasses.field(default=None)
@@ -67,6 +67,7 @@ class PersonalizationTrainingConfig(TrainingConfig):
6767
target_action_types: List[str] = dataclasses.field(default=None)
6868
target_action_weights: Dict[str, float] = dataclasses.field(default=None)
6969
session_event_types: List[str] = dataclasses.field(default=None)
70+
allow_duplicate_action_types: List[str] = dataclasses.field(default=None)
7071

7172
# data split
7273
test_split: int = dataclasses.field(default=None)
@@ -270,6 +271,7 @@ class ForecastingTrainingConfig(TrainingConfig):
270271
use_log_transforms (bool): Apply logarithmic transformations to input data.
271272
smooth_history (float): Smooth (low pass filter) the timeseries.
272273
local_scale_target (bool): Using per training/prediction window target scaling.
274+
use_clipping (bool): Apply clipping to input data to stabilize the training.
273275
timeseries_weight_column (str): If set, we use the values in this column from timeseries data to assign time dependent item weights during training and evaluation.
274276
item_attributes_weight_column (str): If set, we use the values in this column from item attributes data to assign weights to items during training and evaluation.
275277
use_timeseries_weights_in_objective (bool): If True, we include weights from column set as "TIMESERIES WEIGHT COLUMN" in objective functions.
@@ -334,6 +336,7 @@ class ForecastingTrainingConfig(TrainingConfig):
334336
use_log_transforms: bool = dataclasses.field(default=None)
335337
smooth_history: float = dataclasses.field(default=None)
336338
local_scale_target: bool = dataclasses.field(default=None)
339+
use_clipping: bool = dataclasses.field(default=None)
337340
# Item weights
338341
timeseries_weight_column: str = dataclasses.field(default=None)
339342
item_attributes_weight_column: str = dataclasses.field(default=None)

abacusai/batch_prediction_version.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ class BatchPredictionVersion(AbstractApiClass):
4444
outputFeatureGroupId (str): The BP output feature group id if applicable
4545
outputFeatureGroupVersion (str): The BP output feature group version if applicable
4646
outputFeatureGroupTableName (str): The BP output feature group name if applicable
47+
batchPredictionWarnings (str): Relevant warnings if any issues are found
4748
batchInputs (PredictionInput): Inputs to the batch prediction
4849
"""
4950

50-
def __init__(self, client, batchPredictionVersion=None, batchPredictionId=None, status=None, driftMonitorStatus=None, deploymentId=None, modelId=None, modelVersion=None, predictionsStartedAt=None, predictionsCompletedAt=None, databaseOutputError=None, totalPredictions=None, failedPredictions=None, databaseConnectorId=None, databaseOutputConfiguration=None, explanations=None, fileConnectorOutputLocation=None, fileOutputFormat=None, connectorType=None, legacyInputLocation=None, error=None, driftMonitorError=None, monitorWarnings=None, csvInputPrefix=None, csvPredictionPrefix=None, csvExplanationsPrefix=None, databaseOutputTotalWrites=None, databaseOutputFailedWrites=None, outputIncludesMetadata=None, resultInputColumns=None, modelMonitorVersion=None, algoName=None, algorithm=None, outputFeatureGroupId=None, outputFeatureGroupVersion=None, outputFeatureGroupTableName=None, batchInputs={}, globalPredictionArgs={}):
51+
def __init__(self, client, batchPredictionVersion=None, batchPredictionId=None, status=None, driftMonitorStatus=None, deploymentId=None, modelId=None, modelVersion=None, predictionsStartedAt=None, predictionsCompletedAt=None, databaseOutputError=None, totalPredictions=None, failedPredictions=None, databaseConnectorId=None, databaseOutputConfiguration=None, explanations=None, fileConnectorOutputLocation=None, fileOutputFormat=None, connectorType=None, legacyInputLocation=None, error=None, driftMonitorError=None, monitorWarnings=None, csvInputPrefix=None, csvPredictionPrefix=None, csvExplanationsPrefix=None, databaseOutputTotalWrites=None, databaseOutputFailedWrites=None, outputIncludesMetadata=None, resultInputColumns=None, modelMonitorVersion=None, algoName=None, algorithm=None, outputFeatureGroupId=None, outputFeatureGroupVersion=None, outputFeatureGroupTableName=None, batchPredictionWarnings=None, batchInputs={}, globalPredictionArgs={}):
5152
super().__init__(client, batchPredictionVersion)
5253
self.batch_prediction_version = batchPredictionVersion
5354
self.batch_prediction_id = batchPredictionId
@@ -84,13 +85,14 @@ def __init__(self, client, batchPredictionVersion=None, batchPredictionId=None,
8485
self.output_feature_group_id = outputFeatureGroupId
8586
self.output_feature_group_version = outputFeatureGroupVersion
8687
self.output_feature_group_table_name = outputFeatureGroupTableName
88+
self.batch_prediction_warnings = batchPredictionWarnings
8789
self.batch_inputs = client._build_class(PredictionInput, batchInputs)
8890
self.global_prediction_args = client._build_class(
8991
BatchPredictionArgs, globalPredictionArgs)
9092

9193
def __repr__(self):
92-
repr_dict = {f'batch_prediction_version': repr(self.batch_prediction_version), f'batch_prediction_id': repr(self.batch_prediction_id), f'status': repr(self.status), f'drift_monitor_status': repr(self.drift_monitor_status), f'deployment_id': repr(self.deployment_id), f'model_id': repr(self.model_id), f'model_version': repr(self.model_version), f'predictions_started_at': repr(self.predictions_started_at), f'predictions_completed_at': repr(self.predictions_completed_at), f'database_output_error': repr(self.database_output_error), f'total_predictions': repr(self.total_predictions), f'failed_predictions': repr(self.failed_predictions), f'database_connector_id': repr(self.database_connector_id), f'database_output_configuration': repr(self.database_output_configuration), f'explanations': repr(self.explanations), f'file_connector_output_location': repr(self.file_connector_output_location), f'file_output_format': repr(self.file_output_format), f'connector_type': repr(self.connector_type), f'legacy_input_location': repr(
93-
self.legacy_input_location), f'error': repr(self.error), f'drift_monitor_error': repr(self.drift_monitor_error), f'monitor_warnings': repr(self.monitor_warnings), f'csv_input_prefix': repr(self.csv_input_prefix), f'csv_prediction_prefix': repr(self.csv_prediction_prefix), f'csv_explanations_prefix': repr(self.csv_explanations_prefix), f'database_output_total_writes': repr(self.database_output_total_writes), f'database_output_failed_writes': repr(self.database_output_failed_writes), f'output_includes_metadata': repr(self.output_includes_metadata), f'result_input_columns': repr(self.result_input_columns), f'model_monitor_version': repr(self.model_monitor_version), f'algo_name': repr(self.algo_name), f'algorithm': repr(self.algorithm), f'output_feature_group_id': repr(self.output_feature_group_id), f'output_feature_group_version': repr(self.output_feature_group_version), f'output_feature_group_table_name': repr(self.output_feature_group_table_name), f'batch_inputs': repr(self.batch_inputs), f'global_prediction_args': repr(self.global_prediction_args)}
94+
repr_dict = {f'batch_prediction_version': repr(self.batch_prediction_version), f'batch_prediction_id': repr(self.batch_prediction_id), f'status': repr(self.status), f'drift_monitor_status': repr(self.drift_monitor_status), f'deployment_id': repr(self.deployment_id), f'model_id': repr(self.model_id), f'model_version': repr(self.model_version), f'predictions_started_at': repr(self.predictions_started_at), f'predictions_completed_at': repr(self.predictions_completed_at), f'database_output_error': repr(self.database_output_error), f'total_predictions': repr(self.total_predictions), f'failed_predictions': repr(self.failed_predictions), f'database_connector_id': repr(self.database_connector_id), f'database_output_configuration': repr(self.database_output_configuration), f'explanations': repr(self.explanations), f'file_connector_output_location': repr(self.file_connector_output_location), f'file_output_format': repr(self.file_output_format), f'connector_type': repr(self.connector_type), f'legacy_input_location': repr(self.legacy_input_location), f'error': repr(
95+
self.error), f'drift_monitor_error': repr(self.drift_monitor_error), f'monitor_warnings': repr(self.monitor_warnings), f'csv_input_prefix': repr(self.csv_input_prefix), f'csv_prediction_prefix': repr(self.csv_prediction_prefix), f'csv_explanations_prefix': repr(self.csv_explanations_prefix), f'database_output_total_writes': repr(self.database_output_total_writes), f'database_output_failed_writes': repr(self.database_output_failed_writes), f'output_includes_metadata': repr(self.output_includes_metadata), f'result_input_columns': repr(self.result_input_columns), f'model_monitor_version': repr(self.model_monitor_version), f'algo_name': repr(self.algo_name), f'algorithm': repr(self.algorithm), f'output_feature_group_id': repr(self.output_feature_group_id), f'output_feature_group_version': repr(self.output_feature_group_version), f'output_feature_group_table_name': repr(self.output_feature_group_table_name), f'batch_prediction_warnings': repr(self.batch_prediction_warnings), f'batch_inputs': repr(self.batch_inputs), f'global_prediction_args': repr(self.global_prediction_args)}
9496
class_name = "BatchPredictionVersion"
9597
repr_str = ',\n '.join([f'{key}={value}' for key, value in repr_dict.items(
9698
) if getattr(self, key, None) is not None])
@@ -103,8 +105,8 @@ def to_dict(self):
103105
Returns:
104106
dict: The dict value representation of the class parameters
105107
"""
106-
resp = {'batch_prediction_version': self.batch_prediction_version, 'batch_prediction_id': self.batch_prediction_id, 'status': self.status, 'drift_monitor_status': self.drift_monitor_status, 'deployment_id': self.deployment_id, 'model_id': self.model_id, 'model_version': self.model_version, 'predictions_started_at': self.predictions_started_at, 'predictions_completed_at': self.predictions_completed_at, 'database_output_error': self.database_output_error, 'total_predictions': self.total_predictions, 'failed_predictions': self.failed_predictions, 'database_connector_id': self.database_connector_id, 'database_output_configuration': self.database_output_configuration, 'explanations': self.explanations, 'file_connector_output_location': self.file_connector_output_location, 'file_output_format': self.file_output_format, 'connector_type': self.connector_type, 'legacy_input_location': self.legacy_input_location, 'error': self.error,
107-
'drift_monitor_error': self.drift_monitor_error, 'monitor_warnings': self.monitor_warnings, 'csv_input_prefix': self.csv_input_prefix, 'csv_prediction_prefix': self.csv_prediction_prefix, 'csv_explanations_prefix': self.csv_explanations_prefix, 'database_output_total_writes': self.database_output_total_writes, 'database_output_failed_writes': self.database_output_failed_writes, 'output_includes_metadata': self.output_includes_metadata, 'result_input_columns': self.result_input_columns, 'model_monitor_version': self.model_monitor_version, 'algo_name': self.algo_name, 'algorithm': self.algorithm, 'output_feature_group_id': self.output_feature_group_id, 'output_feature_group_version': self.output_feature_group_version, 'output_feature_group_table_name': self.output_feature_group_table_name, 'batch_inputs': self._get_attribute_as_dict(self.batch_inputs), 'global_prediction_args': self._get_attribute_as_dict(self.global_prediction_args)}
108+
resp = {'batch_prediction_version': self.batch_prediction_version, 'batch_prediction_id': self.batch_prediction_id, 'status': self.status, 'drift_monitor_status': self.drift_monitor_status, 'deployment_id': self.deployment_id, 'model_id': self.model_id, 'model_version': self.model_version, 'predictions_started_at': self.predictions_started_at, 'predictions_completed_at': self.predictions_completed_at, 'database_output_error': self.database_output_error, 'total_predictions': self.total_predictions, 'failed_predictions': self.failed_predictions, 'database_connector_id': self.database_connector_id, 'database_output_configuration': self.database_output_configuration, 'explanations': self.explanations, 'file_connector_output_location': self.file_connector_output_location, 'file_output_format': self.file_output_format, 'connector_type': self.connector_type, 'legacy_input_location': self.legacy_input_location, 'error': self.error, 'drift_monitor_error': self.drift_monitor_error,
109+
'monitor_warnings': self.monitor_warnings, 'csv_input_prefix': self.csv_input_prefix, 'csv_prediction_prefix': self.csv_prediction_prefix, 'csv_explanations_prefix': self.csv_explanations_prefix, 'database_output_total_writes': self.database_output_total_writes, 'database_output_failed_writes': self.database_output_failed_writes, 'output_includes_metadata': self.output_includes_metadata, 'result_input_columns': self.result_input_columns, 'model_monitor_version': self.model_monitor_version, 'algo_name': self.algo_name, 'algorithm': self.algorithm, 'output_feature_group_id': self.output_feature_group_id, 'output_feature_group_version': self.output_feature_group_version, 'output_feature_group_table_name': self.output_feature_group_table_name, 'batch_prediction_warnings': self.batch_prediction_warnings, 'batch_inputs': self._get_attribute_as_dict(self.batch_inputs), 'global_prediction_args': self._get_attribute_as_dict(self.global_prediction_args)}
108110
return {key: value for key, value in resp.items() if value is not None}
109111

110112
def download_batch_prediction_result_chunk(self, offset: int = 0, chunk_size: int = 10485760):
@@ -148,6 +150,18 @@ def describe(self):
148150
"""
149151
return self.client.describe_batch_prediction_version(self.batch_prediction_version)
150152

153+
def get_logs(self):
154+
"""
155+
Retrieves the batch prediction logs.
156+
157+
Args:
158+
batch_prediction_version (str): The unique version ID of the batch prediction version.
159+
160+
Returns:
161+
BatchPredictionVersionLogs: The logs for the specified batch prediction version.
162+
"""
163+
return self.client.get_batch_prediction_version_logs(self.batch_prediction_version)
164+
151165
def download_result_to_file(self, file):
152166
"""
153167
Downloads the batch prediction version in a local file.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from .return_class import AbstractApiClass
2+
3+
4+
class BatchPredictionVersionLogs(AbstractApiClass):
5+
"""
6+
Logs from batch prediction version.
7+
8+
Args:
9+
client (ApiClient): An authenticated API Client instance
10+
logs (list[str]): List of logs from batch prediction version.
11+
warnings (list[str]): List of warnings from batch prediction version.
12+
"""
13+
14+
def __init__(self, client, logs=None, warnings=None):
15+
super().__init__(client, None)
16+
self.logs = logs
17+
self.warnings = warnings
18+
19+
def __repr__(self):
20+
repr_dict = {f'logs': repr(
21+
self.logs), f'warnings': repr(self.warnings)}
22+
class_name = "BatchPredictionVersionLogs"
23+
repr_str = ',\n '.join([f'{key}={value}' for key, value in repr_dict.items(
24+
) if getattr(self, key, None) is not None])
25+
return f"{class_name}({repr_str})"
26+
27+
def to_dict(self):
28+
"""
29+
Get a dict representation of the parameters in this class
30+
31+
Returns:
32+
dict: The dict value representation of the class parameters
33+
"""
34+
resp = {'logs': self.logs, 'warnings': self.warnings}
35+
return {key: value for key, value in resp.items() if value is not None}

abacusai/chat_message.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,21 @@ class ChatMessage(AbstractApiClass):
1212
timestamp (str): The timestamp at which the message was sent
1313
isUseful (bool): Whether this message was marked as useful or not
1414
feedback (str): The feedback provided for the message
15+
docId (str): Id of the uploaded document if the message has
1516
"""
1617

17-
def __init__(self, client, role=None, text=None, timestamp=None, isUseful=None, feedback=None):
18+
def __init__(self, client, role=None, text=None, timestamp=None, isUseful=None, feedback=None, docId=None):
1819
super().__init__(client, None)
1920
self.role = role
2021
self.text = text
2122
self.timestamp = timestamp
2223
self.is_useful = isUseful
2324
self.feedback = feedback
25+
self.doc_id = docId
2426

2527
def __repr__(self):
2628
repr_dict = {f'role': repr(self.role), f'text': repr(self.text), f'timestamp': repr(
27-
self.timestamp), f'is_useful': repr(self.is_useful), f'feedback': repr(self.feedback)}
29+
self.timestamp), f'is_useful': repr(self.is_useful), f'feedback': repr(self.feedback), f'doc_id': repr(self.doc_id)}
2830
class_name = "ChatMessage"
2931
repr_str = ',\n '.join([f'{key}={value}' for key, value in repr_dict.items(
3032
) if getattr(self, key, None) is not None])
@@ -38,5 +40,5 @@ def to_dict(self):
3840
dict: The dict value representation of the class parameters
3941
"""
4042
resp = {'role': self.role, 'text': self.text, 'timestamp': self.timestamp,
41-
'is_useful': self.is_useful, 'feedback': self.feedback}
43+
'is_useful': self.is_useful, 'feedback': self.feedback, 'doc_id': self.doc_id}
4244
return {key: value for key, value in resp.items() if value is not None}

0 commit comments

Comments
 (0)