Update app.py
app.py CHANGED
@@ -129,13 +129,15 @@ def load_model(model_path):
     model.to(device)
     return model
 
-#
-model = load_model('gpt_model.pth')
+# Don't load the model here
+# model = load_model('gpt_model.pth')
 enc = tiktoken.get_encoding('gpt2')
 
 # Update the generate_text function
-@spaces.GPU(duration=60)
+@spaces.GPU(duration=60)
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
+    # Load the model inside the GPU-decorated function
+    model = load_model('gpt_model.pth')
     device = next(model.parameters()).device
     input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
     generated = []
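This first hunk is the usual ZeroGPU fix: on Spaces ZeroGPU hardware, module-level code runs with no GPU attached, so calling load_model('gpt_model.pth') at import time (it ends in model.to(device)) cannot land on CUDA; inside a @spaces.GPU-decorated function a GPU is attached for the duration of the call, so the load moves there. Below is a minimal sketch of the same idea with a one-slot cache so the checkpoint is not re-read on every request. The get_model helper and the lru_cache wrapper are illustrative assumptions, not part of this commit, and whether the cache survives between requests depends on how the Space's GPU workers are recycled:

import functools

import spaces

@functools.lru_cache(maxsize=1)
def get_model():
    # load_model is the function defined earlier in app.py; this sketch
    # assumes it can be called repeatedly with the same path.
    return load_model('gpt_model.pth')

@spaces.GPU(duration=60)
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    model = get_model()  # loaded on first call, reused while the worker is warm
    device = next(model.parameters()).device
    ...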
@@ -159,18 +161,11 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
     if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
         break
 
-    await asyncio.sleep(0.02)
+    await asyncio.sleep(0.02)
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
 
-# Update the gradio_generate function
-@spaces.GPU(duration=60)  # Adjust duration as needed
-async def gradio_generate(prompt, max_length, temperature, top_k):
-    output = ""
-    async for token in generate_text(prompt, max_length, temperature, top_k):
-        output += token
-        yield output
 
 # # Your existing imports and model code here...
 
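The second hunk leaves all streaming to the single decorated function: the await asyncio.sleep(0.02) paces the stream and yields control to the event loop so each token can be flushed to the UI, and dropping the gradio_generate wrapper also removes its second @spaces.GPU decorator (stacking one decorated function on another would negotiate the GPU twice per request). Recent Gradio versions accept an async generator directly as an event handler, but each yielded value replaces the output component's contents, so something still has to accumulate tokens. A hedged sketch of that wiring; the component layout and the stream_output helper are assumptions, not from app.py:

import gradio as gr

async def stream_output(prompt, max_length, temperature, top_k):
    # Undecorated accumulator: generate_text (decorated above) owns the
    # GPU request; this coroutine only concatenates tokens for display.
    text = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
        text += token
        yield text  # each yield replaces the Textbox contents

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    max_len = gr.Slider(16, 432, value=432, step=1, label="Max length")
    temp = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
    top_k = gr.Slider(1, 100, value=40, step=1, label="Top-k")
    out = gr.Textbox(label="Generated text")
    gr.Button("Generate").click(stream_output,
                                inputs=[prompt, max_len, temp, top_k],
                                outputs=out)

demo.launch()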