Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/vectorcode/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,11 @@ async def get_collection(
logger.debug(
f"Getting/Creating collection with the following metadata: {collection_meta}"
)
if not make_if_missing:
__COLLECTION_CACHE[full_path] = await client.get_collection(
try:
collection = await client.get_collection(
collection_name, embedding_function
)
else:
collection = await client.get_or_create_collection(
collection_name,
metadata=collection_meta,
embedding_function=embedding_function,
)
__COLLECTION_CACHE[full_path] = collection
if (
not collection.metadata.get("hostname") == socket.gethostname()
or collection.metadata.get("username")
Expand All @@ -208,7 +203,17 @@ async def get_collection(
raise IndexError(
"Failed to create the collection due to hash collision. Please file a bug report."
)
__COLLECTION_CACHE[full_path] = collection
except ValueError:
if make_if_missing:
collection = await client.create_collection(
collection_name,
metadata=collection_meta,
embedding_function=embedding_function,
)

__COLLECTION_CACHE[full_path] = collection
else:
raise
return __COLLECTION_CACHE[full_path]


Expand Down
45 changes: 35 additions & 10 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ async def test_get_collection():
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
mock_client = MagicMock(spec=AsyncClientAPI)
mock_collection = MagicMock()
mock_collection.metadata = {
"path": config.project_root,
"hostname": socket.gethostname(),
"created-by": "VectorCode",
"username": os.environ.get(
"USER", os.environ.get("USERNAME", "DEFAULT_USER")
),
"embedding_function": config.embedding_function,
"hnsw:M": 64,
}
mock_client.get_collection.return_value = mock_collection
MockAsyncHttpClient.return_value = mock_client

Expand All @@ -232,6 +242,18 @@ async def test_get_collection():
mock_client.get_collection.assert_called_once()
mock_client.get_or_create_collection.assert_not_called()

# Test retrieving a non-existing collection
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
from vectorcode.common import __COLLECTION_CACHE

__COLLECTION_CACHE.clear()
mock_client = MagicMock(spec=AsyncClientAPI)
mock_client.get_collection.side_effect = ValueError
MockAsyncHttpClient.return_value = mock_client

with pytest.raises(ValueError):
collection = await get_collection(mock_client, config, False)

# Test creating a collection if it doesn't exist
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
mock_client = MagicMock(spec=AsyncClientAPI)
Expand All @@ -252,7 +274,7 @@ async def test_get_collection():
"created-by": "VectorCode",
}

async def mock_get_or_create_collection(
async def mock_create_collection(
self,
name=None,
configuration=None,
Expand All @@ -263,7 +285,7 @@ async def mock_get_or_create_collection(
mock_collection.metadata.update(metadata or {})
return mock_collection

mock_client.get_or_create_collection.side_effect = mock_get_or_create_collection
mock_client.create_collection.side_effect = mock_create_collection
MockAsyncHttpClient.return_value = mock_client

collection = await get_collection(mock_client, config, make_if_missing=True)
Expand All @@ -273,16 +295,18 @@ async def mock_get_or_create_collection(
)
assert collection.metadata["created-by"] == "VectorCode"
assert collection.metadata["hnsw:M"] == 64
mock_client.get_or_create_collection.assert_called_once()
mock_client.create_collection.assert_called_once()
mock_client.get_collection.side_effect = None

# Test raising IndexError on hash collision.
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
with (
patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient,
patch("socket.gethostname", side_effect=(lambda: "dummy")),
):
mock_client = MagicMock(spec=AsyncClientAPI)
mock_client.get_or_create_collection.side_effect = IndexError(
"Hash collision occurred"
)

MockAsyncHttpClient.return_value = mock_client
mock_client.get_collection = AsyncMock(return_value=mock_collection)
from vectorcode.common import __COLLECTION_CACHE

__COLLECTION_CACHE.clear()
Expand Down Expand Up @@ -315,7 +339,8 @@ async def test_get_collection_hnsw():
"embedding_function": "SentenceTransformerEmbeddingFunction",
"path": "/test_project",
}
mock_client.get_or_create_collection.return_value = mock_collection
mock_client.create_collection.return_value = mock_collection
mock_client.get_collection.side_effect = ValueError
MockAsyncHttpClient.return_value = mock_client

# Clear the collection cache to force creation
Expand All @@ -332,9 +357,9 @@ async def test_get_collection_hnsw():
assert collection.metadata["created-by"] == "VectorCode"
assert collection.metadata["hnsw:ef_construction"] == 200
assert collection.metadata["hnsw:M"] == 32
mock_client.get_or_create_collection.assert_called_once()
mock_client.create_collection.assert_called_once()
assert (
mock_client.get_or_create_collection.call_args.kwargs["metadata"]
mock_client.create_collection.call_args.kwargs["metadata"]
== mock_collection.metadata
)

Expand Down
Loading