Update app.py
app.py
CHANGED
@@ -3,31 +3,9 @@ import torch.nn as nn
 from torch.nn import functional as F
 import tiktoken
 import gradio as gr
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-import asyncio
-import gradio as gr
 import asyncio

-#
-def post_process_text(text):
-    # Ensure the text starts with a capital letter
-    text = text.capitalize()
-
-    # Remove any incomplete sentences at the end
-    sentences = text.split('.')
-    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
-
-    # Rejoin sentences and add a period if missing
-    processed_text = '. '.join(complete_sentences)
-    if not processed_text.endswith('.'):
-        processed_text += '.'
-
-    return processed_text
-# Define the model architecture
+# Model Configuration
 class GPTConfig:
     def __init__(self):
         self.block_size = 1024
@@ -36,6 +14,7 @@ class GPTConfig:
         self.n_head = 12
         self.n_embd = 768

+# Causal Self-Attention
 class CausalSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -43,7 +22,6 @@ class CausalSelfAttention(nn.Module):
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.n_head = config.n_head
-        self.n_embd = config.n_embd
         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

     def forward(self, x):
@@ -56,6 +34,7 @@ class CausalSelfAttention(nn.Module):
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         return self.c_proj(y)

+# Multi-Layer Perceptron
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -66,6 +45,7 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.c_proj(self.gelu(self.c_fc(x)))

+# Transformer Block
 class Block(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -79,6 +59,7 @@ class Block(nn.Module):
         x = x + self.mlp(self.ln_2(x))
         return x

+# GPT Model
 class GPT(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -121,15 +102,17 @@ class GPT(nn.Module):

         return logits, loss

-# Load
+# Load Model
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
-
-
-
-
-
+    try:
+        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))  # Load on CPU first
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Model file not found at: {model_path}")
+    except Exception as e:
+        raise Exception(f"Error loading model: {e}")
+
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
     else:
@@ -138,59 +121,59 @@ def load_model(model_path):
     model.eval()
     return model

-# Load the model
-model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
-enc = tiktoken.get_encoding('gpt2')

-#
-
-
-
-
-
-
-
+# Text Post-processing
+def post_process_text(text):
+    text = text.capitalize()
+    sentences = text.split('.')
+    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
+    processed_text = '. '.join(complete_sentences)
+    if not processed_text.endswith('.'):
+        processed_text += '.'
+    return processed_text

-#
+# Text Generation Function (Asynchronous)
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
-
+    enc = tiktoken.get_encoding('gpt2')
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
     generated = []
-
+
     with torch.no_grad():
         for _ in range(max_length):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            await asyncio.sleep(0.02)  # Slightly faster typing effect
+            try:
+                outputs, _ = model(input_ids)
+                next_token_logits = outputs[:, -1, :]
+                next_token_logits = next_token_logits / temperature
+                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
+                next_token_probs = F.softmax(top_k_logits, dim=-1)
+                next_token_index = torch.multinomial(next_token_probs, num_samples=1)
+                next_token = top_k_indices.gather(-1, next_token_index)
+
+                input_ids = torch.cat([input_ids, next_token], dim=-1)
+                generated.append(next_token.item())
+
+                next_token_str = enc.decode([next_token.item()])
+                yield next_token_str
+
+                if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
+                    break

-
-
-
+                await asyncio.sleep(0.02)  # For typing effect
+
+            except Exception as e:
+                yield f"Error during generation: {e}"
+                return
+
+# Gradio Generate Function
 async def gradio_generate(prompt, max_length, temperature, top_k):
     output = ""
     async for token in generate_text(prompt, max_length, temperature, top_k):
         output += token
         yield output
+    output = post_process_text(output)
+    yield output

-#
-import gradio as gr
-import asyncio
-
-# Your existing imports and model code here...
+# Load the model (replace with your model path

 css = """
 <style>