Add getter functions for TLM defaults #59

Merged · 17 commits · May 4, 2025
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.1.2] - 2025-05-01

- Add getter functions for `_TLM_DEFAULT_MODEL`, `_DEFAULT_TLM_QUALITY_PRESET`, `_TLM_DEFAULT_CONTEXT_LIMIT`, `_TLM_MAX_TOKEN_RANGE`.
- Add unit tests for the getter functions.

## [1.1.1] - 2025-04-23

### Changed
@@ -141,7 +146,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Release of the Cleanlab TLM Python client.

[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.1...HEAD
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.2...HEAD
[1.1.2]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.1...v1.1.2
[1.1.1]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.0...v1.1.1
[1.1.0]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.0.23...v1.1.0
[1.0.23]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.0.22...v1.0.23
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ extra-dependencies = [
"pytest",
"pytest-asyncio",
"python-dotenv",
"tiktoken",
]
[tool.hatch.envs.types.scripts]
check = "mypy --strict --install-types --non-interactive {args:src/cleanlab_tlm tests}"
@@ -57,6 +58,7 @@ allow-direct-references = true
extra-dependencies = [
"python-dotenv",
"pytest-asyncio",
"tiktoken",
]

[tool.hatch.envs.hatch-test.env-vars]
2 changes: 1 addition & 1 deletion src/cleanlab_tlm/__about__.py
@@ -1,2 +1,2 @@
# SPDX-License-Identifier: MIT
__version__ = "1.1.1"
__version__ = "1.1.2"
2 changes: 2 additions & 0 deletions src/cleanlab_tlm/internal/constants.py
@@ -5,6 +5,7 @@
_VALID_TLM_QUALITY_PRESETS: list[str] = ["best", "high", "medium", "low", "base"]
_VALID_TLM_QUALITY_PRESETS_RAG: list[str] = ["medium", "low", "base"]
_DEFAULT_TLM_QUALITY_PRESET: TLMQualityPreset = "medium"
_DEFAULT_TLM_MAX_TOKENS: int = 512
_VALID_TLM_MODELS: list[str] = [
"gpt-3.5-turbo-16k",
"gpt-4",
@@ -32,6 +33,7 @@
"nova-pro",
]
_TLM_DEFAULT_MODEL: str = "gpt-4o-mini"
_TLM_DEFAULT_CONTEXT_LIMIT: int = 70000
_VALID_TLM_TASKS: set[str] = {task.value for task in Task}
TLM_TASK_SUPPORTING_CONSTRAIN_OUTPUTS: set[Task] = {
Task.DEFAULT,
46 changes: 46 additions & 0 deletions src/cleanlab_tlm/utils/config.py
@@ -0,0 +1,46 @@
from cleanlab_tlm.internal.constants import (
_DEFAULT_TLM_MAX_TOKENS,
_DEFAULT_TLM_QUALITY_PRESET,
_TLM_DEFAULT_CONTEXT_LIMIT,
_TLM_DEFAULT_MODEL,
)


def get_default_model() -> str:
"""
Get the default model name for TLM.

Returns:
str: The default model name for TLM.
"""
return _TLM_DEFAULT_MODEL


def get_default_quality_preset() -> str:
"""
Get the default quality preset for TLM.

Returns:
str: The default quality preset for TLM.
"""
return _DEFAULT_TLM_QUALITY_PRESET


def get_default_context_limit() -> int:
"""
Get the default context limit for TLM.

Returns:
int: The default context limit for TLM.
"""
return _TLM_DEFAULT_CONTEXT_LIMIT


def get_default_max_tokens() -> int:
"""
Get the default maximum output tokens allowed.

Returns:
int: The default maximum output tokens.
"""
return _DEFAULT_TLM_MAX_TOKENS
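
Taken together, these getters expose the library defaults without reaching into `cleanlab_tlm.internal.constants`. A minimal usage sketch; the commented values reflect the constants as of this PR and may change in later releases:

```python
from cleanlab_tlm.utils.config import (
    get_default_context_limit,
    get_default_max_tokens,
    get_default_model,
    get_default_quality_preset,
)

# Inspect the TLM defaults through the public getters.
print(get_default_model())           # "gpt-4o-mini" as of this PR
print(get_default_quality_preset())  # "medium"
print(get_default_context_limit())   # 70000
print(get_default_max_tokens())      # 512
```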
2 changes: 2 additions & 0 deletions tests/constants.py
@@ -17,6 +17,8 @@
MAX_COMBINED_LENGTH_TOKENS: int = 70_000

CHARACTERS_PER_TOKEN: int = 4
# 4 characters (3 characters + 1 space) = 1 token
WORD_THAT_EQUALS_ONE_TOKEN = "orb " # noqa: S105

# Property tests for TLM
excluded_tlm_models: list[str] = [
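
The tests that follow lean on the assumption documented above: that `WORD_THAT_EQUALS_ONE_TOKEN` (`"orb "`) repeated N times tokenizes to roughly N tokens. A quick local check of that assumption, as an illustration only, since exact counts depend on the encoding tiktoken selects:

```python
import tiktoken

# Fall back to gpt-4o if this tiktoken version does not know gpt-4o-mini,
# mirroring the pattern used in tests/test_config.py below.
try:
    enc = tiktoken.encoding_for_model("gpt-4o-mini")
except KeyError:
    enc = tiktoken.encoding_for_model("gpt-4o")

for n in (1, 10, 100):
    # Expect the token count to track n closely if the assumption holds.
    print(n, len(enc.encode("orb " * n)))
```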
58 changes: 58 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,58 @@
import pytest
import tiktoken

from cleanlab_tlm.errors import TlmBadRequestError
from cleanlab_tlm.tlm import TLM
from cleanlab_tlm.utils.config import (
get_default_context_limit,
get_default_max_tokens,
get_default_model,
get_default_quality_preset,
)
from tests.constants import WORD_THAT_EQUALS_ONE_TOKEN

tlm_with_default_setting = TLM()


def test_get_default_model(tlm: TLM) -> None:
assert tlm.get_model_name() == get_default_model()


def test_get_default_quality_preset(tlm: TLM) -> None:
assert get_default_quality_preset() == tlm._quality_preset


def test_prompt_too_long_exception_single_prompt(tlm: TLM) -> None:
"""Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a single prompt."""
with pytest.raises(TlmBadRequestError) as exc_info:
tlm.prompt(WORD_THAT_EQUALS_ONE_TOKEN * (get_default_context_limit() + 1))

assert exc_info.value.message.startswith("Prompt length exceeds")
assert exc_info.value.retryable is False


def test_prompt_within_context_limit_returns_response(tlm: TLM) -> None:
"""Tests that no error is raised when prompt length is within limit."""
response = tlm.prompt(WORD_THAT_EQUALS_ONE_TOKEN * (get_default_context_limit() - 1000))

assert isinstance(response, dict)
assert "response" in response
assert isinstance(response["response"], str)


def test_response_within_max_tokens() -> None:
"""Tests that response is within max tokens limit."""
tlm_base = TLM(quality_preset="base")
prompt = "write a 100 page book about computer science. make sure it is extremely long and comprehensive."

result = tlm_base.prompt(prompt)
assert isinstance(result, dict)
response = result["response"]
assert isinstance(response, str)

try:
enc = tiktoken.encoding_for_model(get_default_model())
except KeyError:
enc = tiktoken.encoding_for_model("gpt-4o")
tokens_in_response = len(enc.encode(response))
assert tokens_in_response <= get_default_max_tokens()
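
These tests take a `tlm: TLM` fixture that is not part of this diff and presumably lives in tests/conftest.py. A hypothetical sketch of such a fixture, shown only so the tests above read as self-contained; the real fixture may be configured differently:

```python
import pytest

from cleanlab_tlm.tlm import TLM


@pytest.fixture(scope="module")
def tlm() -> TLM:
    # Assumed fixture: a TLM client with library defaults, matching the
    # module-level tlm_with_default_setting instance in test_config.py.
    return TLM()
```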
26 changes: 13 additions & 13 deletions tests/test_validation.py
@@ -12,13 +12,13 @@
from cleanlab_tlm.utils.rag import Eval, TrustworthyRAG
from tests.conftest import make_text_unique
from tests.constants import (
CHARACTERS_PER_TOKEN,
MAX_COMBINED_LENGTH_TOKENS,
MAX_PROMPT_LENGTH_TOKENS,
MAX_RESPONSE_LENGTH_TOKENS,
TEST_PROMPT,
TEST_PROMPT_BATCH,
TEST_RESPONSE,
WORD_THAT_EQUALS_ONE_TOKEN,
)
from tests.test_get_trustworthiness_score import is_tlm_score_response_with_error
from tests.test_prompt import is_tlm_response_with_error
@@ -208,7 +208,7 @@ def test_prompt_too_long_exception_single_prompt(tlm: TLM) -> None:
"""Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a single prompt."""
with pytest.raises(TlmBadRequestError) as exc_info:
tlm.prompt(
"a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1),
)

assert exc_info.value.message.startswith("Prompt length exceeds")
@@ -221,7 +221,7 @@ def test_prompt_too_long_exception_prompt(tlm: TLM, num_prompts: int) -> None:
# create batch of prompts with one prompt that is too long
prompts = [test_prompt] * num_prompts
prompt_too_long_index = np.random.randint(0, num_prompts)
prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
prompts[prompt_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1)

tlm_responses = cast(list[TLMResponse], tlm.prompt(prompts))

@@ -232,8 +232,8 @@ def test_response_too_long_exception_single_score(tlm: TLM) -> None:
"""Tests that bad request error is raised when response is too long when calling tlm.get_trustworthiness_score with a single prompt."""
with pytest.raises(TlmBadRequestError) as exc_info:
tlm.get_trustworthiness_score(
"a",
"a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
WORD_THAT_EQUALS_ONE_TOKEN,
WORD_THAT_EQUALS_ONE_TOKEN * (MAX_RESPONSE_LENGTH_TOKENS + 1),
)

assert exc_info.value.message.startswith("Response length exceeds")
@@ -247,7 +247,7 @@ def test_response_too_long_exception_score(tlm: TLM, num_prompts: int) -> None:
prompts = [test_prompt] * num_prompts
responses = [TEST_RESPONSE] * num_prompts
response_too_long_index = np.random.randint(0, num_prompts)
responses[response_too_long_index] = "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
responses[response_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_RESPONSE_LENGTH_TOKENS + 1)

tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))

@@ -258,8 +258,8 @@ def test_prompt_too_long_exception_single_score(tlm: TLM) -> None:
"""Tests that bad request error is raised when prompt is too long when calling tlm.get_trustworthiness_score with a single prompt."""
with pytest.raises(TlmBadRequestError) as exc_info:
tlm.get_trustworthiness_score(
"a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
"a",
WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1),
WORD_THAT_EQUALS_ONE_TOKEN,
)

assert exc_info.value.message.startswith("Prompt length exceeds")
@@ -273,7 +273,7 @@ def test_prompt_too_long_exception_score(tlm: TLM, num_prompts: int) -> None:
prompts = [test_prompt] * num_prompts
responses = [TEST_RESPONSE] * num_prompts
prompt_too_long_index = np.random.randint(0, num_prompts)
prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
prompts[prompt_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1)

tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))

@@ -286,8 +286,8 @@ def test_combined_too_long_exception_single_score(tlm: TLM) -> None:

with pytest.raises(TlmBadRequestError) as exc_info:
tlm.get_trustworthiness_score(
"a" * max_prompt_length * CHARACTERS_PER_TOKEN,
"a" * MAX_RESPONSE_LENGTH_TOKENS * CHARACTERS_PER_TOKEN,
WORD_THAT_EQUALS_ONE_TOKEN * max_prompt_length,
WORD_THAT_EQUALS_ONE_TOKEN * MAX_RESPONSE_LENGTH_TOKENS,
)

assert exc_info.value.message.startswith("Prompt and response combined length exceeds")
@@ -306,8 +306,8 @@ def test_prompt_and_response_combined_too_long_exception_batch_score(tlm: TLM, num_prompts: int) -> None:
combined_too_long_index = np.random.randint(0, num_prompts)

max_prompt_length = MAX_COMBINED_LENGTH_TOKENS - MAX_RESPONSE_LENGTH_TOKENS + 1
prompts[combined_too_long_index] = "a" * max_prompt_length * CHARACTERS_PER_TOKEN
responses[combined_too_long_index] = "a" * MAX_RESPONSE_LENGTH_TOKENS * CHARACTERS_PER_TOKEN
prompts[combined_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * max_prompt_length
responses[combined_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * MAX_RESPONSE_LENGTH_TOKENS

tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))
