|
from smolagents.tools import Tool |
|
from typing import Optional |
|
import os |
|
from transformers import pipeline |
|
import requests |
|
import io |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
class TranscriptSummarizer(Tool):
    """Summarize a transcript chunk by chunk and generate an accompanying image.

    The text is summarized locally with the ``google/pegasus-xsum``
    summarization pipeline. The image is requested over HTTP from the
    endpoint stored in ``self.api_url`` using the supplied Hugging Face API
    key, then written to ``Image/image.jpg``.
    """

    description = "Summarizes a transcript and generates blog content using the transformers library and Hugging Face API for image generation."
    name = "transcript_summarizer"
    inputs = {'transcript': {'type': 'string', 'description': 'The transcript to summarize.'}}
    output_type = "string"

    def __init__(self, *args, hf_api_key: Optional[str] = None, **kwargs):
        """Load the summarization pipeline and prepare the image API request headers.

        Args:
            hf_api_key: Hugging Face API token sent as a Bearer token. When it
                is missing, ``forward`` returns an explanatory message instead
                of doing any work.
        """
        super().__init__(*args, **kwargs)
        self.summarizer = pipeline("summarization", model="google/pegasus-xsum")
        self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
        self.hf_api_key = hf_api_key
        self.headers = {"Authorization": f"Bearer {self.hf_api_key}"}

    def query(self, payload, timeout: float = 120.0):
        """POST ``payload`` as JSON to the image endpoint and return the raw response body.

        Args:
            payload: JSON-serializable request body.
            timeout: Seconds to wait for the server before giving up.

        Returns:
            The response body as bytes.

        Raises:
            requests.HTTPError: If the server answers with an error status.
                Without this check an error body would be handed to the image
                decoder and fail there with an unrelated message.
        """
        response = requests.post(
            self.api_url, headers=self.headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        return response.content

    def _split_into_chunks(self, text: str, chunk_size: int) -> list:
        """Split ``text`` into pieces of at most about ``chunk_size`` characters.

        Cuts are moved back to the last space or newline inside each window so
        that words are not split in half. A very short trailing piece is
        merged into the previous one rather than summarized on its own.
        """
        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            end = min(start + chunk_size, text_length)
            if end < text_length:
                cut = max(text.rfind(" ", start, end), text.rfind("\n", start, end))
                # Only use the whitespace cut if it still makes progress.
                if cut > start:
                    end = cut
            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)
            start = end

        if len(chunks) > 1 and len(chunks[-1]) < chunk_size // 5:
            tail = chunks.pop()
            chunks[-1] = f"{chunks[-1]} {tail}"
        return chunks

    def _summary_lengths(self, chunk: str):
        """Return ``(max_length, min_length)`` for summarizing ``chunk``.

        The generation limits are expressed in tokens, so they are derived
        from the chunk's token count (80% / 20%) rather than from its
        character count, which would request a summary about as long as the
        input itself.
        """
        token_count = len(self.summarizer.tokenizer.encode(chunk))
        max_length = max(int(token_count * 0.8), 2)
        # Keep min_length strictly below max_length so generation stays valid.
        min_length = min(int(token_count * 0.2), max_length - 1)
        return max_length, min_length

    def forward(self, transcript: str) -> str:
        """Summarize ``transcript`` and generate an image for the summary.

        Args:
            transcript: The text to summarize; must be at least 500 characters.

        Returns:
            The summary followed by the path of the saved image, or a
            human-readable message when the API key is missing, the transcript
            is too short, or any step fails.
        """
        try:
            if not self.hf_api_key:
                return "Hugging Face API key is required. Please provide it in the input field."

            if len(transcript) < 500:
                return "Transcript is too short to summarize."

            chunks = self._split_into_chunks(transcript, chunk_size=500)
            if not chunks:
                # The transcript consisted of whitespace only.
                return "Transcript is too short to summarize."

            summaries = []
            for chunk in chunks:
                max_length, min_length = self._summary_lengths(chunk)
                result = self.summarizer(
                    chunk, max_length=max_length, min_length=min_length, do_sample=False
                )
                summaries.append(result[0]['summary_text'])
            full_summary = "\n".join(summaries)

            # The image prompt is built from the first 15 words of the summary.
            prompt_words = full_summary.split()[:15]
            image_prompt = f"Generate an image related to: {' '.join(prompt_words)}, cartoon style"
            image_bytes = self.query({"inputs": image_prompt})
            image = Image.open(io.BytesIO(image_bytes))

            image_folder = "Image"
            os.makedirs(image_folder, exist_ok=True)
            image_url = os.path.join(image_folder, "image.jpg")
            # JPEG cannot store alpha or palette modes; normalize before saving.
            if image.mode != "RGB":
                image = image.convert("RGB")
            image.save(image_url)

            return f"{full_summary}\n\nImage URL: {image_url}"
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            return f"An unexpected error occurred: {str(e)}"
|
|
|
class YouTubeTranscriptExtractor(Tool):
    """Fetch the captions of a YouTube video and return them as plain text.

    English captions are preferred; when there are none, the first available
    caption track is used instead.
    """

    description = "Extracts the transcript from a YouTube video."
    name = "youtube_transcript_extractor"
    inputs = {'video_url': {'type': 'string', 'description': 'The URL of the YouTube video.'}}
    output_type = "string"

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the initialization flag is set.
        # NOTE(review): the base-class __init__ is not called here - confirm
        # that Tool needs no further setup.
        self.is_initialized = False

    def forward(self, video_url: str) -> str:
        """Return the caption text of the video at ``video_url``.

        SRT sequence numbers and timing lines are removed from the captions.

        Args:
            video_url: The URL of the YouTube video.

        Returns:
            The cleaned caption text, or a human-readable message when the
            video has no captions or any step fails.
        """
        try:
            from pytubefix import YouTube

            captions = YouTube(video_url).captions
            if 'en' in captions:
                track = captions['en']
            else:
                available = captions.all()
                if not available:
                    # Avoid an opaque "list index out of range" for caption-less videos.
                    return "No captions are available for this video."
                track = available[0]
            transcript = track.generate_srt_captions()

            # Drop SRT cue numbers (digit-only lines) and timing lines ("-->").
            cleaned_transcript = "".join(
                f"{line}\n"
                for line in transcript.splitlines()
                if not line.strip().isdigit() and "-->" not in line
            )

            print("transcript : ", cleaned_transcript)
            return cleaned_transcript
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            return f"An unexpected error occurred: {str(e)}"
|
|