#!/usr/bin/env python3
import json
import os
import re

import requests

# Source subtitle file; the translated copy is written next to it with a
# "_zh-TW" suffix inserted before the extension.
srt_file = "/home/laihenyi/clawd/downloads/Siguiriya y Martinete【Let's dance to flamenco compas】 [-hhSuIMTKOc].srt"
output_file = srt_file.replace(".srt", "_zh-TW.srt")

# Load the whole file into memory as UTF-8 text.
with open(srt_file, mode='r', encoding='utf-8') as f:
    content = f.read()

# One SRT block = index line, timing line, then the text, which runs up to
# the next blank line or the end of the file. DOTALL lets the text group span
# several lines. Each match is an (index, timing, text) tuple of strings.
pattern = r'(\d+)\n(\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3})\n(.*?)(?=\n\n|\Z)'
blocks = re.compile(pattern, re.DOTALL).findall(content)

print(f"Found {len(blocks)} subtitle blocks")

# Translate function
def translate_batch(texts, timeout=120):
    """Translate a batch of Japanese subtitle texts into Traditional Chinese.

    All texts are sent as one numbered list in a single chat-completion
    request, and the reply is split back into one translation per line.

    Args:
        texts: Subtitle texts to translate, in display order.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        The stripped, non-empty lines of the model's reply. The number of
        lines is NOT guaranteed to equal ``len(texts)``; the caller has to
        reconcile a mismatch. An empty ``texts`` returns ``[]`` without
        contacting the API.

    Raises:
        RuntimeError: If the ``ZAI_API_KEY`` environment variable is unset
            or empty.
        requests.RequestException: On a network failure, a timeout, or an
            HTTP error status.
        KeyError, IndexError: If the reply JSON lacks the expected
            ``choices[0].message.content`` path.
    """
    if not texts:
        # Nothing to translate; skip the API round trip.
        return []

    # The credential is read from the environment so that it never lives in
    # the source file. Checked before any network activity.
    api_key = os.environ.get("ZAI_API_KEY")
    if not api_key:
        raise RuntimeError("ZAI_API_KEY environment variable is not set")

    # The reply is mapped back to the inputs line by line, so every input
    # must occupy exactly one line in the prompt: collapse embedded newlines
    # (multi-line subtitles) and other whitespace runs into single spaces.
    flattened = [" ".join(t.split()) for t in texts]
    numbered = "\n".join(f"{i}. {t}" for i, t in enumerate(flattened, start=1))

    prompt = f"""將以下日語翻譯成繁體中文，保持專業舞蹈術語的準確性（弗拉門戈 Flamenco）：

{numbered}

請只返回翻譯結果，每行對應原文的順序，不要添加編號。"""

    response = requests.post(
        "https://api.z.ai/api/coding/paas/v4/chat/completions",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={
            "model": "glm-4.7",
            "messages": [
                {"role": "user", "content": prompt}
            ],
            "max_tokens": 2000,
        },
        # Without a timeout a stalled connection would block forever.
        timeout=timeout,
    )
    # Surface HTTP errors (bad key, rate limit, ...) as a clear exception
    # instead of a confusing KeyError on the error payload below.
    response.raise_for_status()

    result = response.json()
    translated = result['choices'][0]['message']['content'].strip()

    # One translation per non-blank line of the reply.
    return [line.strip() for line in translated.split('\n') if line.strip()]

# Translate the parsed blocks in fixed-size batches (one API request per
# batch) and collect (index, timing, text) tuples for the output file.
translated_blocks = []

batch_size = 15
total_batches = (len(blocks) - 1) // batch_size + 1

for batch_number, start in enumerate(range(0, len(blocks), batch_size), start=1):
    batch = blocks[start:start + batch_size]
    print(f"Translating batch {batch_number}/{total_batches}...")

    texts = [block[2] for block in batch]

    try:
        translations = translate_batch(texts)
    except Exception as e:
        # Deliberate best-effort: a failed batch keeps its original text so
        # the output file still contains every subtitle block.
        print(f"  Error: {e}, using original text")
        translations = list(texts)
    else:
        print(f"  Got {len(translations)} translations for {len(texts)} texts")
        # If the reply has too few lines, fill each missing slot with the
        # original text at the SAME position. Repeating the last translation
        # would display a wrong line on unrelated subtitles.
        if len(translations) < len(texts):
            translations.extend(texts[len(translations):])

    # zip() stops at the shorter sequence, so surplus reply lines are dropped.
    for (num, timing, _), translated in zip(batch, translations):
        translated_blocks.append((num, timing, translated))

# Write translated SRT: index, timing, text, then a blank separator line.
with open(output_file, 'w', encoding='utf-8') as f:
    for num, timing, text in translated_blocks:
        f.write(f"{num}\n{timing}\n{text}\n\n")

print(f"Saved to: {output_file}")
