Rinkalkumar M. Shah committed
Commit cac5aaf
1 Parent(s): 624c056

update in App.py

Files changed (1)
  1. app.py +17 -26
app.py CHANGED
@@ -11,7 +11,7 @@ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
     device = "mps"
 
 model = GPT(GPTConfig())
-ckpt = torch.load("gpt2.pt", map_location=torch.device(device))
+ckpt = torch.load("model.pt", map_location=torch.device(device))
 unwanted_prefix = '_orig_mod.'
 for k,v in list(ckpt.items()):
     if k.startswith(unwanted_prefix):
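Note on the hunk above: the '_orig_mod.' prefix is what torch.compile prepends to every key of a compiled model's state_dict, and the truncated for-loop strips it so the checkpoint loads into a plain GPT(GPTConfig()) instance. Below is a small runnable sketch of that pattern on a toy module; the loop body (pop the old key, re-insert it without the prefix) and the final load_state_dict call are assumptions about how the cut-off code continues, following the common nanoGPT-style idiom.

import torch
import torch.nn as nn

# Toy stand-in for the repo's GPT model; app.py loads GPT(GPTConfig()) instead.
model = nn.Linear(4, 4)

# Simulate a checkpoint saved from a torch.compile'd module: every key gains "_orig_mod.".
ckpt = {"_orig_mod." + k: v for k, v in model.state_dict().items()}

# Rename the keys in place so they match the uncompiled module (assumed continuation of the loop above).
unwanted_prefix = '_orig_mod.'
for k, v in list(ckpt.items()):
    if k.startswith(unwanted_prefix):
        ckpt[k[len(unwanted_prefix):]] = ckpt.pop(k)

model.load_state_dict(ckpt)  # loads cleanly once the prefix is gone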
@@ -28,22 +28,13 @@ def inference(input_text, num_return_sequences, max_length):
     x = input_tokens.to(device)
 
     while x.size(1) < max_length:
-        # forward the model to get the logits
         with torch.no_grad():
-            logits = model(x)[0] # (B, T, vocab_size)
-            # take the logits at the last position
-            logits = logits[:, -1, :] # (B, vocab_size)
-            # get the probabilities
+            logits = model(x)[0]
+            logits = logits[:, -1, :]
             probs = F.softmax(logits, dim=-1)
-            # do top-k sampling of 50 (huggingface pipeline default)
-            # topk_probs here becomes (5, 50), topk_indices is (5, 50)
             topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
-            # select a token from the top-k probabilities
-            # note: multinomial does not demand the input to sum to 1
             ix = torch.multinomial(topk_probs, 1) # (B, 1)
-            # gather the corresponding indices
             xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
-            # append to the sequence
             x = torch.cat((x, xcol), dim=1)
 
     decode_list = []
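Note on the hunk above: the loop is standard top-k sampling. Each iteration takes the logits for the last position, softmaxes them, keeps the 50 most likely tokens, draws one token per sequence with torch.multinomial (which does not require its input to sum to 1), and appends the chosen id to the running sequence. A self-contained sketch of one such step on dummy logits; the batch size of 5 and the 50257-token GPT-2 vocabulary are illustrative values, not read from app.py.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(5, 50257)                             # (B, vocab_size), i.e. logits[:, -1, :] in the app
probs = F.softmax(logits, dim=-1)                          # probabilities over the vocabulary
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)   # keep only the 50 most likely tokens
ix = torch.multinomial(topk_probs, 1)                      # sample a position within the top-k
xcol = torch.gather(topk_indices, -1, ix)                  # map that position back to a vocabulary id
print(xcol.shape)                                          # torch.Size([5, 1]): one new token per sequence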
@@ -56,25 +47,25 @@ def inference(input_text, num_return_sequences, max_length):
     output = "\n======\n".join(decode_list)
     return output
 
-title = "GPT-2 trained on Shakespeare Plays dataset"
-description = "A simple Gradio interface to generate text from GPT-2 model trained on Shakespeare Plays"
-examples = [["Please put on these earmuffs because I can't hear you.", 5, 50],
-            ["Twin 4-month-olds slept in the shade of the palm tree while the mother tanned in the sun.", 5, 50],
-            ["Happiness can be found in the depths of chocolate pudding.", 5, 50],
-            ["Seek success, but always be prepared for random cats.", 5, 50],
-            ["This made him feel like an old-style rootbeer float smells.", 5, 50],
-            ["The view from the lighthouse excited even the most seasoned traveler.", 5, 50],
-            ["I've always wanted to go to Tajikistan, but my cat would miss me.", 5, 50],
-            ["He found rain fascinating yet unpleasant.", 5, 50],
-            ["Plans for this weekend include turning wine into water.", 5, 50],
-            ["Iron pyrite is the most foolish of all minerals.", 5, 50],
+title = "GPT-2"
+description = "A simple Gradio interface to generate text from GPT-2 model trained on Shakespeare Plays dataset"
+examples = [["This above all: to thine own self be true", 5, 50],
+            ["Our doubts are traitors, And make us lose the good we oft might win, By fearing to attempt.", 5, 50],
+            ["There is nothing either good or bad, but thinking makes it so.", 5, 50],
+            ["I do love nothing in the world so well as you: is not that strange?", 5, 50],
+            ["The course of true love never did run smooth.", 5, 50],
+            ["Love looks not with the eyes, but with the mind; and therefore is winged Cupid painted blind.", 5, 50],
+            ["We know what we are, but know not what we may be.", 5, 50],
+            ["As, I confess, it is my nature's plague To spy into abuses, and oft my jealousy Shapes faults that are not.", 5, 50],
+            ["Good company, good wine, good welcome can make good people", 5, 50],
+            ["Better three hours too soon than a minute late", 5, 50],
             ]
 demo = gr.Interface(
     inference,
     inputs = [
         gr.Textbox(label="Enter some text", type="text"),
-        gr.Slider(minimum=1, maximum=5, step=1, value=5, label="Number of outputs"),
-        gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Maximum lenght of a sequence")
+        gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of outputs"),
+        gr.Slider(minimum=10, maximum=500, step=1, value=50, label="Maximum lenght of a sequence")
     ],
     outputs = [
         gr.Textbox(label="Output", type="text")
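Note on the hunk above: the diff context ends inside the gr.Interface(...) call, so the remaining arguments and the launch call are not shown. A minimal self-contained sketch of how such an interface is typically wired and launched; the stub inference function and the title/description values are placeholders for illustration, not the repo's exact code.

import gradio as gr

def inference(input_text, num_return_sequences, max_length):
    # Stub standing in for the real generation function defined earlier in app.py.
    return "\n======\n".join([input_text] * int(num_return_sequences))

demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Enter some text", type="text"),
        gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of outputs"),
        gr.Slider(minimum=10, maximum=500, step=1, value=50, label="Maximum length of a sequence"),
    ],
    outputs=[gr.Textbox(label="Output", type="text")],
    title="GPT-2",
    description="A simple Gradio interface to generate text from a GPT-2 model trained on Shakespeare plays",
)

if __name__ == "__main__":
    demo.launch()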