Spaces:
Runtime error
Runtime error
File size: 4,075 Bytes
4ecf8d9 530d98b fdb5dd9 530d98b a5b6307 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 9045e66 fdb5dd9 530d98b fdb5dd9 530d98b 849c4b7 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b a5b6307 fdb5dd9 530d98b fdb5dd9 19d1bcf 530d98b fdb5dd9 530d98b fdb5dd9 19d1bcf 77f4643 19d1bcf 1912310 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import gradio as gr
from typing import Dict, List
import torch
torch.backends.cudnn.enabled = False
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqClassifier
import re
# Sequence length used when encoding text, plus the network settings that are
# handed to SeqClassifier further down.
max_len = 128
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True
device = "cpu"

# Directories holding the trained checkpoint and the preprocessing artefacts.
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Restore the pickled vocabulary object.
with (cache_dir / "vocab.pkl").open("rb") as f:
    vocab: Vocab = pickle.load(f)

# Read the intent -> index table, then derive the inverse index -> intent
# table that idx2label() looks labels up in.
intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
__idx2label = dict(zip(intent2idx.values(), intent2idx.keys()))
def idx2label(idx: int):
    """Return the intent name stored for class index ``idx``.

    The lookup goes through the module-level ``__idx2label`` dictionary, so an
    index that is not present there raises ``KeyError``.
    """
    label = __idx2label[idx]
    return label
# Shape of the embedding matrix handed to the model: (rows, embedding dim).
embeddings_size = (5621, 300)
# torch.empty leaves the values uninitialised; only the shape is meaningful
# here. NOTE(review): presumably the real embedding weights are restored by
# load_state_dict() below -- confirm against SeqClassifier.
# Allocate directly on the target device: Tensor.to() returns a new tensor
# rather than moving in place, so a bare `embeddings.to(device)` has no effect.
embeddings = torch.empty(embeddings_size, device=device)

# Build the classifier with the same settings defined at the top of the file.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Path to the checkpoint file.
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the saved weights, mapping every stored tensor onto the configured
# device so the checkpoint loads regardless of where it was saved.
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode for inference.
best_model.eval()
# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Tokenize one input string and encode it as a tensor of vocabulary ids.

    Args:
        texts: The raw text to encode (a single string, despite the plural
            parameter name, which is kept for backward compatibility).

    Returns:
        A tensor built from the single row that ``vocab.encode_batch``
        returns when called with ``to_len=max_len``.
    """
    # Split into runs of word characters and individual punctuation marks;
    # whitespace is discarded.
    tokens = re.findall(r"\w+|[^\w\s]", texts)
    # encode_batch works on a batch of token lists, so wrap the single sample
    # in a one-element batch and take the first (only) result back out.
    token_ids = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(token_ids)
# Classification function
def classifier(text):
    """Predict the intent label for a piece of text.

    Args:
        text: The raw text to classify.

    Returns:
        The intent name that ``idx2label`` returns for the index of the
        largest value in the model output.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax without a dim argument indexes into the flattened output.
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)
import gradio as gr
from gradio.components import Textbox
import random
def random_sample():
    """Return one entry of the module-level ``examples`` list, picked at random."""
    # random.choice draws a uniformly random element directly, replacing the
    # manual "random index, then subscript" sequence.
    return random.choice(examples)
# Sample inputs offered by the "Random" button.
examples = [
    # Lower-case utterances
    "what are some fun things i can partake in in atlanta",
    "how do i make pumpkin pie",
    "what's the currency conversion between rubles and pounds",
    "please set an alarm for mid day",
    "how many hours will it take to get to my destination",
    "so i made a fraudulent transaction",
    "tell lydia and laura where i am located",
    "i want you to talk more quickly",
    "what's the deal with my health care",
    # Capitalised rephrasings of several of the utterances above
    "What's the exchange rate for rubles to pounds",
    "How long will it take to reach my destination",
    "I suspect a fraudulent transaction on my account",
    "Inform Lydia and Laura of my current location",
    "I'd like you to speak faster",
    "Can you provide information about my health care",
    # Further capitalised requests
    "Give me the details on my health insurance",
    "What's the local time now",
    "Find a recipe for chocolate chip cookies",
    "Check my credit card balance",
    "Translate 'Hello' to French",
    "Recommend a good restaurant nearby",
]

# Name of the demo.
title = "Text Intent Classification"

# Markdown shown at the top of the page (leading and trailing newlines kept).
description = (
    "\n"
    "# Text Intent Classification\n"
    "This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.\n"
)
# Build the UI: a markdown header, an input/output pair of text boxes, and two
# buttons. The page title comes from the module-level `title` constant so it
# matches what the demo actually does (it was previously hard-coded to an
# unrelated "Question Answering").
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = gr.Textbox(lines=3, label="Context paragraph", placeholder="Please enter text")
        A_output = gr.Textbox(lines=3, label="Category")
    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")
    # "Random" fills the input box with one of the sample sentences.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    # "classifier" runs the model on the input box and shows the predicted label.
    classifier_button.click(classifier, inputs=C_input, outputs=A_output)

# Start the app; share=True additionally requests a public share link.
demo.launch(share=True)