"""
ICAC RAG Engine — local PDF indexing + semantic search
Session 4: PyMuPDF + sentence-transformers + ChromaDB

Pipeline:
  1. PDF → text extraction (PyMuPDF/fitz)
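  2. Text → chunks (~500 chars, 100 overlap; chunk ends snapped to a nearby
     sentence/line/word boundary when one exists)
  3. Chunks → embeddings (sentence-transformers, all-MiniLM-L6-v2)
  4. Embeddings → ChromaDB (local, persistent)
  5. Query → cosine similarity → top-k chunks (default 5) scoring ≥ min_score (default 0.35)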
"""

import bisect
import hashlib
import logging
import re
import unicodedata
from pathlib import Path
from typing import Optional

log = logging.getLogger("icac.rag")

# Chunking parameters, in characters: target chunk length and the overlap
# kept between consecutive chunks.
CHUNK_SIZE: int = 500
CHUNK_OVERLAP: int = 100
# sentence-transformers model used to embed both indexed chunks and queries.
EMBED_MODEL: str = "all-MiniLM-L6-v2"
# Name of the ChromaDB collection that holds every indexed chunk.
COLLECTION_NAME: str = "icac_documents"


class RAGEngine:
    """Local RAG engine using ChromaDB + sentence-transformers.

    PDFs are extracted with PyMuPDF, split into overlapping character chunks,
    embedded with ``EMBED_MODEL`` and stored in a persistent ChromaDB
    collection (cosine space). Heavy dependencies are imported lazily; if any
    of them is missing or fails to load, the engine stays uninitialized and
    the public methods return empty results (0, [], "" or a stats stub).
    """

    # Marker that _extract_text_from_pdf puts in front of each page's text;
    # _chunk_text uses it to attribute a page number to every chunk.
    _PAGE_MARKER = re.compile(r'\[PAGE (\d+)\]')

    def __init__(self, docs_dir: str, chroma_dir: str):
        """Create both directories if needed and try to load ChromaDB + the embedder."""
        self.docs_dir = Path(docs_dir)
        self.chroma_dir = Path(chroma_dir)
        self.docs_dir.mkdir(parents=True, exist_ok=True)
        self.chroma_dir.mkdir(parents=True, exist_ok=True)

        self._embedder = None
        self._chroma_client = None
        self._collection = None
        self._initialized = False

        self._lazy_init()

    def _lazy_init(self):
        """Load ChromaDB and the embedding model; set ``_initialized`` on success.

        Safe to call repeatedly: it returns immediately once initialized, and
        on a retry after a partial failure it only loads the pieces that are
        still missing (an already-open collection is reused rather than a new
        client being created). Missing packages or load errors are logged and
        leave the engine uninitialized.
        """
        if self._initialized:
            return

        if self._collection is None:
            try:
                import chromadb
                from chromadb.config import Settings

                self._chroma_client = chromadb.PersistentClient(
                    path=str(self.chroma_dir),
                    settings=Settings(anonymized_telemetry=False)
                )
                self._collection = self._chroma_client.get_or_create_collection(
                    name=COLLECTION_NAME,
                    metadata={"hnsw:space": "cosine"}
                )
                log.info("ChromaDB ready — collection '%s' (%d docs)",
                         COLLECTION_NAME, self._collection.count())
            except ImportError:
                log.warning("chromadb not installed — RAG disabled")
                return
            except Exception as e:
                log.error("ChromaDB init failed: %s", e)
                # Drop any half-initialized state so the next call retries cleanly.
                self._chroma_client = None
                self._collection = None
                return

        if self._embedder is None:
            try:
                from sentence_transformers import SentenceTransformer
                self._embedder = SentenceTransformer(EMBED_MODEL)
                log.info("Embedding model loaded: %s", EMBED_MODEL)
            except ImportError:
                log.warning("sentence-transformers not installed — RAG disabled")
                return
            except Exception as e:
                log.error("Embedding model failed: %s", e)
                return

        self._initialized = True
        log.info("RAGEngine fully initialized — docs: %s, chroma: %s",
                 self.docs_dir, self.chroma_dir)

    def _extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract the text of every non-blank page of a PDF with PyMuPDF (fitz).

        Each kept page is prefixed with a ``[PAGE n]`` marker (1-based page
        number) and pages are joined with blank lines. Returns ``""`` when
        PyMuPDF is not installed or the file cannot be read (error is logged).
        """
        try:
            import fitz  # PyMuPDF
        except ImportError:
            log.error("PyMuPDF (fitz) not installed — cannot extract PDF")
            return ""

        try:
            text_parts = []
            # The context manager closes the document even if a page fails.
            with fitz.open(pdf_path) as doc:
                for page_num, page in enumerate(doc, start=1):
                    text = page.get_text()
                    if text.strip():
                        text_parts.append(f"[PAGE {page_num}]\n{text}")
        except Exception as e:
            log.error("PDF extraction failed for %s: %s", pdf_path, e)
            return ""

        full_text = "\n\n".join(text_parts)
        log.info("Extracted %d chars from %d pages: %s",
                 len(full_text), len(text_parts), Path(pdf_path).name)
        return full_text

    @staticmethod
    def _chunk_spans(text: str, chunk_size: int, overlap: int) -> list:
        """Return ``(start, end)`` offsets of overlapping chunks covering *text*.

        A span nominally covers ``chunk_size`` characters, but when the text
        continues past it the end is moved to the last boundary found between
        ``chunk_size // 2`` and ``chunk_size + 50`` characters after the start,
        trying sentence breaks first, then line breaks, then spaces.
        Consecutive spans overlap by ``overlap`` characters. The span that
        reaches the end of the text is always the last one (no trailing
        duplicate of the final overlap), and every step advances by at least
        one character.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")

        spans = []
        length = len(text)
        pos = 0
        while pos < length:
            end = pos + chunk_size
            if end < length:
                # Prefer ending on a sentence break, then a paragraph/line
                # break, then any space; the first boundary kind found wins.
                for boundary in ['. ', '.\n', '\n\n', '\n', ' ']:
                    bp = text.rfind(boundary, pos + chunk_size // 2, end + 50)
                    if bp > pos:
                        end = bp + len(boundary)
                        break
            spans.append((pos, min(end, length)))
            if end >= length:
                break  # this span already reaches the end of the text
            # Step back by the overlap, but never stall or move backwards.
            pos = max(end - overlap, pos + 1)
        return spans

    @staticmethod
    def _page_for_offset(marker_offsets: list, marker_pages: list, offset: int) -> int:
        """Return the page of the last ``[PAGE n]`` marker at or before *offset*.

        *marker_offsets* must be sorted ascending and aligned index-by-index
        with *marker_pages*. An offset before the first marker maps to the
        first marker's page; with no markers at all, page 1 is returned.
        """
        if not marker_offsets:
            return 1
        idx = bisect.bisect_right(marker_offsets, offset) - 1
        return marker_pages[max(idx, 0)]

    def _chunk_text(self, text: str, source: str) -> list:
        """Split *text* into overlapping chunk dicts ready for indexing.

        Returns a list of dicts with keys ``id``, ``text``, ``source``,
        ``chunk_idx`` and ``page``. ``page`` comes from the last ``[PAGE n]``
        marker at or before the chunk's first non-blank character (1 when the
        text carries no markers). Chunks of 30 stripped characters or fewer
        are dropped and do not consume a ``chunk_idx``.
        """
        # Collapse runs of blank lines and of spaces before measuring chunks.
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r' {2,}', ' ', text)

        marker_offsets = []
        marker_pages = []
        for match in self._PAGE_MARKER.finditer(text):
            marker_offsets.append(match.start())
            marker_pages.append(int(match.group(1)))

        chunks = []
        for start, end in self._chunk_spans(text, CHUNK_SIZE, CHUNK_OVERLAP):
            raw = text[start:end]
            chunk_text = raw.strip()
            if len(chunk_text) <= 30:  # skip very short chunks
                continue
            chunk_idx = len(chunks)
            first_char = start + (len(raw) - len(raw.lstrip()))
            # Deterministic id: re-chunking the same source yields the same ids.
            chunk_hash = hashlib.md5(f"{source}:{chunk_idx}".encode()).hexdigest()[:12]
            chunks.append({
                "id": f"{source}_{chunk_hash}",
                "text": chunk_text,
                "source": source,
                "chunk_idx": chunk_idx,
                "page": self._page_for_offset(marker_offsets, marker_pages, first_char),
            })

        log.info("Chunked %d chars → %d chunks (source: %s)",
                 len(text), len(chunks), source)
        return chunks

    @staticmethod
    def _detect_doc_type(filename: str) -> str:
        """Classify a document by keywords in its filename (first match wins).

        The filename is NFC-normalized first so accented keywords ("séance",
        "marché") also match names stored in decomposed (NFD) form.
        """
        fn = unicodedata.normalize("NFC", filename).lower()
        if any(k in fn for k in ["compte", "gestion", "cg_", "comptes"]):
            return "compte_gestion"
        elif any(k in fn for k in ["pv", "proces", "conseil", "deliber", "séance"]):
            return "pv_conseil"
        elif any(k in fn for k in ["budget", "primitif", "bp_", "bp2"]):
            return "budget_primitif"
        elif any(k in fn for k in ["bulletin", "magazine", "info_municipale"]):
            return "bulletin_municipal"
        elif any(k in fn for k in ["marche", "marché", "appel", "lot_"]):
            return "marche_public"
        return "autre"

    @staticmethod
    def _detect_annee(filename: str) -> str:
        """Return the first year in 2010-2029 found in *filename*, or ``""``."""
        m = re.search(r'20[12]\d', filename)
        return m.group(0) if m else ""

    def index_document(self, pdf_path: str) -> int:
        """Extract, chunk and embed one PDF into ChromaDB.

        Returns the number of chunks stored for the document. If chunks with
        the same ``source`` (file stem) already exist, nothing is re-indexed
        and their count is returned. Returns 0 when the engine cannot be
        initialized or the PDF yields no text/chunks.
        """
        if not self._initialized:
            self._lazy_init()
            if not self._initialized:
                log.warning("RAG not initialized — skipping index")
                return 0

        path = Path(pdf_path)
        pdf_name = path.stem
        filename = path.name

        # Already indexed? Only the ids are needed, so skip fetching payloads.
        existing = self._collection.get(where={"source": pdf_name}, include=[])
        if existing and existing.get("ids"):
            log.info("Document already indexed: %s (%d chunks)",
                     pdf_name, len(existing["ids"]))
            return len(existing["ids"])

        doc_type = self._detect_doc_type(filename)
        annee = self._detect_annee(filename)
        log.info("Indexing %s → doc_type=%s, annee=%s", filename, doc_type, annee)

        text = self._extract_text_from_pdf(pdf_path)
        if not text:
            return 0

        chunks = self._chunk_text(text, pdf_name)
        if not chunks:
            return 0

        # Embed and store with enriched metadata.
        texts = [c["text"] for c in chunks]
        ids = [c["id"] for c in chunks]
        metadatas = [{
            "source": c["source"],
            "chunk_idx": c["chunk_idx"],
            "filename": filename,
            "doc_type": doc_type,
            "annee": annee,
            "page": c.get("page", 1),
        } for c in chunks]

        embeddings = self._embedder.encode(texts, show_progress_bar=False).tolist()

        self._collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

        log.info("Indexed %s: %d chunks stored in ChromaDB (type=%s)", pdf_name, len(chunks), doc_type)
        return len(chunks)

    def search(self, query: str, doc_type: Optional[str] = None,
               top_k: int = 5, min_score: float = 0.35) -> list:
        """
        Semantic search in ChromaDB.

        Fetches up to ``2 * top_k`` nearest chunks (optionally restricted to
        ``doc_type``), drops those whose score (``1 - cosine distance``) is
        below ``min_score``, and returns the best ``top_k`` as dicts sorted by
        descending score. Cosine score: 1.0 = identical, 0.0 = unrelated;
        0.35 is a reasonable threshold for administrative text.
        Returns [] when the engine is unavailable, the collection is empty,
        or the query fails (error is logged).
        """
        if not self._initialized:
            self._lazy_init()
            if not self._initialized:
                return []

        total = self._collection.count()
        if total == 0:
            return []

        try:
            where = {"doc_type": doc_type} if doc_type else None
            query_embedding = self._embedder.encode([query]).tolist()

            results = self._collection.query(
                query_embeddings=query_embedding,
                n_results=min(top_k * 2, total),
                where=where,
                include=["documents", "metadatas", "distances"]
            )

            chunks = []
            if results and results.get("documents"):
                metadatas = results.get("metadatas")
                distances = results.get("distances")
                for i, doc in enumerate(results["documents"][0]):
                    # A row's metadata may be None; treat it as empty.
                    meta = (metadatas[0][i] if metadatas else None) or {}
                    dist = distances[0][i] if distances else 0
                    score = round(1 - dist, 3)

                    # Filter non-relevant chunks
                    if score < min_score:
                        log.info(
                            "RAG ignored (score=%.3f < %.2f): %s p.%s",
                            score, min_score,
                            meta.get("filename", "?"),
                            meta.get("page", "?")
                        )
                        continue

                    chunks.append({
                        "text": doc,
                        "source": meta.get("source", "?"),
                        "filename": meta.get("filename", ""),
                        "doc_type": meta.get("doc_type", ""),
                        "annee": meta.get("annee", ""),
                        "page": meta.get("page", 1),
                        "chunk_idx": meta.get("chunk_idx", 0),
                        "score": score,
                    })

            # Keep only top_k after filtering
            chunks = sorted(chunks, key=lambda x: x["score"], reverse=True)[:top_k]

            log.info(
                "RAG search '%s' (doc_type=%s) → %d chunks relevant (threshold=%.2f)",
                query[:50], doc_type or "all", len(chunks), min_score
            )
            return chunks

        except Exception as e:
            log.error("Search RAG: %s", e)
            return []

    def get_context(self, query: str, doc_type: Optional[str] = None) -> str:
        """Return RAG search results formatted as an LLM context block ("" if none)."""
        chunks = self.search(query, doc_type=doc_type, top_k=5, min_score=0.35)
        if not chunks:
            log.info("RAG: no relevant chunk for '%s' (doc_type=%s)", query[:60], doc_type or "all")
            return ""

        parts = ["DOCUMENTS LOCAUX (RAG):"]
        for i, c in enumerate(chunks, 1):
            # Skip the year when unknown instead of leaving a double space.
            src = " ".join(p for p in (
                f"{c['filename']} p.{c['page']}",
                c["annee"],
                f"[score={c['score']:.2f}]",
            ) if p)
            parts.append(f"\n--- Extrait {i} ({src}) ---")
            parts.append(c["text"])

        return "\n".join(parts)

    def get_stats(self) -> dict:
        """Return RAG engine stats (chunk count, PDFs on disk, model settings)."""
        if not self._initialized:
            return {"initialized": False, "total_chunks": 0, "docs_on_disk": 0}

        docs_on_disk = len(list(self.docs_dir.glob("*.pdf")))
        return {
            "initialized": True,
            "total_chunks": self._collection.count() if self._collection else 0,
            "docs_on_disk": docs_on_disk,
            "embed_model": EMBED_MODEL,
            "chunk_size": CHUNK_SIZE,
        }

    def index_all_documents(self) -> dict:
        """Index every ``*.pdf`` in the docs directory, in sorted filename order.

        Returns a summary dict:
            indexed: files that have at least one chunk stored (newly or previously).
            total_chunks: sum of those files' chunk counts.
            skipped: filenames that produced no chunks (no text, or engine unavailable).
            errors: ``{"file", "error"}`` entries for files whose indexing raised.
        """
        total_chunks = 0
        indexed = 0
        skipped = []
        errors = []

        for pdf_file in sorted(self.docs_dir.glob("*.pdf")):
            try:
                n = self.index_document(str(pdf_file))
            except Exception as e:
                log.error("Failed to index %s: %s", pdf_file.name, e)
                errors.append({"file": pdf_file.name, "error": str(e)})
                continue

            if n > 0:
                total_chunks += n
                indexed += 1
                log.info("Indexed %s → %d chunks", pdf_file.name, n)
            else:
                skipped.append(pdf_file.name)
                log.warning("No chunks indexed for %s", pdf_file.name)

        return {
            "indexed": indexed,
            "total_chunks": total_chunks,
            "skipped": skipped,
            "errors": errors,
        }
