from flask import Flask, render_template, request
import os

# Must be set before transformers is imported so model downloads land here.
os.environ['TRANSFORMERS_CACHE'] = '/code/cache/'

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
# transformers.AdamW is deprecated/removed in recent versions; use the torch one.
from torch.optim import AdamW
import pytorch_lightning as pl

MODEL_NAME = 't5-base'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_MAX_LEN = 512
OUTPUT_MAX_LEN = 512

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)

app = Flask(__name__)
app.jinja_env.auto_reload = True
app.config['TEMPLATES_AUTO_RELOAD'] = True


@app.route("/")
def index():
    return render_template('chat.html')


@app.route("/get", methods=["GET", "POST"])
def chat():
    # The chat UI posts the user's message as form field "msg"; fall back to a
    # query-string parameter so GET requests work too.
    msg = request.form.get("msg") or request.args.get("msg", "")
    return get_Chat_response(msg)


class T5Model(pl.LightningModule):
    """Lightning wrapper around T5. The class definition must match the one
    used at training time so the fine-tuned checkpoint can be restored."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=0.0001)


# Restore the fine-tuned weights and switch to inference mode.
train_model = T5Model.load_from_checkpoint('best-model-version.ckpt', map_location=DEVICE)
train_model.freeze()
train_model.to(DEVICE)


def get_Chat_response(question):
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt"
    )

    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"].to(DEVICE),
        attention_mask=inputs_encoding["attention_mask"].to(DEVICE),
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]

    return "".join(preds)
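
# For reference, training_step/validation_step in T5Model expect batches with
# "input_ids", "attention_mask", and "target" tensors. A minimal sketch of a
# compatible Dataset (the real one lives in the training script, which is not
# part of this app; `questions`/`answers` are hypothetical name lists):
#
#     class ChatDataset(torch.utils.data.Dataset):
#         def __len__(self):
#             return len(questions)
#         def __getitem__(self, idx):
#             q = tokenizer(questions[idx], max_length=INPUT_MAX_LEN,
#                           padding='max_length', truncation=True, return_tensors='pt')
#             a = tokenizer(answers[idx], max_length=OUTPUT_MAX_LEN,
#                           padding='max_length', truncation=True, return_tensors='pt')
#             return {"input_ids": q["input_ids"].squeeze(0),
#                     "attention_mask": q["attention_mask"].squeeze(0),
#                     "target": a["input_ids"].squeeze(0)}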
if __name__ == '__main__':
    app.run(debug=True)
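
# A minimal sketch of exercising the /get endpoint from a separate client
# process, assuming the app is running on Flask's default port 5000 (the
# `requests` package is an extra dependency, not used by this app itself):
#
#     import requests
#     reply = requests.post("http://127.0.0.1:5000/get", data={"msg": "Hello!"})
#     print(reply.text)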