7
7
8
8
import os
9
9
from copy import deepcopy
10
- from unittest import TestCase , mock , SkipTest
10
+ from unittest import SkipTest , TestCase , mock , skipIf
11
11
12
- from langchain . llms import Cohere
12
+ import langchain_core
13
13
from langchain .chains import LLMChain
14
+ from langchain .llms import Cohere
14
15
from langchain .prompts import PromptTemplate
15
16
from langchain .schema .runnable import RunnableParallel , RunnablePassthrough
16
17
17
- from ads .llm .serialize import load , dump
18
18
from ads .llm import (
19
19
GenerativeAI ,
20
20
GenerativeAIEmbeddings ,
21
21
ModelDeploymentTGI ,
22
22
ModelDeploymentVLLM ,
23
23
)
24
+ from ads .llm .serialize import dump , load
25
+
26
+
27
def version_tuple(version):
    """Parse a dotted version string (e.g. "0.1.50") into a tuple of ints.

    The resulting tuple compares element-wise, so version ordering works
    with the standard comparison operators (e.g. ``(0, 1, 50) < (0, 2, 0)``).
    Raises ``ValueError`` if any dot-separated segment is not an integer.
    """
    return tuple(int(segment) for segment in version.split("."))
24
29
25
30
26
31
class ChainSerializationTest (TestCase ):
@@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self):
142
147
self .assertEqual (llm_chain .llm .model , "my_model" )
143
148
self .assertEqual (llm_chain .input_keys , ["subject" ])
144
149
150
+ @skipIf (
151
+ version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ),
152
+ "Serialization not supported in this langchain_core version" ,
153
+ )
145
154
def test_oci_gen_ai_serialization (self ):
146
155
"""Tests serialization of OCI Gen AI LLM."""
147
156
try :
@@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self):
157
166
self .assertEqual (llm .compartment_id , self .COMPARTMENT_ID )
158
167
self .assertEqual (llm .client_kwargs , self .GEN_AI_KWARGS )
159
168
169
+ @skipIf (
170
+ version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ),
171
+ "Serialization not supported in this langchain_core version" ,
172
+ )
160
173
def test_gen_ai_embeddings_serialization (self ):
161
174
"""Tests serialization of OCI Gen AI embeddings."""
162
175
try :
@@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self):
201
214
element_3 = kwargs .get ("last" )
202
215
self .assertNotIn ("_type" , element_3 )
203
216
self .assertEqual (element_3 .get ("id" ), ["ads" , "llm" , "ModelDeploymentTGI" ])
204
- self .assertEqual (
205
- element_3 .get ("kwargs" ),
206
- {"endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" },
207
- )
217
+
218
+ if version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ):
219
+ self .assertEqual (
220
+ element_3 .get ("kwargs" ),
221
+ {
222
+ "max_tokens" : 256 ,
223
+ "temperature" : 0.2 ,
224
+ "p" : 0.75 ,
225
+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
226
+ "best_of" : 1 ,
227
+ "do_sample" : True ,
228
+ "watermark" : True ,
229
+ },
230
+ )
231
+ else :
232
+ self .assertEqual (
233
+ element_3 .get ("kwargs" ),
234
+ {
235
+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
236
+ },
237
+ )
208
238
209
239
chain = load (serialized )
210
240
self .assertEqual (len (chain .steps ), 3 )
0 commit comments