koey811 commited on
Commit
b92cea6
·
verified ·
1 Parent(s): 4dfca06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -17,10 +17,12 @@ except ImportError:
17
  # Load the image captioning model
18
  caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
19
 
20
- story_generator = pipeline("text-generation", model="distilbert/distilgpt2")
21
 
22
  #story_generator = pipeline("text-generation", model="isarth/distill_gpt2_story_generator")
23
 
 
 
24
  def generate_caption(image):
25
  # Generate the caption for the uploaded image
26
  caption = caption_model(image)[0]["generated_text"]
@@ -28,15 +30,11 @@ def generate_caption(image):
28
 
29
  def generate_story(caption):
30
  # Generate the story based on the caption using the GPT-2 model
31
- prompt = f"Once upon a time, in a world inspired by the image of {caption}, a delightful children's story took place. The story, suitable for ages 3-10, goes like this:\n\nIntroduction (1-2 sentences): Introduce the main character(s) and the setting.\n\nBeginning (2-3 sentences): Describe the character's normal life or routine.\n\nMiddle (3-4 sentences): Present a problem or challenge the character faces.\n\nEnd (2-3 sentences): Show how the character solves the problem or learns a lesson.\n\nThe story should be simple, engaging, and convey a positive message. Let's begin the tale:\n\n"
32
  story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
33
 
34
  # Extract the story text from the generated output
35
- story_parts = story.split("\n\n")
36
- if len(story_parts) > 7:
37
- story = "\n\n".join(story_parts[7:]).strip()
38
- else:
39
- story = "\n\n".join(story_parts).strip()
40
 
41
  # Post-process the story (example: remove inappropriate words)
42
  inappropriate_words = ["violence", "horror", "scary", "adult", "death", "gun", "shoot"]
 
17
  # Load the image captioning model
18
  caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
19
 
20
+ #story_generator = pipeline("text-generation", model="distilbert/distilgpt2")
21
 
22
  #story_generator = pipeline("text-generation", model="isarth/distill_gpt2_story_generator")
23
 
24
+ story_generator = pipeline("text-generation", model="TheBloke/storytime-13B-GPTQ")
25
+
26
  def generate_caption(image):
27
  # Generate the caption for the uploaded image
28
  caption = caption_model(image)[0]["generated_text"]
 
30
 
31
  def generate_story(caption):
32
  # Generate the story based on the caption using the GPT-2 model
33
+ prompt = f"Write a short, simple children's story inspired by the image of {caption}. Here's the story:\n\n"
34
  story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
35
 
36
  # Extract the story text from the generated output
37
+ story = story.strip()
 
 
 
 
38
 
39
  # Post-process the story (example: remove inappropriate words)
40
  inappropriate_words = ["violence", "horror", "scary", "adult", "death", "gun", "shoot"]