CamiloVega committed on
Commit
30cf2e4
·
verified ·
1 Parent(s): a1d8b8e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +541 -0
app.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import logging
4
+ import os
5
+ import tempfile
6
+ import pandas as pd
7
+ import requests
8
+ from bs4 import BeautifulSoup
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ import torch
11
+ import whisper
12
+ from moviepy.editor import VideoFileClip
13
+ from pydub import AudioSegment
14
+ import fitz
15
+ import docx
16
+ import yt_dlp
17
+ from functools import lru_cache
18
+ import gc
19
+
20
# Configure logging: INFO and above, each record prefixed with a timestamp
# and its severity level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
26
+
27
class ModelManager:
    """Singleton holder for the tokenizer, LLM, generation pipeline and Whisper model.

    Every call to ``ModelManager()`` returns the same instance. Models are
    loaded lazily: ``get_models()`` triggers ``initialize_models()`` the first
    time any of the four handles is still ``None``.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: build the instance once and hand it back afterwards.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; the flag keeps a repeated
        # call from wiping models that were already loaded.
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.news_generator = None
            self.whisper_model = None
            self._initialized = True

    # NOTE(review): this method stores its results on ``self``. If the
    # spaces.GPU decorator runs the call in a separate worker process, those
    # assignments may not be visible to the caller -- confirm against the
    # ZeroGPU documentation.
    @spaces.GPU(duration=120)
    def initialize_models(self):
        """Initialize models with ZeroGPU compatible settings.

        Loads the Llama-2 chat tokenizer and model, wraps them in a
        text-generation pipeline, and loads the "tiny" Whisper model.

        Returns:
            True when every model was loaded.

        Raises:
            ValueError: if the HUGGINGFACE_TOKEN environment variable is unset.
            Exception: any loading error is logged, the manager is reset, and
                the error is re-raised.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
            if not HUGGINGFACE_TOKEN:
                raise ValueError("HUGGINGFACE_TOKEN environment variable not set")

            logger.info("Starting model initialization...")
            model_name = "meta-llama/Llama-2-7b-chat-hf"

            # Load tokenizer
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                use_fast=True,
                model_max_length=512
            )
            # The tokenizer gets an explicit pad token by reusing EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Initialize model with ZeroGPU compatible settings
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                device_map="auto",  # Automatically handle device placement
                torch_dtype=torch.float16,  # Use float16 to reduce memory usage
                low_cpu_mem_usage=True,  # Optimize CPU memory usage
                use_safetensors=True,  # Use safetensors for better memory management
                max_memory={0: "6GB"},  # Limit GPU memory usage
                offload_folder="offload",  # Folder for offloading to CPU
                offload_state_dict=True  # Offload state dict to CPU
            )

            # Create pipeline with minimal settings
            logger.info("Creating pipeline...")
            from transformers import pipeline
            self.news_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",  # Automatically handle device placement
                torch_dtype=torch.float16,  # Use float16 for memory efficiency
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                num_return_sequences=1,
                early_stopping=True
            )

            # Load Whisper model with minimal settings
            logger.info("Loading Whisper model...")
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )

            logger.info("All models initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Error during model initialization: {str(e)}")
            # Drop anything that was partially loaded before propagating.
            self.reset_models()
            raise

    def reset_models(self):
        """Reset all models and clear memory.

        Best effort: failures are logged and never raised. A failure while
        moving one model to the CPU no longer prevents the remaining handles
        from being cleared.
        """
        # Move the heavy models off the GPU first; each move is isolated so a
        # single failure cannot abort the rest of the reset.
        for attr_name in ("model", "whisper_model"):
            loaded = getattr(self, attr_name, None)
            if loaded is not None and hasattr(loaded, "cpu"):
                try:
                    loaded.cpu()
                except Exception as e:
                    logger.error(f"Error during model reset: {str(e)}")
            loaded = None

        # Always drop every reference, whatever happened above.
        self.tokenizer = None
        self.model = None
        self.news_generator = None
        self.whisper_model = None

        try:
            # Collect first so memory released by the collector can actually
            # be returned when the CUDA cache is emptied.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Error during model reset: {str(e)}")

    def check_models_initialized(self):
        """Check if all models are properly initialized, loading them if not."""
        handles = (self.tokenizer, self.model, self.news_generator, self.whisper_model)
        # Identity checks avoid invoking ``==`` on model objects.
        if any(handle is None for handle in handles):
            logger.warning("Models not initialized, attempting to initialize...")
            self.initialize_models()

    def get_models(self):
        """Get initialized models, initializing if necessary.

        Returns:
            Tuple ``(tokenizer, model, news_generator, whisper_model)``.
        """
        self.check_models_initialized()
        return self.tokenizer, self.model, self.news_generator, self.whisper_model
159
+
160
# Create global model manager instance.
# ModelManager is a singleton, so this is the one shared instance; its model
# attributes start as None and are loaded on the first get_models() call.
model_manager = ModelManager()
162
+
163
@lru_cache(maxsize=32)
def download_social_media_video(url):
    """Fetch the audio track of an online video with yt-dlp and return its path.

    The audio is extracted to MP3 and saved as ``<video id>.mp3`` relative to
    the current working directory. Results are memoised per URL.

    Raises:
        Exception: any download/extraction error is logged and re-raised.
    """
    mp3_extractor = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': '192',
    }
    options = {
        'format': 'bestaudio/best',
        'postprocessors': [mp3_extractor],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(options) as downloader:
            metadata = downloader.extract_info(url, download=True)
            # The post-processor rewrites the extension to .mp3.
            mp3_name = f"{metadata['id']}.mp3"
            logger.info(f"Video downloaded successfully: {mp3_name}")
            return mp3_name
    except Exception as download_error:
        logger.error(f"Error downloading video: {str(download_error)}")
        raise
184
+
185
def convert_video_to_audio(video_file):
    """Convert a video file to audio.

    Args:
        video_file: path of the video to read.

    Returns:
        Path of a newly created temporary ``.mp3`` file. The caller owns the
        file and is responsible for deleting it.

    Raises:
        ValueError: if the video has no audio track.
        Exception: any other conversion error is logged and re-raised.
    """
    audio_path = None
    try:
        # Reserve a unique file name, then close the handle so the encoder can
        # write to the path without a competing open descriptor.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_path = temp_file.name

        video = VideoFileClip(video_file)
        try:
            if video.audio is None:
                raise ValueError(f"No audio track found in video: {video_file}")
            video.audio.write_audiofile(audio_path)
        finally:
            # Release the reader resources held by the clip.
            video.close()

        logger.info(f"Video converted to audio: {audio_path}")
        return audio_path
    except Exception as e:
        # Do not leave an empty or partial temp file behind on failure.
        if audio_path is not None and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                pass
        logger.error(f"Error converting video: {str(e)}")
        raise
196
+
197
def preprocess_audio(audio_file):
    """Preprocess the audio file to improve quality.

    Applies a gain so the clip's average loudness becomes -20 dBFS and
    re-encodes it as MP3. Completely silent audio is exported unchanged,
    because its loudness is negative infinity and the computed gain would be
    infinite.

    Args:
        audio_file: anything ``AudioSegment.from_file`` accepts.

    Returns:
        Path of a newly created temporary ``.mp3`` file. The caller owns the
        file and is responsible for deleting it.

    Raises:
        Exception: any decoding/encoding error is logged and re-raised.
    """
    target_dbfs = -20
    output_path = None
    try:
        audio = AudioSegment.from_file(audio_file)
        if audio.dBFS != float("-inf"):
            audio = audio.apply_gain(-audio.dBFS + target_dbfs)

        # Reserve a unique file name, then close the handle before exporting.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_path = temp_file.name
        audio.export(output_path, format="mp3")

        logger.info(f"Audio preprocessed: {output_path}")
        return output_path
    except Exception as e:
        # Do not leave an empty or partial temp file behind on failure.
        if output_path is not None and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                pass
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
209
+
210
@spaces.GPU(duration=120)
def transcribe_audio(file):
    """Transcribe an audio or video file.

    Args:
        file: an ``http...`` URL, a local path, or an upload object exposing
            its path as a string ``.name`` attribute.

    Returns:
        The transcribed text, or a string starting with
        ``"Error processing the file:"`` on failure. This function does not
        raise; callers detect failure through that prefix.
    """
    # Path of an intermediate file created here, deleted once we are done.
    temp_path = None
    try:
        _, _, _, whisper_model = model_manager.get_models()

        # Upload widgets may hand over a file object instead of a plain path;
        # resolve it so the URL/video checks below see a string.
        source = file
        if not isinstance(source, str):
            candidate = getattr(source, "name", None)
            if isinstance(candidate, str):
                source = candidate

        if isinstance(source, str) and source.startswith('http'):
            # Not treated as temporary: the downloader memoises this path.
            file_path = download_social_media_video(source)
        elif isinstance(source, str) and source.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(source)
            temp_path = file_path
        else:
            file_path = preprocess_audio(source)
            temp_path = file_path

        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        with torch.inference_mode():
            result = whisper_model.transcribe(file_path)
            if not result:
                raise RuntimeError("Transcription failed to produce results")

        transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
    finally:
        # Remove the intermediate MP3 so repeated requests do not fill /tmp.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_path}: {str(cleanup_error)}")
238
+
239
@lru_cache(maxsize=32)
def read_document(document_path):
    """Read the content of a document.

    The file type is chosen from the extension, compared case-insensitively
    so that e.g. ``REPORT.PDF`` is accepted.

    Args:
        document_path: path to a PDF, DOCX, XLSX or CSV file.

    Returns:
        The extracted text; a message starting with ``"Unsupported file type"``
        for other extensions; or a message starting with
        ``"Error reading document:"`` if parsing fails. Results, including
        these messages, are memoised per path.
    """
    try:
        lowered = document_path.lower()
        if lowered.endswith(".pdf"):
            # Context manager closes the PDF handle once the text is extracted.
            with fitz.open(document_path) as pdf:
                return "\n".join(page.get_text() for page in pdf)
        if lowered.endswith(".docx"):
            word_doc = docx.Document(document_path)
            return "\n".join(paragraph.text for paragraph in word_doc.paragraphs)
        if lowered.endswith(".xlsx"):
            return pd.read_excel(document_path).to_string()
        if lowered.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
258
+
259
@lru_cache(maxsize=32)
def _fetch_url_text(url):
    """Download *url* and return its visible text; raises on any failure."""
    # A timeout keeps an unresponsive server from blocking the request forever.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()


def read_url(url):
    """Read the content of a URL.

    Successful fetches are memoised per URL. Failures are not cached, so a
    transient network error does not stick to that URL for the rest of the
    process.

    Returns:
        The page text, or a string starting with ``"Error reading URL:"`` on
        failure. This function does not raise.
    """
    try:
        return _fetch_url_text(url)
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"


# Keep the cache-management helpers reachable under the public name.
read_url.cache_info = _fetch_url_text.cache_info
read_url.cache_clear = _fetch_url_text.cache_clear
270
+
271
def process_social_content(url):
    """Collect the page text and the audio transcription for a social media URL.

    Returns:
        A dict with keys ``"text"`` (result of ``read_url``) and ``"video"``
        (result of ``transcribe_audio``, or ``None`` if that call raised), or
        ``None`` if processing failed altogether.
    """
    try:
        collected = {"text": read_url(url)}

        try:
            transcript = transcribe_audio(url)
        except Exception as video_error:
            logger.error(f"Error processing video content: {str(video_error)}")
            transcript = None

        collected["video"] = transcript
        return collected
    except Exception as social_error:
        logger.error(f"Error processing social content: {str(social_error)}")
        return None
288
+
289
@spaces.GPU(duration=120)
def generate_news(instructions, facts, size, tone, *args):
    """Generate a news article draft from the collected inputs.

    Args:
        instructions: free-text instructions for the article.
        facts: description of the facts to report.
        size: requested body length in words.
        tone: tone to use for the article.
        *args: flat list of the remaining UI inputs, in this order:
            5 x (audio/video file, name, position),
            3 x (social URL, name, context),
            5 x URL,
            then the uploaded documents.

    Returns:
        Tuple ``(article_text, raw_transcriptions)``. On failure the first
        element is a message starting with
        ``"Error generating the news article:"`` and the second is ``""``.
    """
    try:
        _, _, news_generator, _ = model_manager.get_models()

        knowledge_base = {
            "instructions": instructions,
            "facts": facts,
            "document_content": [],
            "audio_data": [],
            "url_content": [],
            "social_content": []
        }

        # Layout of *args (see docstring): three fields per audio and per
        # social entry, one field per URL, documents last.
        num_audios = 5 * 3
        num_social_urls = 3 * 3
        num_urls = 5

        social_start = num_audios
        urls_start = social_start + num_social_urls
        documents_start = urls_start + num_urls

        audios = args[:social_start]
        social_urls = args[social_start:urls_start]
        urls = args[urls_start:documents_start]
        documents = args[documents_start:]

        for url in urls:
            if url:
                content = read_url(url)
                if content and not content.startswith("Error"):
                    knowledge_base["url_content"].append(content)

        for document in documents:
            if document is not None:
                # Accept both upload objects (path in .name) and plain paths.
                content = read_document(getattr(document, "name", document))
                if content and not content.startswith("Error"):
                    knowledge_base["document_content"].append(content)

        for i in range(0, len(audios), 3):
            audio_file, name, position = audios[i:i+3]
            if audio_file is not None:
                knowledge_base["audio_data"].append({
                    "audio": audio_file,
                    "name": name,
                    "position": position
                })

        for i in range(0, len(social_urls), 3):
            social_url, social_name, social_context = social_urls[i:i+3]
            if social_url:
                social_content = process_social_content(social_url)
                if social_content:
                    knowledge_base["social_content"].append({
                        "url": social_url,
                        "name": social_name,
                        "context": social_context,
                        "text": social_content["text"],
                        "video": social_content["video"]
                    })

        transcriptions_text = ""
        raw_transcriptions = ""

        for idx, data in enumerate(knowledge_base["audio_data"]):
            if data["audio"] is not None:
                transcription = transcribe_audio(data["audio"])
                if not transcription.startswith("Error"):
                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n'
                    raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'

        for data in knowledge_base["social_content"]:
            if data["text"] and not str(data["text"]).startswith("Error"):
                social_text = f'[Social media text]: "{data["text"][:200]}..." - {data["name"]}, {data["context"]}\n'
                transcriptions_text += social_text
                # Append only this entry; appending the whole accumulated
                # transcriptions_text would repeat every earlier quote.
                raw_transcriptions += social_text + "\n\n"
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n'
                transcriptions_text += video_transcription
                raw_transcriptions += video_transcription + "\n\n"

        document_content = "\n\n".join(knowledge_base["document_content"])
        url_content = "\n\n".join(knowledge_base["url_content"])

        prompt = f"""[INST] You are a professional news writer. Write a news article based on the following information:

Instructions: {knowledge_base["instructions"]}
Facts: {knowledge_base["facts"]}
Additional content from documents: {document_content}
Additional content from URLs: {url_content}

Use these transcriptions as direct and indirect quotes:
{transcriptions_text}

Follow these requirements:
- Write a title
- Write a 15-word hook that complements the title
- Write the body with {size} words
- Use a {tone} tone
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
- Use at least 80% direct quotes (in quotation marks)
- Use proper journalistic style
- Do not invent information
- Be rigorous with the provided facts [/INST]"""

        # Optimize size and max tokens: about 1.5 tokens per requested word,
        # capped at 512 and never below 1 so generation always has a budget.
        max_tokens = max(1, min(int(size * 1.5), 512))

        # Generate article with optimized settings
        with torch.inference_mode():
            try:
                news_article = news_generator(
                    prompt,
                    max_new_tokens=max_tokens,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    early_stopping=True,
                    # Return only the completion; otherwise the whole prompt
                    # is echoed at the start of the article.
                    return_full_text=False
                )

                # Process the generated text
                if isinstance(news_article, list):
                    news_article = news_article[0]['generated_text']
                    news_article = news_article.replace('[INST]', '').replace('[/INST]', '').strip()

            except Exception as gen_error:
                logger.error(f"Error in text generation: {str(gen_error)}")
                raise

        return news_article, raw_transcriptions

    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            # Attempt to recover by resetting and reinitializing models
            model_manager.reset_models()
            model_manager.initialize_models()
            logger.info("Models reinitialized successfully after error")
        except Exception as reinit_error:
            logger.error(f"Failed to reinitialize models: {str(reinit_error)}")
        return f"Error generating the news article: {str(e)}", ""
428
+
429
def create_demo():
    """Build the Gradio interface and wire it to ``generate_news``.

    The order in which components are added to ``inputs_list`` is significant:
    ``generate_news`` slices its positional arguments as 4 base fields,
    5 x 3 audio fields, 3 x 3 social fields, 5 URLs and 5 documents.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Generador de noticias todo en uno")

        with gr.Row():
            with gr.Column(scale=2):
                instrucciones = gr.Textbox(
                    label="Instrucciones para la noticia",
                    lines=2
                )
                hechos = gr.Textbox(
                    label="Describe los hechos de la noticia",
                    lines=4
                )
                tamaño = gr.Number(
                    label="Tamaño del cuerpo de la noticia (en palabras)",
                    value=100
                )
                tono = gr.Dropdown(
                    label="Tono de la noticia",
                    choices=["serio", "neutral", "divertido"],
                    value="neutral"
                )

            with gr.Column(scale=3):
                inputs_list = [instrucciones, hechos, tamaño, tono]

                with gr.Tabs():
                    # Five interview recordings, each with speaker name and role.
                    for i in range(1, 6):
                        with gr.TabItem(f"Audio/Video {i}"):
                            file = gr.File(
                                label=f"Audio/Video {i}",
                                file_types=["audio", "video"]
                            )
                            nombre = gr.Textbox(
                                label="Nombre",
                                placeholder="Nombre del entrevistado"
                            )
                            cargo = gr.Textbox(
                                label="Cargo",
                                placeholder="Cargo o rol"
                            )
                            inputs_list.extend([file, nombre, cargo])

                    # Three social media posts, each with account and context.
                    for i in range(1, 4):
                        with gr.TabItem(f"Red Social {i}"):
                            social_url = gr.Textbox(
                                label=f"URL de red social {i}",
                                placeholder="https://..."
                            )
                            social_nombre = gr.Textbox(
                                label=f"Nombre de persona/cuenta {i}"
                            )
                            social_contexto = gr.Textbox(
                                label=f"Contexto del contenido {i}",
                                lines=2
                            )
                            inputs_list.extend([social_url, social_nombre, social_contexto])

                    # Five plain web pages.
                    for i in range(1, 6):
                        with gr.TabItem(f"URL {i}"):
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://..."
                            )
                            inputs_list.append(url)

                    # Five documents.
                    for i in range(1, 6):
                        with gr.TabItem(f"Documento {i}"):
                            documento = gr.File(
                                label=f"Documento {i}",
                                # Extensions need the leading dot; undotted
                                # names are read as file categories.
                                file_types=[".pdf", ".docx", ".xlsx", ".csv"],
                                file_count="single"
                            )
                            inputs_list.append(documento)

        gr.Markdown("---")

        with gr.Row():
            transcripciones_output = gr.Textbox(
                label="Transcripciones",
                lines=10,
                show_copy_button=True
            )

        gr.Markdown("---")

        with gr.Row():
            generar = gr.Button("Generar borrador")

        with gr.Row():
            noticia_output = gr.Textbox(
                label="Borrador generado",
                lines=20,
                show_copy_button=True
            )

        generar.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[noticia_output, transcripciones_output]
        )

    return demo
533
+
534
if __name__ == "__main__":
    # Script entry point: build the UI, enable request queuing, then serve on
    # all interfaces at port 7860 with a public share link requested.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_demo()
    demo.queue()
    demo.launch(**launch_options)