Sijuade committed on
Commit
87d9314
1 Parent(s): f59792e

Upload 3 files

Files changed (3)
  1. app.py +45 -0
  2. requirements.txt +6 -0
  3. utils.py +181 -0
app.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+ import random
+ import torch
+ from utils import *
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         show_label = True
+         gr.HTML(value=generate_html, show_label=show_label)
+
+     with gr.Row():
+         temp = gr.Slider(0, 1, value=0.2, label="Temperature", info="Choose between 0 and 1")
+         seed = gr.Slider(0, 1000, value=42, label="Seed", info="Select Random Seed")
+         max_tokens = gr.Slider(100, 1000, value=200, label="Max Tokens", info="Choose Max Tokens")
+
+     with gr.Row():
+         with gr.Column():
+             chatbot = gr.Chatbot()
+             msg = gr.Textbox(label='Message AI Assistant')
+             clear = gr.ClearButton([msg, chatbot])
+
+     def respond(message, chat_history, temp, seed, max_tokens):
+         # Seed the RNG for reproducible generations; note the temperature slider value is not forwarded to model.generate() here.
+         torch.manual_seed(seed)
+         # Wrap the user message in the [INST] ... [/INST] prompt template the adapter expects.
+         model_inputs = tokenizer(
+             [f"[INST] {message} [/INST]"],
+             return_tensors="pt", padding=True)
+         generated_ids = model.generate(**model_inputs, max_new_tokens=max_tokens)
+         result = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+         # Strip the echoed prompt, keeping only the model's reply.
+         bot_message = extract_responses(result[0])
+         chat_history.append((message, bot_message))
+
+         return "", chat_history
+
+     msg.submit(respond, [msg, chatbot, temp, seed, max_tokens], [msg, chatbot])
+
+     with gr.Row():
+         show_label = True
+         gr.HTML(value=generate_footer, show_label=show_label)
+
+ if __name__ == "__main__":
+     demo.launch()
+
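Note: respond() is defined at module level (the gr.Blocks context does not create a new scope), so it can be smoke-tested without launching the UI. A rough sketch, assuming utils.py and the local checkpoint-780 adapter load successfully and slow CPU inference is acceptable:

    from app import respond  # importing app also loads the model via utils
    _, history = respond("What is parameter-efficient fine-tuning?", [], temp=0.2, seed=42, max_tokens=200)
    print(history[-1][1])    # the assistant reply returned by extract_responses()
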
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ transformers
+ einops
+ git+https://github.com/huggingface/peft.git
+ accelerate
+ gradio
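To run the Space locally, a minimal sketch, assuming a checkout that also contains the checkpoint-780 adapter directory referenced by utils.py:

    pip install -r requirements.txt
    python app.py   # launches the Gradio demo (by default at http://127.0.0.1:7860)
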
utils.py ADDED
@@ -0,0 +1,181 @@
+ import re
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ def extract_responses(text):
+     """
+     Extracts and returns the responses from the text, excluding the parts
+     between and including the [INST] tags.
+
+     Args:
+         text (str): The input text containing responses and [INST] tags.
+
+     Returns:
+         str: The extracted responses.
+     """
+     # Split the text by [INST] tags and accumulate the non-tag parts
+     parts = re.split(r'\[INST\].*?\[/INST\]', text, flags=re.DOTALL)
+     cleaned_text = "".join(parts)
+
+     # Return the cleaned and trimmed text
+     return cleaned_text.strip()
+
+
+ def generate_html():
+     # Static header markup (styles, title, and description) rendered by gr.HTML.
+     return (
+         '''
+         <!DOCTYPE html>
+         <html lang="en">
+         <head>
+             <meta charset="UTF-8">
+             <meta name="viewport" content="width=device-width, initial-scale=1.0">
+             <title>Your Gradio App</title>
+             <style>
+                 @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@300;400&display=swap');
+
+                 body, html {
+                     margin: 0;
+                     padding: 0;
+                     font-family: 'Montserrat', sans-serif;
+                     background: #f9f9f9;
+                 }
+
+                 header {
+                     background-color: #e8f0fe;
+                     color: #333;
+                     text-align: center;
+                     padding: 40px 20px;
+                     border-radius: 0 0 25px 25px;
+                     background-image: linear-gradient(to right, #a7c7e7, #c0d8f0);
+                     box-shadow: 0 8px 16px 0 rgba(0,0,0,0.2);
+                     position: relative;
+                     overflow: hidden;
+                 }
+
+                 .background-shapes {
+                     position: absolute;
+                     top: 0;
+                     left: 0;
+                     right: 0;
+                     bottom: 0;
+                     background-image: linear-gradient(120deg, #a7c7e7 0%, #c0d8f0 100%);
+                     opacity: 0.6;
+                     animation: pulse 5s ease-in-out infinite alternate;
+                 }
+
+                 .header-content h1 {
+                     font-size: 2.8em;
+                     margin: 0;
+                 }
+
+                 .header-content p {
+                     font-size: 1.3em;
+                     margin-top: 20px;
+                 }
+
+                 @keyframes pulse {
+                     from { background-size: 100% 100%; }
+                     to { background-size: 110% 110%; }
+                 }
+             </style>
+         </head>
+         <body>
+             <header>
+                 <div class="background-shapes"></div>
+                 <div class="header-content">
+                     <h1>AI Assistant</h1>
+                     <p>This interactive app leverages the power of a fine-tuned Phi 2 AI model to provide insightful responses. Type your query below and witness AI in action.</p>
+                 </div>
+             </header>
+             <!-- Rest of your Gradio app goes here -->
+         </body>
+         </html>
+         ''')
+
+
+ def generate_footer():
+     # Static footer markup rendered by gr.HTML at the bottom of the page.
+     return (
+         '''
+         <!DOCTYPE html>
+         <html lang="en">
+         <head>
+             <meta charset="UTF-8">
+             <meta name="viewport" content="width=device-width, initial-scale=1.0">
+             <title>Your Gradio App</title>
+             <style>
+                 @import url('https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@400;700&display=swap');
+
+                 body, html {
+                     margin: 0;
+                     padding: 0;
+                     font-family: 'Roboto Slab', serif;
+                     background: #f9f9f9;
+                 }
+
+                 header, footer {
+                     color: #333;
+                     text-align: center;
+                     padding: 40px 20px;
+                     border-radius: 25px;
+                     background: linear-gradient(120deg, #a7c7e7 0%, #c0d8f0 100%);
+                     background-size: 200% 200%;
+                     animation: gradientShift 8s ease-in-out infinite;
+                     position: relative;
+                     overflow: hidden;
+                 }
+
+                 .header-content, .footer-content {
+                     position: relative;
+                     z-index: 1;
+                 }
+
+                 .header-content h1, .footer-content p {
+                     font-size: 2.8em;
+                     margin: 0;
+                 }
+
+                 .header-content p, .footer-content p {
+                     font-size: 1.3em;
+                     margin-top: 20px;
+                 }
+
+                 @keyframes gradientShift {
+                     0% { background-position: 0% 50%; }
+                     50% { background-position: 100% 50%; }
+                     100% { background-position: 0% 50%; }
+                 }
+
+                 footer {
+                     margin-top: 40px;
+                     border-radius: 25px 25px 0 0;
+                 }
+             </style>
+         </head>
+         <body>
+
+             <footer>
+                 <div class="footer-content">
+                     <p>This model was fine-tuned on a subset of the OpenAssistant dataset.</p>
+                 </div>
+             </footer>
+         </body>
+         </html>
+         ''')
+
+
+ # Load the base Phi-2 model on CPU and attach the fine-tuned PEFT adapter
+ # from the local 'checkpoint-780' directory.
+ model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/phi-2",
+     torch_dtype=torch.float32,
+     device_map="cpu",
+     trust_remote_code=True
+ )
+ model.load_adapter('checkpoint-780')
+
+
+ # The tokenizer comes from the same checkpoint; pad with EOS since Phi-2 has
+ # no dedicated padding token.
+ tokenizer = AutoTokenizer.from_pretrained('checkpoint-780', trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
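
For reference, extract_responses() simply drops every [INST] ... [/INST] span and returns whatever remains, stripped of surrounding whitespace. A quick illustration with a hypothetical input string:

    extract_responses("[INST] Hello [/INST] Hi there, how can I help?")
    # -> 'Hi there, how can I help?'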