Rinkalkumar M. Shah
commited on
Commit
•
cac5aaf
1
Parent(s):
624c056
update in App.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
11 |
device = "mps"
|
12 |
|
13 |
model = GPT(GPTConfig())
|
14 |
-
ckpt = torch.load("
|
15 |
unwanted_prefix = '_orig_mod.'
|
16 |
for k,v in list(ckpt.items()):
|
17 |
if k.startswith(unwanted_prefix):
|
@@ -28,22 +28,13 @@ def inference(input_text, num_return_sequences, max_length):
|
|
28 |
x = input_tokens.to(device)
|
29 |
|
30 |
while x.size(1) < max_length:
|
31 |
-
# forward the model to get the logits
|
32 |
with torch.no_grad():
|
33 |
-
logits = model(x)[0]
|
34 |
-
|
35 |
-
logits = logits[:, -1, :] # (B, vocab_size)
|
36 |
-
# get the probabilities
|
37 |
probs = F.softmax(logits, dim=-1)
|
38 |
-
# do top-k sampling of 50 (huggingface pipeline default)
|
39 |
-
# topk_probs here becomes (5, 50), topk_indices is (5, 50)
|
40 |
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
|
41 |
-
# select a token from the top-k probabilities
|
42 |
-
# note: multinomial does not demand the input to sum to 1
|
43 |
ix = torch.multinomial(topk_probs, 1) # (B, 1)
|
44 |
-
# gather the corresponding indices
|
45 |
xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
|
46 |
-
# append to the sequence
|
47 |
x = torch.cat((x, xcol), dim=1)
|
48 |
|
49 |
decode_list = []
|
@@ -56,25 +47,25 @@ def inference(input_text, num_return_sequences, max_length):
|
|
56 |
output = "\n======\n".join(decode_list)
|
57 |
return output
|
58 |
|
59 |
-
title = "GPT-2
|
60 |
-
description = "A simple Gradio interface to generate text from GPT-2 model trained on Shakespeare Plays"
|
61 |
-
examples = [["
|
62 |
-
["
|
63 |
-
["
|
64 |
-
["
|
65 |
-
["
|
66 |
-
["
|
67 |
-
["
|
68 |
-
["
|
69 |
-
["
|
70 |
-
["
|
71 |
]
|
72 |
demo = gr.Interface(
|
73 |
inference,
|
74 |
inputs = [
|
75 |
gr.Textbox(label="Enter some text", type="text"),
|
76 |
-
gr.Slider(minimum=1, maximum=
|
77 |
-
gr.Slider(minimum=10, maximum=
|
78 |
],
|
79 |
outputs = [
|
80 |
gr.Textbox(label="Output", type="text")
|
|
|
11 |
device = "mps"
|
12 |
|
13 |
model = GPT(GPTConfig())
|
14 |
+
ckpt = torch.load("model.pt", map_location=torch.device(device))
|
15 |
unwanted_prefix = '_orig_mod.'
|
16 |
for k,v in list(ckpt.items()):
|
17 |
if k.startswith(unwanted_prefix):
|
|
|
28 |
x = input_tokens.to(device)
|
29 |
|
30 |
while x.size(1) < max_length:
|
|
|
31 |
with torch.no_grad():
|
32 |
+
logits = model(x)[0]
|
33 |
+
logits = logits[:, -1, :]
|
|
|
|
|
34 |
probs = F.softmax(logits, dim=-1)
|
|
|
|
|
35 |
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
|
|
|
|
|
36 |
ix = torch.multinomial(topk_probs, 1) # (B, 1)
|
|
|
37 |
xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
|
|
|
38 |
x = torch.cat((x, xcol), dim=1)
|
39 |
|
40 |
decode_list = []
|
|
|
47 |
output = "\n======\n".join(decode_list)
|
48 |
return output
|
49 |
|
50 |
+
title = "GPT-2"
|
51 |
+
description = "A simple Gradio interface to generate text from GPT-2 model trained on Shakespeare Plays dataset"
|
52 |
+
examples = [["This above all: to thine own self be true", 5, 50],
|
53 |
+
["Our doubts are traitors, And make us lose the good we oft might win, By fearing to attempt.", 5, 50],
|
54 |
+
["There is nothing either good or bad, but thinking makes it so.", 5, 50],
|
55 |
+
["I do love nothing in the world so well as you: is not that strange?", 5, 50],
|
56 |
+
["The course of true love never did run smooth.", 5, 50],
|
57 |
+
["Love looks not with the eyes, but with the mind; and therefore is winged Cupid painted blind.", 5, 50],
|
58 |
+
["We know what we are, but know not what we may be.", 5, 50],
|
59 |
+
["As, I confess, it is my nature's plague To spy into abuses, and oft my jealousy Shapes faults that are not.", 5, 50],
|
60 |
+
["Good company, good wine, good welcome can make good people", 5, 50],
|
61 |
+
["Better three hours too soon than a minute late", 5, 50],
|
62 |
]
|
63 |
demo = gr.Interface(
|
64 |
inference,
|
65 |
inputs = [
|
66 |
gr.Textbox(label="Enter some text", type="text"),
|
67 |
+
gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of outputs"),
|
68 |
+
gr.Slider(minimum=10, maximum=500, step=1, value=50, label="Maximum lenght of a sequence")
|
69 |
],
|
70 |
outputs = [
|
71 |
gr.Textbox(label="Output", type="text")
|