# ner-kazakh / app.py
from nltk.tokenize import word_tokenize
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
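# word_tokenize relies on NLTK's "punkt" tokenizer data; a minimal safeguard, assuming
# the data may not already be available in the runtime environment
import nltk
nltk.download("punkt", quiet = True)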
# cache the model with @st.cache so the large checkpoint is not reloaded on every rerun;
# allow_output_mutation = True makes Streamlit skip hashing the returned objects and
# treat them as a singleton
@st.cache(allow_output_mutation = True)
def load_model():
    # load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
    model = AutoModelForTokenClassification.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
    return tokenizer, model

tokenizer, model = load_model()
labels_dict = {0: 'O',
1: 'B-ADAGE',
2: 'I-ADAGE',
3: 'B-ART',
4: 'I-ART',
5: 'B-CARDINAL',
6: 'I-CARDINAL',
7: 'B-CONTACT',
8: 'I-CONTACT',
9: 'B-DATE',
10: 'I-DATE',
11: 'B-DISEASE',
12: 'I-DISEASE',
13: 'B-EVENT',
14: 'I-EVENT',
15: 'B-FACILITY',
16: 'I-FACILITY',
17: 'B-GPE',
18: 'I-GPE',
19: 'B-LANGUAGE',
20: 'I-LANGUAGE',
21: 'B-LAW',
22: 'I-LAW',
23: 'B-LOCATION',
24: 'I-LOCATION',
25: 'B-MISCELLANEOUS',
26: 'I-MISCELLANEOUS',
27: 'B-MONEY',
28: 'I-MONEY',
29: 'B-NON_HUMAN',
30: 'I-NON_HUMAN',
31: 'B-NORP',
32: 'I-NORP',
33: 'B-ORDINAL',
34: 'I-ORDINAL',
35: 'B-ORGANISATION',
36: 'I-ORGANISATION',
37: 'B-PERSON',
38: 'I-PERSON',
39: 'B-PERCENTAGE',
40: 'I-PERCENTAGE',
41: 'B-POSITION',
42: 'I-POSITION',
43: 'B-PRODUCT',
44: 'I-PRODUCT',
45: 'B-PROJECT',
46: 'I-PROJECT',
47: 'B-QUANTITY',
48: 'I-QUANTITY',
49: 'B-TIME',
50: 'I-TIME'}
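# labels_dict maps the model's output indices to IOB2 tags: 'O' plus B-/I- pairs
# for 25 entity types (GPE, PERSON, ORGANISATION, DATE, etc.)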
# # define function for ner
# def label_sentence(text):

# load the ner pipeline (kept for reference; the example below calls the model directly)
nlp = pipeline("ner", model = model, tokenizer = tokenizer)
example = "Қазақстан Республикасы — Шығыс Еуропа мен Орталық Азияда орналасқан мемлекет."
single_sentence_tokens = word_tokenize(example)
tokenized_input = tokenizer(single_sentence_tokens, is_split_into_words = True, return_tensors = "pt")
tokens = tokenized_input.tokens()
output = model(**tokenized_input).logits
predictions = torch.argmax(output, dim = 2)
# convert label IDs to label names
word_ids = tokenized_input.word_ids(batch_index = 0)
# print(count, word_ids)
previous_word_id = None
labels = []
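# align sub-token predictions back to the word_tokenize tokens: special tokens have a
# word_id of None, and only the first sub-token of each word contributes its label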
for token, word_id, prediction in zip(tokens, word_ids, predictions[0].numpy()):
    # special tokens (e.g. <s>, </s>) have a word id of None and are skipped,
    # as are the second and later sub-tokens of a word
    # print(token, word_id, prediction)
    if word_id is None or word_id == previous_word_id:
        continue
    elif word_id != previous_word_id:
        labels.append(labels_dict[prediction])
        previous_word_id = word_id

# print(len(single_sentence_tokens), single_sentence_tokens)
# print(len(labels), labels)
assert len(single_sentence_tokens) == len(labels), "Mismatch between input token and label sizes!"

for token, label in zip(single_sentence_tokens, labels):
    print(token, label)
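# a minimal sketch, for comparison: the pipeline created above can produce entity
# predictions directly from the raw string; its output fields ("entity", "score",
# "word", "start", "end") are per sub-token, so they are not aligned to the
# word_tokenize tokens used above
# print(nlp(example))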
# st.markdown("# Hello")
# # st.set_page_config(page_title = "Kazakh Named Entity Recognition", page_icon = "🔍")
# # st.title("🔍 Kazakh Named Entity Recognition")
# x = st.slider('Select a value')
# st.write(x, 'squared is', x * x)
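# a minimal sketch of how the word/label pairs could be rendered in the Streamlit UI,
# assuming a simple two-column table is the desired output format
# st.set_page_config(page_title = "Kazakh Named Entity Recognition", page_icon = "🔍")
# st.title("🔍 Kazakh Named Entity Recognition")
# st.table({"token": single_sentence_tokens, "label": labels})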