#!/usr/bin/env python3
"""
Whisper forced alignment v2: 用 Whisper 自己的分段获取精确时间戳，
通过 rapidfuzz 把 Radio Ambulante 的 speaker 标签映射上去。

策略：信任 Whisper 的音频分段（时间戳准确），用模糊匹配从旧转录里找对应段落的 speaker。
"""
import json, sys, os, re
import numpy as np

os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# ── Load the existing transcript ─────────────────
# The JSON contains accented Spanish text; open with an explicit UTF-8
# encoding so the script does not depend on the locale's default codec.
with open("deuda_segments.json", encoding="utf-8") as f:
    data = json.load(f)
    old_segments = data["segments"]

# 构建 speaker 映射表：{规范化文本 → speaker}
def normalize(s):
    """Lowercase *s*, drop punctuation, and collapse whitespace.

    Applied to both sides of the speaker match (old transcript text and
    Whisper output) so the two only need to agree after normalization.
    """
    s = s.lower().strip()
    # \w is Unicode-aware in Python 3, so the explicit accented characters
    # are belt-and-braces; the pattern is kept for readability.
    s = re.sub(r'[^\w\sáéíóúñü]', '', s)
    s = re.sub(r'\s+', ' ', s)
    # Removing punctuation can leave stray edge spaces ("hola !" → "hola "),
    # which would make the substring containment check miss; trim them.
    return s.strip()

old_map = [(normalize(s["text"]), s["speaker"]) for s in old_segments]

# ── Run Whisper ──────────────────────────────────
print("Loading faster-whisper model (small)...")
from faster_whisper import WhisperModel

# CPU + int8 quantization keeps memory modest; "small" trades accuracy for speed.
model = WhisperModel("small", device="cpu", compute_type="int8")

print("Transcribing (51 min audio, ~5-8 min)...")
segments, info = model.transcribe(
    "deuda.mp3",
    language="es",
    beam_size=5,
    word_timestamps=True,
    vad_filter=True,
    vad_parameters={"min_silence_duration_ms": 500},
)

# Collect Whisper's segments, skipping any that are empty after trimming.
whisper_segments = []
for chunk in segments:
    stripped = chunk.text.strip()
    if stripped:
        whisper_segments.append({
            "start": round(chunk.start, 2),
            "end": round(chunk.end, 2),
            "text": stripped,
        })

print(f"Whisper produced {len(whisper_segments)} segments")

# ── Map speakers onto Whisper segments via fuzzy matching ──
print("Loading rapidfuzz for speaker matching...")
from rapidfuzz import fuzz, process

# Candidate pool for process.extractOne: just the normalized texts.
old_texts = [text for text, _speaker in old_map]

new_segments = []
unmatched = 0

for ws in whisper_segments:
    ws_norm = normalize(ws["text"])
    speaker = None

    # Strategy 1: substring containment (same sentence, different split).
    # old_map texts were already normalized when the table was built, so
    # compare them directly — no need to re-normalize per iteration.
    for old_norm, old_speaker in old_map:
        if ws_norm in old_norm or old_norm in ws_norm:
            speaker = old_speaker
            break

    # Strategy 2: rapidfuzz similarity against the whole pool.
    if not speaker:
        best = process.extractOne(ws_norm, old_texts, scorer=fuzz.token_sort_ratio)
        if best and best[1] >= 75:  # best is (choice, score, index)
            speaker = old_map[best[2]][1]

    # Strategy 3: inherit the previous speaker for near-contiguous speech
    # (gap under 2 s suggests the same continuous dialogue).
    if not speaker and new_segments:
        if ws["start"] - new_segments[-1]["end"] < 2.0:
            speaker = new_segments[-1]["speaker"]

    # Keep unmatched segments too, labelled "Unknown".
    if not speaker:
        unmatched += 1
        speaker = "Unknown"

    new_segments.append({
        "start": ws["start"],
        "end": ws["end"],
        "speaker": speaker,
        "text": ws["text"],
    })

print(f"Mapped {len(new_segments) - unmatched} segments, {unmatched} unmatched")

# ── Merge adjacent segments that share a speaker ──
print("Merging adjacent same-speaker segments...")
merged = []
for current in new_segments:
    previous = merged[-1] if merged else None
    if previous is not None and previous["speaker"] == current["speaker"]:
        # Same voice continues: extend the running segment in place.
        previous["end"] = current["end"]
        previous["text"] = previous["text"] + " " + current["text"]
    else:
        merged.append(dict(current))

print(f"Merged: {len(new_segments)} → {len(merged)} segments")

# ── Write output ──────────────────────────────────
output = {
    "title": data["title"],
    "audio_url": data["audio_url"],
    "segments": merged,
}

# ensure_ascii=False emits raw UTF-8 (á, ñ, ✅ in prints too), so the file
# must be opened with an explicit encoding, not the locale default — a
# non-UTF-8 locale would otherwise raise UnicodeEncodeError here.
with open("deuda_segments.json", "w", encoding="utf-8") as f:
    json.dump(output, f, ensure_ascii=False, indent=2)

print(f"\n✅ Updated deuda_segments.json with {len(merged)} segments")
print(f"Duration: {merged[-1]['end']:.0f}s ({merged[-1]['end']/60:.1f} min)")
print(f"\nFirst 8 segments:")
for s in merged[:8]:
    print(f"  [{s['start']:7.1f}s - {s['end']:7.1f}s] {s['speaker']:15s}: {s['text'][:80]}...")

# Tally how many merged segments each speaker ended up with.
from collections import Counter
speakers = Counter(seg["speaker"] for seg in merged)
print(f"\nSpeaker distribution:")
for name, total in speakers.most_common():
    print(f"  {name:20s}: {total} segments")
