RohanHBTU committed
Commit
0adb0e5
1 Parent(s): b7d50cb

Upload 3 files

Files changed (3)
  1. app.py +132 -0
  2. best-model-version.ckpt +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ from flask import Flask, render_template, request
+ #from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ import torch
+ from torch.optim import AdamW  # transformers.AdamW is deprecated; use the torch optimizer
+ import pytorch_lightning as pl
+
+ MODEL_NAME = 't5-base'
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ INPUT_MAX_LEN = 512
+ OUTPUT_MAX_LEN = 512
+
+ #tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+ #model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
+
+ app = Flask(__name__)
+ app.jinja_env.auto_reload = True
+ app.config['TEMPLATES_AUTO_RELOAD'] = True
+
+
+ @app.route("/")
+ def index():
+     return render_template('chat.html')
+
+
+ @app.route("/get", methods=["GET", "POST"])
+ def chat():
+     # request.values covers both query-string (GET) and form-data (POST),
+     # so the route works for either method it is registered for
+     msg = request.values["msg"]
+     return get_Chat_response(msg)
+
+
+ class T5Model(pl.LightningModule):
+     # Lightning wrapper around T5; at serving time it is only used to
+     # restore the fine-tuned checkpoint saved during training.
+
+     def __init__(self):
+         super().__init__()
+         self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
+
+     def forward(self, input_ids, attention_mask, labels=None):
+         output = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             labels=labels
+         )
+         return output.loss, output.logits
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["target"]
+         loss, logits = self(input_ids, attention_mask, labels)
+         self.log("train_loss", loss, prog_bar=True, logger=True)
+         return {'loss': loss}
+
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["target"]
+         loss, logits = self(input_ids, attention_mask, labels)
+         self.log("val_loss", loss, prog_bar=True, logger=True)
+         return {'val_loss': loss}
+
+     def configure_optimizers(self):
+         return AdamW(self.parameters(), lr=0.0001)
+
+
+ # Restore the fine-tuned weights, freeze them for inference, and move the
+ # model to the selected device so inputs and parameters match.
+ train_model = T5Model.load_from_checkpoint('best-model-version.ckpt', map_location=DEVICE)
+ train_model.freeze()
+ train_model.to(DEVICE)
+
+
+ def get_Chat_response(question):
+     inputs_encoding = tokenizer(
+         question,
+         add_special_tokens=True,
+         max_length=INPUT_MAX_LEN,
+         padding='max_length',
+         truncation='only_first',
+         return_attention_mask=True,
+         return_tensors="pt"
+     )
+
+     # Beam-search decoding; max_length caps the generated reply
+     # (OUTPUT_MAX_LEN, not the input length), and inputs are moved
+     # to the model's device.
+     generate_ids = train_model.model.generate(
+         input_ids=inputs_encoding["input_ids"].to(DEVICE),
+         attention_mask=inputs_encoding["attention_mask"].to(DEVICE),
+         max_length=OUTPUT_MAX_LEN,
+         num_beams=4,
+         num_return_sequences=1,
+         no_repeat_ngram_size=2,
+         early_stopping=True,
+     )
+
+     preds = [
+         tokenizer.decode(gen_id,
+                          skip_special_tokens=True,
+                          clean_up_tokenization_spaces=True)
+         for gen_id in generate_ids
+     ]
+
+     return "".join(preds)
+
+
+ #def get_Chat_response(text):
+ #    # Let's chat for 5 lines
+ #    for step in range(5):
+ #        # encode the new user input, add the eos_token and return a tensor in PyTorch
+ #        new_user_input_ids = tokenizer.encode(str(text) + tokenizer.eos_token, return_tensors='pt')
+ #
+ #        # append the new user input tokens to the chat history
+ #        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
+ #
+ #        # generate a response while limiting the total chat history to 1000 tokens
+ #        chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+ #
+ #        # pretty-print the last output tokens from the bot
+ #        return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+
+
+ if __name__ == '__main__':
+     app.run(debug=True)
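For a quick smoke test of the chat endpoint, a minimal sketch: it assumes the app is running locally on Flask's default port 5000, that a templates/chat.html exists for the index route, and that the requests package is installed (it is not listed in requirements.txt).

import requests

# The /get route reads the "msg" field from the request, so a
# form-encoded POST is enough to exercise the model end to end.
resp = requests.post("http://127.0.0.1:5000/get", data={"msg": "hello"})
print(resp.text)  # the generated T5 reply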
best-model-version.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2a974dd588bf796be688aeae2b0d764e0fe0f48e3c689919cfea2c23bf98317
+ size 2675123319
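Note: this diff only adds the Git LFS pointer file; the actual checkpoint (about 2.7 GB, per the size field above) lives in LFS storage and is fetched on clone when git-lfs is installed.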
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers==4.27.4
+ pandas==1.5.3
+ torch==2.0.0
+ pytorch-lightning==2.0.2
+ flask==3.0.0
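Note: the slow T5Tokenizer used in app.py also depends on the sentencepiece package, which is not pinned here and needs to be installed separately.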