Spaces:
Running
Running
import torch | |
import sentencepiece | |
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline | |
import os | |
import spacy | |
import spacy_transformers | |
import zipfile | |
from collections import defaultdict | |
class Models():
    """Hold the two named-entity recognisers used by the application.

    ``self.ner`` is a Hugging Face token-classification pipeline built from
    the ``Jean-Baptiste/camembert-ner-with-dates`` checkpoint, and
    ``self.custom_ner`` is a spaCy pipeline loaded from the
    ``spacy_model_v2/output/model-best`` directory next to this module.
    """

    # Entities must score strictly above this value to be kept.
    _MIN_SCORE = 0.75
    # Entity groups returned by ``extract_ner``; every other group is dropped.
    _KEPT_GROUPS = ('DATE', 'ORG', 'LOC')

    def __init__(self) -> None:
        """Load both models immediately so the instance is ready to use."""
        self.load_trained_models()

    def load_trained_models(self):
        """Create ``self.ner`` and ``self.custom_ner``.

        If the spaCy model directory is missing, ``spacy_model_v2.zip`` is
        unpacked into the ``spacy_model_v2`` folder beside this module before
        the model is loaded.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            "Jean-Baptiste/camembert-ner-with-dates", use_fast=False
        )
        model = AutoModelForTokenClassification.from_pretrained(
            "Jean-Baptiste/camembert-ner-with-dates"
        )
        self.ner = pipeline(
            'ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple"
        )

        current_directory = os.path.dirname(os.path.realpath(__file__))
        destination_folder = os.path.join(current_directory, 'spacy_model_v2')
        custom_ner_path = os.path.join(destination_folder, 'output', 'model-best')

        if not os.path.exists(custom_ner_path):
            # Look for the archive beside this module first so loading does
            # not depend on the directory the process was started from; fall
            # back to the historical working-directory-relative location.
            archive_path = os.path.join(current_directory, 'spacy_model_v2.zip')
            if not os.path.exists(archive_path):
                archive_path = r"./spacy_model_v2.zip"
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                # Extract everything into the module-local model folder.
                zip_ref.extractall(destination_folder)

        self.custom_ner = spacy.load(custom_ner_path)

    def extract_ner(self, text):
        """Run ``self.ner`` on *text* and return the confident entities.

        Each pipeline result is read through its ``'score'``,
        ``'entity_group'`` and ``'word'`` keys. Results whose score is not
        strictly greater than ``_MIN_SCORE`` are ignored, as are groups other
        than DATE, ORG and LOC.

        Returns:
            A ``(dates, organisations, locations)`` tuple of lists of entity
            words, each in the order the pipeline produced them.
        """
        grouped = {label: [] for label in self._KEPT_GROUPS}
        for entity in self.ner(text):
            if entity['score'] > self._MIN_SCORE:
                label = entity['entity_group']
                if label in grouped:
                    grouped[label].append(entity['word'])
        return grouped['DATE'], grouped['ORG'], grouped['LOC']

    def get_ner(self, text, recover_text):
        """Extract entities from *text*, falling back to *recover_text*.

        Any category that comes back empty for *text* is replaced by the
        corresponding category extracted from *recover_text*. The fallback
        text is only analysed when at least one category is empty, which
        avoids a needless second model inference.

        Returns:
            A ``(dates, companies, locations)`` tuple of lists.
        """
        dates, companies, locations = self.extract_ner(text)
        if dates and companies and locations:
            return dates, companies, locations

        alternative_dates, alternative_companies, alternative_locations = (
            self.extract_ner(recover_text)
        )
        return (
            dates or alternative_dates,
            companies or alternative_companies,
            locations or alternative_locations,
        )

    def get_custom_ner(self, text):
        """Run the custom spaCy model on *text* and group entities by label.

        Returns:
            A ``defaultdict(list)`` mapping each entity label found to the
            list of matching entity texts, in document order.
        """
        doc = self.custom_ner(text)
        sort_dict = defaultdict(list)
        for entity in doc.ents:
            sort_dict[entity.label_].append(entity.text)
        return sort_dict