Add RAG Example using FAISS and Harmony Prompts #207
Open: Ujjwal-Bajpayee wants to merge 1 commit into openai:main from Ujjwal-Bajpayee:main (+368 −0)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Minimal RAG + gpt-oss Example (FAISS Retrieval) | ||
| | ||
| This example demonstrates a simple, end-to-end Retrieval-Augmented Generation (RAG) pipeline using FAISS, sentence-transformers, and gpt-oss (or any OpenAI-compatible endpoint). | ||
| | ||
| **No project configs or core files are changed. All code and dependencies are local to `examples/`.** | ||
| | ||
| ## Setup | ||
| | ||
| 1. Install requirements (in a virtualenv): | ||
| | ||
| ```sh | ||
| pip install -r examples/requirements-rag.txt | ||
| ``` | ||
| | ||
| 2. Set environment variables: | ||
| | ||
| - `OPENAI_API_KEY` (your key) | ||
| - `OPENAI_BASE_URL` (e.g., `http://localhost:8000/v1` for vLLM/gpt-oss) | ||
| - `GPT_OSS_MODEL` (model name, e.g., `gpt-oss-20b`) | ||
| | ||
| ## Usage | ||
| | ||
| ```sh | ||
| python examples/rag_gpt_oss.py --query "What is vector search?" --top_k 4 | ||
| ``` | ||
| | ||
| Optional flags: | ||
| - `--rebuild-index` (force reindex) | ||
| - `--no-stream` (disable streaming) | ||
| - `--chunk-size` (characters, default 800) | ||
| - `--chunk-overlap` (characters, default 120) | ||
| | ||
| ## What it does | ||
| | ||
| - Loads docs from `examples/data/*.{txt,md,pdf}` (PDFs require `pymupdf`) | ||
| - Builds or loads a FAISS index in `examples/data/.faiss/` | ||
| - Retrieves top-k chunks with metadata (source file, char span) | ||
| - Constructs a Harmony prompt (system guides behavior, user includes question and retrieved context, sources cited) | ||
| - Calls an OpenAI-compatible chat endpoint using the official `openai` Python SDK | ||
| - Streams output (unless `--no-stream`) | ||
| - Prints the answer and a compact citation list (`[n] filename`) | ||
| - Saves a JSONL transcript to `examples/data/runs/{timestamp}.jsonl` | ||
| | ||
| ## Example Output | ||
| | ||
| ``` | ||
| Answer: Vector search is a method ... | ||
| | ||
| Sources: | ||
| [1] intro_vector_search.md | ||
| [2] embeddings_and_faiss.md | ||
| ``` | ||
| | ||
| ## Pointing to a Local vLLM Server | ||
| | ||
| Set `OPENAI_BASE_URL` to your vLLM/gpt-oss endpoint, e.g.: | ||
| | ||
| ```sh | ||
| export OPENAI_BASE_URL=http://localhost:8000/v1 | ||
| ``` | ||
| | ||
| ## Notes | ||
| | ||
| - This is a minimal, example-only script. It does not alter project configs or CI. | ||
| - If a required package (e.g., faiss, pymupdf) is missing, the script prints an install hint and exits with status 2. | ||
| - All code is self-contained under `examples/`, with no changes to core project files. |
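
The three environment variables listed under Setup can be checked up front, before any model call. The following is a minimal sketch, assuming a hypothetical helper `load_settings` and a `Settings` tuple (neither is part of this PR); it reads from a mapping so it can be tried without touching the real environment.

```python
import os
from typing import Mapping, NamedTuple


class Settings(NamedTuple):
    """Hypothetical container for the three variables named in Setup."""
    api_key: str
    base_url: str
    model: str


def load_settings(env: Mapping[str, str] = os.environ) -> Settings:
    """Read the variables; raise if any is missing or empty."""
    names = ("OPENAI_API_KEY", "OPENAI_BASE_URL", "GPT_OSS_MODEL")
    missing = [name for name in names if not env.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return Settings(*(env[name] for name in names))


# Example values taken from the README; the key is a placeholder.
demo = load_settings({
    "OPENAI_API_KEY": "sk-placeholder",
    "OPENAI_BASE_URL": "http://localhost:8000/v1",
    "GPT_OSS_MODEL": "gpt-oss-20b",
})
print(demo.base_url, demo.model)
```

Raising an exception instead of exiting keeps the check reusable from other scripts; a CLI wrapper can still turn it into an error message and exit code.
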
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # Embeddings and FAISS | ||
| | ||
| Embeddings are vector representations of text. FAISS is a fast library for similarity search and clustering of dense vectors. To use FAISS: | ||
| | ||
| 1. Generate embeddings for your text chunks using a model like sentence-transformers/all-MiniLM-L6-v2. | ||
| 2. Build a FAISS index from these vectors. | ||
| 3. Retrieve top-k similar chunks for a query using cosine similarity. |
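
Step 3 can be illustrated without FAISS. The sketch below is a pure-Python stand-in for an exhaustive inner-product search over unit-length vectors; the function names and the toy 3-dimensional vectors are made up for illustration and are not model output.

```python
import math
from typing import List, Sequence, Tuple


def normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def top_k_inner_product(query: Sequence[float],
                        vectors: Sequence[Sequence[float]],
                        k: int) -> List[Tuple[int, float]]:
    """Brute-force search: score every vector, return the k best (index, score) pairs."""
    q = normalize(query)
    scored = [
        (i, sum(a * b for a, b in zip(q, normalize(v))))
        for i, v in enumerate(vectors)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy "embeddings" for three chunks and one query.
chunks = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]
hits = top_k_inner_product([1.0, 0.05, 0.0], chunks, k=2)
print(hits)
```

Because every vector is normalized first, each score is the cosine similarity between the query and that chunk.
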
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # Introduction to Vector Search | ||
| | ||
| Vector search is a technique that enables searching for information based on the semantic meaning of text, rather than exact keyword matches. It works by converting text into high-dimensional vectors (embeddings) and finding the most similar vectors using distance metrics like cosine similarity. This approach powers modern retrieval-augmented generation (RAG) systems and semantic search engines. |
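
For reference, the cosine similarity between a query embedding $q$ and a chunk embedding $d$ (symbols chosen here for illustration) can be written as:

$$
\cos(q, d) = \frac{q \cdot d}{\lVert q \rVert \, \lVert d \rVert}
           = \frac{\sum_i q_i d_i}{\sqrt{\sum_i q_i^2}\,\sqrt{\sum_i d_i^2}}
$$

When both vectors are scaled to unit length, the denominator equals 1 and the score reduces to the inner product $q \cdot d$.
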
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| Chunking and overlap are crucial for effective retrieval: | ||
| | ||
| - Use chunk sizes that balance context (e.g., 800 characters) and retrieval granularity. | ||
| - Overlap chunks (e.g., 120 characters) to avoid missing relevant information at boundaries. | ||
| - Clean and normalize text before indexing. | ||
| - Always cite sources for transparency. |
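
The interaction between chunk size and overlap is easiest to see on concrete numbers. This is a small sketch, assuming a hypothetical helper `chunk_spans` that returns only the character spans of a sliding window; it uses the 800/120 values mentioned above.

```python
from typing import List, Tuple


def chunk_spans(text_len: int, chunk_size: int = 800, chunk_overlap: int = 120) -> List[Tuple[int, int]]:
    """Return (start, end) character spans for a sliding window with overlap."""
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= chunk_overlap < chunk_size")
    stride = chunk_size - chunk_overlap  # how far each window advances
    spans = []
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        spans.append((start, end))
        if end == text_len:
            break
        start += stride
    return spans


# A 2000-character document yields three windows; neighbours share 120 characters.
print(chunk_spans(2000))  # [(0, 800), (680, 1480), (1360, 2000)]
```

The guard on `chunk_overlap` matters: if the overlap were as large as the chunk size, the window would never advance.
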
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,253 @@ | ||
| #!/usr/bin/env python | ||
| """ | ||
| Minimal RAG + gpt-oss example using FAISS retrieval. | ||
| See docs/examples/rag_gpt_oss.md for details. | ||
| """ | ||
| import os | ||
| import sys | ||
| import time | ||
| import json | ||
| import argparse | ||
| import glob | ||
| import hashlib | ||
| import datetime | ||
| from pathlib import Path | ||
| from typing import List, Dict, Optional, Tuple | ||
| | ||
| # --- Dependency checks and fallbacks --- | ||
| # Version specifiers are quoted so the shell does not treat ">" as a redirect. | ||
| try: | ||
|     import faiss | ||
| except ImportError: | ||
|     print('[ERROR] Missing dependency: faiss-cpu. Install with: pip install "faiss-cpu>=1.8"', file=sys.stderr) | ||
|     sys.exit(2) | ||
| try: | ||
|     from sentence_transformers import SentenceTransformer | ||
| except ImportError: | ||
|     print('[ERROR] Missing dependency: sentence-transformers. Install with: pip install "sentence-transformers>=2.6"', file=sys.stderr) | ||
|     sys.exit(2) | ||
| try: | ||
|     import tiktoken | ||
|     def count_tokens(text): | ||
|         enc = tiktoken.get_encoding("cl100k_base") | ||
|         return len(enc.encode(text)) | ||
| except ImportError: | ||
|     def count_tokens(text): | ||
|         return len(text.encode("utf-8")) // 4  # crude fallback | ||
| try: | ||
|     import fitz  # pymupdf | ||
|     def extract_pdf_text(path): | ||
|         with fitz.open(path) as doc:  # close the document when done | ||
|             return "\n".join(page.get_text() for page in doc) | ||
| except ImportError: | ||
|     def extract_pdf_text(path): | ||
|         print('[ERROR] pymupdf not installed. Install with: pip install "pymupdf>=1.24"', file=sys.stderr) | ||
|         sys.exit(2) | ||
| try: | ||
|     from openai import OpenAI | ||
| except ImportError: | ||
|     print('[ERROR] Missing dependency: openai. Install with: pip install "openai>=1.40"', file=sys.stderr) | ||
|     sys.exit(2) | ||
| | ||
| # --- Harmony helpers --- | ||
| # Put the repo root on sys.path so `examples.*` resolves when this file is run as a script. | ||
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | ||
| from examples.utils.harmony_helpers import build_harmony_messages, validate_harmony_response | ||
| | ||
| # --- Chunker --- | ||
| def recursive_chunk(text, chunk_size=800, chunk_overlap=120): | ||
|     """Split text into overlapping character windows; returns (start, end, chunk) tuples.""" | ||
|     # An overlap >= chunk_size would stop the window from advancing (infinite loop). | ||
|     if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size: | ||
|         raise ValueError("Need chunk_size > 0 and 0 <= chunk_overlap < chunk_size") | ||
|     chunks = [] | ||
|     start = 0 | ||
|     text_len = len(text) | ||
|     while start < text_len: | ||
|         end = min(start + chunk_size, text_len) | ||
|         chunk = text[start:end] | ||
|         chunks.append((start, end, chunk)) | ||
|         if end == text_len: | ||
|             break | ||
|         start += chunk_size - chunk_overlap | ||
|     return chunks | ||
| | ||
| # --- Doc loader --- | ||
| def load_docs(data_dir: str) -> List[Dict]: | ||
|     docs = [] | ||
|     # Sort so the document order (and therefore the index hash) is stable across runs. | ||
|     for path in sorted(glob.glob(os.path.join(data_dir, '*.*'))): | ||
|         ext = os.path.splitext(path)[1].lower() | ||
|         if ext in {'.md', '.txt'}: | ||
|             with open(path, encoding='utf-8') as f: | ||
|                 text = f.read() | ||
|         elif ext == '.pdf': | ||
|             text = extract_pdf_text(path) | ||
|         else: | ||
|             continue | ||
|         docs.append({'path': path, 'text': text}) | ||
|     return docs | ||
| | ||
| # --- Indexing --- | ||
| def build_or_load_faiss(docs: List[Dict], faiss_dir: str, chunk_size: int, chunk_overlap: int, model_name: str) -> Tuple[faiss.IndexFlatIP, List[Dict]]: | ||
|     os.makedirs(faiss_dir, exist_ok=True) | ||
|     meta_path = os.path.join(faiss_dir, 'meta.json') | ||
|     index_path = os.path.join(faiss_dir, 'index.bin') | ||
|     chunks_path = os.path.join(faiss_dir, 'chunks.jsonl') | ||
|     # Check if index exists and is up-to-date | ||
|     doc_hash = hashlib.sha1() | ||
|     # Include chunking and model settings so changing them invalidates the cached index. | ||
|     doc_hash.update(f"{chunk_size}:{chunk_overlap}:{model_name}".encode()) | ||
|     for doc in docs: | ||
|         stat = os.stat(doc['path']) | ||
|         doc_hash.update(f"{doc['path']}:{stat.st_mtime}".encode()) | ||
|     hash_hex = doc_hash.hexdigest() | ||
|     if os.path.exists(meta_path): | ||
|         with open(meta_path, encoding='utf-8') as f: | ||
|             meta = json.load(f) | ||
|         if meta.get('hash') == hash_hex and os.path.exists(index_path) and os.path.exists(chunks_path): | ||
|             index = faiss.read_index(index_path) | ||
|             with open(chunks_path, encoding='utf-8') as f: | ||
|                 chunks = [json.loads(line) for line in f] | ||
|             return index, chunks | ||
|     # Rebuild index | ||
|     model = SentenceTransformer(model_name) | ||
|     all_chunks = [] | ||
|     texts = []  # chunk texts to embed, in the same order as all_chunks | ||
|     for doc in docs: | ||
|         for i, (start, end, chunk) in enumerate(recursive_chunk(doc['text'], chunk_size, chunk_overlap)): | ||
|             chunk_id = f"{os.path.basename(doc['path'])}#{i}" | ||
|             all_chunks.append({ | ||
|                 'id': chunk_id, | ||
|                 'text': chunk, | ||
|                 'source': os.path.basename(doc['path']), | ||
|                 'span': [start, end], | ||
|                 'path': doc['path'] | ||
|             }) | ||
|             texts.append(chunk) | ||
|     if not all_chunks: | ||
|         print("[ERROR] No chunks found for indexing.", file=sys.stderr) | ||
|         sys.exit(2) | ||
|     # Normalized embeddings + inner-product index give cosine-similarity scores. | ||
|     embeds = model.encode(texts, normalize_embeddings=True, show_progress_bar=True) | ||
|     dim = embeds.shape[1] | ||
|     index = faiss.IndexFlatIP(dim) | ||
|     index.add(embeds) | ||
|     faiss.write_index(index, index_path) | ||
|     with open(chunks_path, 'w', encoding='utf-8') as f: | ||
|         for chunk in all_chunks: | ||
|             f.write(json.dumps(chunk, ensure_ascii=False) + '\n') | ||
|     with open(meta_path, 'w', encoding='utf-8') as f: | ||
|         json.dump({'hash': hash_hex, 'dim': dim, 'model': model_name}, f) | ||
|     return index, all_chunks | ||
| | ||
| # --- Retrieval --- | ||
| def retrieve(query: str, index, chunks: List[Dict], model_name: str, top_k: int) -> List[Dict]: | ||
|     model = SentenceTransformer(model_name) | ||
|     qvec = model.encode([query], normalize_embeddings=True) | ||
|     D, I = index.search(qvec, top_k) | ||
|     results = [] | ||
|     for rank, idx in enumerate(I[0]): | ||
|         if idx < 0 or idx >= len(chunks): | ||
|             continue | ||
|         chunk = chunks[idx].copy() | ||
|         chunk['score'] = float(D[0][rank]) | ||
|         chunk['rank'] = rank + 1 | ||
|         results.append(chunk) | ||
|     return results | ||
| | ||
| # --- Main CLI --- | ||
| def main(): | ||
|     parser = argparse.ArgumentParser(description="Minimal RAG + gpt-oss example (FAISS retrieval)") | ||
|     parser.add_argument('--query', required=True, help='User query') | ||
|     parser.add_argument('--top_k', type=int, default=4, help='Top-k chunks to retrieve') | ||
|     parser.add_argument('--rebuild-index', action='store_true', help='Force rebuild FAISS index') | ||
|     parser.add_argument('--no-stream', action='store_true', help='Disable streaming output') | ||
|     parser.add_argument('--chunk-size', type=int, default=800, help='Chunk size (chars)') | ||
|     parser.add_argument('--chunk-overlap', type=int, default=120, help='Chunk overlap (chars)') | ||
|     args = parser.parse_args() | ||
| | ||
|     # Env vars | ||
|     api_key = os.getenv('OPENAI_API_KEY') | ||
|     base_url = os.getenv('OPENAI_BASE_URL') | ||
|     model = os.getenv('GPT_OSS_MODEL') | ||
|     if not (api_key and base_url and model): | ||
|         print("[ERROR] Set OPENAI_API_KEY, OPENAI_BASE_URL, and GPT_OSS_MODEL.", file=sys.stderr) | ||
|         sys.exit(2) | ||
| | ||
|     data_dir = os.path.join(os.path.dirname(__file__), 'data') | ||
|     faiss_dir = os.path.join(data_dir, '.faiss') | ||
|     runs_dir = os.path.join(data_dir, 'runs') | ||
|     os.makedirs(runs_dir, exist_ok=True) | ||
| | ||
|     docs = load_docs(data_dir) | ||
|     if not docs: | ||
|         print("[ERROR] No documents found in examples/data/", file=sys.stderr) | ||
|         sys.exit(2) | ||
| | ||
|     # Index | ||
|     if args.rebuild_index: | ||
|         for f in Path(faiss_dir).glob('*'): | ||
|             if f.is_file(): | ||
|                 f.unlink() | ||
|     index, all_chunks = build_or_load_faiss(docs, faiss_dir, args.chunk_size, args.chunk_overlap, 'sentence-transformers/all-MiniLM-L6-v2') | ||
|     if index.ntotal == 0 or not all_chunks: | ||
|         print("[ERROR] FAISS index is empty.", file=sys.stderr) | ||
|         sys.exit(2) | ||
| | ||
|     # Retrieval | ||
|     retrieved = retrieve(args.query, index, all_chunks, 'sentence-transformers/all-MiniLM-L6-v2', args.top_k) | ||
|     if not retrieved: | ||
|         print("[ERROR] No relevant chunks retrieved.", file=sys.stderr) | ||
|         sys.exit(2) | ||
| | ||
|     # Prompt | ||
|     system_prompt = "You are a helpful assistant. Use ONLY the provided CONTEXT. Cite sources as [1], [2], ... Map them to filenames at the end under 'Sources'." | ||
|     messages = build_harmony_messages(system_prompt, args.query, retrieved) | ||
| | ||
|     # OpenAI-compatible call | ||
|     client = OpenAI(base_url=base_url, api_key=api_key) | ||
|     start_time = time.time() | ||
|     response_text = "" | ||
|     try: | ||
|         stream = not args.no_stream | ||
|         completion = client.chat.completions.create( | ||
|             model=model, | ||
|             messages=messages, | ||
|             stream=stream, | ||
|             temperature=0.2, | ||
|             max_tokens=512 | ||
|         ) | ||
|         if stream: | ||
|             print("\nAnswer:", end=" ", flush=True) | ||
|             for chunk in completion: | ||
|                 if not chunk.choices:  # some servers send chunks with no choices | ||
|                     continue | ||
|                 delta = getattr(chunk.choices[0].delta, 'content', None) | ||
|                 if delta: | ||
|                     print(delta, end="", flush=True) | ||
|                     response_text += delta | ||
|             print() | ||
|         else: | ||
|             response_text = completion.choices[0].message.content or "" | ||
|             print("\nAnswer:", response_text) | ||
|     except Exception as e: | ||
|         print(f"[ERROR] Model call failed: {e}", file=sys.stderr) | ||
|         sys.exit(2) | ||
|     latency_ms = int((time.time() - start_time) * 1000) | ||
| | ||
|     # Validate response | ||
|     if not validate_harmony_response(response_text): | ||
|         print("[ERROR] Model returned empty or invalid response.", file=sys.stderr) | ||
|         sys.exit(2) | ||
| | ||
|     # Citations | ||
|     print("\nSources:") | ||
|     for i, chunk in enumerate(retrieved, 1): | ||
|         print(f"[{i}] {chunk['source']}") | ||
| | ||
|     # Save transcript | ||
|     ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | ||
|     run_path = os.path.join(runs_dir, f'{ts}.jsonl') | ||
|     with open(run_path, 'w', encoding='utf-8') as f: | ||
|         log = { | ||
|             'query': args.query, | ||
|             'retrieved_ids': [c['id'] for c in retrieved], | ||
|             'prompt': messages, | ||
|             'model': model, | ||
|             'latency_ms': latency_ms, | ||
|             'answer': response_text | ||
|         } | ||
|         f.write(json.dumps(log, ensure_ascii=False) + '\n') | ||
|     # Simple inline test | ||
|     assert os.path.exists(run_path) and os.path.getsize(run_path) > 0, "Transcript not saved!" | ||
| | ||
| if __name__ == '__main__': | ||
|     main() | ||
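
The transcript written at the end of `main()` is one JSON object per line, so it can be read back with the standard library alone. This is a minimal sketch, assuming a hypothetical helper `read_transcripts`; the sample record is invented and only reuses the keys the script writes.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def read_transcripts(runs_dir: Path) -> List[Dict]:
    """Load every record from every *.jsonl file in runs_dir, sorted by filename."""
    records: List[Dict] = []
    for path in sorted(runs_dir.glob("*.jsonl")):
        with open(path, encoding="utf-8") as f:
            records.extend(json.loads(line) for line in f if line.strip())
    return records


# Demo with a made-up record written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    runs = Path(tmp)
    sample = {
        "query": "What is vector search?",
        "retrieved_ids": ["intro_vector_search.md#0"],
        "prompt": [],
        "model": "gpt-oss-20b",
        "latency_ms": 1234,
        "answer": "...",
    }
    (runs / "20250101_000000.jsonl").write_text(json.dumps(sample) + "\n", encoding="utf-8")
    loaded = read_transcripts(runs)

print(loaded[0]["query"], loaded[0]["latency_ms"])
```

Because the filenames are timestamps in `YYYYMMDD_HHMMSS` form, sorting by name also sorts the runs chronologically.
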
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| faiss-cpu>=1.8 | ||
| sentence-transformers>=2.6 | ||
| pymupdf>=1.24 | ||
| tiktoken>=0.7 | ||
| openai>=1.40 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| import re | ||
| | ||
| def build_harmony_messages(system_prompt: str, user_query: str, retrieved_chunks: list[dict]) -> list[dict]: | ||
|     """ | ||
|     Build Harmony-style messages for OpenAI-compatible chat completion. | ||
|     Each chunk is numbered [n] in CONTEXT and tagged with its source filename, | ||
|     so the model can map citations to files. | ||
|     """ | ||
|     context_lines = [] | ||
|     for i, chunk in enumerate(retrieved_chunks, 1): | ||
|         source = chunk.get('source', 'unknown') | ||
|         context_lines.append(f"[{i}] (source: {source}) {chunk['text']}") | ||
|     context = "\n".join(context_lines) | ||
|     user_content = f"QUESTION: {user_query}\nCONTEXT:\n{context}" | ||
|     messages = [ | ||
|         {"role": "system", "content": system_prompt}, | ||
|         {"role": "user", "content": user_content}, | ||
|     ] | ||
|     return messages | ||
| | ||
| def validate_harmony_response(text: str) -> bool: | ||
|     """ | ||
|     Minimal checks: non-empty, not a tool-call JSON. | ||
|     """ | ||
|     if not text or not text.strip(): | ||
|         return False | ||
|     # Disallow tool-call JSON (e.g., starts with '{' and contains "tool_call") | ||
|     if text.strip().startswith('{') and 'tool_call' in text: | ||
|         return False | ||
|     return True |
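
The validation above checks only that the response is non-empty and not a tool call. A further check one could add is whether the `[n]` markers in an answer refer to chunks that were actually retrieved. The sketch below assumes two hypothetical helpers, `cited_indices` and `invalid_citations`, which are not part of this PR.

```python
import re
from typing import List, Set


def cited_indices(answer: str) -> Set[int]:
    """Collect the integers that appear as [n] markers in the answer text."""
    return {int(m) for m in re.findall(r"\[(\d+)\]", answer)}


def invalid_citations(answer: str, num_chunks: int) -> List[int]:
    """Return cited numbers that fall outside 1..num_chunks, sorted ascending."""
    return sorted(n for n in cited_indices(answer) if not 1 <= n <= num_chunks)


answer = "Vector search compares embeddings [1]. FAISS indexes them [2][5]."
print(invalid_citations(answer, num_chunks=4))  # [5]
```

A caller could log or reject answers for which `invalid_citations` is non-empty, depending on how strict the example should be.
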
The retrieval function forwards the user-provided `top_k` directly to `index.search` without clamping it to the number of indexed chunks or ensuring it is positive. When the corpus is small (e.g., only three chunks) and the CLI is invoked with `--top_k 100` or `--top_k 0`, `faiss.IndexFlatIP.search` raises a `Faiss assertion 'k <= index.ntotal' failed` (or similar) before any error handling runs, terminating the program instead of emitting the friendly error messages used elsewhere. Validating `top_k` against `index.ntotal` and requiring it to be > 0 would avoid the crash.
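
One possible shape for the validation suggested in this comment, as a sketch only: `resolve_top_k` is a hypothetical helper, and `ntotal` stands for the number of vectors in the index (`index.ntotal` in the script).

```python
def resolve_top_k(requested: int, ntotal: int) -> int:
    """Return a k that is safe to use for a search over `ntotal` vectors."""
    if ntotal <= 0:
        raise ValueError("index is empty")
    if requested <= 0:
        raise ValueError(f"--top_k must be positive, got {requested}")
    # Never ask for more neighbours than there are indexed chunks.
    return min(requested, ntotal)


print(resolve_top_k(100, 3))  # 3
print(resolve_top_k(2, 3))    # 2
```

In the CLI, the `ValueError` could be caught and reported in the same `[ERROR] ...` style as the other checks before exiting.
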