import json
import uuid
import urllib.parse
import subprocess
import os
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

# Root directory under which one sub-directory per source audio file is
# created; each sub-directory receives the trimmed chunk_NNNNN.wav clips.
TRIMMED_DIR = "/root/trimmed_audio"
MAX_WORKERS = 16  # Adjust this based on your CPU cores

# Shared progress tracking (guarded by progress_lock; mutated by worker threads)
progress_lock = threading.Lock()
completed_count = 0
skipped_count = 0

def trim_audio_worker(task_info):
    """Run a single ffmpeg trim job.

    Parameters
    ----------
    task_info : tuple
        ``(source_audio, trim_start, trim_end, trimmed_path, total_jobs)``
        where the times are seconds and ``total_jobs`` is only used for the
        progress printout.

    Returns
    -------
    bool
        True when the clip exists (already present, or freshly written by
        ffmpeg); False when ffmpeg exited with a non-zero status.
    """
    global completed_count, skipped_count
    source_audio, trim_start, trim_end, trimmed_path, total_jobs = task_info

    # Idempotence: a previous (interrupted) run may already have produced
    # this clip — count it as done and skip the expensive ffmpeg call.
    if os.path.exists(trimmed_path):
        with progress_lock:
            skipped_count += 1
            completed_count += 1
            if completed_count % 100 == 0:
                print(f"Progress: {completed_count}/{total_jobs} (Skipped: {skipped_count})")
        return True

    command = [
        'ffmpeg', '-y',
        '-i', source_audio,
        '-ss', str(trim_start),
        '-to', str(trim_end),
        trimmed_path
    ]
    # ffmpeg's console output is discarded, so the return code is our only
    # failure signal — the original ignored it and always reported success.
    result = subprocess.run(command, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
    succeeded = result.returncode == 0
    if not succeeded:
        print(f"WARNING: ffmpeg exited with code {result.returncode} for {trimmed_path}")

    with progress_lock:
        completed_count += 1
        if completed_count % 100 == 0:
            print(f"Progress: {completed_count}/{total_jobs} (Processed: {completed_count - skipped_count})")
    return succeeded

def process_manifest(manifest_path):
    """Read a JSONL manifest, trim one context-padded clip per sentence, and
    return the corresponding Label Studio pre-annotation tasks.

    Each manifest line is a JSON object with at least ``source_audio_filepath``,
    ``text`` and ``offset`` (seconds, presumably from the start of the source
    file — confirm against the aligner that produced the manifest), plus an
    optional ``duration`` (defaults to 10.0 s when absent).

    For every sentence the emitted clip spans from the previous sentence's
    start to the next sentence's end, padded by a further 5 s on each side,
    so annotators hear surrounding context.  The label region inside the clip
    is expressed relative to the clip's own start (``trim_start``).

    Returns a list of Label Studio task dicts; as a side effect it creates
    the trimmed WAV files under ``TRIMMED_DIR`` via a thread pool.
    """
    # Group by source audio
    audio_data = defaultdict(lambda: {"sentences": [], "source_path": ""})
    
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line.strip())
            source_audio = item["source_audio_filepath"]
            audio_data[source_audio]["source_path"] = source_audio
            
            sentence_info = {
                "text": item["text"],
                "start_time": item["offset"],
                "end_time": item["offset"] + item.get("duration", 10.0)
            }
            audio_data[source_audio]["sentences"].append(sentence_info)
            
    # temp_jobs and all_tasks are built in lockstep: index i of one
    # corresponds to index i of the other (same sentence/clip).
    all_tasks = []
    temp_jobs = []
    
    print("Preparing task metadata...")
    for source_audio, data in audio_data.items():
        # Manifest order is not guaranteed; sort so neighbour lookups below
        # (i-1 / i+1) really are the temporally adjacent sentences.
        sentences = sorted(data["sentences"], key=lambda x: x["start_time"])
        basename = os.path.splitext(os.path.basename(source_audio))[0]
        out_dir = os.path.join(TRIMMED_DIR, basename)
        os.makedirs(out_dir, exist_ok=True)
        
        for i, sentence in enumerate(sentences):
            # Text context shown to the annotator: previous + current + next.
            prev_chunk = sentences[i-1]["text"] if i > 0 else ""
            curr_chunk = sentence["text"]
            next_chunk = sentences[i+1]["text"] if i < len(sentences) - 1 else ""
            
            parts = [p for p in [prev_chunk, curr_chunk, next_chunk] if p]
            full_text = " ".join(parts)
            
            # Character span of the current sentence inside full_text; the
            # +1 accounts for the joining space after prev_chunk.
            start_offset = len(prev_chunk) + 1 if prev_chunk else 0
            end_offset = start_offset + len(curr_chunk)
            
            # Audio window: widen to include the neighbouring sentences...
            trim_start = sentence["start_time"]
            trim_end = sentence["end_time"]
            if i > 0:
                trim_start = sentences[i-1]["start_time"]
            if i < len(sentences) - 1:
                trim_end = sentences[i+1]["end_time"]
                
            # ...then pad 5 s each side (clamped at 0; end may exceed the
            # file's duration — ffmpeg tolerates that and stops at EOF).
            trim_start = max(0.0, trim_start - 5.0)
            trim_end = trim_end + 5.0
            
            trimmed_filename = f"chunk_{i:05d}.wav"
            trimmed_path = os.path.join(out_dir, trimmed_filename)
            
            # Temporary store jobs to count them first
            temp_jobs.append((source_audio, trim_start, trim_end, trimmed_path))
            
            # Build Label Studio Task Data
            # URL path mirrors the on-disk layout below /root/.
            relative_trimmed = trimmed_path[len("/root/"):] if trimmed_path.startswith("/root/") else trimmed_path
            audio_url = f"https://tts-labelstudio-audios.ai4bharat.org/{urllib.parse.quote(relative_trimmed)}"
            
            task = {
                "data": {
                    "audio": audio_url,
                    "full_text": full_text,
                    # Kept so downstream tooling can map clip-relative times
                    # back to the source recording.
                    "original_trim_start": trim_start,
                    "original_trim_end": trim_end
                },
                "predictions": [{
                    "model_version": "auto_aligner_v1",
                    "result": [
                        {
                            # Audio region: sentence boundaries relative to
                            # the trimmed clip's start.
                            "id": f"audio_{uuid.uuid4().hex[:8]}",
                            "from_name": "labels", "to_name": "audio", "type": "labels",
                            "value": {
                                "start": sentence["start_time"] - trim_start,
                                "end": sentence["end_time"] - trim_start,
                                "labels": ["Sentence"]
                            }
                        },
                        {
                            # Matching character region inside full_text.
                            "id": f"text_{uuid.uuid4().hex[:8]}",
                            "from_name": "text_labels", "to_name": "full_text", "type": "labels",
                            "value": {
                                "start": start_offset, "end": end_offset,
                                "text": curr_chunk, "labels": ["Sentence"]
                            }
                        }
                    ]
                }]
            }
            all_tasks.append(task)
            
    total_jobs = len(temp_jobs)
    # Add total_jobs to each task info for the worker to print
    trim_jobs = [(*job, total_jobs) for job in temp_jobs]
            
    # Run ffmpeg jobs in parallel
    print(f"Starting parallel trimming with {MAX_WORKERS} workers ({total_jobs} total tasks)...")
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # list(...) drains the iterator so we block until all jobs finish.
        list(executor.map(trim_audio_worker, trim_jobs))
            
    return all_tasks

if __name__ == "__main__":
    # Entry point: build the tasks (trimming audio as a side effect), then
    # dump them as a Label Studio pre-annotation JSON file.
    source_manifest = "/root/aivanta_chunks/manifest_final.jsonl"
    tasks = process_manifest(source_manifest)

    destination = "pre_annotations_sentences.json"
    with open(destination, "w") as out:
        out.write(json.dumps(tasks, indent=2))
    print(f"\nSuccessfully generated {destination} with {len(tasks)} tasks")