File size: 3,337 Bytes
7a2c0ff
 
 
 
 
 
 
 
 
9ac5adf
7a2c0ff
 
 
 
 
 
9ac5adf
7a2c0ff
 
 
79d4bfd
7a2c0ff
 
 
9ac5adf
 
 
 
 
 
 
7a2c0ff
 
 
 
6829958
7a2c0ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6829958
 
7a2c0ff
 
 
 
 
 
 
885d7a0
 
 
7a2c0ff
 
 
9ac5adf
7a2c0ff
9ac5adf
 
e61925e
9ac5adf
 
 
 
 
7a2c0ff
 
9ac5adf
 
 
 
7a2c0ff
 
9ac5adf
7a2c0ff
 
6829958
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import openai
import gradio as gr


# Instruction prefix for every task offered in the UI.
# Key order matters: the UI pairs each key with an emoji by position to build
# the radio labels, and ``chat`` strips that leading emoji back off to look the
# task up here. "freestyle" has no instruction: its text is sent verbatim.
instructions = dict(
    completion="Please help me complete the text",
    correction="Please help me correct mistakes in the text",
    polishing="Please help me polish the language and improve my writing",
    paraphrase="Please help me paraphrase the text",
    translation="Please help me translate the text",
    freestyle="",
)

# Prompt layout used for every task except "freestyle".
template = "{instruction}:" "\n\n" "Text: {text}"


def chat(task_type: str, text: str, api_key: str, tgt_lang: str = "") -> str:
    """Run one writing task through ChatGPT and return the model's reply.

    Args:
        task_type: Radio-button label from the UI: one leading emoji followed
            by the task name (e.g. "📝Completion"). The first character is
            dropped and the rest is lower-cased to look up ``instructions``.
        text: The user's input text.
        api_key: OpenAI API key for this request. It is passed per request
            instead of being stored in the global ``openai.api_key``, so
            concurrent users of a shared app cannot pick up each other's key.
        tgt_lang: Target language. It is required only for the "translation" task.

    Returns:
        The assistant replies joined with single spaces and stripped. If the
        first reply does not end with ``finish_reason == "stop"``, the model
        is asked once to "please continue" and both replies are joined.

    Raises:
        ValueError: If no task is selected, the task is unknown, the API key
            is empty, or a translation is requested without a target language.
    """
    # Gradio passes None when no radio option has been selected yet.
    if not task_type:
        raise ValueError("Please select a task")
    api_key = (api_key or "").strip()
    if not api_key:
        raise ValueError("OpenAI API key cannot be empty")

    # Drop the single-character emoji prefix added by the UI.
    task = task_type[1:].strip().lower()
    if task not in instructions:
        raise ValueError(f"Unknown task: {task_type!r}")

    if task == "freestyle":
        prompt = text
    else:
        instruction = instructions[task]
        if task == "translation":
            # Strip first so a whitespace-only value counts as empty.
            lang = (tgt_lang or "").strip()
            if not lang:
                raise ValueError("Target language cannot be empty when translating")
            instruction += f" into {lang}"
        prompt = template.format(instruction=instruction, text=text)

    messages = [
        {
            "role": "system",
            "content": f"You are a helpful writing assistant who can do {task}.",
        },
        {"role": "user", "content": prompt},
    ]

    # One request, plus at most one "please continue" follow-up if the first
    # reply was cut off (finish_reason other than "stop").
    max_rounds = 2
    for round_idx in range(max_rounds):
        if round_idx:
            messages.append({"role": "user", "content": "please continue"})
        res = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0301",
            messages=messages,
            api_key=api_key,
        )
        choice = res["choices"][0]
        messages.append(choice["message"])
        if choice["finish_reason"] == "stop":
            break

    return " ".join(
        msg["content"] for msg in messages if msg["role"] == "assistant"
    ).strip()


# Build the Gradio UI: an API-key field, a task picker with a target-language
# box, and side-by-side input/output text areas wired to ``chat``.
with gr.Blocks(css="") as demo:
    gr.Markdown("# ✒️ Writing Comrade")
    gr.Markdown("Comrade, I'm your faithful writing fellow powered by ChatGPT. Destination, commander?")
    gr.Markdown(
        "🎮 This demo is hosted on: [Huggingface Spaces](https://huggingface.co/spaces/Spico/writing-comrade) <br />"
        "⭐ Star me on GitHub: [Spico197/writing-comrade](https://github.com/Spico197/writing-comrade) <br />"
        "You may want to follow [this instruction](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) to get an API key."
    )

    with gr.Row():
        # Masked input; the key is passed to ``chat`` on every click.
        api_key = gr.Textbox(label='OpenAI API Key', type="password")

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=3):
            # One emoji per task, matched to ``instructions`` by position.
            # ``chat`` strips this single leading character before the lookup,
            # so each emoji here must be exactly one code point.
            emojis = "📝🥊💎🍦🚌🎤"
            task_type = gr.Radio([f"{emojis[i]}{k.title()}" for i, k in enumerate(instructions.keys())], label="Task")
        with gr.Column(min_width=100):
            # Only consulted by ``chat`` for the "Translation" task.
            tgt_lang = gr.Textbox(label="Target language in translation")
        with gr.Column():
            text_button = gr.Button("Can~ do!", variant="primary")

    with gr.Row():
        with gr.Column():
            text_input = gr.TextArea(lines=15, label="Input")
        with gr.Column():
            text_output = gr.TextArea(lines=15, label="Output")

    # Input order must match chat(task_type, text, api_key, tgt_lang).
    text_button.click(
        chat, inputs=[task_type, text_input, api_key, tgt_lang], outputs=text_output
    )

demo.launch(show_error=True)