Skip to content

Commit ede4b5b

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: RAG - Introduce configuration to corpus types, with DocumentCorpus and MemoryCorpus options.
PiperOrigin-RevId: 766977388
1 parent aa6eda4 commit ede4b5b

File tree

6 files changed

+275
-35
lines changed

6 files changed

+275
-35
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from vertexai.preview.rag import (
4343
ANN,
4444
Basic,
45+
DocumentCorpus,
4546
EmbeddingModelConfig,
4647
Enterprise,
4748
Filter,
@@ -52,8 +53,10 @@
5253
LayoutParserConfig,
5354
LlmParserConfig,
5455
LlmRanker,
56+
MemoryCorpus,
5557
Pinecone,
5658
RagCorpus,
59+
RagCorpusTypeConfig,
5760
RagEmbeddingModelConfig,
5861
RagEngineConfig,
5962
RagFile,
@@ -226,6 +229,7 @@
226229
)
227230
),
228231
)
232+
229233
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
230234
publisher_model="publishers/google/models/textembedding-gecko",
231235
)
@@ -816,12 +820,69 @@
816820
max_parsing_requests_per_min=100,
817821
)
818822

823+
TEST_GAPIC_LLM_PARSER = RagFileParsingConfig.LlmParser(
824+
model_name="gemini-1.5-pro-002",
825+
max_parsing_requests_per_min=500,
826+
global_max_parsing_requests_per_min=1000,
827+
custom_parsing_prompt="test-custom-parsing-prompt",
828+
)
829+
819830
TEST_LLM_PARSER_CONFIG = LlmParserConfig(
820831
model_name="gemini-1.5-pro-002",
821832
max_parsing_requests_per_min=500,
833+
global_max_parsing_requests_per_min=1000,
822834
custom_parsing_prompt="test-custom-parsing-prompt",
823835
)
824836

837+
TEST_RAG_MEMORY_CORPUS_CONFIG = MemoryCorpus(
838+
llm_parser=TEST_LLM_PARSER_CONFIG,
839+
)
840+
841+
TEST_RAG_MEMORY_CORPUS = RagCorpus(
842+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
843+
display_name=TEST_CORPUS_DISPLAY_NAME,
844+
description=TEST_CORPUS_DISCRIPTION,
845+
corpus_type_config=RagCorpusTypeConfig(
846+
corpus_type_config=TEST_RAG_MEMORY_CORPUS_CONFIG
847+
),
848+
)
849+
850+
TEST_GAPIC_RAG_MEMORY_CORPUS = GapicRagCorpus(
851+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
852+
display_name=TEST_CORPUS_DISPLAY_NAME,
853+
description=TEST_CORPUS_DISCRIPTION,
854+
corpus_type_config=GapicRagCorpus.CorpusTypeConfig(
855+
memory_corpus=GapicRagCorpus.CorpusTypeConfig.MemoryCorpus(
856+
llm_parser=RagFileParsingConfig.LlmParser(
857+
model_name="gemini-1.5-pro-002",
858+
max_parsing_requests_per_min=500,
859+
global_max_parsing_requests_per_min=1000,
860+
custom_parsing_prompt="test-custom-parsing-prompt",
861+
)
862+
)
863+
),
864+
)
865+
866+
TEST_RAG_DOCUMENT_CORPUS_CONFIG = DocumentCorpus()
867+
868+
TEST_RAG_DOCUMENT_CORPUS = RagCorpus(
869+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
870+
display_name=TEST_CORPUS_DISPLAY_NAME,
871+
description=TEST_CORPUS_DISCRIPTION,
872+
corpus_type_config=RagCorpusTypeConfig(
873+
corpus_type_config=TEST_RAG_DOCUMENT_CORPUS_CONFIG
874+
),
875+
)
876+
877+
TEST_GAPIC_RAG_DOCUMENT_CORPUS = GapicRagCorpus(
878+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
879+
display_name=TEST_CORPUS_DISPLAY_NAME,
880+
description=TEST_CORPUS_DISCRIPTION,
881+
corpus_type_config=GapicRagCorpus.CorpusTypeConfig(
882+
document_corpus=GapicRagCorpus.CorpusTypeConfig.DocumentCorpus()
883+
),
884+
)
885+
825886
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
826887
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
827888
share_point_sources=GapicSharePointSources(
@@ -885,6 +946,7 @@
885946
llm_parser=RagFileParsingConfig.LlmParser(
886947
model_name="gemini-1.5-pro-002",
887948
max_parsing_requests_per_min=500,
949+
global_max_parsing_requests_per_min=1000,
888950
custom_parsing_prompt="test-custom-parsing-prompt",
889951
)
890952
)

tests/unit/vertex_rag/test_rag_data_preview.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,36 @@ def create_rag_corpus_mock_vertex_ai_datastore_search_config():
313313
yield create_rag_corpus_mock_vertex_ai_datastore_search_config
314314

315315

316+
@pytest.fixture
317+
def create_rag_corpus_mock_memory_corpus():
318+
with mock.patch.object(
319+
VertexRagDataServiceClient,
320+
"create_rag_corpus",
321+
) as create_rag_corpus_mock_memory_corpus:
322+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
323+
create_rag_corpus_lro_mock.done.return_value = True
324+
create_rag_corpus_lro_mock.result.return_value = (
325+
test_rag_constants_preview.TEST_GAPIC_RAG_MEMORY_CORPUS
326+
)
327+
create_rag_corpus_mock_memory_corpus.return_value = create_rag_corpus_lro_mock
328+
yield create_rag_corpus_mock_memory_corpus
329+
330+
331+
@pytest.fixture
332+
def create_rag_corpus_mock_document_corpus():
333+
with mock.patch.object(
334+
VertexRagDataServiceClient,
335+
"create_rag_corpus",
336+
) as create_rag_corpus_mock_document_corpus:
337+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
338+
create_rag_corpus_lro_mock.done.return_value = True
339+
create_rag_corpus_lro_mock.result.return_value = (
340+
test_rag_constants_preview.TEST_GAPIC_RAG_DOCUMENT_CORPUS
341+
)
342+
create_rag_corpus_mock_document_corpus.return_value = create_rag_corpus_lro_mock
343+
yield create_rag_corpus_mock_document_corpus
344+
345+
316346
@pytest.fixture
317347
def update_rag_corpus_mock_vertex_ai_engine_search_config():
318348
with mock.patch.object(
@@ -591,6 +621,7 @@ def rag_corpus_eq(returned_corpus, expected_corpus):
591621
assert returned_corpus.vertex_ai_search_config.__eq__(
592622
expected_corpus.vertex_ai_search_config
593623
)
624+
assert returned_corpus.corpus_type_config.__eq__(expected_corpus.corpus_type_config)
594625

595626

596627
def rag_file_eq(returned_file, expected_file):
@@ -918,6 +949,28 @@ def test_create_corpus_failure(self):
918949
)
919950
e.match("Failed in RagCorpus creation due to")
920951

952+
@pytest.mark.usefixtures("create_rag_corpus_mock_memory_corpus")
953+
def test_create_memory_corpus_success(self):
954+
rag_corpus = rag.create_corpus(
955+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
956+
corpus_type_config=rag.RagCorpusTypeConfig(
957+
corpus_type_config=test_rag_constants_preview.TEST_RAG_MEMORY_CORPUS_CONFIG
958+
),
959+
)
960+
961+
rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_MEMORY_CORPUS)
962+
963+
@pytest.mark.usefixtures("create_rag_corpus_mock_document_corpus")
964+
def test_create_document_corpus_success(self):
965+
rag_corpus = rag.create_corpus(
966+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
967+
corpus_type_config=rag.RagCorpusTypeConfig(
968+
corpus_type_config=test_rag_constants_preview.TEST_RAG_DOCUMENT_CORPUS_CONFIG
969+
),
970+
)
971+
972+
rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_DOCUMENT_CORPUS)
973+
921974
@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
922975
def test_update_corpus_weaviate_success(self):
923976
rag_corpus = rag.update_corpus(

vertexai/preview/rag/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
ANN,
4040
ChunkingConfig,
4141
Basic,
42+
DocumentCorpus,
4243
Enterprise,
4344
EmbeddingModelConfig,
4445
Filter,
@@ -49,8 +50,10 @@
4950
LayoutParserConfig,
5051
LlmParserConfig,
5152
LlmRanker,
53+
MemoryCorpus,
5254
Pinecone,
5355
RagCorpus,
56+
RagCorpusTypeConfig,
5457
RagEmbeddingModelConfig,
5558
RagEngineConfig,
5659
RagFile,
@@ -77,6 +80,7 @@
7780
"ANN",
7881
"Basic",
7982
"ChunkingConfig",
83+
"DocumentCorpus",
8084
"Enterprise",
8185
"EmbeddingModelConfig",
8286
"Filter",
@@ -87,14 +91,18 @@
8791
"LayoutParserConfig",
8892
"LlmParserConfig",
8993
"LlmRanker",
94+
"MemoryCorpus",
9095
"Pinecone",
9196
"RagEngineConfig",
9297
"RagCorpus",
98+
"RagCorpusTypeConfig",
99+
"RagEmbeddingModelConfig",
93100
"RagFile",
94101
"RagManagedDb",
95102
"RagManagedDbConfig",
96103
"RagResource",
97104
"RagRetrievalConfig",
105+
"RagVectorDbConfig",
98106
"Ranking",
99107
"RankService",
100108
"Retrieval",
@@ -105,12 +113,10 @@
105113
"TransformationConfig",
106114
"VertexAiSearchConfig",
107115
"VertexFeatureStore",
116+
"VertexPredictionEndpoint",
108117
"VertexRagStore",
109118
"VertexVectorSearch",
110119
"Weaviate",
111-
"RagEmbeddingModelConfig",
112-
"VertexPredictionEndpoint",
113-
"RagVectorDbConfig",
114120
"create_corpus",
115121
"delete_corpus",
116122
"delete_file",

vertexai/preview/rag/rag_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
LlmParserConfig,
5353
Pinecone,
5454
RagCorpus,
55+
RagCorpusTypeConfig,
5556
RagEngineConfig,
5657
RagFile,
5758
RagManagedDb,
@@ -69,6 +70,7 @@
6970
def create_corpus(
7071
display_name: Optional[str] = None,
7172
description: Optional[str] = None,
73+
corpus_type_config: Optional[RagCorpusTypeConfig] = None,
7274
embedding_model_config: Optional[EmbeddingModelConfig] = None,
7375
vector_db: Optional[
7476
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
@@ -96,6 +98,7 @@ def create_corpus(
9698
the RagCorpus. The name can be up to 128 characters long and can consist
9799
of any UTF-8 characters.
98100
description: The description of the RagCorpus.
101+
corpus_type_config: The corpus type config of the RagCorpus.
99102
embedding_model_config: The embedding model config.
100103
Note: Deprecated. Use backend_config instead.
101104
vector_db: The vector db config of the RagCorpus. If unspecified, the
@@ -119,6 +122,13 @@ def create_corpus(
119122
parent = initializer.global_config.common_location_path(project=None, location=None)
120123

121124
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
125+
126+
if corpus_type_config:
127+
_gapic_utils.set_corpus_type_config(
128+
corpus_type_config=corpus_type_config,
129+
rag_corpus=rag_corpus,
130+
)
131+
122132
if embedding_model_config:
123133
_gapic_utils.set_embedding_model_config(
124134
embedding_model_config=embedding_model_config,

vertexai/preview/rag/utils/_gapic_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,16 @@
4444
from google.cloud.aiplatform_v1beta1.types import EncryptionSpec
4545
from vertexai.preview.rag.utils.resources import (
4646
ANN,
47+
DocumentCorpus,
4748
EmbeddingModelConfig,
4849
JiraSource,
4950
KNN,
5051
LayoutParserConfig,
5152
LlmParserConfig,
53+
MemoryCorpus,
5254
Pinecone,
5355
RagCorpus,
56+
RagCorpusTypeConfig,
5457
RagEmbeddingModelConfig,
5558
RagEngineConfig,
5659
RagFile,
@@ -312,12 +315,35 @@ def convert_gapic_to_backend_config(
312315
return vector_config
313316

314317

318+
def convert_gapic_to_rag_corpus_type_config(
319+
gapic_rag_corpus_type_config: GapicRagCorpus.CorpusTypeConfig,
320+
) -> RagCorpusTypeConfig:
321+
"""Convert GapicRagCorpus.CorpusTypeConfig to RagCorpusTypeConfig."""
322+
if gapic_rag_corpus_type_config.document_corpus:
323+
return RagCorpusTypeConfig(corpus_type_config=DocumentCorpus())
324+
elif gapic_rag_corpus_type_config.memory_corpus:
325+
return RagCorpusTypeConfig(
326+
corpus_type_config=MemoryCorpus(
327+
llm_parser=LlmParserConfig(
328+
model_name=gapic_rag_corpus_type_config.memory_corpus.llm_parser.model_name,
329+
max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.max_parsing_requests_per_min,
330+
global_max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.global_max_parsing_requests_per_min,
331+
custom_parsing_prompt=gapic_rag_corpus_type_config.memory_corpus.llm_parser.custom_parsing_prompt,
332+
)
333+
)
334+
)
335+
return None
336+
337+
315338
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
316339
"""Convert GapicRagCorpus to RagCorpus."""
317340
rag_corpus = RagCorpus(
318341
name=gapic_rag_corpus.name,
319342
display_name=gapic_rag_corpus.display_name,
320343
description=gapic_rag_corpus.description,
344+
corpus_type_config=convert_gapic_to_rag_corpus_type_config(
345+
gapic_rag_corpus.corpus_type_config
346+
),
321347
embedding_model_config=convert_gapic_to_embedding_model_config(
322348
gapic_rag_corpus.rag_embedding_model_config
323349
),
@@ -553,6 +579,10 @@ def prepare_import_files_request(
553579
rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = (
554580
llm_parser.max_parsing_requests_per_min
555581
)
582+
if llm_parser.global_max_parsing_requests_per_min is not None:
583+
rag_file_parsing_config.llm_parser.global_max_parsing_requests_per_min = (
584+
llm_parser.global_max_parsing_requests_per_min
585+
)
556586
if llm_parser.custom_parsing_prompt is not None:
557587
rag_file_parsing_config.llm_parser.custom_parsing_prompt = (
558588
llm_parser.custom_parsing_prompt
@@ -671,10 +701,51 @@ def get_file_name(
671701
)
672702

673703

704+
def set_corpus_type_config(
705+
corpus_type_config: RagCorpusTypeConfig,
706+
rag_corpus: GapicRagCorpus,
707+
) -> None:
708+
"""Set corpus type config in GapicRagCorpus."""
709+
if isinstance(corpus_type_config.corpus_type_config, DocumentCorpus):
710+
rag_corpus.corpus_type_config = GapicRagCorpus.CorpusTypeConfig(
711+
document_corpus=GapicRagCorpus.CorpusTypeConfig.DocumentCorpus()
712+
)
713+
elif isinstance(corpus_type_config.corpus_type_config, MemoryCorpus):
714+
memory_corpus = GapicRagCorpus.CorpusTypeConfig.MemoryCorpus()
715+
if corpus_type_config.corpus_type_config.llm_parser is not None:
716+
memory_corpus.llm_parser = RagFileParsingConfig.LlmParser(
717+
model_name=corpus_type_config.corpus_type_config.llm_parser.model_name
718+
)
719+
if (
720+
corpus_type_config.corpus_type_config.llm_parser.max_parsing_requests_per_min
721+
is not None
722+
):
723+
memory_corpus.llm_parser.max_parsing_requests_per_min = (
724+
corpus_type_config.corpus_type_config.llm_parser.max_parsing_requests_per_min
725+
)
726+
if (
727+
corpus_type_config.corpus_type_config.llm_parser.global_max_parsing_requests_per_min
728+
is not None
729+
):
730+
memory_corpus.llm_parser.global_max_parsing_requests_per_min = (
731+
corpus_type_config.corpus_type_config.llm_parser.global_max_parsing_requests_per_min
732+
)
733+
if (
734+
corpus_type_config.corpus_type_config.llm_parser.custom_parsing_prompt
735+
is not None
736+
):
737+
memory_corpus.llm_parser.custom_parsing_prompt = (
738+
corpus_type_config.corpus_type_config.llm_parser.custom_parsing_prompt
739+
)
740+
else:
741+
raise TypeError
742+
743+
674744
def set_embedding_model_config(
675745
embedding_model_config: EmbeddingModelConfig,
676746
rag_corpus: GapicRagCorpus,
677747
) -> None:
748+
"""Sets the embedding model config for the rag corpus."""
678749
if embedding_model_config.publisher_model and embedding_model_config.endpoint:
679750
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
680751
if (

0 commit comments

Comments
 (0)