@@ -19,7 +19,7 @@
 from . import _transformers as t
 from . import types
 from .models import AsyncModels, Models
-from .types import Content, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
+from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
 
 
 if sys.version_info >= (3, 10):
@@ -116,14 +116,21 @@ def __init__(
       *,
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
-      history: list[Content],
+      history: list[ContentOrDict],
   ):
     self._model = model
     self._config = config
-    self._comprehensive_history = history
+    content_models = []
+    for content in history:
+      if not isinstance(content, Content):
+        content_model = Content.model_validate(content)
+      else:
+        content_model = content
+      content_models.append(content_model)
+    self._comprehensive_history = content_models
     """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
     """
-    self._curated_history = _extract_curated_history(history)
+    self._curated_history = _extract_curated_history(content_models)
     """Curated history is the set of valid turns that will be used in the subsequent send requests.
     """
 
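
For reference, a minimal sketch of the dict-to-model coercion the new loop performs. It assumes only that google.genai.types exposes the pydantic Content model imported above; the dict shape is illustrative, not taken from this diff:

from google.genai import types

# A history entry supplied as a plain dict instead of a Content model
# (illustrative ContentDict-style shape: a role plus a list of part dicts).
raw_entry = {"role": "user", "parts": [{"text": "Hello there"}]}

# Mirror the normalization above: validate dicts, pass Content instances through.
entry = (
    raw_entry
    if isinstance(raw_entry, types.Content)
    else types.Content.model_validate(raw_entry)
)

assert isinstance(entry, types.Content)
assert entry.parts[0].text == "Hello there"
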
@@ -210,7 +217,7 @@ def __init__(
       modules: Models,
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
-      history: list[Content],
+      history: list[ContentOrDict],
   ):
     self._modules = modules
     super().__init__(
@@ -344,7 +351,7 @@ def create(
       *,
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
-      history: Optional[list[Content]] = None,
+      history: Optional[list[ContentOrDict]] = None,
   ) -> Chat:
     """Creates a new chat session.
 
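
With this change, Chats.create should accept history entries as plain dicts as well as Content models. A hedged usage sketch; the client construction and model name below are placeholders, not part of this diff:

from google import genai

client = genai.Client()  # placeholder: credentials/project setup omitted

# History passed as ContentOrDict values: dicts and Content models may be mixed.
chat = client.chats.create(
    model="gemini-2.0-flash",  # placeholder model name
    history=[
        {"role": "user", "parts": [{"text": "Hi."}]},
        {"role": "model", "parts": [{"text": "Hello! How can I help?"}]},
    ],
)
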
@@ -373,7 +380,7 @@ def __init__(
       modules: AsyncModels,
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
-      history: list[Content],
+      history: list[ContentOrDict],
   ):
     self._modules = modules
     super().__init__(
@@ -501,7 +508,7 @@ def create(
       *,
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
-      history: Optional[list[Content]] = None,
+      history: Optional[list[ContentOrDict]] = None,
   ) -> AsyncChat:
     """Creates a new chat session.
 
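
The async surface mirrors the sync one. A sketch under the assumption that AsyncChats is reachable via client.aio.chats, as elsewhere in the SDK; again the model name is a placeholder:

import asyncio

from google import genai


async def main() -> None:
  client = genai.Client()  # placeholder: credentials/project setup omitted
  # create() returns an AsyncChat; only the history typing changes in this diff.
  chat = client.aio.chats.create(
      model="gemini-2.0-flash",  # placeholder model name
      history=[{"role": "user", "parts": [{"text": "Hi."}]}],
  )


asyncio.run(main())
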