Brian Watson commited on
Commit
df19b76
β€’
1 Parent(s): b459cca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -19
app.py CHANGED
@@ -1,35 +1,26 @@
1
  import torch
2
  import gradio as gr
3
- import json
4
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
5
 
6
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
7
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
8
  model = GPT2LMHeadModel.from_pretrained('FredZhang7/anime-anything-promptgen-v2')
9
 
10
- # prompt = r'1girl, genshin'
11
-
12
- # generate text using fine-tuned model
13
  nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
14
 
15
  def generate(prompt):
16
- # generate 10 samples using contrastive search
17
  outs = nlp(prompt, max_length=76, num_return_sequences=3, do_sample=True, repetition_penalty=1.2, temperature=0.7, top_k=3, early_stopping=True)
18
- jsonStr = json.dumps(outs)
19
- print(prompt)
20
- print(jsonStr)
21
- return jsonStr
22
-
 
23
 
24
- # for i in range(len(outs)):
25
- # remove trailing commas and double spaces
26
- # outs[i] = str(outs[i]['generated_text']).replace(' ', '').rstrip(',')
27
- # print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
28
- # print(str(outs[i]['generated_text']))
29
-
30
- input_component = gr.Textbox(label = "Input a prompt", value = "1girl, genshin")
31
- output_component = gr.Textbox(label = "detail Prompt")
32
  examples = []
33
  description = ""
34
- gr.Interface(generate, inputs = input_component, outputs=output_component, examples=examples, title = "anything prompt", description=description).launch()
35
 
 
 
1
  import torch
2
  import gradio as gr
 
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
4
 
5
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
6
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
7
  model = GPT2LMHeadModel.from_pretrained('FredZhang7/anime-anything-promptgen-v2')
8
 
 
 
 
9
  nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
10
 
11
  def generate(prompt):
12
+ output = ''
13
  outs = nlp(prompt, max_length=76, num_return_sequences=3, do_sample=True, repetition_penalty=1.2, temperature=0.7, top_k=3, early_stopping=True)
14
+ for i in range(len(outs)):
15
+ generated_text = str(outs[i]['generated_text']).replace(' ', '').rstrip(',')
16
+ output += generated_text
17
+ if i < len(outs)-1:
18
+ output += '\n\n'
19
+ return output
20
 
21
+ input_component = gr.Textbox(label="Prompt Idea", value="")
22
+ output_component = gr.Textbox(label="Extended Prompts")
 
 
 
 
 
 
23
  examples = []
24
  description = ""
 
25
 
26
+ gr.Interface(generate, inputs=input_component, outputs=output_component, examples=examples, title="Anime Prompt Gen", description=description).launch()