File size: 5,834 Bytes
0db3431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import random
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoProcessor
from transformers import pipeline, set_seed

# Target device for the image-captioning path; get_prompt_from_image moves
# both its inputs and big_model onto it at call time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Processor/model pair used by get_prompt_from_image to turn an image into a caption.
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Text-generation pipeline used by text_generate to extend an English prompt.
text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

# Seq2seq model/tokenizer pairs used by translate_zh2en and translate_en2zh.
# Both models are switched to eval mode once, here at load time.
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")


def load_prompter():
    """Load the Promptist model together with its GPT-2 tokenizer.

    The tokenizer is configured to pad on the left and to reuse its
    end-of-sequence token as the padding token.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")

    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.padding_side = "left"
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    return model, gpt2_tokenizer


# Load the prompt-rewriting model and its tokenizer once at import time;
# generate_prompter reads both as module-level globals.
prompter_model, prompter_tokenizer = load_prompter()


def generate_prompter(plain_text, max_new_tokens=75, num_beams=8, num_return_sequences=8, length_penalty=-1.0):
    """Rewrite a plain prompt into several candidate prompts using beam search.

    The text ``"<plain_text stripped> Rephrase:"`` is fed to the prompter
    model, and that same prefix is removed from every decoded candidate.

    Args:
        plain_text: The prompt to rewrite.
        max_new_tokens: Maximum number of tokens generated per candidate.
        num_beams: Beam width for the search.
        num_return_sequences: Number of candidates to return. Values larger
            than ``num_beams`` are clamped to ``num_beams``.
        length_penalty: Length penalty forwarded to ``generate``.

    Returns:
        The candidate prompts joined with newlines, one per line.
    """
    # Values coming from UI sliders may be floats; generate() needs integer counts.
    max_new_tokens = int(max_new_tokens)
    num_beams = int(num_beams)
    # Beam search cannot return more sequences than it has beams; clamp
    # instead of letting generate() raise on an invalid combination.
    num_return_sequences = min(int(num_return_sequences), num_beams)

    # Build the prompt once so the text fed to the model and the text
    # stripped from the outputs are guaranteed to be identical (the input
    # may carry leading/trailing whitespace).
    prompt = plain_text.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=length_penalty
    )
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return "\n".join(text.replace(prompt, "").strip() for text in output_texts)


def translate_zh2en(text):
    """Translate ``text`` with the Helsinki-NLP/opus-mt-zh-en model.

    Args:
        text: The text to translate. ``None``, empty, or whitespace-only
            input is returned as an empty string without running the model.

    Returns:
        The decoded translation of ``text`` (special tokens removed).
    """
    if not text or not text.strip():
        # Nothing to translate; skip the model call entirely.
        return ""

    with torch.no_grad():
        # truncation=True caps the input at the tokenizer's maximum length so
        # an over-long input is shortened instead of failing inside the model.
        encoded = zh2en_tokenizer([text], return_tensors='pt', truncation=True)
        sequences = zh2en_model.generate(**encoded)
    return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def translate_en2zh(text):
    """Translate ``text`` with the Helsinki-NLP/opus-mt-en-zh model.

    Args:
        text: The text to translate. ``None``, empty, or whitespace-only
            input is returned as an empty string without running the model.

    Returns:
        The decoded translation of ``text`` (special tokens removed).
    """
    if not text or not text.strip():
        # Nothing to translate; skip the model call entirely.
        return ""

    with torch.no_grad():
        # truncation=True caps the input at the tokenizer's maximum length so
        # an over-long input is shortened instead of failing inside the model.
        encoded = en2zh_tokenizer([text], return_tensors="pt", truncation=True)
        sequences = en2zh_model.generate(**encoded)
    return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def text_generate(text_in_english):
    """Generate prompt continuations for ``text_in_english`` and translate them.

    Up to six attempts are made. Each attempt samples eight continuations
    from ``text_pipe`` and keeps those that differ from the input, are more
    than four characters longer than it, and do not end in ``:``, ``-`` or
    ``—``. Tokens containing an embedded dot and any ``<``/``>`` characters
    are then removed. The first attempt that leaves non-empty text wins.

    Args:
        text_in_english: The seed prompt passed to the generation pipeline.

    Returns:
        A ``(generated, translated)`` tuple: the newline-joined generated
        prompts and the same lines passed one by one through
        ``translate_en2zh``. Both are empty strings if every attempt fails.
    """
    seed = random.randint(100, 1000000)
    set_seed(seed)

    result = ""
    for _ in range(6):
        sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            if (line != text_in_english
                    and len(line) > len(text_in_english) + 4
                    and not line.endswith((':', '-', '—'))):
                lines.append(line)

        result = "\n".join(lines)
        # Drop tokens with an embedded dot. Newlines are excluded from the
        # token classes so a match can never swallow the line break between
        # two candidates and merge them into one line.
        result = re.sub(r'[^ \n]+\.[^ \n]+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            break
    return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if line)


def get_prompt_from_image(input_image):
    """Generate a text caption for an image.

    Args:
        input_image: An image object exposing ``convert('RGB')`` (the UI
            supplies a PIL image), or ``None`` when no image was provided.

    Returns:
        The generated caption, or an empty string when ``input_image`` is
        ``None``.
    """
    if input_image is None:
        # No image supplied (e.g. the button was clicked without an upload):
        # return an empty caption rather than failing on None.convert().
        return ""

    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values

    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_caption)
    return generated_caption


# UI definition: two tabs (text-based prompt tools, and prompt-from-image),
# followed by the wiring of each button to its handler function.
with gr.Blocks() as block:
    with gr.Column():
        # Tab 1: text generation tools.
        with gr.Tab('文本生成'):
            with gr.Row():
                # Free-form user input and its translation; the translation
                # box is also the input for the two prompt generators below.
                input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
                translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')

            # Collapsed panel with the generation parameters forwarded to generate_prompter.
            with gr.Accordion('SD优化参数设置', open=False):
                max_new_tokens = gr.Slider(1, 255, 75, label='max_new_tokens', step=1)
                # NOTE: the variable is spelled `nub_beams` (sic); it feeds the
                # num_beams parameter of generate_prompter.
                nub_beams = gr.Slider(1, 30, 8, label='num_beams', step=1)
                num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
                length_penalty = gr.Slider(-1.0, 1.0, -1.0, label='length_penalty')

            # Output of generate_prompter.
            generate_prompter_output = gr.Textbox(lines=6, label='SD优化的 Prompt')

            # Outputs of text_generate: generated text and its translation.
            output = gr.Textbox(lines=6, label='瞎编的 Prompt')
            output_zh = gr.Textbox(lines=6, label='瞎编的 Prompt(zh)')
            with gr.Row():
                translate_btn = gr.Button('翻译')
                generate_prompter_btn = gr.Button('SD优化')
                gpt_btn = gr.Button('瞎编')

        # Tab 2: generate a prompt from an uploaded image.
        with gr.Tab('从图片中生成'):
            with gr.Row():
                # type='pil' so the handler receives a PIL image object.
                input_image = gr.Image(type='pil')
            img_btn = gr.Button('提交')
            output_image = gr.Textbox(lines=6, label='生成的 Prompt')
    # Event wiring: input_text -> translate_zh2en -> translate_output.
    translate_btn.click(
        fn=translate_zh2en,
        inputs=input_text,
        outputs=translate_output
    )
    # translate_output plus the slider values -> generate_prompter.
    # The list order must match generate_prompter's positional parameters.
    generate_prompter_btn.click(
        fn=generate_prompter,
        inputs=[translate_output, max_new_tokens, nub_beams, num_return_sequences, length_penalty],
        outputs=generate_prompter_output
    )
    # translate_output -> text_generate, which returns two values (one per output box).
    gpt_btn.click(
        fn=text_generate,
        inputs=translate_output,
        outputs=[output, output_zh]
    )
    # Uploaded image -> get_prompt_from_image -> caption textbox.
    img_btn.click(
        fn=get_prompt_from_image,
        inputs=input_image,
        outputs=output_image
    )

# Queueing is already enabled by .queue(); the redundant `enable_queue` launch
# flag is deprecated in Gradio 3.x and rejected by launch() in Gradio 4.x, so
# it is intentionally not passed.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')