import os
import json
import soundfile as sf
from datasets import load_dataset, get_dataset_config_names, Audio
from tqdm import tqdm

def _select_balanced_indices(examples, target_count):
    """Return indices of the first `target_count` Male and Female examples.

    Iterates `examples` (an iterable of dict-like rows carrying a "gender"
    field) in order and collects row indices until each gender bucket holds
    `target_count` items, then stops early. Rows whose capitalized gender is
    not exactly "Male" or "Female" are skipped.
    """
    selected_indices = []
    counts = {"Male": 0, "Female": 0}

    for idx, item in enumerate(examples):
        if counts["Male"] >= target_count and counts["Female"] >= target_count:
            break

        gender = item.get("gender", "").capitalize()
        if gender in ["Male", "Female"] and counts[gender] < target_count:
            selected_indices.append(idx)
            counts[gender] += 1

    return selected_indices


def _process_train_split(dataset_name, lang, audio_dir, jsonl_path,
                         target_style, target_count, num_proc):
    """Build the train subset for one language: save WAVs + write JSONL.

    Loads the full train split, keeps only rows whose "style" matches
    `target_style`, selects up to `target_count` examples per gender, writes
    each example's decoded audio to `audio_dir` as WAV via soundfile, and
    appends one JSON metadata line per example to `jsonl_path`.

    Raises: whatever `load_dataset`/`filter`/`sf.write` raise (caller logs).
    """
    # Load full train split (non-streaming, multi-proc).
    ds_train = load_dataset(
        dataset_name,
        lang,
        split="train",
        streaming=False,
        num_proc=num_proc,
    )

    # Filter for the target style first (multi-proc) to shrink the data
    # before the sequential index scan below.
    ds_conv = ds_train.filter(
        lambda x: x.get("style", "") == target_style,
        num_proc=num_proc,
    )

    # Gender-balanced selection (fast since data is local after the filter).
    ds_subset = ds_conv.select(_select_balanced_indices(ds_conv, target_count))

    with open(jsonl_path, "w", encoding="utf-8") as f_train:
        for i, item in enumerate(ds_subset):
            gender = item.get("gender", "").capitalize()

            # Decode and persist the audio as a standalone WAV file.
            audio_data = item["audio"]["array"]
            sr = item["audio"]["sampling_rate"]
            filename = f"{lang}_{gender}_{target_style}_{i}.wav"
            file_path = os.path.join(audio_dir, filename)
            sf.write(file_path, audio_data, sr)

            meta = {
                "language": lang,
                "gender": gender,
                "style": target_style,
                "text": item.get("text", ""),
                "audio_filepath": file_path,
                "sampling_rate": sr,
            }
            f_train.write(json.dumps(meta, ensure_ascii=False) + "\n")


def _process_test_split(dataset_name, lang, jsonl_path, num_proc):
    """Write test-split metadata for one language to `jsonl_path`.

    Audio is NOT decoded (Audio(decode=False)) — only the remote path is
    recorded — so this stays cheap even for large test sets. Formatting is
    done with a batched multi-proc `map`, then rows are streamed to disk.

    Raises: whatever `load_dataset`/`map` raise, including a missing-split
    error (caller logs it).
    """
    # Crucial: cast_column Audio(decode=False) to avoid decoding audio.
    ds_test = load_dataset(
        dataset_name,
        lang,
        split="test",
        streaming=False,
        num_proc=num_proc,
    ).cast_column("audio", Audio(decode=False))

    def format_test_entry(batch):
        # Pre-serialize each row to a JSON string; batched + multi-proc is
        # faster than per-row serialization for large test sets.
        entries = []
        for i in range(len(batch["text"])):
            audio_info = batch["audio"][i]
            entries.append(json.dumps({
                "language": lang,
                "gender": batch["gender"][i],
                "style": batch["style"][i],
                "text": batch["text"][i],
                "remote_path": audio_info.get("path", ""),
                # NOTE(review): with decode=False the audio cell typically
                # only exposes "path"/"bytes", so this .get() likely always
                # falls back to 16000 — confirm against the Rasa schema.
                "sampling_rate": audio_info.get("sampling_rate", 16000),
            }, ensure_ascii=False))
        return {"json_str": entries}

    ds_test_mapped = ds_test.map(
        format_test_entry,
        batched=True,
        num_proc=num_proc,
        remove_columns=ds_test.column_names,
    )

    # Open the output file only once the (potentially long) map is done, so
    # the handle is not held open across the multi-process phase.
    with open(jsonl_path, "w", encoding="utf-8") as f_test:
        for row in ds_test_mapped:
            f_test.write(row["json_str"] + "\n")


def process_rasa_per_language(
    output_base_dir="rasa_subset_processed",
    target_style="CONV",
    target_count=10,
    num_proc=16,
):
    """Export a per-language subset of the ai4bharat/Rasa dataset.

    For every language config of the dataset (except "default"):
      * train split: keep `target_count` Male + `target_count` Female
        examples of style `target_style`, save their audio as WAVs and
        their metadata to <lang>/train.jsonl;
      * test split: dump metadata (no audio decoding) to <lang>/test.jsonl.

    Parameters
    ----------
    output_base_dir : str
        Root directory for all per-language output (created if missing).
    target_style : str
        Value of the "style" field to keep in the train split.
    target_count : int
        Number of train examples to keep per gender.
    num_proc : int
        Worker count passed to `load_dataset`/`filter`/`map`.

    Errors in one language/split are logged and do not stop the run.
    """
    dataset_name = "ai4bharat/Rasa"
    os.makedirs(output_base_dir, exist_ok=True)

    # Get all available language configs.
    print(f"Fetching configuration names for {dataset_name}...")
    try:
        # Note: If trust_remote_code=False is required for
        # get_dataset_config_names, it relies on HF's default parsing.
        all_configs = get_dataset_config_names(dataset_name)
    except Exception as e:
        print(f"Error fetching configs. Ensure you have access rights. Error: {e}")
        return

    lang_configs = [c for c in all_configs if c != "default"]
    print(f"Found {len(lang_configs)} languages.")

    for lang in tqdm(lang_configs, desc="Processing Languages"):
        # Language-specific output layout: <base>/<lang>/{audio,*.jsonl}.
        lang_dir = os.path.join(output_base_dir, lang)
        audio_dir = os.path.join(lang_dir, "audio")
        os.makedirs(audio_dir, exist_ok=True)

        try:
            _process_train_split(
                dataset_name,
                lang,
                audio_dir,
                os.path.join(lang_dir, "train.jsonl"),
                target_style,
                target_count,
                num_proc,
            )
        except Exception as e:
            # Best-effort: keep going with the remaining languages.
            print(f"[{lang}] Train split error: {e}")

        try:
            _process_test_split(
                dataset_name,
                lang,
                os.path.join(lang_dir, "test.jsonl"),
                num_proc,
            )
        except Exception as e:
            print(f"[{lang}] Test split error (or split missing): {e}")

    print(f"\nProcessing complete. Output saved to {output_base_dir}")

# Run the full per-language extraction only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    process_rasa_per_language()