import os
import json
import soundfile as sf
from datasets import load_dataset, get_dataset_config_names, Audio
from tqdm import tqdm

def _save_train_samples(ds_train, lang, lang_dir, train_f, target_style, target_count):
    """Stream the train split, keeping up to *target_count* clips per gender.

    Only samples whose style matches *target_style* are kept. Each kept clip
    is decoded, written as a wav file under *lang_dir*, and described by one
    JSONL line appended to the open file handle *train_f*.
    """
    collected = {"Male": 0, "Female": 0}

    for sample in ds_train:
        # Stop as soon as both gender quotas are filled.
        if all(count >= target_count for count in collected.values()):
            break

        # .get() may return None when the field exists but is null; coerce to
        # "" before calling str methods so we skip the sample instead of
        # crashing with AttributeError.
        style = (sample.get("style") or "").upper()
        gender = (sample.get("gender") or "").capitalize()

        if style != target_style or gender not in collected:
            continue
        if collected[gender] >= target_count:
            continue

        # Decoded audio: `datasets` yields a dict with the waveform array
        # and its sampling rate.
        audio_data = sample["audio"]["array"]
        sr = sample["audio"]["sampling_rate"]

        filename = f"{lang}_{gender}_{style}_{collected[gender] + 1}.wav"
        filepath = os.path.join(lang_dir, filename)
        sf.write(filepath, audio_data, sr)

        meta_entry = {
            "split": "train",
            "language": lang,
            "gender": gender,
            "style": style,
            "text": sample.get("text", ""),
            "speaker_id": sample.get("speaker_id", "unknown"),
            "audio_filepath": filepath,  # Local path
            "sampling_rate": sr,
        }
        train_f.write(json.dumps(meta_entry, ensure_ascii=False) + "\n")
        collected[gender] += 1


def _write_test_metadata(ds_test, lang, test_f):
    """Write one JSONL metadata line per test sample, without audio download.

    *ds_test* must have its "audio" column cast with Audio(decode=False), so
    sample["audio"] is a dict holding the remote path rather than a decoded
    waveform.
    """
    for sample in ds_test:
        # All styles are kept for evaluation. To restrict to CONV only:
        # if (sample.get("style") or "").upper() != "CONV": continue
        meta_entry = {
            "split": "test",
            "language": lang,
            "gender": (sample.get("gender") or "").capitalize(),
            "style": sample.get("style", ""),
            "text": sample.get("text", ""),
            "speaker_id": sample.get("speaker_id", "unknown"),
            # Since decode=False, 'audio' is a dict containing 'path'
            # (remote URL/path) instead of a decoded array.
            "remote_audio_path": sample["audio"].get("path", ""),
            "sampling_rate": sample["audio"].get("sampling_rate", 16000),
        }
        test_f.write(json.dumps(meta_entry, ensure_ascii=False) + "\n")


def download_rasa_with_test_metadata():
    """Download a small per-language train subset of ai4bharat/Rasa.

    For each language config: saves up to 10 CONV-style clips per gender
    from the train split (audio + JSONL metadata), and records metadata-only
    JSONL entries for every sample of the test split (no audio downloaded).

    Outputs under ``rasa_dataset_subset/``:
      - ``<lang>/*.wav``        decoded train clips
      - ``train_metadata.jsonl`` one line per saved train clip
      - ``test_metadata.jsonl``  one line per test sample (remote path only)
    """
    # Configuration
    dataset_name = "ai4bharat/Rasa"
    output_base_dir = "rasa_dataset_subset"
    train_metadata_file = os.path.join(output_base_dir, "train_metadata.jsonl")
    test_metadata_file = os.path.join(output_base_dir, "test_metadata.jsonl")

    target_style = "CONV"
    target_train_count = 10  # per gender per language for training

    # Create base directory
    os.makedirs(output_base_dir, exist_ok=True)

    # Get all available language configs; this requires accepted HF terms.
    print(f"Fetching configuration names for {dataset_name}...")
    try:
        all_configs = get_dataset_config_names(dataset_name)
    except Exception as e:
        print(f"Error fetching configs. Ensure you have accepted the dataset terms on HF. Error: {e}")
        return

    # Filter out 'default'
    lang_configs = [c for c in all_configs if c != "default"]
    print(f"Found {len(lang_configs)} languages.")

    # Keep both metadata files open across all languages; entries are
    # appended as we stream.
    with open(train_metadata_file, "w", encoding="utf-8") as train_f, \
         open(test_metadata_file, "w", encoding="utf-8") as test_f:

        for lang in tqdm(lang_configs, desc="Processing Languages"):

            # --- PART 1: TRAIN SET (Download Audio + Metadata) ---
            lang_dir = os.path.join(output_base_dir, lang)
            os.makedirs(lang_dir, exist_ok=True)

            try:
                ds_train = load_dataset(
                    dataset_name, lang, split="train", streaming=True, trust_remote_code=True
                )
                _save_train_samples(
                    ds_train, lang, lang_dir, train_f,
                    target_style, target_train_count,
                )
            except Exception as e:
                # Best-effort per language: report and continue with the next.
                print(f"Error processing TRAIN for {lang}: {e}")

            # --- PART 2: TEST SET (Metadata ONLY) ---
            try:
                # IMPORTANT: cast_column("audio", Audio(decode=False)) prevents
                # decoding/downloading the actual wav files.
                ds_test = load_dataset(
                    dataset_name, lang, split="test", streaming=True, trust_remote_code=True
                ).cast_column("audio", Audio(decode=False))
                _write_test_metadata(ds_test, lang, test_f)
            except Exception as e:
                # Some languages might not have a test split
                print(f"Note: Could not process TEST for {lang} (or it doesn't exist): {e}")

    print(f"\nProcessing complete.")
    print(f"Train samples saved to: {output_base_dir}/")
    print(f"Train metadata: {train_metadata_file}")
    print(f"Test metadata:  {test_metadata_file}")

# Script entry point: run the full subset download when executed directly.
if __name__ == "__main__":
    download_rasa_with_test_metadata()