Inference code
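
This script rebuilds the ASR model defined in model.py (a Whisper-large-v2 encoder feeding a Llama-3.2-1B decoder truncated to 2 blocks), loads the fine-tuned checkpoint, and runs greedy decoding on 10 LibriSpeech test-clean samples, printing each hypothesis next to its reference transcript.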

import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
from model import create_asr_model, modify_llama_blocks, decode_asr_output
import gc
import librosa
import numpy as np
import os
from datasets import load_dataset


def load_trained_model(model_path):
    gc.collect()
    torch.cuda.empty_cache()
    
    try:
        if torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(0.5)
        
        print("Loading Whisper encoder...")
        whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
            "openai/whisper-large-v2",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # ์ž๋™์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
        )
        whisper_encoder = whisper.get_encoder()
        
        print("Loading Llama...")
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            use_fast=True
        )
        
        # ํ† ํฌ๋‚˜์ด์ € ์„ค์ •
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        
        # Llama ๋ชจ๋ธ ์„ค์ •
        llama = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # ์ž๋™์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
        )
        llama.config.pad_token_id = tokenizer.pad_token_id
        llama.resize_token_embeddings(len(tokenizer))
        
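        # Truncate Llama to its first 2 decoder blocks; num_blocks_to_keep must match the trained checkpoint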
        modified_llama = modify_llama_blocks(llama, num_blocks_to_keep=2)
        del llama
        gc.collect()
        
        print("Creating model...")
        model = create_asr_model(whisper_encoder, modified_llama)
        model = model.half()
        
        print("Loading weights...")
        state_dict = torch.load(model_path, map_location='cpu')
        
        # Print debugging info about vocab sizes and special tokens
        print(f"\nModel vocab size: {model.decoder.model.embed_tokens.weight.shape[0]}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"BOS token id: {tokenizer.bos_token_id}")
        print(f"EOS token id: {tokenizer.eos_token_id}")
        print(f"PAD token id: {tokenizer.pad_token_id}")
        
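        # strict=False tolerates key mismatches; check the missing/unexpected counts printed below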
        missing, unexpected = model.load_state_dict(
            {k: v.half() for k, v in state_dict.items()}, 
            strict=False
        )
        
        print(f"\nMissing keys: {len(missing)}")
        print(f"Unexpected keys: {len(unexpected)}")
        processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")

        model.eval()
        
        return model, processor, tokenizer
        
    except Exception as e:
        print(f"Error during model loading: {e}")
        gc.collect()
        torch.cuda.empty_cache()
        raise

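# Helper for transcribing local audio files; main() below feeds LibriSpeech arrays to the
# processor directly, so this function is not used in the demo loop.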
def process_audio(audio_path, processor):
    try:
        print(f"Loading audio from {audio_path}...")
        # Load the audio at its native sampling rate (avoids librosa's default 22,050 Hz resample)
        audio, orig_sr = librosa.load(audio_path, sr=None)
        
        # Resample to 16 kHz, the rate Whisper expects
        if orig_sr != 16000:
            print(f"Resampling from {orig_sr}Hz to 16000Hz")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=16000)
        
        # Peak-normalize the audio to [-1, 1] (guard against all-zero input)
        audio = audio / max(np.abs(audio).max(), 1e-8)
        
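        # Convert the waveform to Whisper log-mel input features (fp16 to match the model)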
        input_features = processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.half()
        
        return input_features
        
    except Exception as e:
        print(f"Error processing audio: {e}")
        raise

def run_inference(model, input_features, tokenizer, max_length=200):
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"\nUsing device: {device}")
        
        # Autocast only matters on CUDA; disable it on CPU-only machines
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            with torch.no_grad():
                if torch.cuda.is_available():
                    model = model.to(device)
                    input_features = input_features.to(device)
                
                print("\nInput features shape:", input_features.shape)
                
                # Start decoding from the BOS token
                start_token = tokenizer.bos_token_id
                print(f"Using start token: {start_token} ({tokenizer.decode([start_token])})")
                
                decoder_input_ids = torch.tensor([[start_token]], 
                                               device=device,
                                               dtype=torch.long)
                
                # Greedy decoding: pick the argmax token at each step, stop at EOS or max_length
                generated_ids = []
                
                for _ in range(max_length):
                    outputs = model(
                        input_features=input_features,
                        decoder_input_ids=decoder_input_ids
                    )
                    
                    next_token_logits = outputs.logits[:, -1, :]
                    next_token_id = torch.argmax(next_token_logits, dim=-1).item()
                    
                    # Print the top-5 candidate tokens with their probabilities (softmax over the top 5)
                    top_tokens = torch.topk(next_token_logits[0], 5)
                    top_probs = torch.softmax(top_tokens.values, dim=-1)
                    print("\nTop 5 tokens for position", len(generated_ids))
                    for token_id, prob in zip(top_tokens.indices, top_probs):
                        token = tokenizer.decode([token_id.item()])
                        print(f"Token: {token}, Probability: {prob.item():.4f}")
                    
                    generated_ids.append(next_token_id)
                    
                    if next_token_id == tokenizer.eos_token_id:
                        break
                        
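                    # Append the new token and re-feed the whole prefix (no KV cache is used here)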
                    decoder_input_ids = torch.cat([
                        decoder_input_ids, 
                        torch.tensor([[next_token_id]], device=device)
                    ], dim=-1)
                
                # Decode the full generated sequence
                text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                
                if torch.cuda.is_available():
                    model = model.cpu()
                    torch.cuda.empty_cache()
                
                return text
                
    except Exception as e:
        print(f"Error during inference: {e}")
        torch.cuda.empty_cache()
        raise
    finally:
        gc.collect()
        torch.cuda.empty_cache()


def main():
    try:
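        # Path to the locally cached Kyudan/whisperllama checkpoint; adjust for your environment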
        model_path = "/home/elicer/.cache/huggingface/hub/models--Kyudan--whisperllama/snapshots/3269c93814c84e38f2d1a46935851f4923d73659/best_model_epoch_0.pt"
        
        # Load LibriSpeech test-clean samples (10 examples)
        print("Loading LibriSpeech test samples...")
        dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
        samples = list(dataset.take(10))  # take only the first 10 samples
        
        print("Loading model...")
        model, processor, tokenizer = load_trained_model(model_path)
        
        # ๊ฐ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด ์ถ”๋ก  ์‹คํ–‰
        for idx, sample in enumerate(samples, 1):
            print(f"\nProcessing sample {idx}/10...")
            print(f"Speaker ID: {sample['speaker_id']}")
            print(f"Chapter ID: {sample['chapter_id']}")
            print(f"Reference text: {sample['text']}")
            
            # Convert the audio array to Whisper input features
            input_features = processor(
                sample['audio']['array'],
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.half()
            
            # Run inference
            print("Running inference...")
            transcribed_text = run_inference(model, input_features, tokenizer)
            
            print("\nTranscription Results:")
            print("-" * 50)
            print(f"Model output: {transcribed_text}")
            print(f"Reference  : {sample['text']}")
            print("-" * 50)
            
    except Exception as e:
        print(f"Error in main: {e}")
    finally:
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()