#!/usr/bin/env python3
"""
Whisper forced alignment: 用 faster-whisper 获取词级时间戳，
通过 difflib 序列匹配对齐到已有转录文本。
"""
import json
import os
import re
import sys
import time
import unicodedata
from difflib import SequenceMatcher

import numpy as np

# Disable huggingface_hub progress bars (e.g. for the model download below),
# unless the caller has already set this variable.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# ── Load the existing transcript ──────────────────
# Read as UTF-8 explicitly: the transcript may contain non-ASCII (e.g. Spanish
# accented) characters, and the platform default encoding (cp1252 on many
# Windows setups) would silently garble them.
with open("deuda_segments.json", encoding="utf-8") as f:
    data = json.load(f)
old_segments = data["segments"]

def clean_text(text):
    """Return *text* without a leading ``[...]`` speaker label.

    Only a bracketed tag at the very start of the string is removed (the
    match is non-greedy, so ``"[A] [B] x"`` keeps ``"[B] x"``), together with
    any whitespace right after it; the result is then stripped at both ends.
    """
    label = re.match(r'^\[.*?\]\s*', text)
    remainder = text[label.end():] if label else text
    return remainder.strip()

# Label-free text of every transcript segment, plus the whole transcript as a
# single string with exactly one space between consecutive segments.
old_texts = [clean_text(entry["text"]) for entry in old_segments]
old_flat = " ".join(old_texts)

# ── Run Whisper to obtain word-level timestamps ───
print("Loading faster-whisper model (small)...")
from faster_whisper import WhisperModel  # imported here, after the transcript has been loaded

model = WhisperModel("small", device="cpu", compute_type="int8")

print("Transcribing audio (this takes ~5-10 min for 51 min audio)...")
transcribe_kwargs = dict(
    language="es",
    beam_size=5,
    word_timestamps=True,   # per-word start/end, consumed by the word collection below
    vad_filter=True,        # let VAD skip silence / non-speech stretches
    vad_parameters={"min_silence_duration_ms": 500},
)
segments, info = model.transcribe("deuda.mp3", **transcribe_kwargs)

# ── Collect every Whisper word with its timestamps ──
# NOTE(review): faster-whisper documents `segments` as a lazy generator, so the
# actual decoding presumably happens while this loop runs — confirm.
whisper_words = []  # [(start, end, word), ...] with surrounding whitespace stripped
whisper_texts = []  # the word strings alone, parallel to whisper_words
for chunk in segments:
    for word in chunk.words or ():
        token = word.word.strip()
        whisper_words.append((word.start, word.end, token))
        whisper_texts.append(token)

whisper_flat = " ".join(whisper_texts)
print(f"Whisper: {len(whisper_words)} words, {len(whisper_texts)} tokens")

# ── Sequence alignment ────────────────────────────
# Align on normalised WORDS rather than raw characters.  On a character
# sequence this long, difflib's default "autojunk" heuristic treats every item
# occurring in more than 1% of positions (all common letters and the space)
# as popular junk, so matches can only be anchored on rare characters and the
# alignment becomes fragile; turning autojunk off at character level is close
# to quadratic.  Word tokens give a much larger alphabet, so autojunk can be
# disabled cheaply and frequent words ("de", "que", ...) still take part.
print("Aligning transcriptions...")


def _normalize_tokens(text):
    """Return the lower-case, accent-folded word tokens of *text*.

    Punctuation is discarded and combining diacritics are removed, so
    "¿Qué?" and "que" both yield ["que"]; text without word characters
    yields [].  Tokenising per string (instead of lowering one big string)
    also avoids offsets drifting when lower() changes a string's length.
    """
    decomposed = unicodedata.normalize("NFD", text.lower())
    folded = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return re.findall(r"\w+", folded)


def _align_segments(words, transcript_segments):
    """Map each transcript segment to the time span of its matched Whisper words.

    Args:
        words: ``[(start, end, text), ...]`` word timestamps from Whisper.
        transcript_segments: dicts whose ``"text"`` may begin with a
            ``[speaker]`` label (removed via ``clean_text``).

    Returns:
        A list parallel to *transcript_segments*: ``(start, end)`` being the
        earliest start and latest end of all Whisper words matched to one of
        the segment's tokens, or ``None`` if none of its tokens matched.
    """
    # Flatten both sides into token lists, remembering the Whisper word /
    # transcript segment each token came from.
    w_tokens, w_owner = [], []
    for word_idx, (_, _, text) in enumerate(words):
        for tok in _normalize_tokens(text):
            w_tokens.append(tok)
            w_owner.append(word_idx)

    o_tokens, o_owner = [], []
    for seg_idx, seg in enumerate(transcript_segments):
        for tok in _normalize_tokens(clean_text(seg["text"])):
            o_tokens.append(tok)
            o_owner.append(seg_idx)

    spans = [None] * len(transcript_segments)
    matcher = SequenceMatcher(None, w_tokens, o_tokens, autojunk=False)
    # Matching blocks are monotone in both sequences, so every matched token
    # pair maps a transcript segment onto Whisper words in time order.
    for a, b, size in matcher.get_matching_blocks():
        for k in range(size):
            start, end, _ = words[w_owner[a + k]]
            seg_idx = o_owner[b + k]
            span = spans[seg_idx]
            if span is None:
                spans[seg_idx] = (start, end)
            else:
                spans[seg_idx] = (min(span[0], start), max(span[1], end))
    return spans


new_segments = []
segment_spans = _align_segments(whisper_words, old_segments)
for seg_idx, (seg, span) in enumerate(zip(old_segments, segment_spans)):
    if span is not None:
        seg_start_time, seg_end_time = span
    else:
        # fallback: keep the segment's previous estimate
        seg_start_time = seg["start"]
        seg_end_time = seg["end"]
        print(f"  ⚠️  Segment {seg_idx} no match, using old estimate")

    new_segments.append({
        "start": round(seg_start_time, 2),
        "end": round(seg_end_time, 2),
        "speaker": seg["speaker"],
        "text": seg["text"],
    })

# ── Validation ────────────────────────────────────
# Word-level mapping can legitimately make neighbouring segments overlap a
# little, so only overlaps larger than 2 s are reported and counted.
print(f"\nAlignment complete: {len(new_segments)} segments")
errors = 0
for i, (prev, cur) in enumerate(zip(new_segments, new_segments[1:]), start=1):
    overlap = prev["end"] - cur["start"]
    if cur["start"] < prev["end"] - 0.5 and overlap > 2.0:
        print(f"  ⚠️  Overlap at seg {i}: prev_end={prev['end']:.1f}s, curr_start={cur['start']:.1f}s, overlap={overlap:.1f}s")
        errors += 1

# NOTE: despite the wording, overlaps of up to 2 s are tolerated here.
if not errors:
    print("✅ All timestamps monotonically increasing")

# ── Output ────────────────────────────────────────
# Written as a bare JSON list of segments.
# NOTE(review): the input file wraps its list as {"segments": [...]}; confirm
# downstream consumers expect this different top-level shape.
# UTF-8 is explicit because ensure_ascii=False emits raw non-ASCII text, which
# the platform default encoding may be unable to encode (or would encode in a
# non-UTF-8 charset that JSON readers do not expect).
with open("deuda_segments_aligned.json", "w", encoding="utf-8") as f:
    json.dump(new_segments, f, ensure_ascii=False, indent=2)

print("\nWrote deuda_segments_aligned.json")
print("First 5 segments:")
for s in new_segments[:5]:
    # Only mark the preview as truncated when the text really was cut.
    preview = s["text"][:60]
    ellipsis = "..." if len(s["text"]) > 60 else ""
    print(f"  [{s['start']:7.1f}s - {s['end']:7.1f}s] {s['speaker']}: {preview}{ellipsis}")
