#!/usr/bin/env python3
"""
Word-level alignment: 用 faster-whisper 词级时间戳 + speaker 标签映射。
输出每个词的 start/end/speaker，播放器逐词高亮。
"""
import json, os, re
# Default the HF_HUB_DISABLE_PROGRESS_BARS switch to "1"; setdefault() leaves
# any value already present in the environment untouched.
# NOTE(review): presumably this has to run before faster_whisper is imported
# further down for it to take effect — confirm.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# ── Load the existing segmentation (source of the speaker labels) ──
# Decode explicitly as UTF-8: the segment texts contain accented Spanish
# characters, and without an explicit encoding open() falls back to the
# platform's default, which is not UTF-8 everywhere.
with open("deuda_segments.json", encoding="utf-8") as f:
    data = json.load(f)
old_segments = data["segments"]

def normalize(s):
    """Return a canonical form of *s* for text comparison.

    The text is lower-cased, every character that is neither a word
    character nor whitespace is removed, whitespace runs are collapsed to a
    single space, and surrounding whitespace is stripped.

    Stripping happens last on purpose: removing punctuation can expose
    whitespace at either end (``"— Hola"`` would otherwise become
    ``" hola"``), and such stray spaces defeat the equality and substring
    comparisons performed on the normalised text.

    Args:
        s: The raw text to normalise.

    Returns:
        The normalised string; empty if *s* holds no word characters.
    """
    s = s.lower()
    # Drop punctuation and symbols; letters, digits, underscore and
    # whitespace survive.
    s = re.sub(r'[^\w\sáéíóúñü]', '', s)
    s = re.sub(r'\s+', ' ', s)
    return s.strip()

# Build two parallel lists from the existing segments, in the same order:
# old_map holds (normalised text, speaker) pairs and old_texts holds the
# normalised texts alone, so position i in old_texts is position i in old_map.
old_map = []
old_texts = []
for entry in old_segments:
    norm_text = normalize(entry["text"])
    old_map.append((norm_text, entry["speaker"]))
    old_texts.append(norm_text)

# ── Run Whisper with word-level timestamps ──
# The import is placed after the status message, as in the rest of this
# script where heavy dependencies are imported right before first use.
print("Loading faster-whisper...")
from faster_whisper import WhisperModel

model = WhisperModel("small", device="cpu", compute_type="int8")

print("Transcribing with word timestamps...")
# Options handed to the voice-activity filter.
vad_options = {"min_silence_duration_ms": 500}
# Keyword arguments for the transcription call, gathered in one place.
transcribe_options = {
    "language": "es",
    "beam_size": 5,
    "word_timestamps": True,
    "vad_filter": True,
    "vad_parameters": vad_options,
}
segments, info = model.transcribe("deuda.mp3", **transcribe_options)

# ── Collect every word, tagged with a speaker ──
from rapidfuzz import fuzz, process

# Minimum token_sort_ratio score for a fuzzy match to be accepted.
FUZZY_MATCH_THRESHOLD = 70
# A segment without a text match inherits the previous word's speaker only
# when it starts less than this many seconds after that word ended.
MAX_INHERIT_GAP_S = 2.0


def _speaker_from_text(seg_norm):
    """Return the speaker of the old segment matching *seg_norm*, or None.

    A containment match (either text contained in the other) is tried first,
    in old_map order; if that yields no speaker, the best fuzzy match is
    accepted when it scores at least FUZZY_MATCH_THRESHOLD.

    Empty strings are never matched by containment: ``"" in anything`` is
    always true, so an empty segment text or an empty old entry would
    otherwise hand every such segment the first old segment's speaker.
    """
    if not seg_norm:
        return None

    speaker = None
    for old_norm, old_speaker in old_map:
        if old_norm and (seg_norm in old_norm or old_norm in seg_norm):
            speaker = old_speaker
            break

    if not speaker:
        best = process.extractOne(seg_norm, old_texts, scorer=fuzz.token_sort_ratio)
        # best is (matched text, score, index into old_texts) or None.
        if best and best[1] >= FUZZY_MATCH_THRESHOLD:
            speaker = old_map[best[2]][1]

    return speaker


words = []

for seg in segments:
    if not seg.words:
        continue

    # 1) Try to recover the speaker from the segment's text.
    speaker = _speaker_from_text(normalize(seg.text))

    # 2) Otherwise inherit the previous speaker when the pause is short.
    if not speaker and words:
        gap = seg.start - words[-1]["end"]
        if gap < MAX_INHERIT_GAP_S:
            speaker = words[-1]["speaker"]

    # 3) Last resort.
    if not speaker:
        speaker = "Unknown"

    # Emit one entry per non-blank word of the segment.
    for w in seg.words:
        word_text = w.word.strip()
        if not word_text:
            continue
        words.append({
            "start": round(w.start, 3),
            "end": round(w.end, 3),
            "word": word_text,
            "speaker": speaker,
        })

print(f"Collected {len(words)} words")

# ── Post-processing: split trailing punctuation off words ──
# Whisper usually leaves punctuation attached to the word ("Hola." arrives as
# word="Hola."). Emit the punctuation as its own entry right after the word,
# giving it a short slice taken from the end of the word's time span.

# A token made of at least one character followed by trailing punctuation.
_TRAILING_PUNCT_RE = re.compile(r'^(.+?)([.,!?;:¿¡—]+)$')
# Seconds taken from the end of a word and assigned to its punctuation.
_PUNCT_SLOT_S = 0.05

processed_words = []
for w in words:
    m = _TRAILING_PUNCT_RE.match(w["word"])
    # Split only when the leading part holds a real word character; a
    # punctuation-only token such as "..." is kept whole rather than being
    # cut into "." + "..".
    if m is None or not re.search(r'\w', m.group(1)):
        processed_words.append(w)
        continue

    start, end = w["start"], w["end"]
    # Never take more than half of the word's duration, so a word shorter
    # than the slot cannot end before it starts.
    slot = max(0.0, min(_PUNCT_SLOT_S, (end - start) / 2))
    # Round after subtracting so binary float noise (1.2300000000000002)
    # does not leak into the JSON output.
    split_at = round(end - slot, 3)

    # The word itself, shortened to make room for the punctuation.
    processed_words.append({
        "start": start,
        "end": split_at,
        "word": m.group(1),
        "speaker": w["speaker"],
    })
    # The punctuation, occupying the slice that was taken off the word.
    processed_words.append({
        "start": split_at,
        "end": end,
        "word": m.group(2),
        "speaker": w["speaker"],
    })

words = processed_words

# ── Speaker change detection ──
# Copy every word entry and add a "speaker_change" flag that is True when the
# entry's speaker differs from the speaker of the entry before it. The first
# entry has no predecessor and is therefore never flagged.
word_sequence = []
last_seen = None
for entry in words:
    current = entry["speaker"]
    if last_seen is None:
        switched = False
    else:
        switched = current != last_seen
    tagged = dict(entry)
    tagged["speaker_change"] = switched
    word_sequence.append(tagged)
    last_seen = current

# ── Write the result ──
output = {
    "title": data["title"],
    "audio_url": data["audio_url"],
    "words": word_sequence,
    "total_words": len(word_sequence),
    # End time of the last word, or 0 when nothing was transcribed.
    "duration": round(words[-1]["end"], 1) if words else 0,
}

# ensure_ascii=False writes accented and other non-ASCII characters verbatim,
# so the file is opened as UTF-8 explicitly; the platform's default encoding
# may be unable to represent them and would raise UnicodeEncodeError.
with open("deuda_words.json", "w", encoding="utf-8") as f:
    json.dump(output, f, ensure_ascii=False, indent=2)

print(f"\n✅ Wrote deuda_words.json ({len(word_sequence)} words, {output['duration']:.1f}s)")

# ── Console summary ──
# Tally how many word entries each speaker has.
from collections import Counter
speakers = Counter()
for entry in word_sequence:
    speakers[entry["speaker"]] += 1

print(f"\nSpeaker distribution:")
for sp, count in speakers.most_common():
    print(f"  {sp:20s}: {count} words ({count/len(word_sequence)*100:.0f}%)")

# Preview the first 20 entries, marking those where the speaker switches.
print(f"\nFirst 20 words:")
for position, w in enumerate(word_sequence):
    if position >= 20:
        break
    if w["speaker_change"]:
        change = " 🔄"
    else:
        change = ""
    print(f"  [{w['start']:7.2f}s] {w['speaker']:10s}: {w['word']}{change}")
