yeshpanovrustem commited on
Commit
8d8fdc8
1 Parent(s): e363ad3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -4
app.py CHANGED
@@ -1,15 +1,99 @@
1
  import streamlit as st
2
- import transformers
3
 
4
  # use @st.cache decorator to cache model — because it is too large, we do not want to reload it every time
5
  # use allow_output_mutation = True to tell streamlit that model should be treated as immutable object — singleton
6
- # @st.cache(allow_output_mutation = True)
7
 
8
  # load model and tokenizer
9
- tokenizer = transformers.AutoTokenizer.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
10
- model = transformers.AutoModelForTokenClassification.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # # define function for ner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  # st.markdown("# Hello")
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
 
4
  # use @st.cache decorator to cache model — because it is too large, we do not want to reload it every time
5
  # use allow_output_mutation = True to tell streamlit that model should be treated as immutable object — singleton
6
+ @st.cache(allow_output_mutation = True)
7
 
8
  # load model and tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
10
+ model = AutoModelForTokenClassification.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
11
+
12
+ labels_dict = {0: 'O',
13
+ 1: 'B-ADAGE',
14
+ 2: 'I-ADAGE',
15
+ 3: 'B-ART',
16
+ 4: 'I-ART',
17
+ 5: 'B-CARDINAL',
18
+ 6: 'I-CARDINAL',
19
+ 7: 'B-CONTACT',
20
+ 8: 'I-CONTACT',
21
+ 9: 'B-DATE',
22
+ 10: 'I-DATE',
23
+ 11: 'B-DISEASE',
24
+ 12: 'I-DISEASE',
25
+ 13: 'B-EVENT',
26
+ 14: 'I-EVENT',
27
+ 15: 'B-FACILITY',
28
+ 16: 'I-FACILITY',
29
+ 17: 'B-GPE',
30
+ 18: 'I-GPE',
31
+ 19: 'B-LANGUAGE',
32
+ 20: 'I-LANGUAGE',
33
+ 21: 'B-LAW',
34
+ 22: 'I-LAW',
35
+ 23: 'B-LOCATION',
36
+ 24: 'I-LOCATION',
37
+ 25: 'B-MISCELLANEOUS',
38
+ 26: 'I-MISCELLANEOUS',
39
+ 27: 'B-MONEY',
40
+ 28: 'I-MONEY',
41
+ 29: 'B-NON_HUMAN',
42
+ 30: 'I-NON_HUMAN',
43
+ 31: 'B-NORP',
44
+ 32: 'I-NORP',
45
+ 33: 'B-ORDINAL',
46
+ 34: 'I-ORDINAL',
47
+ 35: 'B-ORGANISATION',
48
+ 36: 'I-ORGANISATION',
49
+ 37: 'B-PERSON',
50
+ 38: 'I-PERSON',
51
+ 39: 'B-PERCENTAGE',
52
+ 40: 'I-PERCENTAGE',
53
+ 41: 'B-POSITION',
54
+ 42: 'I-POSITION',
55
+ 43: 'B-PRODUCT',
56
+ 44: 'I-PRODUCT',
57
+ 45: 'B-PROJECT',
58
+ 46: 'I-PROJECT',
59
+ 47: 'B-QUANTITY',
60
+ 48: 'I-QUANTITY',
61
+ 49: 'B-TIME',
62
+ 50: 'I-TIME'}
63
 
64
  # # define function for ner
65
+ # def label_sentence(text):
66
+ # load pipeline
67
+ nlp = pipeline("ner", model = model, tokenizer = tokenizer)
68
+ example = "Қазақстан Республикасы — Шығыс Еуропа мен Орталық Азияда орналасқан мемлекет."
69
+
70
+ single_sentence_tokens = word_tokenize(example)
71
+ tokenized_input = tokenizer(single_sentence_tokens, is_split_into_words = True, return_tensors = "pt")
72
+ tokens = tokenized_input.tokens()
73
+ output = model(**tokenized_input).logits
74
+ predictions = torch.argmax(output, dim = 2)
75
+
76
+ # convert label IDs to label names
77
+ word_ids = tokenized_input.word_ids(batch_index = 0)
78
+ # print(count, word_ids)
79
+ previous_word_id = None
80
+ labels = []
81
+ for token, word_id, prediction in zip(tokens, word_ids, predictions[0].numpy()):
82
+ # # Special tokens have a word id that is None. We set the label to -100 so they are
83
+ # # automatically ignored in the loss function.
84
+ # print(token, word_id, prediction)
85
+ if word_id is None or word_id == previous_word_id:
86
+ continue
87
+ elif word_id != previous_word_id:
88
+ labels.append(labels_dict[prediction])
89
+ previous_word_id = word_id
90
+ # print(len(sentence_tokens), sentence_tokens)
91
+ # print(len(labels), labels)
92
+ assert len(single_sentence_tokens) == len(labels), "Mismatch between input token and label sizes!"
93
+
94
+ for token, label in zip(single_sentence_tokens, labels):
95
+ print(token, label)
96
+
97
 
98
 
99
  # st.markdown("# Hello")