# Author: Awaz-e-Sehat
# Commit 69a9697 (verified): "Update handler.py"
from typing import Dict, List, Any
from transformers import pipeline
from transformers import (
AutomaticSpeechRecognitionPipeline,
WhisperForConditionalGeneration,
WhisperTokenizer,
WhisperProcessor,
)
# import faster_whisper
# import json
import logging
from peft import PeftModel, PeftConfig
import torch
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Whisper speech recognition.

    Loads the base Whisper checkpoint named by a PEFT adapter's config in
    8-bit, applies the adapter with ``PeftModel``, and wraps the result in an
    ``AutomaticSpeechRecognitionPipeline`` that decodes with a forced
    language/task prompt (Urdu transcription by default).
    """

    # Upper bound on generated tokens, forwarded to ``generate`` via
    # ``generate_kwargs``.
    MAX_NEW_TOKENS = 255

    def __init__(
        self,
        path: str = "",
        *,
        peft_model_id: str = "Awaz-e-Sehat/whisper-fine-tune-new-LoRA",
        language: str = "Urdu",
        task: str = "transcribe",
    ):
        """Load the adapted model, tokenizer, processor and ASR pipeline.

        Args:
            path: Model directory supplied by the endpoint runtime. Currently
                unused: the adapter is always loaded from ``peft_model_id``.
            peft_model_id: Hub ID (or local path) of the PEFT adapter. Its
                config's ``base_model_name_or_path`` selects the base model,
                tokenizer and processor.
            language: Language given to the tokenizer/processor and used to
                build the forced decoder prompt.
            task: Whisper task used for the same purposes as ``language``.
        """
        peft_config = PeftConfig.from_pretrained(peft_model_id)
        base_model_id = peft_config.base_model_name_or_path

        # Base model is loaded 8-bit quantized; device_map="auto" delegates
        # weight placement to the loader.
        base_model = WhisperForConditionalGeneration.from_pretrained(
            base_model_id, load_in_8bit=True, device_map="auto"
        )
        model = PeftModel.from_pretrained(base_model, peft_model_id)

        tokenizer = WhisperTokenizer.from_pretrained(
            base_model_id, language=language, task=task
        )
        processor = WhisperProcessor.from_pretrained(
            base_model_id, language=language, task=task
        )
        # Decoder prompt that pins the output language and task at generate time.
        self.forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=language, task=task
        )
        # chunk_length_s / stride_length_s enable chunked decoding of long audio.
        self.pipe = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=30,
            stride_length_s=5,
        )
        logger.info("Model Initialized")

    @staticmethod
    def _extract_inputs(data: Any) -> Any:
        """Return the value to transcribe from a request payload.

        For a dict, this is its ``"inputs"`` entry, or the whole dict when the
        key is absent. Any other value is returned as-is. The payload is never
        mutated.
        """
        if isinstance(data, dict):
            return data.get("inputs", data)
        return data

    def __call__(self, data: Any) -> str:
        """Transcribe one request payload.

        Args:
            data: Request payload. If it is a dict, its ``"inputs"`` entry is
                transcribed (falling back to the whole dict when the key is
                missing); any other value is handed to the pipeline unchanged.

        Returns:
            The transcription: the ``"text"`` field of the pipeline output.
        """
        logger.info("In inference")
        inputs = self._extract_inputs(data)
        # Log only a summary: the payload is raw audio, which is large and may
        # contain sensitive speech.
        logger.debug("Received inputs of type %s", type(inputs).__name__)

        # torch.cuda.amp.autocast() is deprecated; disable autocast cleanly
        # when CUDA is unavailable instead of relying on a runtime warning.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            output = self.pipe(
                inputs,
                generate_kwargs={
                    "forced_decoder_ids": self.forced_decoder_ids,
                    # Passed here rather than as a top-level pipeline argument,
                    # which newer transformers releases deprecate.
                    "max_new_tokens": self.MAX_NEW_TOKENS,
                },
            )
        return output["text"]