INTPYTHON-412 Add linting of test files (#30)
blink1073 authored Dec 10, 2024
1 parent 8f2cafe commit fcbc22f
Showing 11 changed files with 56 additions and 32 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/_lint.yml
@@ -97,7 +97,7 @@ jobs:
${{ env.WORKDIR }}/.mypy_cache_test
key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}

# - name: Analysing the code with our lint
# working-directory: ${{ inputs.working-directory }}
# run: |
# make lint_tests
- name: Analysing the code with our lint
working-directory: ${{ inputs.working-directory }}
run: |
make lint_tests
@@ -41,7 +41,7 @@ def __hash__(self) -> int:


class AnyDict(dict):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def __eq__(self, other: object) -> bool:
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from typing import Any, AsyncGenerator, Generator, Optional
from uuid import UUID

import pytest
@@ -11,7 +13,6 @@
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.mongodb import MongoDBSaver
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore

pytest.register_assert_rewrite("tests.memory_assert")
@@ -25,7 +26,7 @@


@pytest.fixture
def anyio_backend():
def anyio_backend() -> str:
return "asyncio"


@@ -41,14 +42,14 @@ def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:


@pytest.fixture(scope="function")
def checkpointer_memory():
def checkpointer_memory() -> Generator[BaseCheckpointSaver, None]:
from .memory_assert import MemorySaverAssertImmutable

yield MemorySaverAssertImmutable()


@pytest.fixture
def checkpointer_mongodb():
def checkpointer_mongodb() -> Generator[BaseCheckpointSaver, None]:
"""Fresh checkpointer without any memories."""
with MongoDBSaver.from_conn_string(
os.environ.get("MONGODB_URI", "mongodb://localhost:27017"),
@@ -60,7 +61,7 @@ def checkpointer_mongodb():


@asynccontextmanager
async def _checkpointer_mongodb_aio():
async def _checkpointer_mongodb_aio() -> AsyncGenerator[AsyncMongoDBSaver, None]:
async with AsyncMongoDBSaver.from_conn_string(
os.environ.get("MONGODB_URI", "mongodb://localhost:27017"),
os.environ.get("DATABASE_NAME", "langchain_checkpoints_db"),
@@ -73,7 +74,7 @@ async def _checkpointer_mongodb_aio():
@asynccontextmanager
async def awith_checkpointer(
checkpointer_name: Optional[str],
) -> AsyncIterator[BaseCheckpointSaver]:
) -> Any:
if checkpointer_name is None:
yield None
elif checkpointer_name == "memory":
@@ -88,12 +89,12 @@ async def awith_checkpointer(


@pytest.fixture(scope="function")
def store_in_memory():
def store_in_memory() -> Generator[InMemoryStore]:
yield InMemoryStore()


@asynccontextmanager
async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
async def awith_store(store_name: Optional[str]) -> Any:
if store_name is None:
yield None
elif store_name == "in_memory":
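
Note: the return-type annotations added to the fixtures above follow the standard pattern for typing yield-based pytest fixtures once mypy's disallow_untyped_defs is in force. A minimal standalone sketch (illustrative names, not repo code):

from typing import Generator

import pytest


@pytest.fixture
def fake_resource() -> Generator[dict, None, None]:
    resource = {"connected": True}   # setup runs before the test
    yield resource                   # the test receives this object
    resource["connected"] = False    # teardown runs after the test

Because the file also gains from __future__ import annotations, the shorter two-argument Generator[...] spelling in the diff is never evaluated at runtime; the fully spelled three-argument form above is the more conservative choice.
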
@@ -1,3 +1,5 @@
from __future__ import annotations

import re
from typing import Any, Iterator, List, Optional, cast

@@ -8,11 +10,11 @@


class FakeChatModel(GenericFakeChatModel):
messages: list[BaseMessage]
messages: list[BaseMessage] # type:ignore[assignment]

i: int = 0

def bind_tools(self, functions: list):
def bind_tools(self, tools: list) -> FakeChatModel: # type:ignore[override]
return self

def _generate(
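
Note: the # type:ignore[override] added to bind_tools is the suppression mypy asks for when a subclass narrows an inherited signature. A rough illustration of the error being silenced (simplified, not the actual GenericFakeChatModel API):

class Base:
    def bind_tools(self, tools: list, **kwargs: object) -> "Base":
        raise NotImplementedError


class Fake(Base):
    # Dropping **kwargs makes this override incompatible with the base
    # signature, so mypy reports [override] unless it is suppressed.
    def bind_tools(self, tools: list) -> "Fake":  # type: ignore[override]
        return self
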
@@ -38,11 +38,11 @@ def __init__(

def put(
self,
config: dict,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> None:
) -> RunnableConfig:
if self.put_sleep:
import time

@@ -53,12 +53,12 @@ def put(
if saved := super().get(config):
assert (
self.serde.loads_typed(
self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]]
self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]] # type:ignore[arg-type]
)
== saved
)
self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = (
self.serde.dumps_typed(copy_checkpoint(checkpoint))
self.serde.dumps_typed(copy_checkpoint(checkpoint)) # type:ignore[assignment]
)
# call super to write checkpoint
return super().put(config, checkpoint, metadata, new_versions)
@@ -78,7 +78,7 @@ def put(
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> None:
) -> RunnableConfig:
"""The implementation of put() merges config["configurable"] (a run's
configurable fields) with the metadata field. The state of the
checkpoint metadata can be asserted to confirm that the run's
@@ -115,7 +115,12 @@ async def aput(
new_versions: ChannelVersions,
) -> RunnableConfig:
return await asyncio.get_running_loop().run_in_executor(
None, self.put, config, checkpoint, metadata, new_versions
None,
self.put,
config,
checkpoint,
metadata,
new_versions,
)


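
Note: aput() above keeps the synchronous put() as the single implementation and pushes it onto a thread so the event loop is never blocked. A generic sketch of that delegation pattern (illustrative names only):

import asyncio


class SyncSaver:
    def put(self, key: str, value: str) -> str:
        # stand-in for blocking work such as a database write
        return f"{key}={value}"


class AsyncSaver(SyncSaver):
    async def aput(self, key: str, value: str) -> str:
        # run the blocking method in the default thread-pool executor
        return await asyncio.get_running_loop().run_in_executor(
            None, self.put, key, value
        )


print(asyncio.run(AsyncSaver().aput("thread_id", "1")))
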
@@ -105,6 +105,8 @@
_AnyIdToolMessage,
)

# mypy: ignore-errors


# define these objects to avoid importing langchain_core.agents
# and therefore avoid relying on core Pydantic version
@@ -101,6 +101,8 @@
_AnyIdToolMessage,
)

# mypy: ignore-errors

pytestmark = pytest.mark.anyio


2 changes: 1 addition & 1 deletion libs/mongodb/langchain_mongodb/cache.py
@@ -303,6 +303,6 @@ def _wait_until(
return retval

if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
raise TimeoutError(f"Didn't ever {success_description}")

time.sleep(interval)
11 changes: 9 additions & 2 deletions libs/mongodb/pyproject.toml
@@ -63,8 +63,15 @@ asyncio_mode = "auto"
[tool.mypy]
disallow_untyped_defs = "True"

[tool.ruff.lint]
select = ["E", "F", "I"]
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]

[tool.coverage.run]
omit = ["tests/*"]
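
Note: the expanded ruff configuration above adds pyupgrade (UP) and flake8-bugbear (B) on top of pycodestyle, Pyflakes and isort, while E501, B008, UP006 and UP007 remain ignored. The UP006/UP007 ignores presumably keep the typing-module spellings still used throughout the code base; roughly, these are the rewrites those two rules would otherwise demand (illustrative snippet, not repo code):

from typing import List, Optional


def legacy(items: Optional[List[str]]) -> int:
    # Accepted here because UP006 (List -> list) and UP007 (Optional[X] -> X | None)
    # are listed in lint.ignore.
    return len(items or [])


def modern(items: "list[str] | None") -> int:
    # The spelling pyupgrade would otherwise enforce on newer Pythons.
    return len(items or [])
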
17 changes: 11 additions & 6 deletions libs/mongodb/tests/integration_tests/test_retrievers.py
@@ -1,4 +1,4 @@
import os
import os # noqa: I001
from typing import Generator, List
from time import sleep, time

@@ -11,7 +11,7 @@
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.index import (
create_fulltext_search_index,
create_vector_search_index
create_vector_search_index,
)
from langchain_mongodb.retrievers import (
MongoDBAtlasFullTextSearchRetriever,
@@ -162,14 +162,19 @@ def test_fulltext_retriever(
)

# Wait for the search index to complete.
search_content = dict(index=SEARCH_INDEX_NAME, wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True))
search_content = dict(
index=SEARCH_INDEX_NAME,
wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True),
)
n_docs = collection.count_documents({})
t0 = time()
while True:
if (time() - t0) > TIMEOUT:
raise TimeoutError(f'Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}')
results = collection.aggregate([{ "$search": search_content }])
if len(list(results)) == n_docs:
raise TimeoutError(
f"Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}"
)
cursor = collection.aggregate([{"$search": search_content}])
if len(list(cursor)) == n_docs:
break
sleep(INTERVAL)

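
Note: the reformatted block above polls the $search aggregation until every document is visible in the index or TIMEOUT elapses. The same wait-until-ready idea as a small reusable helper (a sketch under assumed names, not part of the test suite):

from time import sleep, time
from typing import Callable


def wait_until(predicate: Callable[[], bool], timeout: float, interval: float) -> None:
    """Poll predicate until it returns True, or raise after timeout seconds."""
    start = time()
    while not predicate():
        if time() - start > timeout:
            raise TimeoutError(f"Condition not met within {timeout}s")
        sleep(interval)


# usage: wait_until(lambda: index_is_queryable(), timeout=60, interval=0.5),
# where index_is_queryable is any zero-argument readiness check.
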
2 changes: 1 addition & 1 deletion libs/mongodb/tests/utils.py
@@ -63,7 +63,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
vector = [1.0] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
