sagar007 committed
Commit aaa17ba · verified · 1 Parent(s): 25893d0

Update app.py

Files changed (1):
  app.py +53 -70
app.py CHANGED
@@ -3,31 +3,9 @@ import torch.nn as nn
 from torch.nn import functional as F
 import tiktoken
 import gradio as gr
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-import asyncio
-import gradio as gr
 import asyncio
 
-# Add the post-processing function here
-def post_process_text(text):
-    # Ensure the text starts with a capital letter
-    text = text.capitalize()
-
-    # Remove any incomplete sentences at the end
-    sentences = text.split('.')
-    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
-
-    # Rejoin sentences and add a period if missing
-    processed_text = '. '.join(complete_sentences)
-    if not processed_text.endswith('.'):
-        processed_text += '.'
-
-    return processed_text
-# Define the model architecture
+# Model Configuration
 class GPTConfig:
     def __init__(self):
        self.block_size = 1024
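
A note on the relocated helper: its logic is unchanged by the move, but two behaviors are easy to miss in review. str.capitalize() lowercases every character after the first (including proper nouns in later sentences), and any fragment after the last period is dropped. A quick standalone check with a hypothetical input string:

# Standalone behavior check; the input string is hypothetical.
def post_process_text(text):
    text = text.capitalize()
    sentences = text.split('.')
    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
    processed_text = '. '.join(complete_sentences)
    if not processed_text.endswith('.'):
        processed_text += '.'
    return processed_text

print(post_process_text("the GPT model streams tokens. it works well. trailing fragm"))
# -> "The gpt model streams tokens.  it works well."
#    Note the lowercased "gpt", the dropped trailing fragment, and the
#    double space introduced by joining on '. ' while each segment keeps
#    its own leading space.
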
@@ -36,6 +14,7 @@ class GPTConfig:
         self.n_head = 12
         self.n_embd = 768
 
+# Causal Self-Attention
 class CausalSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -43,7 +22,6 @@ class CausalSelfAttention(nn.Module):
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.n_head = config.n_head
-        self.n_embd = config.n_embd
         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
 
     def forward(self, x):
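
The bias buffer registered above is a lower-triangular (1, 1, T, T) mask that broadcasts across the batch and head dimensions. The forward body that consumes it is elided from this hunk; below is a minimal sketch of the conventional use, where the masked_fill step is an assumption, not something shown in this diff:

import torch
import torch.nn.functional as F

T = 4  # tiny block size for illustration
bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)  # same construction as the buffer
att = torch.randn(1, 1, T, T)                         # raw attention scores (B, nh, T, T)
att = att.masked_fill(bias[:, :, :T, :T] == 0, float('-inf'))  # hide future positions
att = F.softmax(att, dim=-1)                          # each row now weights only positions <= t
print(att[0, 0])  # strictly upper-triangular entries are exactly 0
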
@@ -56,6 +34,7 @@ class CausalSelfAttention(nn.Module):
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         return self.c_proj(y)
 
+# Multi-Layer Perceptron
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
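
For context, the MLP's __init__ is elided from this hunk; only the forward chain c_fc -> gelu -> c_proj is visible below. A sketch of the likely layer shapes, assuming the standard GPT-2 choice of a 4x hidden expansion (the 4x factor is an assumption, not confirmed by this diff):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 4 * n_embd is the conventional GPT-2 width; assumed, not shown in the diff
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))
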
@@ -66,6 +45,7 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.c_proj(self.gelu(self.c_fc(x)))
 
+# Transformer Block
 class Block(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -79,6 +59,7 @@ class Block(nn.Module):
         x = x + self.mlp(self.ln_2(x))
         return x
 
+# GPT Model
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -121,15 +102,17 @@ class GPT(nn.Module):
 
         return logits, loss
 
-# Load the model
+# Load Model
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
-
-    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-
-    print("Checkpoint keys:", checkpoint.keys())  # Debug print
-
+    try:
+        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))  # Load on CPU first
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Model file not found at: {model_path}")
+    except Exception as e:
+        raise Exception(f"Error loading model: {e}")
+
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
     else:
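
The rewritten loader accepts two checkpoint layouts: a dict wrapping the weights under a 'model_state_dict' key, or a bare state_dict. A minimal sketch of producing each, assuming a constructed GPT model is in scope (file names are placeholders):

import torch

# Wrapped layout: exercises the 'model_state_dict' branch of load_model.
torch.save({'model_state_dict': model.state_dict()}, 'gpt_model.pth')

# Bare layout: exercises the else branch.
torch.save(model.state_dict(), 'gpt_model_raw.pth')
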
@@ -138,59 +121,59 @@ def load_model(model_path):
     model.eval()
     return model
 
-# Load the model
-model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
-enc = tiktoken.get_encoding('gpt2')
 
-# Improved text generation function
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-
-# [Your existing model code remains unchanged]
+# Text Post-processing
+def post_process_text(text):
+    text = text.capitalize()
+    sentences = text.split('.')
+    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
+    processed_text = '. '.join(complete_sentences)
+    if not processed_text.endswith('.'):
+        processed_text += '.'
+    return processed_text
 
-# Modify the generate_text function to be asynchronous
+# Text Generation Function (Asynchronous)
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
-    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
+    enc = tiktoken.get_encoding('gpt2')
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
     generated = []
-
+
     with torch.no_grad():
         for _ in range(max_length):
-            outputs, _ = model(input_ids)
-            next_token_logits = outputs[:, -1, :]
-            next_token_logits = next_token_logits / temperature
-            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
-            next_token_probs = F.softmax(top_k_logits, dim=-1)
-            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
-            next_token = top_k_indices.gather(-1, next_token_index)
-
-            input_ids = torch.cat([input_ids, next_token], dim=-1)
-            generated.append(next_token.item())
-
-            next_token_str = enc.decode([next_token.item()])
-            yield next_token_str
-
-            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
-                break
-
-            await asyncio.sleep(0.02)  # Slightly faster typing effect
+            try:
+                outputs, _ = model(input_ids)
+                next_token_logits = outputs[:, -1, :]
+                next_token_logits = next_token_logits / temperature
+                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
+                next_token_probs = F.softmax(top_k_logits, dim=-1)
+                next_token_index = torch.multinomial(next_token_probs, num_samples=1)
+                next_token = top_k_indices.gather(-1, next_token_index)
+
+                input_ids = torch.cat([input_ids, next_token], dim=-1)
+                generated.append(next_token.item())
+
+                next_token_str = enc.decode([next_token.item()])
+                yield next_token_str
+
+                if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
+                    break
 
-    if len(generated) == max_length:
-        yield "... (output truncated due to length)"
-# Modify the gradio_generate function to be asynchronous
+                await asyncio.sleep(0.02)  # For typing effect
+
+            except Exception as e:
+                yield f"Error during generation: {e}"
+                return
+
+# Gradio Generate Function
 async def gradio_generate(prompt, max_length, temperature, top_k):
     output = ""
     async for token in generate_text(prompt, max_length, temperature, top_k):
         output += token
         yield output
+    output = post_process_text(output)
+    yield output
 
+# Load the model (replace with your model path
 
-# Custom CSS for the animation effect
-import gradio as gr
-import asyncio
-
-# Your existing imports and model code here...
 
 css = """
 <style>
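
One thing to flag in this hunk: the new input_ids line calls .to(device), and the generation loop still calls a module-level model, but the old module-level load_model(...) call is deleted and its replacement ends in a truncated comment. Both names must still be defined somewhere in app.py; a sketch of the usual definitions (hypothetical, not shown in this diff):

import torch

# Assumed module-level setup; the diff only shows the truncated comment
# "# Load the model (replace with your model path".
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model('gpt_model.pth').to(device)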
 
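The diff ends inside the css block, so the UI wiring is not shown. For orientation, an async generator handler like gradio_generate streams into a Gradio app roughly as below; the Blocks layout and component names here are invented, and only the handler signature comes from the diff:

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    max_length = gr.Slider(50, 432, value=432, step=1, label="Max length")
    temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature")
    top_k = gr.Slider(1, 100, value=40, step=1, label="Top-k")
    output = gr.Textbox(label="Generated text")
    generate_btn = gr.Button("Generate")
    # Gradio streams each value yielded by the async generator into `output`.
    generate_btn.click(gradio_generate,
                       inputs=[prompt, max_length, temperature, top_k],
                       outputs=output)

demo.launch()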