# Spaces: Running
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from gtts import gTTS | |
import io | |
from PIL import Image | |
# Make sure PyTorch is importable, installing it on demand when it is missing.
# NOTE(review): torch is not referenced directly in the visible code; it is
# presumably needed as the backend for the transformers pipelines - confirm.
try:
    import torch
except ImportError:
    import importlib
    import subprocess
    import sys

    st.warning("PyTorch is not installed. Installing PyTorch...")
    try:
        # Invoke pip through the running interpreter so the package lands in
        # the environment that is executing this app, and let a non-zero exit
        # status raise instead of being silently ignored.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "torch"],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as err:
        st.error(f"Failed to install PyTorch: {err}")
        raise
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import torch

    # Only report success once the import has actually worked.
    st.success("PyTorch has been successfully installed!")
@st.cache_resource
def _load_pipelines():
    """Build the image-captioning and story-generation pipelines.

    Streamlit re-executes the whole script on every interaction; caching the
    pipelines as a resource keeps both models from being re-instantiated on
    each rerun.

    Returns:
        A ``(caption_pipeline, story_pipeline)`` tuple.
    """
    captioner = pipeline("image-to-text", model="unography/blip-large-long-cap")
    # NOTE: "isarth/distill_gpt2_story_generator" was left commented out in the
    # original code as an alternative text-generation model.
    storyteller = pipeline("text-generation", model="distilbert/distilgpt2")
    return captioner, storyteller


# Module-level handles used by generate_caption() and generate_story().
caption_model, story_generator = _load_pipelines()
def generate_caption(image):
    """Return the caption text produced for *image* by ``caption_model``.

    Only the first result of the captioning pipeline is used; its
    ``"generated_text"`` entry is returned unchanged.
    """
    results = caption_model(image)
    first_result = results[0]
    return first_result["generated_text"]
# Words removed from generated stories before they are shown to the user.
_INAPPROPRIATE_WORDS = ("violence", "horror", "scary", "adult", "death", "gun", "shoot")


def _clean_story(generated, prompt, max_words):
    """Strip the echoed prompt, filter flagged words, and cap the word count."""
    import re

    # The generator output is expected to start with the prompt itself; keep
    # only the continuation. Cutting by prefix (instead of counting blank-line
    # separated parts) neither drops the first generated paragraph nor leaks
    # the prompt when the continuation contains no blank line.
    if generated.startswith(prompt):
        story = generated[len(prompt):]
    else:
        story = generated
    story = story.strip()

    # Match flagged words as whole words (optionally with a plural or "-ing"
    # suffix), ignoring case, so that e.g. "begun" is left intact while
    # "Gun" and "guns" are removed.
    flagged = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _INAPPROPRIATE_WORDS)) + r")(?:s|ing)?\b",
        re.IGNORECASE,
    )
    story = flagged.sub("", story)

    # Tidy the gaps left behind by removed words without touching paragraph
    # breaks (newlines are deliberately not collapsed).
    story = re.sub(r"[ \t]{2,}", " ", story)
    story = re.sub(r"[ \t]+([,.;:!?])", r"\1", story)
    story = story.strip()

    words = story.split()
    if len(words) > max_words:
        story = " ".join(words[:max_words]) + "..."
    return story


def generate_story(caption, *, generator=None, max_words=100):
    """Generate a short children's story inspired by an image caption.

    Args:
        caption: Caption text that is embedded into the story prompt.
        generator: Optional callable used instead of the module-level
            ``story_generator``. It is called as
            ``generator(prompt, max_length=500, num_return_sequences=1)`` and
            must return a sequence whose first item has a
            ``"generated_text"`` entry.
        max_words: Maximum number of words kept; longer stories are cut at
            this many words and suffixed with ``"..."``.

    Returns:
        The cleaned story text, without the prompt. May be an empty string if
        nothing remains after cleaning.
    """
    if generator is None:
        generator = story_generator

    prompt = (
        f"Once upon a time, in a world inspired by the image of {caption}, "
        "a delightful children's story took place. "
        "The story, suitable for ages 3-10, goes like this:\n\n"
        "Introduction (1-2 sentences): Introduce the main character(s) and the setting.\n\n"
        "Beginning (2-3 sentences): Describe the character's normal life or routine.\n\n"
        "Middle (3-4 sentences): Present a problem or challenge the character faces.\n\n"
        "End (2-3 sentences): Show how the character solves the problem or learns a lesson.\n\n"
        "The story should be simple, engaging, and convey a positive message. "
        "Let's begin the tale:\n\n"
    )

    outputs = generator(prompt, max_length=500, num_return_sequences=1)
    generated = outputs[0]["generated_text"]
    return _clean_story(generated, prompt, max_words)
def convert_to_audio(story):
    """Synthesize *story* as English speech with gTTS into an in-memory stream.

    Returns:
        An ``io.BytesIO`` holding the synthesized audio, rewound to position 0
        so it can be read immediately.
    """
    buffer = io.BytesIO()
    speech = gTTS(text=story, lang="en")
    speech.write_to_fp(buffer)
    buffer.seek(0)  # rewind so the caller reads from the beginning
    return buffer
def main():
    """Render the app page: upload an image, then caption, story and audio.

    Each stage is shown as soon as it is available. Failures in reading the
    upload or synthesizing audio are reported on the page instead of crashing
    the script.
    """
    # Local import: the exception type is only needed for the audio step.
    from gtts import gTTSError

    st.title("Storytelling Application")

    # File uploader for the image (restricted to JPG)
    uploaded_image = st.file_uploader("Upload an image", type=["jpg"])
    if uploaded_image is None:
        return

    # Convert the upload to a PIL image; Pillow signals unreadable files with
    # OSError (or a subclass of it).
    try:
        image = Image.open(uploaded_image)
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption="Uploaded Image", use_container_width=True)

    caption = generate_caption(image)
    st.subheader("Generated Caption:")
    st.write(caption)

    story = generate_story(caption)
    st.subheader("Generated Story:")
    st.write(story)

    # There is nothing to narrate for a blank story, so skip the audio step.
    if not story.strip():
        st.warning("No story text was generated, so there is nothing to narrate.")
        return

    try:
        audio_bytes = convert_to_audio(story)
    except gTTSError as err:
        st.error(f"Could not convert the story to audio: {err}")
        return

    st.audio(audio_bytes, format="audio/mp3")
# Launch the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()