alexkueck commited on
Commit
72b1673
·
1 Parent(s): 4dc9c10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -195
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  import torch
7
  from utils import *
8
  from presets import *
 
9
 
10
 
11
  ######################################################################
@@ -19,12 +20,9 @@ base_model = "project-baize/baize-v2-7b" #load_8bit = False (in load_tokenizer_
19
  tokenizer,model,device = load_tokenizer_and_model(base_model,False)
20
  dataset_neu = daten_laden("alexkueck/tis")
21
 
22
- ###################################
23
- #Vorbereiten für das training der neuen Daten
24
- #Datensets in den Tokenizer schieben...
25
- def tokenize_function(examples):
26
- return tokenizer(examples["text"])
27
-
28
 
29
  #alles zusammen auf das neue datenset anwenden - batched = True und 4 Prozesse, um die Berechnung zu beschleunigen. Die "text" - Spalte braucht man anschließend nicht mehr, daher weglassen.
30
  tokenized_datasets = dataset_neu.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
@@ -36,203 +34,84 @@ tokenized_datasets = dataset_neu.map(tokenize_function, batched=True, num_proc=4
36
  # block_size = tokenizer.model_max_length
37
  block_size = 128
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
 
40
 
 
 
41
 
42
- ########################################################################
43
- #Chat KI nutzen, um Text zu generieren...
44
- def predict(text,
45
- chatbotGr,
46
- history,
47
- top_p,
48
- temperature,
49
- max_length_tokens,
50
- max_context_length_tokens,):
51
- if text=="":
52
- yield chatbotGr,history,"Empty context."
53
- return
54
- try:
55
- model
56
- except:
57
- yield [[text,"No Model Found"]],[],"No Model Found"
58
- return
59
-
60
- inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
61
- if inputs is None:
62
- yield chatbotGr,history,"Input too long."
63
- return
64
- else:
65
- prompt,inputs=inputs
66
- begin_length = len(prompt)
67
-
68
- input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
69
- torch.cuda.empty_cache()
70
-
71
- #torch.no_grad() bedeutet, dass für die betreffenden tensoren keine Ableitungen berechnet werden bei der backpropagation
72
- #hier soll das NN ja auch nicht geändert werden 8backprop ist nicht nötig), da es um interference-prompts geht!
73
- with torch.no_grad():
74
- #die vergangenen prompts werden alle als Tupel in history abgelegt sortiert nach 'Human' und 'AI'- dass sind daher auch die stop-words, die den jeweils nächsten Eintrag kennzeichnen
75
- for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
76
- if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
77
- if "[|Human|]" in x:
78
- x = x[:x.index("[|Human|]")].strip()
79
- if "[|AI|]" in x:
80
- x = x[:x.index("[|AI|]")].strip()
81
- x = x.strip()
82
- a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
83
- yield a, b, "Generating..."
84
- if shared_state.interrupted:
85
- shared_state.recover()
86
- try:
87
- yield a, b, "Stop: Success"
88
- return
89
- except:
90
- pass
91
- del input_ids
92
- gc.collect()
93
- torch.cuda.empty_cache()
94
-
95
- try:
96
- yield a,b,"Generate: Success"
97
- except:
98
- pass
99
 
100
 
101
- def reset_chat():
102
- #id_new = chatbot.new_conversation()
103
- #chatbot.change_conversation(id_new)
104
- reset_textbox()
105
-
106
-
107
- ##########################################################
108
- #Übersetzungs Ki nutzen
109
- def translate():
110
- return "Kommt noch!"
111
 
112
- #Programmcode KI
113
- def coding():
114
- return "Kommt noch!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  #######################################################################
117
  #Darstellung mit Gradio
118
 
119
- with open("custom.css", "r", encoding="utf-8") as f:
120
- customCSS = f.read()
121
-
122
- with gr.Blocks(theme=small_and_beautiful_theme) as demo:
123
- history = gr.State([])
124
- user_question = gr.State("")
125
- gr.Markdown("KIs am LI - wähle aus, was du bzgl. KI-Bots ausprobieren möchtest!")
126
- with gr.Tabs():
127
- with gr.TabItem("LI-Chat"):
128
- with gr.Row():
129
- gr.HTML(title)
130
- status_display = gr.Markdown("Erfolg", elem_id="status_display")
131
- gr.Markdown(description_top)
132
- with gr.Row(scale=1).style(equal_height=True):
133
- with gr.Column(scale=5):
134
- with gr.Row(scale=1):
135
- chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
136
- with gr.Row(scale=1):
137
- with gr.Column(scale=12):
138
- user_input = gr.Textbox(
139
- show_label=False, placeholder="Gib deinen Text / Frage ein."
140
- ).style(container=False)
141
- with gr.Column(min_width=100, scale=1):
142
- submitBtn = gr.Button("Absenden")
143
- with gr.Column(min_width=100, scale=1):
144
- cancelBtn = gr.Button("Stoppen")
145
- with gr.Row(scale=1):
146
- emptyBtn = gr.Button(
147
- "🧹 Neuer Chat",
148
- )
149
- with gr.Column():
150
- with gr.Column(min_width=50, scale=1):
151
- with gr.Tab(label="Parameter zum Model"):
152
- gr.Markdown("# Parameters")
153
- top_p = gr.Slider(
154
- minimum=-0,
155
- maximum=1.0,
156
- value=0.95,
157
- step=0.05,
158
- interactive=True,
159
- label="Top-p",
160
- )
161
- temperature = gr.Slider(
162
- minimum=0.1,
163
- maximum=2.0,
164
- value=1,
165
- step=0.1,
166
- interactive=True,
167
- label="Temperature",
168
- )
169
- max_length_tokens = gr.Slider(
170
- minimum=0,
171
- maximum=512,
172
- value=512,
173
- step=8,
174
- interactive=True,
175
- label="Max Generation Tokens",
176
- )
177
- max_context_length_tokens = gr.Slider(
178
- minimum=0,
179
- maximum=4096,
180
- value=2048,
181
- step=128,
182
- interactive=True,
183
- label="Max History Tokens",
184
- )
185
- gr.Markdown(description)
186
-
187
- with gr.TabItem("Übersetzungen"):
188
- with gr.Row():
189
- gr.Textbox(
190
- show_label=False, placeholder="Ist noch in Arbeit..."
191
- ).style(container=False)
192
- with gr.TabItem("Code-Generierungen"):
193
- with gr.Row():
194
- gr.Textbox(
195
- show_label=False, placeholder="Ist noch in Arbeit..."
196
- ).style(container=False)
197
-
198
- predict_args = dict(
199
- fn=predict,
200
- inputs=[
201
- user_question,
202
- chatbotGr,
203
- history,
204
- top_p,
205
- temperature,
206
- max_length_tokens,
207
- max_context_length_tokens,
208
- ],
209
- outputs=[chatbotGr, history, status_display],
210
- show_progress=True,
211
- )
212
-
213
- #neuer Chat
214
- reset_args = dict(
215
- #fn=reset_chat, inputs=[], outputs=[user_input, status_display]
216
- fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
217
- )
218
-
219
- # Chatbot
220
- transfer_input_args = dict(
221
- fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
222
- )
223
-
224
- #Listener auf Start-Click auf Button oder Return
225
- predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
226
- predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
227
-
228
- #Listener, Wenn reset...
229
- emptyBtn.click(
230
- reset_state,
231
- outputs=[chatbotGr, history, status_display],
232
- show_progress=True,
233
- )
234
- emptyBtn.click(**reset_args)
235
 
236
- demo.title = "LI Chat"
237
- #demo.queue(concurrency_count=1).launch(share=True)
238
- demo.queue(concurrency_count=1).launch(debug=True)
 
6
  import torch
7
  from utils import *
8
  from presets import *
9
+ from transformers import Trainer, TrainingArguments
10
 
11
 
12
  ######################################################################
 
20
  tokenizer,model,device = load_tokenizer_and_model(base_model,False)
21
  dataset_neu = daten_laden("alexkueck/tis")
22
 
23
+ #############################################
24
+ #Vorbereiten für das Training der neuen Daten
25
+ #############################################
 
 
 
26
 
27
  #alles zusammen auf das neue datenset anwenden - batched = True und 4 Prozesse, um die Berechnung zu beschleunigen. Die "text" - Spalte braucht man anschließend nicht mehr, daher weglassen.
28
  tokenized_datasets = dataset_neu.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
 
34
  # block_size = tokenizer.model_max_length
35
  block_size = 128
36
 
37
+ #nochmal die map-Funktion auf das bereits tokenisierte Datenset anwenden
38
+ #die bereits tokenisierten Datensatze ändern sich dadurch: die samples enthalten nun Mengen aus block_size Tokens
39
+ lm_datasets = tokenized_datasets.map(
40
+ group_texts,
41
+ batched=True,
42
+ batch_size=1000,
43
+ num_proc=4,
44
+ )
45
+
46
+ #die Daten wurden nun "gereinigt" und für das Model vorbereitet.
47
+ #z.B. anschauen mit: tokenizer.decode(lm_datasets["train"][1]["input_ids"])
48
+
49
+ ####################################################
50
+ #Training
51
+ ####################################################
52
+
53
+ #Training Args
54
+ model_name = base_model.split("/")[-1]
55
+ training_args = TrainingArguments(
56
+ f"{model_name}-finetuned-tis",
57
+ evaluation_strategy = "epoch",
58
+ learning_rate=2e-5,
59
+ weight_decay=0.01,
60
+ push_to_hub=True,
61
+ )
62
+
63
+ ############################################
64
+ def trainieren_neu():
65
+ #Trainer zusammenstellen
66
+ trainer = Trainer(
67
+ model=model,
68
+ args=training_args,
69
+ train_dataset=lm_datasets["train"],
70
+ eval_dataset=lm_datasets["validation"],
71
+ )
72
 
73
+ #trainer ausführen
74
+ trainer.train()
75
 
76
+ #in den Hub laden
77
+ trainer.push_to_hub()
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
+ #####################################################
82
+ #Hilfsfunktionen für das training
83
+ #####################################################
84
+ #Datensets in den Tokenizer schieben...
85
+ def tokenize_function(examples):
86
+ return tokenizer(examples["text"])
87
+
 
 
 
88
 
89
+ #Funktion, die den gegebenen Text aus dem Datenset gruppiert
90
+ def group_texts(examples):
91
+ # Concatenate all texts.
92
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
93
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
94
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
95
+ # customize this part to your needs.
96
+ total_length = (total_length // block_size) * block_size
97
+ # Split by chunks of max_len.
98
+ result = {
99
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
100
+ for k, t in concatenated_examples.items()
101
+ }
102
+ result["labels"] = result["input_ids"].copy()
103
+ return result
104
+
105
+
106
+
107
 
108
  #######################################################################
109
  #Darstellung mit Gradio
110
 
111
+ with gr.Blocks() as demo:
112
+ output = gr.Textbox(label="Output Box")
113
+ start_btn = gr.Button("Start")
114
+ start_btn.click(fn=greet, inputs, outputs=output, api_name="trainieren_neu")
115
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ demo.launch()