From a5637002e4599e8b9e78db8e7be0cdb380942673 Mon Sep 17 00:00:00 2001 From: Nicolas Forstner <70640646+nlsfnr@users.noreply.github.com> Date: Fri, 3 Mar 2023 18:40:46 +0000 Subject: [PATCH] Unified error messages (#199) Unifies error messages across the project to make them more useful and consistent Improves the error message for invalid collection names --------- Co-authored-by: Hammad Bashir --- chromadb/__init__.py | 4 +- chromadb/api/local.py | 27 ++++++++------ chromadb/api/types.py | 53 ++++++++++++++------------- chromadb/db/clickhouse.py | 22 +++++------ chromadb/db/duckdb.py | 14 +++---- chromadb/test/test_api.py | 27 +++++--------- chromadb/utils/embedding_functions.py | 2 +- 7 files changed, 73 insertions(+), 76 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 2cd1c03a1bc..48eab1cc384 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -42,7 +42,7 @@ def require(key): return chromadb.db.duckdb.DuckDB(settings) else: - raise Exception(f"Unknown value '{setting} for chroma_db_impl") + raise ValueError(f"Expected chroma_db_impl to be one of clickhouse, duckdb, duckdb+parquet, got {setting}") def Client(settings=__settings): @@ -67,4 +67,4 @@ def require(key): return chromadb.api.local.LocalAPI(settings, get_db(settings)) else: - raise Exception(f"Unknown value '{setting} for chroma_api_impl") + raise ValueError(f"Expected chroma_api_impl to be one of rest, local, got {setting}") diff --git a/chromadb/api/local.py b/chromadb/api/local.py index 64e9d43cb88..9bb1ddead76 100644 --- a/chromadb/api/local.py +++ b/chromadb/api/local.py @@ -22,16 +22,22 @@ # mimics s3 bucket requirements for naming -def is_valid_index_name(index_name): +def check_index_name(index_name): + msg = ("Expected collection name that " + "(1) contains 3-63 characters, " + "(2) starts and ends with an alphanumeric character, " + "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), " + "(4) contains no two consecutive 
periods (..) and " + "(5) is not a valid IPv4 address, " + f"got {index_name}") if len(index_name) < 3 or len(index_name) > 63: - return False + raise ValueError(msg) if not re.match("^[a-z0-9][a-z0-9._-]*[a-z0-9]$", index_name): - return False + raise ValueError(msg) if ".." in index_name: - return False + raise ValueError(msg) if re.match("^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$", index_name): - return False - return True + raise ValueError(msg) class LocalAPI(API): @@ -51,8 +57,7 @@ def create_collection( embedding_function: Optional[Callable] = None, get_or_create: bool = False, ) -> Collection: - if not is_valid_index_name(name): - raise ValueError("Invalid index name: %s" % name) # NIT: tell the user why + check_index_name(name) res = self._db.create_collection(name, metadata, get_or_create) return Collection( @@ -74,7 +79,7 @@ def get_collection( ) -> Collection: res = self._db.get_collection(name) if len(res) == 0: - raise ValueError("Collection not found: %s" % name) + raise ValueError(f"Collection {name} does not exist") return Collection( client=self, name=name, embedding_function=embedding_function, metadata=res[0][2] ) @@ -94,10 +99,8 @@ def _modify( new_name: Optional[str] = None, new_metadata: Optional[Dict] = None, ): - # NIT: make sure we have a valid name like we do in create if new_name is not None: - if not is_valid_index_name(new_name): - raise ValueError("Invalid index name: %s" % new_name) + check_index_name(new_name) self._db.update_collection(current_name, new_name, new_metadata) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 1a5c0746665..441c6b5786f 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -80,29 +80,29 @@ def maybe_cast_one_to_many( def validate_ids(ids: IDs) -> IDs: """Validates ids to ensure it is a list of strings""" if not isinstance(ids, list): - raise ValueError("IDs must be a list") + raise ValueError(f"Expected IDs to be a list, got {ids}") for id in ids: if not isinstance(id, str): 
- raise ValueError(f"ID {id} must be a string") + raise ValueError(f"Expected ID to be a str, got {id}") return ids def validate_metadata(metadata: Metadata) -> Metadata: """Validates metadata to ensure it is a dictionary of strings to strings, ints, or floats""" if not isinstance(metadata, dict): - raise ValueError("Metadata must be a dictionary") + raise ValueError(f"Expected metadata to be a dict, got {metadata}") for key, value in metadata.items(): if not isinstance(key, str): - raise ValueError(f"Metadata key {key} must be a string") + raise ValueError(f"Expected metadata key to be a str, got {key}") if not isinstance(value, (str, int, float)): - raise ValueError(f"Metadata value {value} must be a string, int, or float") + raise ValueError(f"Expected metadata value to be a str, int, or float, got {value}") return metadata def validate_metadatas(metadatas: Metadatas) -> Metadatas: """Validates metadatas to ensure it is a list of dictionaries of strings to strings, ints, or floats""" if not isinstance(metadatas, list): - raise ValueError("Metadatas must be a list") + raise ValueError(f"Expected metadatas to be a list, got {metadatas}") for metadata in metadatas: validate_metadata(metadata) return metadatas @@ -114,22 +114,22 @@ def validate_where(where: Where) -> Where: or in the case of $and and $or, a list of where expressions """ if not isinstance(where, dict): - raise ValueError("Where must be a dictionary") + raise ValueError(f"Expected where to be a dict, got {where}") for key, value in where.items(): if not isinstance(key, str): - raise ValueError(f"Where key {key} must be a string") + raise ValueError(f"Expected where key to be a str, got {key}") if key != "$and" and key != "$or" and not isinstance(value, (str, int, float, dict)): raise ValueError( - f"Where value {value} must be a string, int, or float, or operator expression" + f"Expected where value to be a str, int, float, or operator expression, got {value}" ) if key == "$and" or key == "$or": if 
not isinstance(value, list): raise ValueError( - f"Where value {value} for $and or $or must be a list of where expressions" + f"Expected where value for $and or $or to be a list of where expressions, got {value}" ) if len(value) <= 1: raise ValueError( - f"Where value {value} for $and or $or must have at least two where expressions" + f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}" ) for where_expression in value: validate_where(where_expression) @@ -138,7 +138,7 @@ def validate_where(where: Where) -> Where: # Ensure there is only one operator if len(value) != 1: raise ValueError( - f"Where operator expression {value} must have exactly one operator" + f"Expected operator expression to have exactly one operator, got {value}" ) for operator, operand in value.items(): @@ -146,17 +146,17 @@ def validate_where(where: Where) -> Where: if operator in ["$gt", "$gte", "$lt", "$lte"]: if not isinstance(operand, (int, float)): raise ValueError( - f"Where operand value {operand} must be an int or float for operator {operator}" + f"Expected operand value to be an int or a float for operator {operator}, got {operand}" ) if operator not in ["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"]: raise ValueError( - f"Where operator must be one of $gt, $gte, $lt, $lte, $ne", "$eq" + f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got {operator}" ) if not isinstance(operand, (str, int, float)): raise ValueError( - f"Where operand value {operand} must be a string, int, or float" + f"Expected where operand value to be a str, int, or float, got {operand}" ) return where @@ -167,27 +167,27 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument: a list of where_document expressions """ if not isinstance(where_document, dict): - raise ValueError("Where document must be a dictionary") + raise ValueError(f"Expected where document to be a dictionary, got {where_document}") if len(where_document) != 1: 
- raise ValueError("Where document must have exactly one operator") + raise ValueError(f"Expected where document to have exactly one operator, got {where_document}") for operator, operand in where_document.items(): if operator not in ["$contains", "$and", "$or"]: - raise ValueError(f"Where document operator must be $contains, $and, or $or") + raise ValueError(f"Expected where document operator to be one of $contains, $and, $or, got {operator}") if operator == "$and" or operator == "$or": if not isinstance(operand, list): raise ValueError( - f"Where document value {operand} for $and or $or must be a list of where document expressions" + f"Expected document value for $and or $or to be a list of where document expressions, got {operand}" ) if len(operand) <= 1: raise ValueError( - f"Where document value {operand} for $and or $or must have at least two where document expressions" + f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}" ) for where_document_expression in operand: validate_where_document(where_document_expression) # Value is a $contains operator elif not isinstance(operand, str): raise ValueError( - f"Where document operand value {operand} must be a string for operator $contains" + f"Expected where document operand value for operator $contains to be a str, got {operand}" ) return where_document @@ -197,10 +197,13 @@ def validate_include(include: Include, allow_distances: bool) -> Include: to control if distances is allowed""" if not isinstance(include, list): - raise ValueError("Include must be a list") + raise ValueError(f"Expected include to be a list, got {include}") for item in include: if not isinstance(item, str): - raise ValueError(f"Include item {item} must be a string") - if item not in ["embeddings", "documents", "metadatas"] + ["distances"] * allow_distances: - raise ValueError(f"Include item {item} value not within allowed values") + raise ValueError(f"Expected include item to be a 
str, got {item}") + allowed_values = ["embeddings", "documents", "metadatas"] + if allow_distances: + allowed_values.append("distances") + if item not in allowed_values: + raise ValueError(f"Expected include item to be one of {', '.join(allowed_values)}, got {item}") return include diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 24e3c92c981..1f0c9f6b8f2 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -126,7 +126,7 @@ def create_collection( print(f"collection with name {name} already exists, returning existing collection") return dupe_check else: - raise Exception(f"collection with name {name} already exists") + raise ValueError(f"Collection with name {name} already exists") collection_uuid = uuid.uuid4() data_to_insert = [[collection_uuid, name, json.dumps(metadata)]] @@ -164,12 +164,12 @@ def update_collection( return self._get_conn().command( f""" - ALTER TABLE - collections + ALTER TABLE + collections UPDATE - metadata = '{json.dumps(new_metadata)}', + metadata = '{json.dumps(new_metadata)}', name = '{new_name}' - WHERE + WHERE name = '{current_name}' """ ) @@ -235,10 +235,10 @@ def _update( parameters[f"d{i}"] = documents[i] update_statement = f""" - UPDATE + UPDATE {",".join(update_fields)} WHERE - id = {{i{i}:String}} AND + id = {{i{i}:String}} AND collection_uuid = '{collection_uuid}'{"" if i == len(ids) - 1 else ","} """ updates.append(update_statement) @@ -257,7 +257,7 @@ def update( # Verify all IDs exist existing_items = self.get(collection_uuid=collection_uuid, ids=ids) if len(existing_items) != len(ids): - raise ValueError("Some of the supplied ids for update were not found") + raise ValueError(f"Could not find {len(ids) - len(existing_items)} items for update") # Update the db self._update(collection_uuid, ids, embeddings, metadatas, documents) @@ -315,7 +315,7 @@ def _format_where(self, where, result): return result.append(f" JSONExtractString(metadata,'{key}') = '{operand}'") return result.append(f" 
JSONExtractFloat(metadata,'{key}') = {operand}") else: - raise ValueError(f"Operator {operator} not supported") + raise ValueError(f"Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got {operator}") elif type(value) == list: all_subresults = [] for subwhere in value: @@ -327,7 +327,7 @@ def _format_where(self, where, result): elif key == "$and": result.append(f"({' AND '.join(all_subresults)})") else: - raise ValueError(f"Operator {key} not supported with a list of where clauses") + raise ValueError(f"Expected one of $or, $and, got {key}") def _format_where_document(self, where_document, results): operator = list(where_document.keys())[0] @@ -344,7 +344,7 @@ def _format_where_document(self, where_document, results): if operator == "$and": results.append(f"({' AND '.join(all_subresults)})") else: - raise ValueError(f"Operator {operator} not supported") + raise ValueError(f"Expected one of $contains, $and, $or, got {operator}") def get( self, diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py index 91a9474af92..20f4167c046 100644 --- a/chromadb/db/duckdb.py +++ b/chromadb/db/duckdb.py @@ -86,7 +86,7 @@ def create_collection( print(f"collection with name {name} already exists, returning existing collection") return dupe_check else: - raise Exception(f"collection with name {name} already exists") + raise ValueError(f"Collection with name {name} already exists") self._conn.execute( f"""INSERT INTO collections (uuid, name, metadata) VALUES (?, ?, ?)""", @@ -281,12 +281,12 @@ def _update( update_fields.append(f"document = ?") update_statement = f""" - UPDATE + UPDATE embeddings SET {", ".join(update_fields)} WHERE - id = ? AND + id = ? 
AND collection_uuid = '{collection_uuid}'; """ self._conn.executemany(update_statement, update_data) @@ -307,7 +307,7 @@ def _delete(self, where_str: Optional[str] = None): def get_by_ids(self, ids: List, columns: Optional[List] = None): # select from duckdb table where ids are in the list if not isinstance(ids, list): - raise Exception("ids must be a list") + raise TypeError(f"Expected ids to be a list, got {ids}") if not ids: # create an empty pandas dataframe @@ -351,7 +351,7 @@ def __del__(self): def persist(self): raise NotImplementedError( - "chroma_db_impl='duckdb+parquet' to get persistence functionality" + "Set chroma_db_impl='duckdb+parquet' to get persistence functionality" ) @@ -362,8 +362,8 @@ def __init__(self, settings): super().__init__(settings=settings) if settings.persist_directory == ".chroma": - raise Exception( - "You cannot use chroma's cache directory, please set a different directory" + raise ValueError( + "You cannot use chroma's cache directory .chroma/, please set a different directory" ) self._save_folder = settings.persist_directory diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 8b4393cd558..1884c2ac3af 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -687,9 +687,8 @@ def test_metadata_validation_add(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='metadata'): collection.add(**bad_metadata_records) - assert "Metadata" in str(e.value) @pytest.mark.parametrize("api_fixture", test_apis) @@ -699,9 +698,8 @@ def test_metadata_validation_update(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") collection.add(**metadata_records) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='metadata'): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) - assert "Metadata" in 
str(e.value) @pytest.mark.parametrize("api_fixture", test_apis) @@ -710,9 +708,8 @@ def test_where_validation_get(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='where'): collection.get(where={"value": {"nested": "5"}}) - assert "Where" in str(e.value) @pytest.mark.parametrize("api_fixture", test_apis) @@ -721,9 +718,8 @@ def test_where_validation_query(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='where'): collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}}) - assert "Where" in str(e.value) operator_records = { @@ -912,17 +908,14 @@ def test_query_document_valid_operators(api_fixture, request): api.reset() collection = api.create_collection("test_where_valid_operators") collection.add(**operator_records) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='where document'): collection.get(where_document={"$lt": {"$nested": 2}}) - assert "Where document" in str(e.value) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='where document'): collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2}) - assert "Where document" in str(e.value) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='where document'): collection.get(where_document={"$contains": []}) - assert "Where document" in str(e.value) # Test invalid $and, $or with pytest.raises(ValueError) as e: @@ -1184,13 +1177,11 @@ def test_get_include(api_fixture, request): assert items["embeddings"] == None assert items["ids"][0] == "id1" - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='include'): items = collection.get(include=["metadatas", "undefined"]) - assert "Include" in str(e.value) - with 
pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match='include'): items = collection.get(include=None) - assert "Include" in str(e.value) # make sure query results are returned in the right order diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index aea478ec63a..4ca43281b4b 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -7,7 +7,7 @@ def __init__(self, model_name: str = "all-MiniLM-L6-v2"): from sentence_transformers import SentenceTransformer except ImportError: raise ValueError( - "sentence_transformers is not installed. Please install it with `pip install sentence_transformers`" + "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" ) self._model = SentenceTransformer(model_name)