Skip to content

Commit 3f5ff91

Browse files
authored
ODSC-58449: ads LangChain plugin update (#877)
1 parent 30534f7 commit 3f5ff91

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

ads/llm/langchain/plugins/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain.llms.base import LLM
99
from langchain.pydantic_v1 import BaseModel, Field, root_validator
1010

11+
from ads import logger
1112
from ads.common.auth import default_signer
1213
from ads.config import COMPARTMENT_OCID
1314

@@ -95,6 +96,11 @@ def validate_environment( # pylint: disable=no-self-argument
9596
"""Validate that python package exists in environment."""
9697
# Initialize client only if user does not pass in client.
9798
# Users may choose to initialize the OCI client by themselves and pass it into this model.
99+
logger.warning(
100+
f"The ads langchain plugin {cls.__name__} will be deprecated soon. "
101+
"Please refer to https://python.langchain.com/v0.2/docs/integrations/providers/oci/ "
102+
"for the latest support."
103+
)
98104
if not values.get("client"):
99105
auth = values.get("auth", {})
100106
client_kwargs = auth.get("client_kwargs") or {}

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ pii = [
175175
"spacy==3.6.1",
176176
"report-creator==1.0.9",
177177
]
178-
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0", "langchain-core<0.1.51"]
178+
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0"]
179179
aqua = ["jupyter_server"]
180180

181181
# To reduce backtracking (decrese deps install time) during test/dev env setup reducing number of versions pip is

tests/unitary/with_extras/langchain/test_serialization.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,25 @@
77

88
import os
99
from copy import deepcopy
10-
from unittest import TestCase, mock, SkipTest
10+
from unittest import SkipTest, TestCase, mock, skipIf
1111

12-
from langchain.llms import Cohere
12+
import langchain_core
1313
from langchain.chains import LLMChain
14+
from langchain.llms import Cohere
1415
from langchain.prompts import PromptTemplate
1516
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
1617

17-
from ads.llm.serialize import load, dump
1818
from ads.llm import (
1919
GenerativeAI,
2020
GenerativeAIEmbeddings,
2121
ModelDeploymentTGI,
2222
ModelDeploymentVLLM,
2323
)
24+
from ads.llm.serialize import dump, load
25+
26+
27+
def version_tuple(version):
28+
return tuple(map(int, version.split(".")))
2429

2530

2631
class ChainSerializationTest(TestCase):
@@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self):
142147
self.assertEqual(llm_chain.llm.model, "my_model")
143148
self.assertEqual(llm_chain.input_keys, ["subject"])
144149

150+
@skipIf(
151+
version_tuple(langchain_core.__version__) > (0, 1, 50),
152+
"Serialization not supported in this langchain_core version",
153+
)
145154
def test_oci_gen_ai_serialization(self):
146155
"""Tests serialization of OCI Gen AI LLM."""
147156
try:
@@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self):
157166
self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID)
158167
self.assertEqual(llm.client_kwargs, self.GEN_AI_KWARGS)
159168

169+
@skipIf(
170+
version_tuple(langchain_core.__version__) > (0, 1, 50),
171+
"Serialization not supported in this langchain_core version",
172+
)
160173
def test_gen_ai_embeddings_serialization(self):
161174
"""Tests serialization of OCI Gen AI embeddings."""
162175
try:
@@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self):
201214
element_3 = kwargs.get("last")
202215
self.assertNotIn("_type", element_3)
203216
self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"])
204-
self.assertEqual(
205-
element_3.get("kwargs"),
206-
{"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"},
207-
)
217+
218+
if version_tuple(langchain_core.__version__) > (0, 1, 50):
219+
self.assertEqual(
220+
element_3.get("kwargs"),
221+
{
222+
"max_tokens": 256,
223+
"temperature": 0.2,
224+
"p": 0.75,
225+
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
226+
"best_of": 1,
227+
"do_sample": True,
228+
"watermark": True,
229+
},
230+
)
231+
else:
232+
self.assertEqual(
233+
element_3.get("kwargs"),
234+
{
235+
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
236+
},
237+
)
208238

209239
chain = load(serialized)
210240
self.assertEqual(len(chain.steps), 3)

0 commit comments

Comments
 (0)