feat(llm): support graph_rag_recall api (apache#79)
imbajin authored Sep 18, 2024
1 parent c89eb31 commit c519ec0
Showing 16 changed files with 209 additions and 184 deletions.
13 changes: 5 additions & 8 deletions hugegraph-llm/README.md
@@ -53,7 +53,7 @@ Refer to [docker-link](https://hub.docker.com/r/hugegraph/hugegraph) & [deploy-d
python3 -m hugegraph_llm.demo.gremlin_generate_web_demo
```

7. After starting the web demo, the config file `.env` will be automatically generated. You can modify its content in the web page. Or modify the file directly and restart the web application.
7. After starting the web demo, the config file `.env` will be automatically generated. You can modify its content on the web page, or modify the file directly and restart the web application.

(Optional) To regenerate the config file, use `config.generate` with `-u` or `--update`.
```bash
@@ -130,22 +130,19 @@ The methods of the `KgBuilder` class can be chained together to perform a sequen

Run the example: `python3 ./hugegraph_llm/examples/graph_rag_test.py`

The `GraphRAG` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
The `RAGPipeline` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
Here is a brief usage guide:

1. **Extract Keywords**: Extract keywords and expand synonyms.

```python
graph_rag.extract_keyword(text="Tell me about Al Pacino.").print_result()
graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result()
```

2. **Query Graph for RAG**: Retrieve the matched keywords and their multi-hop relationships from HugeGraph.

```python
graph_rag.query_graph_for_rag(
max_deep=2,
max_items=30
).print_result()
graph_rag.query_graphdb(max_deep=2, max_items=30).print_result()
```
3. **Synthesize Answer**: Summarize the retrieved results and compose the final answer, as sketched below.
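
A minimal sketch of this step (keyword arguments follow this commit's `synthesize_answer()` and `run()` signatures; exact usage may differ):

```python
graph_rag.synthesize_answer(graph_only_answer=True).run(query="Tell me about Al Pacino.", graph_search=True)
```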

15 changes: 13 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -15,20 +15,31 @@
# specific language governing permissions and limitations
# under the License.

from typing import Optional
from typing import Optional, Literal

from pydantic import BaseModel


class RAGRequest(BaseModel):
query: str
raw_llm: Optional[bool] = True
raw_llm: Optional[bool] = False
vector_only: Optional[bool] = False
graph_only: Optional[bool] = False
graph_vector: Optional[bool] = False
graph_ratio: float = 0.5
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
custom_related_information: Optional[str] = None
answer_prompt: Optional[str] = None


class GraphRAGRequest(BaseModel):
query: str
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
custom_related_information: Optional[str] = None


class GraphConfigRequest(BaseModel):
ip: str = "127.0.0.1"
port: str = "8080"
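For illustration, the new `GraphRAGRequest` model can be exercised like this (a sketch; only `query` is required, the other fields fall back to the defaults declared above):

```python
from hugegraph_llm.api.models.rag_requests import GraphRAGRequest

# Only `query` is required; the other fields use their declared defaults.
req = GraphRAGRequest(query="Tell me about Al Pacino.")
assert req.rerank_method == "bleu" and req.near_neighbor_first is False
```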
52 changes: 49 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,22 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Literal

from fastapi import status, APIRouter
from fastapi import status, APIRouter, HTTPException

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest,
RerankerConfigRequest, GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings
from hugegraph_llm.utils.log import log


def graph_rag_recall(
text: str,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str
) -> dict:
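# Recall-only pipeline: extract keywords, map them to graph vertex IDs, query
# the graph, then rerank the merged results; answer synthesis is skipped.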
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
rag = RAGPipeline()
rag.extract_keywords().keywords_to_vid().query_graphdb().merge_dedup_rerank(
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
)
context = rag.run(verbose=True, query=text, graph_search=True)
return context


def rag_http_api(
router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
@@ -41,6 +41,32 @@ def rag_answer_api(req: RAGRequest):
if getattr(req, key)
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
try:
result = graph_rag_recall(
text=req.query,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_related_information
)
# TODO/FIXME: handle QianFanClient error (not dict..critical)
# log.critical(f"## {type(result)}, {json.dumps(result)}")
if isinstance(result, dict):
log.critical(f"##1. {type(result)}")
return {"graph_recall": result}
else:
log.critical(f"##2. {type(result)}")
return {"graph_recall": json.dumps(result)}

except TypeError as e:
log.error(f"TypeError in graph_rag_recall_api: {e}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
log.error(f"Unexpected error occurred: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")


@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
# Accept status code
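For reference, a sketch of calling the new recall endpoint (it assumes the web demo below is running on `127.0.0.1:8001` with authentication disabled and no extra router prefix):

```python
import requests

# POST /rag/graph with a GraphRAGRequest-shaped body; the handler above wraps
# the recalled context under the "graph_recall" key.
resp = requests.post(
    "http://127.0.0.1:8001/rag/graph",
    json={"query": "Tell me about Al Pacino."},
)
resp.raise_for_status()
print(resp.json()["graph_recall"])
```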
53 changes: 24 additions & 29 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -68,50 +68,51 @@ def rag_answer(
custom_related_information: str,
answer_prompt: str,
) -> Tuple:

if prompt.default_question != text or prompt.custom_rerank_info != custom_related_information or prompt.answer_prompt != answer_prompt:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
1. Initialize the RAGPipeline.
2. Select vector search or graph search based on parameters.
3. Merge, deduplicate, and rerank the results.
4. Synthesize the final answer.
5. Run the pipeline and return the results.
"""
should_update_prompt = prompt.default_question != text or prompt.answer_prompt != answer_prompt
if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.update_yaml_file()

vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer

if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
searcher = RAGPipeline()

rag = RAGPipeline()
if vector_search:
searcher.query_vector_index_for_rag()
rag.query_vector_index()
if graph_search:
searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
rag.extract_keywords().keywords_to_vid().query_graphdb()
# TODO: add more user-defined search strategies
searcher.merge_dedup_rerank(
graph_ratio, rerank_method, near_neighbor_first, custom_related_information
).synthesize_answer(
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
answer_prompt=answer_prompt,
)
rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information)
rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)

try:
context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
if context.get("switch_to_bleu"):
gr.Warning("Online reranker fails, automatically switches to local bleu method.")
gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
return (
context.get("raw_answer", ""),
context.get("vector_only_answer", ""),
context.get("graph_only_answer", ""),
context.get("graph_vector_answer", ""),
)
except ValueError as e:
log.error(e)
log.critical(e)
raise gr.Error(str(e))
except Exception as e:
log.error(e)
log.critical(e)
raise gr.Error(f"An unexpected error occurred: {str(e)}")


@@ -665,19 +666,13 @@ def several_rag_answer(
parser.add_argument("--port", type=int, default=8001, help="port")
args = parser.parse_args()
app = FastAPI()
app_auth = APIRouter(dependencies=[Depends(authenticate)])
api_auth = APIRouter(dependencies=[Depends(authenticate)])

hugegraph_llm = init_rag_ui()
rag_http_api(
app_auth,
rag_answer,
apply_graph_config,
apply_llm_config,
apply_embedding_config,
apply_reranker_config,
)
rag_http_api(api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config,
apply_reranker_config)

app.include_router(app_auth)
app.include_router(api_auth)
auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
log.info("(Status) Authentication is %s now.", "enabled" if auth_enabled else "disabled")
# TODO: support multi-user login when need
4 changes: 2 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -17,7 +17,7 @@


from hugegraph_llm.models.llms.ollama import OllamaClient
from hugegraph_llm.models.llms.openai import OpenAIChat
from hugegraph_llm.models.llms.openai import OpenAIClient
from hugegraph_llm.models.llms.qianfan import QianfanClient
from hugegraph_llm.config import settings

@@ -34,7 +34,7 @@ def get_llm(self):
secret_key=settings.qianfan_secret_key
)
if self.llm_type == "openai":
return OpenAIChat(
return OpenAIClient(
api_key=settings.openai_api_key,
api_base=settings.openai_api_base,
model_name=settings.openai_language_model,
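A hypothetical use of the factory in `init_llm.py` (the enclosing class name and the `generate` call are assumptions; neither appears in full in this hunk):

```python
from hugegraph_llm.models.llms.init_llm import LLMs  # class name assumed

llm = LLMs().get_llm()  # returns an OpenAIClient when llm_type is "openai"
print(llm.generate(prompt="Say hi"))  # `generate` is assumed from BaseLLM
```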
4 changes: 2 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -27,8 +27,8 @@
from hugegraph_llm.utils.log import log


class OpenAIChat(BaseLLM):
"""Wrapper around OpenAI Chat large language models."""
class OpenAIClient(BaseLLM):
"""Wrapper for OpenAI Client."""

def __init__(
self,
@@ -62,7 +62,8 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:

verbose = context.get("verbose") or False
if verbose:
print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
from hugegraph_llm.utils.log import log
log.info(f"KEYWORDS: {context['keywords']}")

return context

37 changes: 11 additions & 26 deletions hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -50,27 +50,18 @@ def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedd
self._embedding = embedding or Embeddings().get_embedding()
self._operators: List[Any] = []

def extract_word(
self,
text: Optional[str] = None,
language: str = "english",
):
def extract_word(self, text: Optional[str] = None, language: str = "english"):
"""
Add a word extraction operator to the pipeline.
:param text: Text to extract words from.
:param language: Language of the text.
:return: Self-instance for chaining.
"""
self._operators.append(
WordExtract(
text=text,
language=language,
)
)
self._operators.append(WordExtract(text=text, language=language))
return self

def extract_keyword(
def extract_keywords(
self,
text: Optional[str] = None,
max_keywords: int = 5,
@@ -99,16 +90,17 @@ def extract_keyword(
)
return self

def match_keyword_to_id(
def keywords_to_vid(
self,
by: Literal["query", "keywords"] = "keywords",
topk_per_keyword: int = 1,
topk_per_query: int = 10,
):
"""
Add a semantic ID query operator to the pipeline.
:param by: Match by query or keywords.
:param topk_per_keyword: Top K results per keyword.
:param topk_per_query: Top K results per query.
:return: Self-instance for chaining.
"""
self._operators.append(
@@ -121,7 +113,7 @@ def match_keyword_to_id(
)
return self

def query_graph_for_rag(
def query_graphdb(
self,
max_deep: int = 2,
max_items: int = 30,
@@ -136,26 +128,19 @@
:return: Self-instance for chaining.
"""
self._operators.append(
GraphRAGQuery(
max_deep=max_deep,
max_items=max_items,
prop_to_match=prop_to_match,
)
GraphRAGQuery(max_deep=max_deep, max_items=max_items, prop_to_match=prop_to_match)
)
return self

def query_vector_index_for_rag(self, max_items: int = 3):
def query_vector_index(self, max_items: int = 3):
"""
Add a vector index query operator to the pipeline.
:param max_items: Maximum number of items to retrieve.
:return: Self-instance for chaining.
"""
self._operators.append(
VectorIndexQuery(
embedding=self._embedding,
topk=max_items,
)
VectorIndexQuery(embedding=self._embedding, topk=max_items)
)
return self

@@ -230,7 +215,7 @@ def run(self, **kwargs) -> Dict[str, Any]:
:return: Final context after all operators have been executed.
"""
if len(self._operators) == 0:
self.extract_keyword().query_graph_for_rag().synthesize_answer()
self.extract_keywords().query_graphdb().synthesize_answer()

context = kwargs
context["llm"] = self._llm
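Putting the renamed methods together, the fallback pipeline that `run()` assembles (see the hunk above) can be spelled out explicitly; every call below appears in this commit:

```python
rag = RAGPipeline()
# Equivalent to the default pipeline run() builds when no operators are queued.
rag.extract_keywords().query_graphdb().synthesize_answer()
context = rag.run(verbose=True, query="Tell me about Al Pacino.", graph_search=True)
```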
