Spaces:
Paused
Paused
import gradio as gr | |
from pydantic import BaseModel, Field | |
from story_beam_search.stories_generator import StoryGenerationSystem | |
# Genres offered in the UI dropdown, in display order.
genre_choices = [
    "children", "mystery", "adventure",
    "sci-fi", "fantasy", "romance",
    "comedy", "drama", "horror",
]
# Validated parameters for a single story-generation request.
# The numeric fields carry inclusive bounds enforced by pydantic; `prompt`
# and `genre` have no constraint beyond their `str` type.
class InputModel(BaseModel):
    prompt: str
    genre: str
    # How many stories to return: 2..7, default 3.
    num_stories: int = Field(default=3, ge=2, le=7)
    # Sampling temperature: 0.7..3.5, default 2.5.
    temperature: float = Field(default=2.5, ge=0.7, le=3.5)
    # Generation length limit: 30..200, default 60.
    max_length: int = Field(default=60, ge=30, le=200)
def _format_ranked_stories(ranked_stories) -> tuple[str, str]:
    """Render ranked stories as the two text blobs shown in the UI.

    Args:
        ranked_stories: Iterable of ``(story, scores)`` pairs; it is consumed
            exactly once. Each ``scores`` object must expose ``total``,
            ``coherence``, ``fluency`` and ``genre_alignment`` attributes that
            can be formatted with ``:.3f``.

    Returns:
        ``(detailed_scores, stories_text)``: the per-story score report and
        the newline-joined story texts, both numbered from 1.
    """
    separator = "-" * 50
    score_blocks = []
    story_blocks = []
    for i, (story, scores) in enumerate(ranked_stories, 1):
        score_blocks.append(
            f"Story {i}:\n"
            f"Total Score: {scores.total:.3f}\n"
            f"Coherence: {scores.coherence:.3f}\n"
            f"Fluency: {scores.fluency:.3f}\n"
            f"Genre Alignment: {scores.genre_alignment:.3f}\n"
            f"{separator}\n"
        )
        story_blocks.append(f"Story {i}:\n{story}\n")
    return "".join(score_blocks), "\n".join(story_blocks)


def create_story_generation_interface() -> gr.Interface:
    """Build the Gradio interface for generating and scoring stories.

    A single ``StoryGenerationSystem`` is created and initialized here and is
    then shared by every call to the request handler.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    # Initialize the story generation system
    system = StoryGenerationSystem()
    system.initialize()

    def generate_stories(
        prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
    ) -> tuple[str, str]:
        """
        Generate and evaluate stories based on user input.

        Returns a tuple of (detailed_scores, stories_text), both plain strings
        suitable for the two output textboxes.

        Raises pydantic's validation error when a numeric argument falls
        outside the bounds declared on ``InputModel``.
        """
        # Validate inputs. Gradio seems to validate choices but not the range
        # of the values.
        input_values = InputModel(
            prompt=prompt,
            genre=genre,
            num_stories=num_stories,
            temperature=temperature,
            max_length=max_length,
        )

        # Update beam search config with user parameters.
        # NOTE(review): `system` is shared across calls, so these assignments
        # persist between requests; concurrent requests could overwrite each
        # other's settings - confirm how calls are serialized.
        system.beam_search.config.temperature = input_values.temperature
        system.beam_search.config.max_length = input_values.max_length

        # Generate and evaluate stories
        ranked_stories = system.generate_and_evaluate(
            input_values.prompt,
            input_values.genre,
            num_stories=input_values.num_stories,
        )

        return _format_ranked_stories(ranked_stories)

    # Define interface components
    prompt_input = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter the beginning of your story...",
        lines=3,
    )
    genre_input = gr.Dropdown(
        choices=genre_choices,
        label="Genre",
        value="fantasy",
    )
    num_stories_input = gr.Slider(
        minimum=2, maximum=7, value=3, step=1, label="Number of Stories to Generate"
    )
    temperature_input = gr.Slider(
        minimum=0.7, maximum=3.5, value=2.5, step=0.1, label="Temperature (Creativity)"
    )
    max_length_input = gr.Slider(
        minimum=40, maximum=200, value=60, step=20, label="Maximum Length"
    )

    # Output components
    scores_output = gr.Textbox(label="Detailed Scores", lines=10, interactive=False)
    stories_output = gr.Textbox(label="Generated Stories", lines=15, interactive=False)

    # Create the interface
    interface = gr.Interface(
        fn=generate_stories,
        inputs=[
            prompt_input,
            genre_input,
            num_stories_input,
            temperature_input,
            max_length_input,
        ],
        outputs=[scores_output, stories_output],
        title="AI Story Generator",
        description=(
            "\n"
            "Generate creative stories using AI! Enter a prompt and choose your preferences.\n"
            "The system will generate multiple stories and evaluate them based on coherence,\n"
            "fluency, and genre alignment.\n"
        ),
        # Each example row follows the order of `inputs` above:
        # prompt, genre, number of stories, temperature, maximum length.
        examples=[
            [
                "Once upon a time in a magical forest, the trees whispered secrets, and moonlight revealed hidden paths to a realm where time stood still.",
                "fantasy",
                3,
                1.8,
                # NOTE(review): 150 is not on the max-length slider's grid
                # (40, 60, 80, ...); confirm this value is intended.
                150,
            ],
            [
                "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
                "mystery",
                3,
                2.7,
                200,
            ],
        ],
        theme=gr.themes.Soft(),
    )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it. `show_error=True` is
    # passed so handler errors are reported in the web UI.
    interface = create_story_generation_interface()
    interface.launch(show_error=True)