Spaces:
Sleeping
Sleeping
from flask import Flask, request | |
import requests | |
import os | |
import re | |
import textwrap | |
from transformers import AutoModelForSeq2SeqLM | |
from transformers import AutoTokenizer | |
from langdetect import detect | |
import subprocess | |
# Model/tokenizer pairs loaded once at import time via from_pretrained.
# get_bot_response uses the "vn_" pair for text langdetect reports as "vi"
# and the plain (English) pair for everything else.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
# NOTE(review): unlike vn_model below, this tokenizer is not pinned to
# revision="worked", so it follows the repo's default branch — confirm the
# tokenizer and the pinned model stay compatible (or pin both).
vn_tokenizer = AutoTokenizer.from_pretrained("GuysTrans/bart-base-vn-ehealth-vn-tokenizer")
# English summarizer: fine-tuned BART, paired with the base "facebook/bart-base" tokenizer above.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-finetuned-xsum", revision="worked")
# Vietnamese model, pinned to the "worked" revision.
vn_model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-vn-ehealth-vn-tokenizer", revision="worked")
# Phrase -> replacement table that post_process applies in insertion order:
# greetings are blanked out and the old service name is rebranded.
# Order matters: the longer/punctuated variants ("Hello,", "Hi,") are listed
# before the bare words so they are consumed first.
map_words = dict([
    ("Hello and Welcome to 'Ask A Doctor' service", ""),
    ("Hello,", ""),
    ("Hi,", ""),
    ("Hello", ""),
    ("Hi", ""),
    ("Ask A Doctor", "MedForum"),
    ("H C M", "Med Forum"),
])
# post_process drops every "."-separated sentence that contains one of these
# phrases (compared case-insensitively).
# Candidates tried and currently disabled: "hello", "hi", "regards", "dr.",
# "physician", "welcome".
word_remove_sentence = ["Welcome to"]
def generate_summary(question, model, tokenizer):
    """Run ``model`` on ``question`` and decode what it generates.

    The question is tokenized into a fixed 512-token window (padded and
    truncated), moved to ``model.device``, and passed to ``model.generate``
    with beam-sampling settings.

    Returns:
        ``(outputs, output_str)``: the raw value returned by
        ``model.generate`` and the list of strings from
        ``tokenizer.batch_decode`` with special tokens skipped.
    """
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    device = model.device
    ids_on_device = encoded.input_ids.to(device)
    # NOTE(review): max_new_tokens=4096 is larger than the position-embedding
    # limit BART checkpoints usually have (1024) — confirm against the model
    # config; generation normally stops earlier at EOS.
    generation_options = dict(
        attention_mask=encoded.attention_mask.to(device),
        max_new_tokens=4096,
        do_sample=True,
        num_beams=4,
        top_k=50,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )
    generated = model.generate(ids_on_device, **generation_options)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, decoded
app = Flask(__name__)

# Messenger Send API endpoint used by send_message.
# NOTE(review): Graph API v2.6 has been deprecated for years — confirm the
# version Facebook currently accepts before relying on this URL.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'
# Shared secret compared against ``hub.verify_token`` in verify_webhook.
# Read from the VERIFY_TOKEN environment variable so the secret need not live
# in source; the literal is kept only as a fallback for existing deployments.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')
# Page access token for the Send API. Required: importing this module raises
# KeyError if the PAGE_ACCESS_TOKEN environment variable is not set.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
def get_bot_response(message):
    """Generate the chatbot reply for one user message.

    The language of ``message`` is detected with langdetect. Text detected
    as Vietnamese ("vi") goes to ``vn_model``/``vn_tokenizer``; anything
    else, including text whose language cannot be detected, goes to the
    English ``model``/``tokenizer``. The generated summary is cleaned by
    ``post_process`` and inserted into a fixed greeting/closing template.

    Args:
        message: The user's text.

    Returns:
        The reply string to send back to the user.
    """
    # Local import keeps this edit self-contained; langdetect is already a
    # dependency of this module.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        # NOTE(review): langdetect can give different answers across runs for
        # short/ambiguous text unless DetectorFactory.seed is set.
        lang = detect(message)
    except LangDetectException:
        # detect() raises when the text has no usable features (empty,
        # digits/emoji only). Fall back to the English model instead of
        # failing the whole request.
        lang = None

    if lang == "vi":
        model_use, tokenizer_use = vn_model, vn_tokenizer
        template = "Chào mừng bạn đến với dịch vụ MedForum chatbot. %s. Cảm ơn bạn đã sử dụng MedForum."
    else:
        model_use, tokenizer_use = model, tokenizer
        template = "Welcome to MedForum chatbot service. %s. Thanks for asking on MedForum."

    # generate_summary returns (raw_outputs, decoded_strings); use the first decoded string.
    summary = generate_summary(message, model_use, tokenizer_use)[1][0]
    return template % post_process(summary)
def verify_webhook(req):
    """Answer the webhook verification handshake.

    Returns the ``hub.challenge`` query parameter when ``hub.verify_token``
    equals VERIFY_TOKEN, and the string "incorrect" otherwise.
    """
    token_matches = req.args.get("hub.verify_token") == VERIFY_TOKEN
    return req.args.get("hub.challenge") if token_matches else "incorrect"
def respond(sender, message):
    """Compute the bot's reply to ``message`` and send it to ``sender``.

    Returns:
        The reply text that was passed to send_message.
    """
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply
def is_user_message(message):
    """Tell whether a messaging event is a text message sent by a user.

    Returns a truthy value only when the event has a non-empty ``message``
    dict with non-empty ``text`` and no truthy ``is_echo`` flag (echoes are
    the page's own outgoing messages). Falsy results are the first missing
    or empty piece (``None``, ``{}``, ``""``) or ``False`` for echoes.
    """
    body = message.get('message')
    if not body:
        return body
    text = body.get('text')
    if not text:
        return text
    return not body.get("is_echo")
def listen():
    """Handle requests to the Messenger webhook (the ``/webhook`` endpoint).

    GET requests are the subscription handshake and are delegated to
    ``verify_webhook``. POST requests carry message events: every text
    message sent by a user (as judged by ``is_user_message``) is answered
    via ``respond``, and "ok" is returned to acknowledge the delivery.

    NOTE(review): no ``@app.route`` decorator is visible here — confirm how
    this function is registered on ``app`` and which methods it accepts.
    """
    if request.method == 'GET':
        return verify_webhook(request)
    if request.method == 'POST':
        payload = request.get_json(silent=True)
        entries = payload.get('entry', []) if isinstance(payload, dict) else []
        # One delivery can batch several entries. The old code read only
        # entry[0] and silently dropped messages in the others. Missing keys
        # are now treated as "no events" instead of raising KeyError (HTTP 500).
        for entry in entries:
            for event in entry.get('messaging', []):
                if is_user_message(event):
                    respond(event['sender']['id'], event['message']['text'])
        return "ok"
def send_message(recipient_id, text, timeout=30):
    """Send ``text`` to ``recipient_id`` through the Messenger Send API.

    Args:
        recipient_id: ID of the message recipient (the event's sender id).
        text: Message body to deliver.
        timeout: Seconds to wait for the Graph API, passed to
            ``requests.post``. Without a timeout a stalled connection could
            block the webhook worker indefinitely.

    Returns:
        The decoded JSON body of the API response.

    Raises:
        requests.exceptions.Timeout: if the API does not answer in time.
    """
    payload = {
        'message': {'text': text},
        'recipient': {'id': recipient_id},
        'notification_type': 'regular',
    }
    response = requests.post(
        FB_API_URL,
        params={'access_token': PAGE_ACCESS_TOKEN},
        json=payload,
        timeout=timeout,
    )
    return response.json()
def chat():
    """JSON chat endpoint: ``{"message": str}`` -> ``{"message": reply}``.

    When the body is not a JSON object with a non-blank string ``message``,
    returns ``({"error": ...}, 400)``. Previously a missing key raised
    KeyError and surfaced as an HTTP 500.
    """
    payload = request.get_json(silent=True)
    message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(message, str) or not message.strip():
        return {"error": "expected a JSON body with a non-empty 'message' string"}, 400
    return {"message": get_bot_response(message)}
def post_process(output, remove_words=None, replacements=None):
    """Clean raw model output before it is shown to the user.

    Steps:
      1. Split on "." and drop every sentence that contains one of
         ``remove_words`` (case-insensitive substring match).
      2. Replace each key of ``replacements`` with its value, in insertion
         order, case-insensitively and literally. A key that starts or ends
         with a word character only matches on a word boundary there, so a
         short key like "Hi" no longer eats the start of "His"/"History".
      3. Dedent, strip and wrap the text to 120 columns.

    Args:
        output: Decoded model text.
        remove_words: Phrases that cause a sentence to be dropped. Defaults
            to the module-level ``word_remove_sentence``.
        replacements: Ordered mapping ``phrase -> replacement``. Defaults to
            the module-level ``map_words``.

    Returns:
        The cleaned text, wrapped with ``textwrap.fill(width=120)``.
    """
    if remove_words is None:
        remove_words = word_remove_sentence
    if replacements is None:
        replacements = map_words

    needles = [word.lower() for word in remove_words]
    # Build a new list rather than calling list.remove() while iterating;
    # that skipped the sentence right after every removed one.
    kept = [
        sentence
        for sentence in output.split(".")
        if not any(needle in sentence.lower() for needle in needles)
    ]
    output = ".".join(kept)

    for phrase, replacement in replacements.items():
        if not phrase:
            continue  # an empty pattern would match between every character
        pattern = re.escape(phrase)
        if re.match(r"\w", phrase[0]):
            pattern = r"\b" + pattern
        if re.match(r"\w", phrase[-1]):
            pattern += r"\b"
        # flags must be passed by keyword: re.sub's 4th positional parameter
        # is ``count``, so the old call capped each phrase at two replacements
        # (re.I == 2) and never ignored case.
        output = re.sub(
            pattern,
            lambda _match, text=replacement: text,  # literal; no backrefs
            output,
            flags=re.IGNORECASE,
        )
    return textwrap.fill(textwrap.dedent(output).strip(), width=120)
# Open a reverse SSH tunnel to serveo.net with autossh, exposing local port
# 7860 as "guysmedchatt" (presumably the port this app is served on —
# confirm). This runs at import time, and the Popen handle is not kept, so
# nothing in this file waits on or stops the tunnel process.
# NOTE(review): StrictHostKeyChecking=no disables SSH host-key verification,
# so the tunnel endpoint is not authenticated (MITM risk); consider pinning
# serveo.net's host key in a known_hosts file instead.
# NOTE(review): "id_rsa" is a relative path, resolved against the process's
# current working directory.
subprocess.Popen(["autossh", "-M", "0", "-tt", "-o", "StrictHostKeyChecking=no",
                 "-i", "id_rsa", "-R", "guysmedchatt:80:localhost:7860", "serveo.net"])
# Earlier plain-ssh variant (forwarding local port 5000), kept for reference:
# subprocess.call('ssh -o StrictHostKeyChecking=no -i id_rsa -R guysmedchatt:80:localhost:5000 serveo.net', shell=True)