sagar007 committed on
Commit d98db84 · verified · 1 Parent(s): 532845f

Update app.py

Files changed (1): app.py (+6, -11)
app.py CHANGED
@@ -129,13 +129,15 @@ def load_model(model_path):
     model.to(device)
     return model
 
-# Load the model
-model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
+# Don't load the model here
+# model = load_model('gpt_model.pth')
 enc = tiktoken.get_encoding('gpt2')
 
 # Update the generate_text function
-@spaces.GPU(duration=60)  # Adjust duration as needed
+@spaces.GPU(duration=60)
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
+    # Load the model inside the GPU-decorated function
+    model = load_model('gpt_model.pth')
     device = next(model.parameters()).device
     input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
     generated = []
@@ -159,18 +161,11 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
         if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
             break
 
-        await asyncio.sleep(0.02)  # Slightly faster typing effect
+        await asyncio.sleep(0.02)
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
 
-# Update the gradio_generate function
-@spaces.GPU(duration=60)  # Adjust duration as needed
-async def gradio_generate(prompt, max_length, temperature, top_k):
-    output = ""
-    async for token in generate_text(prompt, max_length, temperature, top_k):
-        output += token
-        yield output
 
 # # Your existing imports and model code here...
 
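
The change follows the standard ZeroGPU pattern: under @spaces.GPU, a GPU is only attached while the decorated call runs, so the model must be loaded inside generate_text rather than at import time. Below is a minimal, self-contained sketch of the resulting pattern; the GPT module class, checkpoint layout, and the exact sampling loop are assumptions reconstructed from the diff's context lines, not the Space's actual code.

import asyncio

import tiktoken
import torch
import torch.nn.functional as F

import spaces  # ZeroGPU helper available on Hugging Face Spaces

enc = tiktoken.get_encoding('gpt2')  # the tokenizer is CPU-only, so it can stay at module level


def load_model(model_path):
    # Hypothetical loader mirroring app.py's load_model: build the network,
    # restore the checkpoint, and move it to whatever device is available.
    model = GPT()  # assumption: the Space defines a GPT nn.Module elsewhere
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    return model.eval()


@spaces.GPU(duration=60)
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    # Loading here, inside the GPU-decorated function, is the point of the
    # commit: at import time a ZeroGPU Space has no GPU attached yet.
    model = load_model('gpt_model.pth')
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)

    generated = []
    for _ in range(max_length):
        with torch.no_grad():
            # Assumes forward() returns (batch, seq, vocab) logits.
            logits = model(input_ids)[:, -1, :] / temperature
        top_vals, top_idx = torch.topk(logits, top_k)
        probs = F.softmax(top_vals, dim=-1)
        next_token = top_idx.gather(-1, torch.multinomial(probs, 1))
        input_ids = torch.cat([input_ids, next_token], dim=1)
        generated.append(next_token.item())
        yield enc.decode([next_token.item()])

        # Stop at a newline once enough has been generated, per the diff.
        if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
            break

        await asyncio.sleep(0.02)  # typing effect, kept from the diff

    if len(generated) == max_length:
        yield "... (output truncated due to length)"

Dropping the gradio_generate wrapper also leaves a single @spaces.GPU entry point; the Gradio handler can accumulate tokens from generate_text directly (e.g. output += token; yield output in a plain async for loop) without requesting a second GPU allocation.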