Spaces:
Runtime error
Runtime error
File size: 2,589 Bytes
4ecf8d9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b 849c4b7 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 530d98b fdb5dd9 9106e6c fdb5dd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
from typing import Dict, List
import torch
torch.backends.cudnn.enabled = False
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqClassifier
# Inference settings: sequence length used when encoding text, the
# classifier's architecture arguments, and the device everything runs on.
max_len = 128
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True
device = "cpu"

# Directories holding the trained checkpoint and the preprocessing cache.
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Restore the pickled vocabulary from the cache directory.
with (cache_dir / "vocab.pkl").open("rb") as f:
    vocab: Vocab = pickle.load(f)

# Read the intent -> index table, then invert it into an index -> intent
# lookup used to turn a predicted class index back into a label.
intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
__idx2label = dict(zip(intent2idx.values(), intent2idx.keys()))
def idx2label(idx: int):
    """Return the intent name registered for class index ``idx``.

    Raises:
        KeyError: If ``idx`` is not a key of the index -> intent lookup.
    """
    label = __idx2label[idx]
    return label
# Placeholder embedding matrix handed to the classifier's constructor.
# It is allocated uninitialised directly on the target device; `Tensor.to`
# returns a new tensor rather than moving in place, so calling it and
# discarding the result would have no effect.
# NOTE(review): the values are presumably replaced by the checkpoint's
# weights below, and (6491, 300) presumably matches the training-time
# vocabulary size / embedding width -- confirm against SeqClassifier.
embeddings_size = (6491, 300)
embeddings = torch.empty(embeddings_size, device=device)

# Build the classifier with the configured architecture arguments; the
# number of output classes is the number of known intents.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Define the path to the checkpoint file
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the model's weights, mapping stored tensors onto the configured
# device so the `device` setting stays the single source of truth.
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint["model_state_dict"])

# Set the model to evaluation mode
best_model.eval()
# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Encode a single input sentence as a tensor of vocabulary indices.

    Args:
        texts: Raw input sentence; it is split on whitespace into tokens.

    Returns:
        A tensor built from the first (and only) row returned by
        ``vocab.encode_batch`` when called with ``to_len=max_len``.
    """
    tokens = texts.split()
    # encode_batch works on a batch of token lists, so wrap the single
    # sentence in a one-element batch and take its only row back out.
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
# Classification function
def classify(text):
    """Predict the intent of ``text`` and return it as a display string.

    Args:
        text: Raw sentence entered by the user.

    Returns:
        The string ``"Category:"`` followed by the predicted intent name.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: disable gradient tracking so no autograd graph is
    # built (and kept alive) for every request.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax without `dim` indexes into the flattened output.
    # NOTE(review): assumes the model returns exactly one score per intent
    # for this single, unbatched input -- confirm against SeqClassifier.
    predicted_class = torch.argmax(output).item()
    return "Category:" + idx2label(predicted_class)
# Example queries offered as one-click inputs in the UI.
_EXAMPLE_QUERIES = [
    ["please set an alarm for mid day"],
    ["tell lydia and laura where i am located"],
    ["what's the deal with my health care"],
]

# Text displayed under the page title.
_DESCRIPTION = "This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result."

# Create a Gradio interface: one text box in, one label out, backed by
# `classify`.
# NOTE(review): `interpretation` and `enable_queue` may not be accepted by
# newer Gradio major versions -- confirm against the installed release.
demo = gr.Interface(
    title="Text Intent Classification",
    description=_DESCRIPTION,
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs="label",
    examples=_EXAMPLE_QUERIES,
    interpretation="none",
    live=False,
    enable_queue=True,
)

# Launch the Gradio interface
demo.launch(share=True)
|