Awaz-e-Sehat committed on
Commit
3a2b92f
1 Parent(s): 0ce23cc

Upload handler.py

Files changed (1):
  1. handler.py +58 -0
handler.py ADDED
@@ -0,0 +1,58 @@
+ from typing import Any
+
+ from transformers import (
+     AutomaticSpeechRecognitionPipeline,
+     WhisperForConditionalGeneration,
+     WhisperTokenizer,
+     WhisperProcessor,
+ )
+ # import faster_whisper  # disabled alternative backend
+ import logging
+ from peft import PeftModel, PeftConfig
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         peft_model_id = "Awaz-e-Sehat/whisper-fine-tune-new-LoRA"  # Use the same model ID as before.
+         peft_config = PeftConfig.from_pretrained(peft_model_id)
+         # Load the base Whisper model in 8-bit, then attach the LoRA adapter.
+         model = WhisperForConditionalGeneration.from_pretrained(
+             peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
+         )
+         # self.model = faster_whisper.WhisperModel(path, device="cuda")  # disabled alternative backend
+         language = "Urdu"
+         task = "transcribe"
+         model = PeftModel.from_pretrained(model, peft_model_id)
+         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
+         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
+         feature_extractor = processor.feature_extractor
+         # Force Urdu transcription at generation time.
+         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
+         self.pipe = AutomaticSpeechRecognitionPipeline(
+             model=model,
+             tokenizer=tokenizer,
+             feature_extractor=feature_extractor,
+             chunk_length_s=30,
+             stride_length_s=5,
+         )
+         logger.info("Model Initialized")
+
+     def __call__(self, data: Any) -> str:
+         """
+         data args:
+             inputs (:obj:`str`): the audio to transcribe
+         Return:
+             A :obj:`str`: the transcribed text
+         """
+         # Get inputs.
+         logger.info("In inference")
+         logger.info(data)
+         inputs = data.pop("inputs", data)
+         logger.info("Data pop")
+         logger.info(inputs)
+         # Disabled faster_whisper path:
+         # segments, _ = self.model.transcribe(inputs, language="ur", task="transcribe")
+         # prediction = "".join(segment[4] for segment in segments)
+         # return prediction
+
+         with torch.cuda.amp.autocast():
+             text = self.pipe(
+                 inputs,
+                 generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids},
+                 max_new_tokens=255,
+             )["text"]
+         return text
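
For reference, a minimal sketch of how this handler can be exercised locally. The module name handler and the test clip sample.wav are illustrative assumptions; Inference Endpoints invokes __call__ with the request payload in the same dict shape.

# Minimal local smoke test (module and file names are illustrative assumptions).
from handler import EndpointHandler

handler = EndpointHandler()
# The ASR pipeline accepts a file path, raw bytes, or a numpy array under "inputs".
transcription = handler({"inputs": "sample.wav"})
print(transcription)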