# Source provenance (HuggingFace Space upload): negarb — "Upload 2 files", commit 488f448
from fire import Fire
import string
import tensorflow as tf
from transformers import AutoTokenizer
from hazm import *
from transformers import pipeline
from transformers import TextClassificationPipeline
# Pretrained Persian BERT checkpoint; its tokenizer is reused at inference time.
original_model = "HooshvareLab/bert-fa-base-uncased"
# Directory containing the fine-tuned Keras model loaded inside predict().
model_path = 'models'
def remove_punctuation(input_string):
    """Return *input_string* with every ASCII punctuation character removed.

    NOTE(review): ``string.punctuation`` covers ASCII only, so Persian
    punctuation marks (e.g. '،', '؟') pass through unchanged — confirm
    that this is intended for the downstream Persian text.
    """
    drop_table = str.maketrans({mark: None for mark in string.punctuation})
    return input_string.translate(drop_table)
def predict(file_path):
    """Classify the Persian text stored in *file_path*.

    Reads the file, strips ASCII punctuation, normalizes the text with
    hazm, tokenizes it with the original checkpoint's tokenizer, and runs
    the fine-tuned Keras model loaded from ``model_path``. The encoded
    inputs and the raw prediction for the single example are printed to
    stdout; nothing is returned.

    :param file_path: path to a UTF-8 text file containing the input text.
    """
    normalizer = Normalizer()
    tokenizer = AutoTokenizer.from_pretrained(original_model)

    # Persian text must be decoded as UTF-8 explicitly; relying on the
    # platform default encoding breaks on non-UTF-8 locales (e.g. Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    text = remove_punctuation(text)
    text = normalizer.normalize(text)

    # Single-example batch; truncated/padded to at most 128 subword tokens.
    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128,
    )
    input_ids = input_tokens["input_ids"]
    attention_mask = input_tokens["attention_mask"]

    # Loading the model on every call is slow but keeps the CLI stateless.
    new_model = tf.keras.models.load_model(model_path)

    print({"input_ids": input_ids, "attention_mask": attention_mask})
    predictions = new_model.predict([{"input_ids": input_ids, "attention_mask": attention_mask}])
    print(predictions[0])
if __name__ == '__main__':
    # Expose predict() as a CLI via python-fire: `python <script> <file_path>`.
    Fire(predict)