{
 "nbformat": 4,
 "nbformat_minor": 5,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# B4 - RAG Essentials: Chunking, Embeddings, and Hybrid Retrieval\n",
    "\n",
    "Companion notebook for article **B4** in *Building with Claude - A Practitioner's Guide to the Anthropic API*.\n",
    "\n",
    "**Attribution.** Concepts adapted from Anthropic's \"Building with the Claude API\" course (Coursera) and public API documentation at [docs.anthropic.com](https://docs.anthropic.com). All code below is original work (c) 2026 DataMy. Not affiliated with Anthropic.\n",
    "\n",
    "---\n",
    "\n",
    "## What you'll build in this notebook\n",
    "\n",
    "A working RAG pipeline from scratch -- no vector database required -- over a three-document corpus:\n",
    "\n",
    "1. **Chunking** -- compare fixed-size and section-boundary strategies on the runbook corpus.\n",
    "2. **Embeddings** -- encode all chunks with VoyageAI `voyage-3`; implement cosine similarity vector search.\n",
    "3. **BM25 keyword search** -- build an in-memory BM25 index; show where it outperforms vector search.\n",
    "4. **Hybrid retrieval with RRF** -- combine both retrievers via Reciprocal Rank Fusion.\n",
    "5. **Full RAG loop** -- retrieve relevant chunks, inject into context, ask Claude a grounded question.\n",
    "\n",
    "**Prerequisites:**\n",
    "- `pip install -r ../requirements.txt`\n",
    "- A `.env` file with `ANTHROPIC_API_KEY` and `VOYAGE_API_KEY` set\n",
    "- Datasets built by `python ../scripts/generate_data.py` (creates the three corpus documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-1",
   "metadata": {},
   "source": [
    "## Section 1 - Setup\n",
    "\n",
    "Same import pattern as B1 and later notebooks: `ClaudeClient` from `llm_client.py`, data from `../data/`.\n",
    "\n",
    "This notebook adds two new dependencies:\n",
    "- `voyageai` -- embedding API client. Reads `VOYAGE_API_KEY` from the environment.\n",
    "- `rank_bm25` -- pure-Python BM25 implementation. No API key required.\n",
    "\n",
    "The three corpus documents are loaded as plain strings. Each is chunked, embedded, and indexed.\n",
    "Retrieval happens at query time against the in-memory structures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import voyageai\n",
    "from dotenv import load_dotenv\n",
    "from rank_bm25 import BM25Okapi\n",
    "\n",
    "from llm_client import ClaudeClient\n",
    "\n",
    "load_dotenv(\"../.env\")\n",
    "\n",
    "DATA_DIR = Path(\"..\") / \"data\"\n",
    "\n",
    "CORPUS_PATHS = {\n",
    "    \"warehouse_runbook\": DATA_DIR / \"runbook_warehouse_cost.md\",\n",
    "    \"quality_runbook\":   DATA_DIR / \"runbook_data_quality.md\",\n",
    "    \"qbr_q3_2025\":       DATA_DIR / \"qbr_q3_2025.md\",\n",
    "}\n",
    "\n",
    "for name, path in CORPUS_PATHS.items():\n",
    "    assert path.exists(), f\"Missing: {path}. Run python ../scripts/generate_data.py\"\n",
    "\n",
    "CORPUS = {name: path.read_text() for name, path in CORPUS_PATHS.items()}\n",
    "\n",
    "vc = voyageai.Client()   # reads VOYAGE_API_KEY from env\n",
    "cc = ClaudeClient()\n",
    "\n",
    "print(\"Corpus loaded:\")\n",
    "for name, text in CORPUS.items():\n",
    "    print(f\"  {name:25s}  {len(text):>7,} chars  (~{len(text)//4:>5,} tokens)\")\n",
    "print(f\"\\nClaudeClient default model : {cc.default_model}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-3",
   "metadata": {},
   "source": [
    "## Section 2 - Chunking strategies\n",
    "\n",
    "Two strategies are implemented and compared:\n",
    "\n",
    "**Fixed-size** -- split on word count with a sliding overlap window. Simple and universal, but splits\n",
    "arbitrarily across semantic boundaries.\n",
    "\n",
    "**Section-boundary** -- split on Markdown `##` headers, further splitting long sections at paragraph\n",
    "boundaries. Each chunk carries its section title, making it self-describing without surrounding context.\n",
    "\n",
    "For the runbook corpus (structured Markdown with clear headers), section-boundary chunking tends to\n",
    "retrieve better because each chunk maps to one topic. The output below compares chunk counts for both\n",
    "strategies and previews the first section-boundary chunk of each document."
   ]
  },
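  {
   "cell_type": "markdown",
   "id": "cell-3a",
   "metadata": {},
   "source": [
    "A toy run makes the fixed-size window arithmetic concrete. The sketch below uses a hypothetical helper,\n",
    "`window_bounds` (not part of this notebook's pipeline), that lists the `[start, end)` word ranges a\n",
    "fixed-size chunker produces. Each window advances by `chunk_size - overlap` words. Stopping as soon as a\n",
    "window reaches the end of the text avoids emitting a final chunk made entirely of overlap.\n",
    "\n",
    "```python\n",
    "def window_bounds(n_words: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:\n",
    "    \"\"\"Return [start, end) word ranges for fixed-size chunks with a sliding overlap.\"\"\"\n",
    "    if not (0 <= overlap < chunk_size):\n",
    "        raise ValueError(\"overlap must satisfy 0 <= overlap < chunk_size\")\n",
    "    bounds, start, step = [], 0, chunk_size - overlap\n",
    "    while start < n_words:\n",
    "        end = min(start + chunk_size, n_words)\n",
    "        bounds.append((start, end))\n",
    "        if end == n_words:  # reached the end: a further window would only repeat overlap\n",
    "            break\n",
    "        start += step\n",
    "    return bounds\n",
    "\n",
    "\n",
    "print(window_bounds(10, chunk_size=4, overlap=1))\n",
    "# [(0, 4), (3, 7), (6, 10)]\n",
    "```\n",
    "\n",
    "With the notebook defaults (`chunk_size=400`, `overlap=50`), consecutive chunks share 50 words. A\n",
    "sentence of up to 50 words that straddles a chunk boundary therefore still appears whole in the next chunk."
   ]
  },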
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_fixed(text: str, chunk_size: int = 400, overlap: int = 50) -> list[str]:\n",
    "    \"\"\"Fixed-size word-count chunks with sliding overlap.\"\"\"\n",
    "    if not (0 <= overlap < chunk_size):\n",
    "        raise ValueError(\n",
    "            f\"overlap must satisfy 0 <= overlap < chunk_size, \"\n",
    "            f\"got overlap={overlap}, chunk_size={chunk_size}\"\n",
    "        )\n",
    "    words = text.split()\n",
    "    chunks, start = [], 0\n",
    "    while start < len(words):\n",
    "        end = min(start + chunk_size, len(words))\n",
    "        chunks.append(\" \".join(words[start:end]))\n",
    "        if end == len(words):\n",
    "            break  # last window reached the end; another step would only repeat overlap\n",
    "        start += chunk_size - overlap\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def chunk_by_section(text: str, max_words: int = 600) -> list[str]:\n",
    "    \"\"\"Split on ## Markdown headers; further split long sections at paragraph boundaries.\n",
    "\n",
    "    Sub-chunks of a long section repeat the section's header line so each chunk stays\n",
    "    self-describing.\n",
    "    \"\"\"\n",
    "    # Lookahead split: only the newline is consumed, so each '## ' header stays with its section\n",
    "    raw_sections = re.split(r\"\\n(?=## )\", text)\n",
    "    chunks = []\n",
    "    for section in raw_sections:\n",
    "        if len(section.split()) <= max_words:\n",
    "            chunks.append(section)\n",
    "            continue\n",
    "        header = section.split(\"\\n\", 1)[0] if section.startswith(\"## \") else \"\"\n",
    "        current, current_words = [], 0\n",
    "        for para in section.split(\"\\n\\n\"):\n",
    "            pw = len(para.split())\n",
    "            if current_words + pw > max_words and current:\n",
    "                chunks.append(\"\\n\\n\".join(current))\n",
    "                # Start the next sub-chunk with the section header (if any)\n",
    "                current = [header] if header else []\n",
    "                current_words = len(header.split())\n",
    "            current.append(para)\n",
    "            current_words += pw\n",
    "        if current:\n",
    "            chunks.append(\"\\n\\n\".join(current))\n",
    "    return [c for c in chunks if c.strip()]\n",
    "\n",
    "\n",
    "# Build the corpus chunk list using the section-boundary strategy (used for the rest of the notebook)\n",
    "all_chunks: list[dict] = []\n",
    "for doc_name, doc_text in CORPUS.items():\n",
    "    for chunk_text in chunk_by_section(doc_text):\n",
    "        all_chunks.append({\"source\": doc_name, \"text\": chunk_text})\n",
    "\n",
    "# Compare chunk counts between strategies\n",
    "print(\"=== Chunking comparison ===\")\n",
    "print(f\"{'Document':25s}  {'Fixed-size':>12}  {'Section-boundary':>16}\")\n",
    "print(\"-\" * 58)\n",
    "for doc_name, doc_text in CORPUS.items():\n",
    "    n_fixed   = len(chunk_fixed(doc_text))\n",
    "    n_section = len(chunk_by_section(doc_text))\n",
    "    print(f\"{doc_name:25s}  {n_fixed:>12}  {n_section:>16}\")\n",
    "\n",
    "print(f\"\\nTotal chunks in corpus (section-boundary): {len(all_chunks)}\")\n",
    "\n",
    "# Sample: first chunk from each document\n",
    "print(\"\\n=== Sample chunks (section-boundary, first chunk per document) ===\")\n",
    "seen = set()\n",
    "for chunk in all_chunks:\n",
    "    if chunk[\"source\"] not in seen:\n",
    "        seen.add(chunk[\"source\"])\n",
    "        preview = chunk[\"text\"][:200].replace(\"\\n\", \" \")\n",
    "        print(f\"\\n[{chunk['source']}]\\n{preview} ...\")"
   ]
  },
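  {
   "cell_type": "markdown",
   "id": "cell-4a",
   "metadata": {},
   "source": [
    "Why the lookahead in `re.split(r\"\\n(?=## )\", text)`? A lookahead `(?=...)` matches a position without\n",
    "consuming characters. The split therefore removes only the newline, and each `## ` header stays at the\n",
    "top of its own section. A quick check on a small, made-up document:\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical three-section document, for illustration only\n",
    "doc = \"# Runbook\\nIntro text.\\n## Triage\\nCheck the query history.\\n## Fix\\nResize the warehouse.\"\n",
    "\n",
    "sections = re.split(r\"\\n(?=## )\", doc)\n",
    "for s in sections:\n",
    "    print(repr(s))\n",
    "# '# Runbook\\nIntro text.'\n",
    "# '## Triage\\nCheck the query history.'\n",
    "# '## Fix\\nResize the warehouse.'\n",
    "\n",
    "# Splitting on the header text itself (no lookahead) strips the titles from the chunks:\n",
    "print(re.split(r\"\\n## \", doc))\n",
    "# ['# Runbook\\nIntro text.', 'Triage\\nCheck the query history.', 'Fix\\nResize the warehouse.']\n",
    "```"
   ]
  },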
  {
   "cell_type": "markdown",
   "id": "cell-5",
   "metadata": {},
   "source": [
    "## Section 3 - Embeddings with VoyageAI\n",
    "\n",
    "Embed all chunks using `voyage-3`. The API accepts a batch of strings and returns a list of float\n",
    "vectors, one per string. Key parameter: `input_type`.\n",
    "\n",
    "- Use `\"document\"` when embedding chunks during indexing.\n",
    "- Use `\"query\"` when embedding the user's question at retrieval time.\n",
    "\n",
    "VoyageAI encodes the two types asymmetrically: query vectors are built to be matched against document\n",
    "vectors. Swapping or omitting `input_type` raises no error but can quietly degrade retrieval quality.\n",
    "\n",
    "After embedding, cosine similarity search is a brute-force scan over the chunk list -- fast enough\n",
    "for corpora of a few hundred chunks, and the right place to start before adding a vector database."
   ]
  },
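  {
   "cell_type": "markdown",
   "id": "cell-5a",
   "metadata": {},
   "source": [
    "A worked example of the cosine formula, `dot(a, b) / (|a| * |b|)`, on two toy 2-D vectors. It also\n",
    "shows why some pipelines skip the norms entirely: once vectors are scaled to unit length, cosine\n",
    "similarity equals the plain dot product. Whether an embedding API already returns unit-length vectors\n",
    "depends on the provider, so check its documentation before relying on that shortcut.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine(a: list[float], b: list[float]) -> float:\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))\n",
    "\n",
    "\n",
    "def unit(v: list[float]) -> list[float]:\n",
    "    \"\"\"Scale a non-zero vector to length 1.\"\"\"\n",
    "    norm = math.sqrt(sum(x * x for x in v))\n",
    "    return [x / norm for x in v]\n",
    "\n",
    "\n",
    "a, b = [3.0, 4.0], [4.0, 3.0]\n",
    "print(cosine(a, b))                          # 0.96, i.e. 24 / (5 * 5)\n",
    "ua, ub = unit(a), unit(b)                    # [0.6, 0.8] and [0.8, 0.6]\n",
    "print(sum(x * y for x, y in zip(ua, ub)))   # ~0.96: for unit vectors, dot product == cosine\n",
    "```"
   ]
  },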
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {},
   "outputs": [],
   "source": [
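    "# Embed all chunks in batches: a single request is capped on input count and tokens,\n",
    "# so batching keeps this cell working as the corpus grows\n",
    "EMBED_BATCH = 128   # conservative batch size (assumption); check Voyage's current per-request limits\n",
    "print(f\"Embedding {len(all_chunks)} chunks with voyage-3 ...\")\n",
    "total_tokens = 0\n",
    "for start in range(0, len(all_chunks), EMBED_BATCH):\n",
    "    batch = all_chunks[start:start + EMBED_BATCH]\n",
    "    embed_result = vc.embed(\n",
    "        [c[\"text\"] for c in batch],\n",
    "        model=\"voyage-3\",\n",
    "        input_type=\"document\",\n",
    "    )\n",
    "    for chunk, emb in zip(batch, embed_result.embeddings):\n",
    "        chunk[\"embedding\"] = emb\n",
    "    total_tokens += embed_result.total_tokens\n",
    "\n",
    "print(f\"Done. Vector dimension: {len(all_chunks[0]['embedding'])}\")\n",
    "print(f\"Total tokens used for indexing: {total_tokens:,}\")\n",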
    "\n",
    "\n",
    "# Cosine similarity helpers\n",
    "def cosine_sim(a: list[float], b: list[float]) -> float:\n",
    "    dot   = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(x * x for x in b))\n",
    "    return dot / (norm_a * norm_b + 1e-10)\n",
    "\n",
    "\n",
    "def embed_query(query: str) -> list[float]:\n",
    "    return vc.embed([query], model=\"voyage-3\", input_type=\"query\").embeddings[0]\n",
    "\n",
    "\n",
    "def vector_search(query: str, k: int = 5) -> list[tuple[dict, float]]:\n",
    "    q_emb  = embed_query(query)\n",
    "    scored = [(cosine_sim(q_emb, c[\"embedding\"]), i) for i, c in enumerate(all_chunks)]\n",
    "    scored.sort(reverse=True)\n",
    "    return [(all_chunks[i], score) for score, i in scored[:k]]\n",
    "\n",
    "\n",
    "# Demo: semantic search on a paraphrased question\n",
    "q = \"How do I figure out which warehouse is driving up my Snowflake bill?\"\n",
    "print(f\"\\n=== Vector search: '{q}' ===\")\n",
    "for rank, (chunk, score) in enumerate(vector_search(q, k=3), 1):\n",
    "    preview = chunk[\"text\"][:120].replace(\"\\n\", \" \")\n",
    "    print(f\"  {rank}. [{chunk['source']}] sim={score:.4f}  {preview} ...\")"
   ]
  },
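  {
   "cell_type": "markdown",
   "id": "cell-6a",
   "metadata": {},
   "source": [
    "`vector_search` sorts every score just to keep the top `k`. That is fine at a few hundred chunks. For a\n",
    "larger in-memory corpus, the standard library's `heapq.nlargest` selects the top `k` in roughly\n",
    "O(n log k) instead of O(n log n). A standalone sketch with made-up scores (`top_k` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import heapq\n",
    "\n",
    "\n",
    "def top_k(scores: list[float], k: int) -> list[tuple[int, float]]:\n",
    "    \"\"\"Return (index, score) pairs for the k highest scores, best first.\"\"\"\n",
    "    best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)\n",
    "    return [(i, scores[i]) for i in best]\n",
    "\n",
    "\n",
    "toy_scores = [0.12, 0.87, 0.45, 0.91, 0.30]\n",
    "print(top_k(toy_scores, k=2))   # [(3, 0.91), (1, 0.87)]\n",
    "```"
   ]
  },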
  {
   "cell_type": "markdown",
   "id": "cell-7",
   "metadata": {},
   "source": [
    "## Section 4 - BM25 keyword search\n",
    "\n",
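    "BM25 is a classical lexical ranking function. A chunk scores higher the more often it contains the query\n",
    "terms; rare terms are weighted more than common ones, and scores are normalised for chunk length. It tends\n",
    "to win exactly where embeddings struggle: exact-match keywords, named entities, dates, error codes, and\n",
    "product names.\n",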
    "\n",
    "The query below asks about a specific incident by date. Watch which retriever finds the right chunk."
   ]
  },
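  {
   "cell_type": "markdown",
   "id": "cell-7a",
   "metadata": {},
   "source": [
    "For intuition, here is a compact from-scratch Okapi BM25 scorer. It is a sketch, not `rank_bm25`'s exact\n",
    "implementation. It uses the common smoothed IDF `log(1 + (N - n + 0.5) / (n + 0.5))`, whereas `BM25Okapi`\n",
    "uses an unsmoothed IDF with a small positive floor for very common terms, so absolute scores differ. The\n",
    "defaults `k1=1.5` (term-frequency saturation) and `b=0.75` (length normalisation) match `BM25Okapi`'s.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def bm25_scores(query: list[str], docs: list[list[str]], k1: float = 1.5, b: float = 0.75) -> list[float]:\n",
    "    \"\"\"Score each tokenized doc against a tokenized query with Okapi BM25.\"\"\"\n",
    "    if not docs:\n",
    "        return []\n",
    "    n_docs = len(docs)\n",
    "    avgdl = sum(len(d) for d in docs) / n_docs\n",
    "    df = Counter(term for d in docs for term in set(d))   # number of docs containing each term\n",
    "    scores = []\n",
    "    for d in docs:\n",
    "        tf = Counter(d)\n",
    "        score = 0.0\n",
    "        for term in query:\n",
    "            if tf[term] == 0:\n",
    "                continue   # absent terms contribute nothing\n",
    "            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))\n",
    "            # Saturating term frequency (k1), normalised by doc length relative to the average (b)\n",
    "            norm_tf = tf[term] * (k1 + 1) / (tf[term] + k1 * (1 - b + b * len(d) / avgdl))\n",
    "            score += idf * norm_tf\n",
    "        scores.append(score)\n",
    "    return scores\n",
    "\n",
    "\n",
    "# Toy pre-tokenized chunks, for illustration only\n",
    "docs = [\n",
    "    \"incident 2025-04-03 warehouse suspended\".split(),\n",
    "    \"resize the warehouse to reduce credits\".split(),\n",
    "    \"dbt tests catch duplicate keys\".split(),\n",
    "]\n",
    "print(bm25_scores([\"2025-04-03\", \"incident\"], docs))\n",
    "# Only the first chunk contains either query term, so only it scores above zero\n",
    "```"
   ]
  },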
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {},
   "outputs": [],
   "source": [
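    "# Build BM25 index over the same corpus.\n",
    "# A regex tokenizer drops punctuation (\"incident?\" -> \"incident\") while keeping hyphenated or\n",
    "# dotted tokens such as dates (\"2025-04-03\") intact; plain str.split() would miss those matches.\n",
    "TOKEN_RE = re.compile(r\"\\w+(?:[-.]\\w+)*\")\n",
    "\n",
    "\n",
    "def tokenize(text: str) -> list[str]:\n",
    "    return TOKEN_RE.findall(text.lower())\n",
    "\n",
    "\n",
    "tokenized_corpus = [tokenize(c[\"text\"]) for c in all_chunks]\n",
    "bm25 = BM25Okapi(tokenized_corpus)\n",
    "print(f\"BM25 index built over {len(tokenized_corpus)} chunks.\")\n",
    "\n",
    "\n",
    "def bm25_search(query: str, k: int = 5) -> list[tuple[dict, float]]:\n",
    "    # Queries must use the same tokenizer as the index\n",
    "    scores = bm25.get_scores(tokenize(query))\n",
    "    top_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]\n",
    "    return [(all_chunks[i], scores[i]) for i in top_idx]\n",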
    "\n",
    "\n",
    "# Query A: named-entity / date lookup -- BM25 should win here\n",
    "q_entity = \"What happened in the 2025-04-03 incident?\"\n",
    "print(f\"\\n=== BM25 search: '{q_entity}' ===\")\n",
    "for rank, (chunk, score) in enumerate(bm25_search(q_entity, k=3), 1):\n",
    "    preview = chunk[\"text\"][:120].replace(\"\\n\", \" \")\n",
    "    print(f\"  {rank}. [{chunk['source']}] score={score:.4f}  {preview} ...\")\n",
    "\n",
    "print(f\"\\n=== Vector search: '{q_entity}' ===\")\n",
    "for rank, (chunk, score) in enumerate(vector_search(q_entity, k=3), 1):\n",
    "    preview = chunk[\"text\"][:120].replace(\"\\n\", \" \")\n",
    "    print(f\"  {rank}. [{chunk['source']}] sim={score:.4f}  {preview} ...\")\n",
    "\n",
    "print()\n",
    "print(\"Expected: BM25 ranks the chunk containing the exact date at or near the top;\")\n",
    "print(\"vector search may return semantically adjacent but less precise results.\")"
   ]
  },
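  {
   "cell_type": "markdown",
   "id": "cell-8a",
   "metadata": {},
   "source": [
    "Tokenisation matters as much as the scoring function. A bare `str.split()` leaves punctuation attached,\n",
    "so the query token `incident?` never matches `incident:` in a chunk, and a date written in bold Markdown as\n",
    "`**2025-04-03**` never matches `2025-04-03`. A regex tokenizer is one simple fix. The pattern below is an\n",
    "assumption (not something `rank_bm25` provides) that keeps hyphenated and dotted tokens such as dates and\n",
    "version numbers intact. The chunk text is made up.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Word characters, optionally joined by '-' or '.' (dates, versions); other punctuation is dropped\n",
    "TOKEN_RE = re.compile(r\"\\w+(?:[-.]\\w+)*\")\n",
    "\n",
    "\n",
    "def regex_tokenize(text: str) -> list[str]:\n",
    "    return TOKEN_RE.findall(text.lower())\n",
    "\n",
    "\n",
    "query = \"What happened in the 2025-04-03 incident?\"\n",
    "chunk = \"**2025-04-03** incident: warehouse credits spiked.\"\n",
    "\n",
    "print(set(query.lower().split()) & set(chunk.lower().split()))   # set(): no shared tokens\n",
    "print(set(regex_tokenize(query)) & set(regex_tokenize(chunk)))   # {'2025-04-03', 'incident'} (order may vary)\n",
    "```\n",
    "\n",
    "Whatever tokenizer you choose, use the same one for indexing and for queries."
   ]
  },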
  {
   "cell_type": "markdown",
   "id": "cell-9",
   "metadata": {},
   "source": [
    "## Section 5 - Hybrid retrieval with Reciprocal Rank Fusion\n",
    "\n",
    "Neither retriever alone is sufficient:\n",
    "- Vector search finds semantic matches but misses exact keyword hits.\n",
    "- BM25 finds keyword hits but misses paraphrased or synonym-heavy queries.\n",
    "\n",
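    "Reciprocal Rank Fusion (RRF) merges the two lists on rank alone. This sidesteps the fact that BM25\n",
    "scores and cosine similarities are on incompatible scales. Each retriever gives a result at rank `r`\n",
    "(counting from 1) a score of `1 / (60 + r)`, and a chunk's scores from both retrievers are summed; a chunk\n",
    "absent from one retriever's list gets 0 from it. The constant 60, the value used in the original RRF\n",
    "paper, flattens the gap between neighbouring ranks so that no single retriever's top hit dominates. It\n",
    "generally works well across retriever types without tuning.\n",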
    "\n",
    "The result: a ranked list that benefits from both semantic and lexical signals simultaneously."
   ]
  },
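  {
   "cell_type": "markdown",
   "id": "cell-9a",
   "metadata": {},
   "source": [
    "A worked example with toy rankings (chunk ids `A` to `D` are made up). Chunk `A` is 1st for vector search\n",
    "and 2nd for BM25; chunk `C` is 3rd and 1st. Neither tops both lists, but both outrank `B` and `D`, which\n",
    "only one retriever found. `rrf_fuse` is a hypothetical helper that fuses any number of ranked id lists.\n",
    "\n",
    "```python\n",
    "def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:\n",
    "    \"\"\"Reciprocal Rank Fusion: sum 1 / (k + rank) over every ranking a doc appears in.\"\"\"\n",
    "    fused: dict[str, float] = {}\n",
    "    for ranking in rankings:\n",
    "        for rank, doc_id in enumerate(ranking, start=1):   # ranks count from 1\n",
    "            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + rank)\n",
    "    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)\n",
    "\n",
    "\n",
    "vector_ranking = [\"A\", \"B\", \"C\"]\n",
    "bm25_ranking = [\"C\", \"A\", \"D\"]\n",
    "for doc_id, score in rrf_fuse([vector_ranking, bm25_ranking]):\n",
    "    print(f\"{doc_id}: {score:.5f}\")\n",
    "# A: 0.03252   (1/61 + 1/62)\n",
    "# C: 0.03227   (1/63 + 1/61)\n",
    "# B: 0.01613   (1/62)\n",
    "# D: 0.01587   (1/63)\n",
    "```"
   ]
  },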
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-10",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rrf_score(rank: int, k: int = 60) -> float:\n",
    "    \"\"\"RRF contribution of one result at 1-based position `rank`.\"\"\"\n",
    "    return 1.0 / (k + rank)\n",
    "\n",
    "\n",
    "def hybrid_search(\n",
    "    query: str,\n",
    "    k: int = 5,\n",
    "    fetch: int = 20,\n",
    ") -> list[tuple[dict, float]]:\n",
    "    \"\"\"Combine vector and BM25 results via Reciprocal Rank Fusion.\"\"\"\n",
    "    vec_results  = vector_search(query, k=fetch)\n",
    "    bm25_results = bm25_search(query,   k=fetch)\n",
    "\n",
    "    # Map object identity -> position in all_chunks: O(1) lookups, and unlike a\n",
    "    # list.index() scan it cannot confuse two chunks that happen to share the same text\n",
    "    _idx = {id(c): i for i, c in enumerate(all_chunks)}\n",
    "    vec_ranks  = {_idx[id(c)]: r + 1 for r, (c, _) in enumerate(vec_results)}\n",
    "    bm25_ranks = {_idx[id(c)]: r + 1 for r, (c, _) in enumerate(bm25_results)}\n",
    "\n",
    "    # A chunk missing from one retriever's top-`fetch` list gets 0 from that retriever\n",
    "    candidates = set(vec_ranks) | set(bm25_ranks)\n",
    "    rrf: dict[int, float] = {}\n",
    "    for idx in candidates:\n",
    "        rrf[idx] = (\n",
    "            (rrf_score(vec_ranks[idx])  if idx in vec_ranks  else 0.0) +\n",
    "            (rrf_score(bm25_ranks[idx]) if idx in bm25_ranks else 0.0)\n",
    "        )\n",
    "\n",
    "    top = sorted(rrf, key=lambda i: rrf[i], reverse=True)[:k]\n",
    "    return [(all_chunks[i], rrf[i]) for i in top]\n",
    "\n",
    "\n",
    "# Compare all three retrievers on two contrasting query types\n",
    "queries = [\n",
    "    (\"semantic\",  \"How do I reduce Snowflake spend without impacting dashboard performance?\"),\n",
    "    (\"entity\",    \"What was the credit impact of the 2025-05-18 incident?\"),\n",
    "]\n",
    "\n",
    "for qtype, q in queries:\n",
    "    vec_top  = vector_search(q, k=1)[0]\n",
    "    bm25_top = bm25_search(q, k=1)[0]\n",
    "    hyb_top  = hybrid_search(q, k=1)[0]\n",
    "\n",
    "    print(f\"Query [{qtype}]: {q}\")\n",
    "    print(f\"  Vector  -> [{vec_top[0]['source']}]  {vec_top[0]['text'][:90].replace(chr(10),' ')} ...\")\n",
    "    print(f\"  BM25    -> [{bm25_top[0]['source']}]  {bm25_top[0]['text'][:90].replace(chr(10),' ')} ...\")\n",
    "    print(f\"  Hybrid  -> [{hyb_top[0]['source']}]  {hyb_top[0]['text'][:90].replace(chr(10),' ')} ...\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-11",
   "metadata": {},
   "source": [
    "## Section 6 - Full RAG loop: retrieve, inject, generate\n",
    "\n",
    "The complete pipeline in one function. Hybrid search retrieves the top-k chunks, assembles them\n",
    "into a context block with source attribution, and passes the context + question to Claude.\n",
    "\n",
    "Three design choices reflected in the implementation:\n",
    "\n",
    "1. **Source labels in the context.** Each chunk is prefixed with its document name so Claude can\n",
    "   cite the source -- and so you can audit whether the right document was retrieved.\n",
    "2. **Explicit grounding instruction.** The system prompt says to use ONLY the provided context and\n",
    "   to acknowledge when the answer is not there. Without it, the model tends to fill gaps from its training data.\n",
    "3. **Temperature 0.** Grounded answers should not vary from run to run. Temperature 0 minimises sampling\n",
    "   variation, although the API does not guarantee fully deterministic output even at 0."
   ]
  },
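  {
   "cell_type": "markdown",
   "id": "cell-11a",
   "metadata": {},
   "source": [
    "Source labels also make the answer auditable in code. Below is a minimal sketch, assuming the model cites\n",
    "sources by their corpus key (for example `quality_runbook`), as the system prompt asks. It compares the\n",
    "source names mentioned in the answer with the sources that were actually retrieved. `audit_citations` and\n",
    "the sample answer are hypothetical.\n",
    "\n",
    "```python\n",
    "def audit_citations(answer: str, retrieved_sources: list[str], all_sources: list[str]) -> dict[str, list[str]]:\n",
    "    \"\"\"Split the corpus source names mentioned in an answer into retrieved vs. not retrieved.\"\"\"\n",
    "    cited = [s for s in all_sources if s in answer]   # plain substring match\n",
    "    return {\n",
    "        \"cited_and_retrieved\": [s for s in cited if s in retrieved_sources],\n",
    "        # Cited but never placed in the context: a sign the model drew on something else\n",
    "        \"cited_not_retrieved\": [s for s in cited if s not in retrieved_sources],\n",
    "    }\n",
    "\n",
    "\n",
    "corpus_keys = [\"warehouse_runbook\", \"quality_runbook\", \"qbr_q3_2025\"]\n",
    "answer = \"Duplicate keys are caught by a uniqueness test (quality_runbook); see also qbr_q3_2025.\"\n",
    "print(audit_citations(answer, retrieved_sources=[\"quality_runbook\"], all_sources=corpus_keys))\n",
    "# {'cited_and_retrieved': ['quality_runbook'], 'cited_not_retrieved': ['qbr_q3_2025']}\n",
    "```\n",
    "\n",
    "In the notebook, `retrieved_sources` would be `[c[\"source\"] for c, _ in sources]` (the second value returned\n",
    "by `rag_answer`) and `all_sources` would be `list(CORPUS)`. Substring matching is deliberately crude; a\n",
    "model may also cite a document by a paraphrased title, which this check misses."
   ]
  },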
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12",
   "metadata": {},
   "outputs": [],
   "source": [
    "RAG_SYSTEM = (\n",
    "    \"You are a data platform assistant for Acme SaaS Co. \"\n",
    "    \"Answer questions using ONLY the context provided below. \"\n",
    "    \"Cite the source document name whenever you reference a specific fact. \"\n",
    "    \"If the answer is not present in the context, say so explicitly -- \"\n",
    "    \"do not draw on outside knowledge.\"\n",
    ")\n",
    "\n",
    "\n",
    "def rag_answer(question: str, k: int = 4) -> tuple[str, list]:\n",
    "    retrieved = hybrid_search(question, k=k)\n",
    "\n",
    "    context = \"\\n\\n---\\n\\n\".join(\n",
    "        f\"[Source: {c['source']}]\\n{c['text']}\"\n",
    "        for c, _ in retrieved\n",
    "    )\n",
    "\n",
    "    resp = cc.client.messages.create(\n",
    "        model=cc.default_model,\n",
    "        max_tokens=800,\n",
    "        temperature=0,\n",
    "        system=RAG_SYSTEM,\n",
    "        messages=[{\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"Context:\\n\\n{context}\\n\\nQuestion: {question}\",\n",
    "        }],\n",
    "    )\n",
    "    return resp.content[0].text, retrieved\n",
    "\n",
    "\n",
    "# Demo queries that span all three corpus documents\n",
    "demo_questions = [\n",
    "    \"What caused the largest credit incident in 2025 and how was it resolved?\",\n",
    "    \"How do duplicate primary keys get detected in the data quality runbook?\",\n",
    "    \"What were the Q4 priorities described in the Q3 QBR report?\",\n",
    "]\n",
    "\n",
    "for question in demo_questions:\n",
    "    print(f\"Q: {question}\")\n",
    "    answer, sources = rag_answer(question)\n",
    "    print(f\"A: {answer[:350].rstrip()} ...\")\n",
    "    print(\"Retrieved from:\")\n",
    "    for chunk, score in sources:\n",
    "        print(f\"  [{chunk['source']}] rrf={score:.4f}  \"\n",
    "              f\"{chunk['text'][:60].replace(chr(10), ' ')} ...\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Out-of-scope guard: ask a question not covered by the corpus\n",
    "out_of_scope = \"What is Acme's policy on remote work?\"\n",
    "answer, sources = rag_answer(out_of_scope)\n",
    "print(f\"Q: {out_of_scope}\")\n",
    "print(f\"A: {answer}\")\n",
    "print()\n",
    "print(\"Expected: the model should say the answer is not in the context.\")\n",
    "print(\"If it speculates, the grounding instruction needs to be tightened.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-14",
   "metadata": {},
   "outputs": [],
   "source": [
    "cc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-15",
   "metadata": {},
   "source": [
    "## Section 7 - Practitioner Lab\n",
    "\n",
    "Open-ended extension. No reference solution.\n",
    "\n",
    "**Goal:** add document-level metadata filtering to the hybrid search.\n",
    "\n",
    "**Problem:** the current `hybrid_search` searches across all three documents regardless of which\n",
    "one the question is about. Add an optional `source_filter` parameter that restricts retrieval to\n",
    "a specified document (or list of documents) before scoring.\n",
    "\n",
    "```python\n",
    "# Target signature:\n",
    "def hybrid_search(\n",
    "    query: str,\n",
    "    k: int = 5,\n",
    "    fetch: int = 20,\n",
    "    source_filter: str | list[str] | None = None,\n",
    ") -> list[tuple[dict, float]]:\n",
    "    ...\n",
    "```\n",
    "\n",
    "**Constraints:**\n",
    "1. When `source_filter` is None, behaviour is unchanged.\n",
    "2. When a filter is applied, only chunks whose `\"source\"` key matches are candidates -- the BM25\n",
    "   and vector retrievers should both be restricted before RRF merging.\n",
    "3. The BM25 index needs to be rebuilt (or a sub-index maintained) for the filtered subset.\n",
    "   Think about whether you rebuild on every call or pre-build per-document indexes.\n",
    "\n",
    "**Test case:**\n",
    "```python\n",
    "results = hybrid_search(\n",
    "    \"What dbt test catches duplicate rows?\",\n",
    "    source_filter=\"quality_runbook\",\n",
    ")\n",
    "# Every returned chunk should have source == \"quality_runbook\"\n",
    "assert all(c[\"source\"] == \"quality_runbook\" for c, _ in results)\n",
    "```\n",
    "\n",
    "**Stretch:** extend `rag_answer` to auto-detect which document to filter on by doing a cheap\n",
    "BM25 search first to identify the most relevant source, then re-running full hybrid search\n",
    "restricted to that source.\n",
    "\n",
    "Why this matters: in production RAG systems, naive cross-document retrieval often returns\n",
    "plausible but wrong chunks when two documents share vocabulary. Metadata filtering is one of the simplest\n",
    "precision improvements, and on corpora like this one it can matter as much as reranking or better chunking.\n",
    "\n",
    "---\n",
    "\n",
    "*Companion article: B4 - RAG Essentials: Chunking, Embeddings, and Hybrid Retrieval.*\n",
    "*Next notebook: B5_rag_advanced.ipynb*"
   ]
  }
 ]
}