-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_env.py
42 lines (35 loc) · 1.36 KB
/
test_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
from cleanlab_codex import Project
from cleanlab_tlm import TLM
from dotenv import load_dotenv
from llama_index.embeddings.google_genai import GoogleGenAIEmbedding # type: ignore
# Load variables from a local .env file (if one exists) into the process
# environment, so the API keys read in main() need not be exported manually.
load_dotenv()

# Prompt sent to every service that main() checks.
QUESTION: str = "How many syllables are in the phrase 'AI User Conference'?"
def main() -> None:
# check that we can query TLM
tlm_api_key = os.getenv("CLEANLAB_TLM_API_KEY")
if not tlm_api_key:
raise ValueError("CLEANLAB_TLM_API_KEY is not set")
tlm = TLM(api_key=tlm_api_key)
tlm_response = tlm.prompt(QUESTION)
assert "trustworthiness_score" in tlm_response
# check that we can query Codex
codex_access_key = os.getenv("CLEANLAB_CODEX_ACCESS_KEY")
if not codex_access_key:
raise ValueError("CLEANLAB_CODEX_ACCESS_KEY is not set")
project = Project.from_access_key(access_key=codex_access_key)
project.query(QUESTION) # Just verify we can query without error
# check that we can query Gemini
gemini_api_key = os.getenv("GOOGLE_API_KEY")
if not gemini_api_key:
raise ValueError("GOOGLE_API_KEY is not set")
embed_model = GoogleGenAIEmbedding(
model_name="text-embedding-004",
embed_batch_size=100,
api_key=gemini_api_key,
)
embedding = embed_model.get_text_embedding(QUESTION)
assert len(embedding) > 0
print("all ok")
# Run the environment checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()