"""Classify a Persian text file with a fine-tuned ParsBERT model.

Reads a text file, strips punctuation, normalizes the text with hazm,
tokenizes it with the HooshvareLab/bert-fa-base-uncased tokenizer, and
runs it through a fine-tuned Keras model loaded from `model_path`.
"""

import string

import tensorflow as tf
from fire import Fire
from hazm import Normalizer
from transformers import AutoTokenizer

# Base ParsBERT checkpoint used for tokenization.
original_model = "HooshvareLab/bert-fa-base-uncased"
# Directory holding the fine-tuned Keras model (assumed to be a SavedModel
# produced during training; the training script is not shown here).
model_path = 'models'

def remove_punctuation(input_string):
    """Strip punctuation (ASCII set only) from the input string."""
    translator = str.maketrans("", "", string.punctuation)
    result = input_string.translate(translator)
    return result
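# Note: string.punctuation covers ASCII symbols only. For Persian text one
# might also strip Persian punctuation marks, e.g. (a sketch, not part of
# the original script):
#   persian_punctuation = "«»،؛؟٪"
#   translator = str.maketrans("", "", string.punctuation + persian_punctuation)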

def predict(file_path):
    """Run the fine-tuned classifier on the contents of `file_path`."""
    normalizer = Normalizer()
    tokenizer = AutoTokenizer.from_pretrained(original_model)

    # Persian text should be decoded explicitly as UTF-8.
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    text = remove_punctuation(text)
    text = normalizer.normalize(text)

    # Encode the cleaned text as TensorFlow tensors, padded and truncated
    # to at most 128 tokens.
    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128
    )
    input_ids = input_tokens["input_ids"]
    attention_mask = input_tokens["attention_mask"]

    new_model = tf.keras.models.load_model(model_path)

    # Debug: show the encoded inputs fed to the model.
    print({"input_ids": input_ids, "attention_mask": attention_mask})
    # Pass the named inputs as a dict; wrapping the dict in a list would hand
    # Keras a single nested positional input instead of named tensors.
    predictions = new_model.predict(
        {"input_ids": input_ids, "attention_mask": attention_mask}
    )
    print(predictions[0])
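    # Note: `predictions` is typically an array of logits with shape
    # (1, num_labels). To get class probabilities one could apply a softmax
    # (a sketch; whether the saved model emits raw logits is an assumption):
    #   probs = tf.nn.softmax(predictions[0]).numpy()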

if __name__ == '__main__':
    Fire(predict)
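# Example invocation (assuming this script is saved as predict.py and a
# fine-tuned model directory exists at ./models):
#   python predict.py path/to/text_file.txt
# python-fire maps the positional CLI argument to predict(file_path).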