rayl-aoit committed on
Commit
23d42a9
1 Parent(s): de3b2d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -1,8 +1,24 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
3
 
4
- get_completion = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def translate(input_text, source, target):
7
  # source_readable = source
8
  # if source == "Auto Detect" or source.startswith("Detected"):
@@ -20,14 +36,19 @@ def translate(input_text, source, target):
20
  return "", f"Error: Translation direction {source_readable} to {target} is not supported by Helsinki Translation Models"
21
 
22
  def summarize(input):
23
- output = get_completion(input)
24
  summary_origin = output[0]['summary_text']
25
  summary_translated = translate(summary_origin,'en','fr')
26
- return summary_origin, summary_translated[0]
 
27
 
28
  demo = gr.Interface(fn=summarize,
29
  inputs=[gr.Textbox(label="Text to summarize", lines=6)],
30
- outputs=[gr.Textbox(label="Result", lines=3),gr.Textbox(label="Translate Result", lines=3)],
 
 
 
 
31
  title="Text summarization with distilbart-cnn",
32
  description="Summarize any text using the `sshleifer/distilbart-cnn-12-6` model under the hood!",
33
  examples=[
 
1
  import gradio as gr
2
+ from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
3
+ import torch
4
+ import uuid
5
+ import scipy.io.wavfile as wav
6
 
 
7
 
8
# Checkpoint id shared by the text-to-speech model and its tokenizer, so the
# two cannot drift apart.
_TTS_CHECKPOINT = "facebook/mms-tts-eng"

# Summarization pipeline, built once at import time and reused for every request.
generate_summary = pipeline(
    task="summarization",
    model="sshleifer/distilbart-cnn-12-6",
)

# Text-to-speech tokenizer and model, both loaded from the same checkpoint.
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
11
+
12
def gen_speech(text: str, seed: int = 555) -> str:
    """Synthesize speech for ``text`` and return the path of the WAV file.

    The text is tokenized with ``tts_tokenizer``, run through ``tts_model``
    without gradient tracking, and the first waveform of the output is
    written as a WAV file at the model's configured sampling rate.

    Args:
        text: Text to convert to speech.
        seed: Value passed to ``set_seed`` before inference so repeated calls
            with the same text are deterministic. Defaults to 555.

    Returns:
        Path of the generated ``.wav`` file, located in the system temporary
        directory.
    """
    # Imported locally so this function does not rely on the module's import
    # block already providing these stdlib modules.
    import os
    import tempfile

    set_seed(seed)  # Make it deterministic

    encoded = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**encoded)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Write to the temp directory rather than the current working directory,
    # so one file per request does not accumulate next to the app sources.
    # The random UUID keeps concurrent requests from colliding on a name.
    output_file = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.wav")
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
21
+
22
  def translate(input_text, source, target):
23
  # source_readable = source
24
  # if source == "Auto Detect" or source.startswith("Detected"):
 
36
  return "", f"Error: Translation direction {source_readable} to {target} is not supported by Helsinki Translation Models"
37
 
38
  def summarize(input):
39
+ output = generate_summary(input)
40
  summary_origin = output[0]['summary_text']
41
  summary_translated = translate(summary_origin,'en','fr')
42
+ audio_output_filepath = gen_speech(summary_origin)
43
+ return summary_origin, summary_translated[0], audio_output_filepath
44
 
45
  demo = gr.Interface(fn=summarize,
46
  inputs=[gr.Textbox(label="Text to summarize", lines=6)],
47
+ outputs=[
48
+ gr.Textbox(label="Result", lines=3),
49
+ gr.Textbox(label="Translate Result", lines=3),
50
+ gr.Audio(type="filepath", label="Generated Speech")
51
+ ],
52
  title="Text summarization with distilbart-cnn",
53
  description="Summarize any text using the `sshleifer/distilbart-cnn-12-6` model under the hood!",
54
  examples=[