@@ -44,13 +44,16 @@
 from google.cloud.aiplatform_v1beta1.types import EncryptionSpec
 from vertexai.preview.rag.utils.resources import (
     ANN,
+    DocumentCorpus,
     EmbeddingModelConfig,
     JiraSource,
     KNN,
     LayoutParserConfig,
     LlmParserConfig,
+    MemoryCorpus,
     Pinecone,
     RagCorpus,
+    RagCorpusTypeConfig,
     RagEmbeddingModelConfig,
     RagEngineConfig,
     RagFile,
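As a point of reference, here is a hedged caller-side sketch of the new config types exported above. It assumes the classes are re-exported from `vertexai.preview.rag` and that `create_corpus` accepts a `corpus_type_config` argument; neither of those appears in this diff.

```python
# Hypothetical usage; the re-export path and the create_corpus parameter are
# assumptions, not part of this change.
from vertexai.preview import rag
from vertexai.preview.rag.utils.resources import (
    DocumentCorpus,
    RagCorpusTypeConfig,
)

corpus = rag.create_corpus(
    display_name="my-document-corpus",
    corpus_type_config=RagCorpusTypeConfig(corpus_type_config=DocumentCorpus()),
)
```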
@@ -312,12 +315,35 @@ def convert_gapic_to_backend_config(
     return vector_config
 
 
+def convert_gapic_to_rag_corpus_type_config(
+    gapic_rag_corpus_type_config: GapicRagCorpus.CorpusTypeConfig,
+) -> Optional[RagCorpusTypeConfig]:
+    """Convert GapicRagCorpus.CorpusTypeConfig to RagCorpusTypeConfig."""
+    if gapic_rag_corpus_type_config.document_corpus:
+        return RagCorpusTypeConfig(corpus_type_config=DocumentCorpus())
+    elif gapic_rag_corpus_type_config.memory_corpus:
+        return RagCorpusTypeConfig(
+            corpus_type_config=MemoryCorpus(
+                llm_parser=LlmParserConfig(
+                    model_name=gapic_rag_corpus_type_config.memory_corpus.llm_parser.model_name,
+                    max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.max_parsing_requests_per_min,
+                    global_max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.global_max_parsing_requests_per_min,
+                    custom_parsing_prompt=gapic_rag_corpus_type_config.memory_corpus.llm_parser.custom_parsing_prompt,
+                )
+            )
+        )
+    return None
+
+
 def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
     """Convert GapicRagCorpus to RagCorpus."""
     rag_corpus = RagCorpus(
         name=gapic_rag_corpus.name,
         display_name=gapic_rag_corpus.display_name,
         description=gapic_rag_corpus.description,
+        corpus_type_config=convert_gapic_to_rag_corpus_type_config(
+            gapic_rag_corpus.corpus_type_config
+        ),
         embedding_model_config=convert_gapic_to_embedding_model_config(
             gapic_rag_corpus.rag_embedding_model_config
         ),
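A minimal round-trip sketch of the new converter, meant to run inside this module (for example, in a unit test that imports it); it relies only on aliases already used in this file, and the resource names are placeholders.

```python
# Illustrative only: resource names are made up; assumes the GapicRagCorpus
# alias in this module maps to the aiplatform_v1beta1 RagCorpus proto.
gapic_corpus = GapicRagCorpus(
    name="projects/p/locations/us-central1/ragCorpora/123",
    display_name="memory-corpus",
    corpus_type_config=GapicRagCorpus.CorpusTypeConfig(
        memory_corpus=GapicRagCorpus.CorpusTypeConfig.MemoryCorpus()
    ),
)

rag_corpus = convert_gapic_to_rag_corpus(gapic_corpus)
# rag_corpus.corpus_type_config should come back as a RagCorpusTypeConfig
# wrapping a MemoryCorpus, with llm_parser fields copied when present.
```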
@@ -553,6 +579,10 @@ def prepare_import_files_request(
            rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = (
                llm_parser.max_parsing_requests_per_min
            )
+        if llm_parser.global_max_parsing_requests_per_min is not None:
+            rag_file_parsing_config.llm_parser.global_max_parsing_requests_per_min = (
+                llm_parser.global_max_parsing_requests_per_min
+            )
        if llm_parser.custom_parsing_prompt is not None:
            rag_file_parsing_config.llm_parser.custom_parsing_prompt = (
                llm_parser.custom_parsing_prompt
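For context, a hedged caller-side sketch of how the new global limit rides along on `LlmParserConfig`. It assumes `import_files` continues to accept an `llm_parser` argument and forwards it to `prepare_import_files_request`, as it does for the existing parser fields; the model name, limits, and paths are illustrative.

```python
# Values and resource names are placeholders; the import_files parameter
# passthrough is an assumption, not shown in this diff.
from vertexai.preview import rag
from vertexai.preview.rag.utils.resources import LlmParserConfig

llm_parser = LlmParserConfig(
    model_name="publishers/google/models/gemini-1.5-pro-002",  # assumed model name
    max_parsing_requests_per_min=100,
    global_max_parsing_requests_per_min=1000,  # new field wired through above
)

rag.import_files(
    corpus_name="projects/p/locations/us-central1/ragCorpora/123",
    paths=["gs://my-bucket/docs/"],
    llm_parser=llm_parser,
)
```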
@@ -671,10 +701,56 @@ def get_file_name(
     )
 
 
+def set_corpus_type_config(
+    corpus_type_config: RagCorpusTypeConfig,
+    rag_corpus: GapicRagCorpus,
+) -> None:
+    """Sets the corpus type config in GapicRagCorpus."""
+    if isinstance(corpus_type_config.corpus_type_config, DocumentCorpus):
+        rag_corpus.corpus_type_config = GapicRagCorpus.CorpusTypeConfig(
+            document_corpus=GapicRagCorpus.CorpusTypeConfig.DocumentCorpus()
+        )
+    elif isinstance(corpus_type_config.corpus_type_config, MemoryCorpus):
+        memory_corpus = GapicRagCorpus.CorpusTypeConfig.MemoryCorpus()
+        if corpus_type_config.corpus_type_config.llm_parser is not None:
+            memory_corpus.llm_parser = RagFileParsingConfig.LlmParser(
+                model_name=corpus_type_config.corpus_type_config.llm_parser.model_name
+            )
+            if (
+                corpus_type_config.corpus_type_config.llm_parser.max_parsing_requests_per_min
+                is not None
+            ):
+                memory_corpus.llm_parser.max_parsing_requests_per_min = (
+                    corpus_type_config.corpus_type_config.llm_parser.max_parsing_requests_per_min
+                )
+            if (
+                corpus_type_config.corpus_type_config.llm_parser.global_max_parsing_requests_per_min
+                is not None
+            ):
+                memory_corpus.llm_parser.global_max_parsing_requests_per_min = (
+                    corpus_type_config.corpus_type_config.llm_parser.global_max_parsing_requests_per_min
+                )
+            if (
+                corpus_type_config.corpus_type_config.llm_parser.custom_parsing_prompt
+                is not None
+            ):
+                memory_corpus.llm_parser.custom_parsing_prompt = (
+                    corpus_type_config.corpus_type_config.llm_parser.custom_parsing_prompt
+                )
+        rag_corpus.corpus_type_config = GapicRagCorpus.CorpusTypeConfig(
+            memory_corpus=memory_corpus
+        )
+    else:
+        raise TypeError(
+            "corpus_type_config must be a DocumentCorpus or a MemoryCorpus."
+        )
+
+
 def set_embedding_model_config(
     embedding_model_config: EmbeddingModelConfig,
     rag_corpus: GapicRagCorpus,
 ) -> None:
+    """Sets the embedding model config for the rag corpus."""
     if embedding_model_config.publisher_model and embedding_model_config.endpoint:
         raise ValueError("publisher_model and endpoint cannot be set at the same time.")
     if (
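To close, a small hedged sketch of the mapping `set_corpus_type_config` performs on a `MemoryCorpus` config, again using only names that appear in this diff; the model name and rate limit are illustrative values.

```python
# Illustrative only; intended to run inside this module or a test importing it.
gapic_corpus = GapicRagCorpus(display_name="memory-corpus")
config = RagCorpusTypeConfig(
    corpus_type_config=MemoryCorpus(
        llm_parser=LlmParserConfig(
            model_name="publishers/google/models/gemini-1.5-flash-002",  # assumed
            max_parsing_requests_per_min=120,
        )
    )
)

set_corpus_type_config(corpus_type_config=config, rag_corpus=gapic_corpus)
# gapic_corpus.corpus_type_config.memory_corpus.llm_parser now carries the model
# name and per-minute limit; passing DocumentCorpus() would populate the
# document_corpus branch of the oneof instead.
```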