import json
import re
import sys
import uuid

import mwxml

def chunk_text(text, chunk_size=1000, overlap=200):
    """Split text into fixed-size chunks with a sliding-window overlap."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    step = chunk_size - overlap  # how far the window advances each iteration
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += step
    return chunks
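
# Example: with chunk_size=1000 and overlap=200 the window advances
# 800 characters per step, so a 2,400-character string yields three
# chunks covering [0:1000], [800:1800], and [1600:2400].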

def chunk_by_sections(text):
    """Split wikitext into sections on == Heading == lines.

    The capturing group makes re.split keep the headings, so parts
    alternates [lead, heading, body, heading, body, ...].
    """
    pattern = r"(^==+ .*? ==+ *$)"
    parts = re.split(pattern, text, flags=re.MULTILINE)

    chunks = []
    # parts[0] is the lead text before the first heading; keep it.
    lead = parts[0].strip()
    if lead:
        chunks.append(lead)
    for i in range(1, len(parts), 2):
        heading = parts[i].strip()
        content = parts[i + 1].strip() if i + 1 < len(parts) else ""
        if content:
            chunks.append(f"{heading}\n{content}")
    return chunks
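
# Example: "Intro.\n== History ==\nBody." splits into
# ["Intro.\n", "== History ==", "\nBody."], which becomes the chunks
# ["Intro.", "== History ==\nBody."].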

def combined_chunking(text, chunk_size=1000, overlap=200):
    """Section-aware chunking: split on headings, then size-limit each section."""
    final_chunks = []
    for section in chunk_by_sections(text):
        final_chunks.extend(chunk_text(section, chunk_size, overlap))
    return final_chunks
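
# Note: only the first slice of an oversized section begins with its
# heading; one possible refinement is to prepend the heading to every
# slice so each chunk stays self-describing.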

def parse_wiki_and_chunk(file_path, chunk_size=1000, overlap=200):
    results = []
    # Use a context manager so the dump file is closed when parsing ends.
    with open(file_path, "rb") as f:
        dump = mwxml.Dump.from_file(f)

        for page in dump.pages:
            # Current-pages dumps carry one revision per page; iterating the
            # page keeps the script working on multi-revision dumps too.
            for revision in page:
                text = revision.text or ""
                title = page.title

                # Skip redirects and near-empty pages.
                if "#REDIRECT" in text.upper() or len(text.strip()) < 100:
                    continue

                chunks = combined_chunking(text, chunk_size, overlap)
                for idx, chunk in enumerate(chunks):
                    results.append({
                        "id": f"{title.replace(' ', '_')}_{idx}_{uuid.uuid4().hex[:6]}",
                        "text": chunk,
                        "metadata": {
                            "title": title,
                            "chunk_index": idx,
                        },
                    })

    return results
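
# Note: mwxml also exposes page.redirect (the redirect target, or None),
# which could replace the "#REDIRECT" text check above if you prefer the
# dump's own redirect markers; the text check is kept for robustness.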

def main():
    if len(sys.argv) != 2:
        print("Usage: python wiki2chunks.py <wiki_dump.xml>")
        sys.exit(1)

    file_path = sys.argv[1]
    data = parse_wiki_and_chunk(file_path)

    # Write one JSON object per line (JSONL) so downstream tools can stream it.
    with open("wiki_chunks.jsonl", "w", encoding="utf-8") as f:
        for entry in data:
            json.dump(entry, f, ensure_ascii=False)
            f.write("\n")
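    # Each output line is one chunk, e.g. (values illustrative):
    # {"id": "Alan_Turing_0_3f9a2c", "text": "== Early life ==\n...",
    #  "metadata": {"title": "Alan Turing", "chunk_index": 0}}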

    print(f"✅ Parsed and chunked {len(data)} text blocks into wiki_chunks.jsonl")

if __name__ == "__main__":
    main()
