import io
import re
import unicodedata

import torch
import gradio as gr
import wikipediaapi
import soundfile as sf
from PIL import Image
from num2words import num2words

from transformers import VitsModel, AutoTokenizer
from transformers import pipeline
from transformers import CLIPProcessor, CLIPModel

from google.cloud import vision
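# NOTE: vision.ImageAnnotatorClient() authenticates via Application Default
# Credentials; point GOOGLE_APPLICATION_CREDENTIALS at a service-account JSON
# key before launching the app.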


def clean_text(text):
    """Strip Wikipedia markup artifacts and normalize text for summarization/TTS."""
    # Drop IPA pronunciation blocks ("МФА: [...]" on Russian pages) and
    # footnote markers such as [1].
    text = re.sub(r'МФА:?\s?\[.*?\]', '', text)
    text = re.sub(r'\[.*?\]', '', text)

    # Remove combining diacritics (e.g. stress marks) via NFD decomposition.
    def remove_diacritics(char):
        if unicodedata.category(char) == 'Mn':
            return ''
        return char

    text = unicodedata.normalize('NFD', text)
    text = ''.join(remove_diacritics(char) for char in text)
    text = unicodedata.normalize('NFC', text)

    # Collapse whitespace and keep only word characters and basic punctuation.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s.,!?-]', '', text)

    return text.strip()


def numbers_to_russian_text(text):
    """Spell out every integer in the text in Russian so the TTS model can pronounce it."""
    def spell(match):
        try:
            return num2words(int(match.group()), lang='ru')
        except NotImplementedError:
            # Keep the raw digits if num2words cannot handle the value.
            return match.group()

    return re.sub(r'\d+', spell, text)


# Models and pipelines are loaded once at import time; the first run downloads
# them from the Hugging Face Hub.
summarization_model = pipeline("summarization", model="facebook/bart-large-cnn")

translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru")

# English Wikipedia client; the first argument is the user-agent string.
wiki = wikipediaapi.Wikipedia("Nikita", "en")

# Russian VITS text-to-speech, loaded once here instead of on every call.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")


def text_to_speech(text, output_path="speech.wav"):
    """Synthesize Russian speech with facebook/mms-tts-rus and write it to a WAV file."""
    text = numbers_to_russian_text(text)
    inputs = tts_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        output = tts_model(**inputs).waveform.squeeze().numpy()

    sf.write(output_path, output, samplerate=tts_model.config.sampling_rate)
    return output_path
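
# Hypothetical usage: text_to_speech("Московскому Кремлю более 500 лет") spells
# out "500" in Russian, writes "speech.wav", and returns its path.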


def fetch_wikipedia_summary(landmark):
    """Return a cleaned Wikipedia summary for the landmark, or None if no page exists."""
    page = wiki.page(landmark)
    if page.exists():
        return clean_text(page.summary)
    return None


def recognize_landmark_google_cloud(image):
    """Identify the landmark on a photo with the Google Cloud Vision API."""
    client = vision.ImageAnnotatorClient()

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Serialize the image to PNG bytes for the Vision API request.
    img_bytes = io.BytesIO()
    image.save(img_bytes, format='PNG')
    content = img_bytes.getvalue()
    vision_image = vision.Image(content=content)

    response = client.landmark_detection(image=vision_image)
    landmarks = response.landmark_annotations
    if landmarks:
        return landmarks[0].description
    return "Unknown"


def tourist_helper_english(landmark):
    """Wikipedia summary -> abstractive summary -> speech, for an English query."""
    wiki_text = fetch_wikipedia_summary(landmark)
    if wiki_text is None:
        return None

    # Note: the summary stays in English but is voiced by the Russian MMS
    # model, so pronunciation is approximate.
    summarized_text = summarization_model(wiki_text, min_length=20, max_length=210)[0]["summary_text"]
    audio_path = text_to_speech(summarized_text)
    return audio_path


def process_image_google_cloud(image):
    recognized = recognize_landmark_google_cloud(image)
    print(f"[GoogleVision] Recognized: {recognized}")
    audio_path = tourist_helper_english(recognized)
    return audio_path


def process_text_google_cloud(landmark):
    return tourist_helper_english(landmark)


clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
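
# `landmark_titles` is not defined anywhere in the original source, but the
# zero-shot CLIP classifier below needs a fixed candidate list. The list here
# is a hypothetical placeholder; swap in the landmarks your demo should cover.
landmark_titles = [
    "Eiffel Tower",
    "Statue of Liberty",
    "Big Ben",
    "Red Square",
    "Colosseum",
    "Taj Mahal",
]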

# Pre-compute L2-normalized text embeddings once; query images are compared
# against them at inference time.
text_inputs = clip_processor(
    text=landmark_titles,
    images=None,
    return_tensors="pt",
    padding=True
)
with torch.no_grad():
    text_embeds = clip_model.get_text_features(**text_inputs)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)


def recognize_landmark_clip(image):
    """Zero-shot landmark recognition: pick the title most similar to the image."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    image_inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_embed = clip_model.get_image_features(**image_inputs)
        image_embed = image_embed / image_embed.norm(p=2, dim=-1, keepdim=True)

    # Both embedding sets are unit-normalized, so the dot product is cosine similarity.
    similarity = (image_embed @ text_embeds.T).squeeze(0)
    best_idx = similarity.argmax().item()
    best_score = similarity[best_idx].item()
    recognized_landmark = landmark_titles[best_idx]
    return recognized_landmark, best_score


def tourist_helper_with_russian(landmark):
    """Wikipedia summary -> abstractive summary -> EN->RU translation -> Russian speech."""
    wiki_text = fetch_wikipedia_summary(landmark)
    if wiki_text is None:
        return None

    print(wiki_text)
    summarized_text = summarization_model(wiki_text, min_length=20, max_length=210)[0]["summary_text"]
    print(summarized_text)

    translated = translator(summarized_text, max_length=1000)[0]["translation_text"]
    print(translated)

    audio_path = text_to_speech(translated)
    return audio_path


def process_image_clip(image):
    recognized, score = recognize_landmark_clip(image)
    print(f"[CLIP] Recognized: {recognized}, score={score:.2f}")
    audio_path = tourist_helper_with_russian(recognized)
    return audio_path


def process_text_clip(landmark):
    return tourist_helper_with_russian(landmark)


with gr.Blocks() as demo:
    gr.Markdown("## Two demos: Google Cloud Vision and CLIP (with translation into Russian)")

    with gr.Tabs():
        with gr.Tab("CLIP + Sum + Translate + T2S"):
            gr.Markdown("### Recognition (CLIP) with translation into Russian")

            with gr.Row():
                image_input_c = gr.Image(label="Upload a photo", type="pil")
                text_input_c = gr.Textbox(label="Or enter a name")

            audio_output_c = gr.Audio(label="Result")

            with gr.Row():
                btn_recognize_c = gr.Button("Recognize and translate into Russian")
                btn_text_c = gr.Button("Search by text")

            btn_recognize_c.click(
                fn=process_image_clip,
                inputs=image_input_c,
                outputs=audio_output_c
            )
            btn_text_c.click(
                fn=process_text_clip,
                inputs=text_input_c,
                outputs=audio_output_c
            )

        with gr.Tab("Google + Sum + T2S"):
            gr.Markdown("### Landmark recognition (Google)")

            with gr.Row():
                image_input_g = gr.Image(label="Upload a photo", type="pil")
                text_input_g = gr.Textbox(label="Or enter a name manually")

            audio_output_g = gr.Audio(label="Result")

            with gr.Row():
                btn_recognize_g = gr.Button("Recognize and speak")
                btn_text_g = gr.Button("Look up by text and speak")

            btn_recognize_g.click(
                fn=process_image_google_cloud,
                inputs=image_input_g,
                outputs=audio_output_g
            )
            btn_text_g.click(
                fn=process_text_google_cloud,
                inputs=text_input_g,
                outputs=audio_output_g
            )

demo.launch(debug=True)