Gbssreejith committed on
Commit 7dfb322
1 Parent(s): 1de94f0

Update app.py

Files changed (1)
  1. app.py +451 -0
app.py CHANGED
@@ -0,0 +1,451 @@
+ import copy
+ import torch
+ import math
+ import torch.nn as nn
+ from torch.nn.parameter import Parameter
+ import random
+ import numpy as np
+ from load_weights import load_weight
+ from sklearn.model_selection import train_test_split
+ from transformers import GPT2TokenizerFast
+ import pandas as pd
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AdamW, get_linear_schedule_with_warmup
+ torch.manual_seed(42)
+ import nltk
+ # nltk.download('punkt')
+
+ from transformers import GPT2Tokenizer
+ from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
+ import datetime
+ import time
+ import os
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ from tqdm import trange
+ import gradio as gr
+
+
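+ # ---- Minimal GPT-2 implementation: token/position embeddings, stacked decoder blocks, tied LM head ----
+ # gelu: the tanh approximation of the GELU activation used by GPT-2.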
+ def gelu(x):
+     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
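+ # Conv1D is GPT-2's "1D convolution": a linear layer whose weight is stored as (in_features, out_features),
+ # matching the layout of the original TensorFlow checkpoints.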
+ class Conv1D(nn.Module):
+     def __init__(self, nf, nx):
+         super(Conv1D, self).__init__()
+         self.nf = nf
+         w = torch.empty(nx, nf)
+         nn.init.normal_(w, std=0.02)
+         self.weight = Parameter(w)
+         self.bias = Parameter(torch.zeros(nf))
+
+     def forward(self, x):
+         size_out = x.size()[:-1] + (self.nf,)
+         x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+         x = x.view(*size_out)
+         return x
+
+ class LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12):
+         """Construct a layernorm module in the TF style (epsilon inside the square root)."""
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, x):
+         u = x.mean(-1, keepdim=True)
+         s = (x - u).pow(2).mean(-1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+         return self.weight * x + self.bias
+
+
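+ # Multi-head causal self-attention. The registered "bias" buffer is a lower-triangular mask that blocks
+ # attention to future positions; layer_past carries cached keys/values for incremental decoding.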
+ class Attention(nn.Module):
+     def __init__(self, nx, n_ctx, config, scale=False):
+         super(Attention, self).__init__()
+         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+         assert n_state % config.n_head == 0
+         self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+         self.n_head = config.n_head
+         self.split_size = n_state
+         self.scale = scale
+         self.c_attn = Conv1D(n_state * 3, nx)
+         self.c_proj = Conv1D(n_state, nx)
+
+     def _attn(self, q, k, v):
+         w = torch.matmul(q, k)
+         if self.scale:
+             w = w / math.sqrt(v.size(-1))
+         nd, ns = w.size(-2), w.size(-1)
+         b = self.bias[:, :, ns-nd:ns, :ns]
+         w = w * b - 1e10 * (1 - b)
+         w = nn.Softmax(dim=-1)(w)
+         return torch.matmul(w, v)
+
+     def merge_heads(self, x):
+         x = x.permute(0, 2, 1, 3).contiguous()
+         new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+         return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+
+     def split_heads(self, x, k=False):
+         new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+         x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
+         if k:
+             return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
+         else:
+             return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+     def forward(self, x, layer_past=None):
+         x = self.c_attn(x)
+         query, key, value = x.split(self.split_size, dim=2)
+         query = self.split_heads(query)
+         key = self.split_heads(key, k=True)
+         value = self.split_heads(value)
+         if layer_past is not None:
+             past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+             key = torch.cat((past_key, key), dim=-1)
+             value = torch.cat((past_value, value), dim=-2)
+         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+         a = self._attn(query, key, value)
+         a = self.merge_heads(a)
+         a = self.c_proj(a)
+         return a, present
+
+
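+ # Position-wise feed-forward network: expand to 4 * n_embd, apply GELU, and project back to n_embd.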
+ class MLP(nn.Module):
+     def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+         super(MLP, self).__init__()
+         nx = config.n_embd
+         self.c_fc = Conv1D(n_state, nx)
+         self.c_proj = Conv1D(nx, n_state)
+         self.act = gelu
+
+     def forward(self, x):
+         h = self.act(self.c_fc(x))
+         h2 = self.c_proj(h)
+         return h2
+
+
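+ # One decoder block: pre-LayerNorm self-attention and MLP, each wrapped in a residual connection.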
+ class Block(nn.Module):
+     def __init__(self, n_ctx, config, scale=False):
+         super(Block, self).__init__()
+         nx = config.n_embd
+         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
+         self.attn = Attention(nx, n_ctx, config, scale)
+         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
+         self.mlp = MLP(4 * nx, config)
+
+     def forward(self, x, layer_past=None):
+         a, present = self.attn(self.ln_1(x), layer_past=layer_past)
+         x = x + a
+         m = self.mlp(self.ln_2(x))
+         x = x + m
+         return x, present
+
+
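+ # Transformer body: token embeddings (wte) + position embeddings (wpe), n_layer stacked Blocks, and a
+ # final LayerNorm. forward() returns hidden states plus per-layer key/value "presents" for reuse as `past`.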
+ class GPT2Model(nn.Module):
+     def __init__(self, config):
+         super(GPT2Model, self).__init__()
+         self.n_layer = config.n_layer
+         self.n_embd = config.n_embd
+         self.n_vocab = config.vocab_size
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+         block = Block(config.n_ctx, config, scale=True)
+         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+     def set_embeddings_weights(self, model_embeddings_weights):
+         embed_shape = model_embeddings_weights.shape
+         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+         self.decoder.weight = model_embeddings_weights  # Tied weights
+
+     def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
+         if (input_ids >= self.n_vocab).any():
+             raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")
+
+         # print(f"input_ids: {input_ids}")  # Debugging statement
+         # print(f"Max input_id: {input_ids.max().item()}")  # Debugging statement
+         # print(f"Min input_id: {input_ids.min().item()}")  # Debugging statement
+
+         if past is None:
+             past_length = 0
+             past = [None] * len(self.h)
+         else:
+             past_length = past[0][0].size(-2)
+         if position_ids is None:
+             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
+                                         device=input_ids.device)
+             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_ids.size(-1))
+         position_ids = position_ids.view(-1, position_ids.size(-1))
+
+         inputs_embeds = self.wte(input_ids)
+         position_embeds = self.wpe(position_ids)
+
+         # print(f"inputs_embeds shape: {inputs_embeds.shape}")
+         # print(f"position_embeds shape: {position_embeds.shape}")
+
+         if token_type_ids is not None:
+             token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
+             token_type_embeds = self.wte(token_type_ids)
+         else:
+             token_type_embeds = 0
+         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+         presents = []
+         for block, layer_past in zip(self.h, past):
+             hidden_states, present = block(hidden_states, layer_past)
+             presents.append(present)
+         hidden_states = self.ln_f(hidden_states)
+         output_shape = input_shape + (hidden_states.size(-1),)
+         return hidden_states.view(*output_shape), presents
+
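+ # Language-model head: projects hidden states back onto the vocabulary with the (tied) embedding matrix.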
+ class GPT2LMHead(nn.Module):
+     def __init__(self, model_embeddings_weights, config):
+         super(GPT2LMHead, self).__init__()
+         self.n_embd = config.n_embd
+         self.set_embeddings_weights(model_embeddings_weights)
+
+     def set_embeddings_weights(self, model_embeddings_weights):
+         embed_shape = model_embeddings_weights.shape
+         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+         self.decoder.weight = model_embeddings_weights  # Tied weights
+
+     def forward(self, hidden_state):
+         # Truncated Language modeling logits (we remove the last token)
+         # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+         lm_logits = self.decoder(hidden_state)
+         return lm_logits
+
+
+ import torch.nn.functional as F
+
+
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+         Args:
+             logits: logits distribution shape (batch size, vocabulary size)
+             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+             filter_value: value to replace filtered logits.
+     """
+     assert logits.dim() == 2  # batch size x vocabulary size
+     top_k = min(top_k, logits.size(-1))  # Safety check
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         indices_to_remove = sorted_indices[sorted_indices_to_remove]
+         logits[indices_to_remove] = filter_value
+     return logits
+
+
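+ # Full model: transformer plus tied LM head. When lm_labels are provided, forward() also returns a
+ # cross-entropy loss on logits shifted one step against the labels (next-token prediction).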
+ class GPT2LMHeadModel(nn.Module):
+     def __init__(self, config):
+         super(GPT2LMHeadModel, self).__init__()
+         self.transformer = GPT2Model(config)
+         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
+
+     def set_tied(self):
+         """ Make sure we are sharing the embeddings """
+         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+
+     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
+         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+         lm_logits = self.lm_head(hidden_states)
+
+         outputs = (lm_logits, presents)
+
+         if lm_labels is not None:
+             shift_logits = lm_logits[..., :-1, :].contiguous()
+             shift_labels = lm_labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+             outputs = (loss,) + outputs
+         return outputs
+
+
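+ # Standalone sampling helper: it expects the model instance as `self` and references self.config
+ # (which GPT2LMHeadModel does not define); it is not called by the Gradio app below.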
+ def generate(
+     self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9, repetition_penalty=1.0, device='cuda'
+ ):
+     self.eval()
+     input_ids = input_ids.to(device)
+     batch_size = input_ids.shape[0]
+     past = None
+
+     generated = input_ids
+     with torch.no_grad():
+         for _ in range(max_length):
+             outputs = self(input_ids, past=past)
+             next_token_logits = outputs[0][:, -1, :]
+             past = outputs[1]
+
+             for i in range(batch_size):
+                 for token_id in set(generated[i].tolist()):
+                     next_token_logits[i, token_id] /= repetition_penalty
+
+             next_token_logits = next_token_logits / temperature
+             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+             generated = torch.cat((generated, next_token), dim=1)
+
+             if (next_token == self.config.eos_token_id).all():
+                 break
+
+             input_ids = next_token
+
+     return generated
+
+
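+ # Configuration matching the released GPT-2 "small" model: 50257-token vocabulary, 1024-token context,
+ # 768-dim embeddings, 12 layers, 12 attention heads.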
+ class GPT2Config(object):
+     def __init__(
+         self,
+         vocab_size_or_config_json_file=50257,
+         n_positions=1024,
+         n_ctx=1024,
+         n_embd=768,
+         n_layer=12,
+         n_head=12,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+     ):
+         self.vocab_size = vocab_size_or_config_json_file
+         self.n_ctx = n_ctx
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+
+
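+ # Build the model, load the fine-tuned checkpoint, and prepare the GPT-2 tokenizer.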
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ config = GPT2Config()
+ model = GPT2LMHeadModel(config)
+ state_dict = torch.load(r'weights/epoch_4.pth', map_location='cpu' if not torch.cuda.is_available() else None)
+ model = load_weight(model, state_dict)
+ model.to(device)
+ print(model)
+ model.eval()
+
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+
+
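+ # Note: this redefinition shadows the earlier top_k_top_p_filtering and is the version used at generation time.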
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+         Args:
+             logits: logits distribution shape (batch size x vocabulary size)
+             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+     """
+     assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
+     top_k = min(top_k, logits.size(-1))  # Safety check
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(nn.Softmax(dim=-1)(sorted_logits), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # Scatter the sorted mask back to the original vocabulary order, then filter
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+         logits[indices_to_remove] = filter_value
+
+     return logits
+
+ # Sampling hyperparameters for the app
+ max_length = 100
+ temperature = 0.7
+ top_k = 1
+ top_p = 0.95
+ repetition_penalty = 1.0
+
+ import re
+
+
+ def generate_response(prompt_text):
+     """Encode the prompt, sample up to max_length tokens, and return the decoded [RESPONSE] text."""
+     prompt = f"\n<|startoftext|>[WP] {prompt_text} \n[RESPONSE]"
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+
+     with torch.no_grad():
+         for _ in range(max_length):
+             outputs = model(input_ids)
+             logits = outputs[0]
+             next_token_logits = logits[:, -1, :] / temperature
+
+             # Apply repetition penalty to tokens already present in the sequence
+             for i in range(input_ids.size(0)):
+                 for token_id in set(input_ids[i].tolist()):
+                     next_token_logits[i, token_id] /= repetition_penalty
+
+             # Filter logits using top-k and/or top-p filtering, then sample the next token
+             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+             input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
+
+     generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+     wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
+     return " ".join(wp_responses).strip()
+
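+ # Example: generate_response("What is a nucleophile in organic chemistry?") returns the decoded text
+ # that follows the [RESPONSE] tag.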
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate_response,
+     inputs="text",
+     outputs="text",
+     title="Custom GPT-2 Model",
+     description="Enter a prompt to get a generated response from the custom-trained GPT-2 model.",
+     examples=[
+         ["What is a nucleophile in organic chemistry?"],
+         ["Explain the concept of quantum entanglement."],
+         ["How does photosynthesis work?"]
+     ]
+ )
+
+ # Launch the Gradio interface
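+ # share=True requests a public share link; debug=True keeps the process attached and surfaces errors in the logs.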
+ iface.launch(share=True, debug=True)