File size: 2,353 Bytes
f4abbca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import List
from resources import set_start, audit_elapsedtime, entities_list_to_dict
from transformers import BertTokenizer, BertForTokenClassification
import torch
# Named-Entity Recognition model
def init_model_ner(model_name: str = "dbmdz/bert-large-cased-finetuned-conll03-english"):
    """Load the tokenizer and token-classification model used for NER.

    Args:
        model_name: Hugging Face checkpoint to load. Both the tokenizer and the
            model are taken from this same checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple, ready to pass to ``get_entity_results``.
    """
    print("Initiating NER model...")
    start = set_start()
    # The tokenizer must come from the same checkpoint as the model: the ids it
    # produces index the model's own (here: cased) vocabulary / embeddings.
    # Pairing a cased model with an uncased tokenizer silently garbles inputs.
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(model_name)
    audit_elapsedtime(function="Initiating NER model", start=start)
    return tokenizer, model
def _merge_bio_entities(tokens: List[str], tags: List[str], special_tokens) -> List[tuple]:
    """Group per-token BIO/IOB tags into ``(surface_text, entity_type)`` pairs.

    Args:
        tokens: WordPiece tokens, one per prediction (``##`` marks a
            continuation of the previous word).
        tags: predicted tag per token, e.g. ``"B-PER"``, ``"I-LOC"``, ``"O"``.
        special_tokens: tokens such as ``[CLS]``/``[SEP]`` that never belong
            to an entity.

    Returns:
        List of ``(text, type)`` tuples in order of appearance, where ``type``
        is the tag with its ``B-``/``I-`` prefix removed.
    """
    entities = []
    words: List[str] = []  # words of the entity currently being built
    current_type = None    # its type, e.g. "PER"

    for token, tag in zip(tokens, tags):
        prefix, _, etype = tag.partition("-")
        is_subword = token.startswith("##")

        # Special tokens and "O" tags close any open entity. A "##" piece is
        # exempt while an entity is open: entities end on word boundaries, so
        # the rest of a word belongs to the entity its first piece started.
        if token in special_tokens or (prefix not in ("B", "I") and not (is_subword and words)):
            if words:
                entities.append((" ".join(words), current_type))
            words, current_type = [], None
            continue

        if is_subword and words:
            # WordPiece continuation: glue onto the previous word without a space.
            words[-1] += token[2:]
        elif words and prefix == "I" and etype == current_type:
            # Inside the same entity: a new word of the same span.
            words.append(token)
        else:
            # "B-" tag, or an "I-" tag of a different type (IOB1-style start):
            # close the open entity and start a new one.
            if words:
                entities.append((" ".join(words), current_type))
            words, current_type = [token[2:] if is_subword else token], etype

    # Flush an entity that runs up to the last token.
    if words:
        entities.append((" ".join(words), current_type))
    return entities


def get_entity_results(tokenizer, model, text: str, entities_list: List[str]):  # -> Lead_labels:
    """Run NER over ``text`` and return the entities selected by ``entities_list``.

    Args:
        tokenizer: BERT tokenizer loaded from the same checkpoint as ``model``.
        model: token-classification model whose ``config.id2label`` maps class
            ids to BIO tags such as ``"B-PER"``, ``"I-LOC"`` or ``"O"``.
        text: input text. NOTE(review): it is not truncated here; presumably
            inputs longer than the model's maximum sequence length fail inside
            the model. Confirm that callers keep texts short.
        entities_list: values to keep. An entity is kept when its type (the
            tag without its ``B-``/``I-`` prefix, e.g. ``"PER"``) or its
            surface text appears in this list. Comparison is exact
            (case-sensitive).

    Returns:
        Surface strings of the kept entities, in order of appearance.
    """
    print("Initiating entity recognition...")
    start = set_start()
    labels = set(entities_list)

    # Encode once and derive the tokens from those same ids, so tokens[i] is
    # guaranteed to correspond to the i-th prediction. Re-tokenizing decoded
    # text does not guarantee that alignment.
    input_ids = tokenizer.encode(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Perform NER prediction. Inputs follow the model's device, so this works
    # whether the caller placed the model on CPU or GPU.
    with torch.no_grad():
        outputs = model(input_ids.to(model.device))
    predicted_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tags = [model.config.id2label[label_id] for label_id in predicted_ids]

    entities = _merge_bio_entities(tokens, tags, set(tokenizer.all_special_tokens))

    # Keep the entities whose type (or, for backward compatibility, text) was requested.
    filtered_entities = [
        surface for surface, etype in entities if etype in labels or surface in labels
    ]
    audit_elapsedtime(function="Retreiving entity labels from text", start=start)
    return filtered_entities