# voicelead / nameder.py
# Author: zeimoto
# Commit: f4abbca ("first commit")
# Size: 2.35 kB
from typing import List
from resources import set_start, audit_elapsedtime, entities_list_to_dict
from transformers import BertTokenizer, BertForTokenClassification
import torch
# Named-entity recognition (NER) model
def init_model_ner(model_name: str = "dbmdz/bert-large-cased-finetuned-conll03-english"):
    """Load the NER tokenizer and token-classification model.

    The tokenizer and the model are loaded from the same checkpoint. The model
    was fine-tuned with its own (cased) WordPiece vocabulary. Token ids produced
    by a different tokenizer (e.g. ``bert-base-uncased``) therefore map to the
    wrong embeddings, or fall outside the embedding table entirely.

    Args:
        model_name: Hugging Face checkpoint name or local path of a BERT
            token-classification model. Defaults to the CoNLL-2003 English
            NER checkpoint this function always used.

    Returns:
        A ``(tokenizer, model)`` tuple to pass to ``get_entity_results``.
    """
    print("Initiating NER model...")
    start = set_start()
    # Tokenizer and model must share a vocabulary, so load both from one checkpoint.
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(model_name)
    audit_elapsedtime(function="Initiating NER model", start=start)
    return tokenizer, model
def _group_entity_spans(tokens: List[str], tags: List[str]) -> List[tuple]:
    """Group aligned WordPiece tokens and B-/I-/O tags into ``(type, pieces)`` spans.

    ``type`` is the tag without its ``B-``/``I-`` prefix (e.g. ``"PER"``);
    ``pieces`` is the list of WordPiece tokens forming the entity.
    """
    spans = []
    current_type = None
    current_pieces: List[str] = []
    for token, tag in zip(tokens, tags):
        if current_pieces and token.startswith("##"):
            # A "##" continuation belongs to the word already being built,
            # whatever tag the model gave this sub-token, so words stay whole.
            current_pieces.append(token)
            continue
        prefix, _, entity_type = tag.partition("-")
        if prefix == "I" and entity_type == current_type:
            # Inside the currently open entity of the same type.
            current_pieces.append(token)
            continue
        # Any other tag closes the open span (if there is one).
        if current_pieces:
            spans.append((current_type, current_pieces))
        if prefix in ("B", "I"):
            # B- starts an entity; an I- that does not continue the open span
            # (IOB1-style tagging, or a change of type) starts one as well.
            current_type, current_pieces = entity_type, [token]
        else:
            current_type, current_pieces = None, []
    # Flush an entity that runs to the end of the sequence.
    if current_pieces:
        spans.append((current_type, current_pieces))
    return spans


def get_entity_results(tokenizer, model, text: str, entities_list: List[str]) -> List[str]:
    """Run NER over *text* and return the recognised entities selected by *entities_list*.

    Args:
        tokenizer: the WordPiece tokenizer returned by ``init_model_ner``.
        model: the token-classification model returned by ``init_model_ner``.
        text: raw input text.
        entities_list: entity texts and/or entity types to keep. An entity is
            kept when its surface text (e.g. ``"Angela Merkel"``) or its type,
            i.e. its ``model.config.id2label`` tag without the ``B-``/``I-``
            prefix (e.g. ``"PER"``), appears in this list.

    Returns:
        Surface texts of the kept entities, in order of appearance (duplicates kept).
    """
    print("Initiating entity recognition...")
    start = set_start()
    # Cap the sequence at the model's position-embedding size; longer inputs
    # would otherwise fail inside the model. Text past the cap is not analysed.
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    ids = input_ids[0].tolist()
    # Take tokens from the exact ids fed to the model, so tokens[i] and the
    # i-th prediction always refer to the same position.
    tokens = tokenizer.convert_ids_to_tokens(ids)
    with torch.no_grad():
        outputs = model(input_ids)
    predicted = torch.argmax(outputs.logits, dim=2)[0].tolist()
    # [CLS]/[SEP]/[PAD] are never part of an entity: force them to "O".
    boundary_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    tags = [
        "O" if token_id in boundary_ids else model.config.id2label[label_id]
        for token_id, label_id in zip(ids, predicted)
    ]
    wanted = set(entities_list)
    filtered_entities = []
    for entity_type, pieces in _group_entity_spans(tokens, tags):
        # Re-joins "##" sub-word pieces into whole words ("John ##son" -> "Johnson").
        entity_text = tokenizer.convert_tokens_to_string(pieces)
        if entity_text in wanted or entity_type in wanted:
            filtered_entities.append(entity_text)
    audit_elapsedtime(function="Retreiving entity labels from text", start=start)
    return filtered_entities