|
import os |
|
from transformers import TFBertForSequenceClassification, BertTokenizerFast |
|
|
|
def load_model(model_name):
    """Load a TensorFlow BERT sequence-classification model.

    The first attempt loads *model_name* as-is. If that raises ``OSError``,
    the load is retried with ``from_pt=True``. Presumably this covers
    checkpoints that only ship PyTorch weights (TODO confirm this is the only
    failure mode expected here).

    Args:
        model_name: Identifier or path accepted by ``from_pretrained``.

    Returns:
        The ``TFBertForSequenceClassification`` instance.

    Raises:
        OSError: if the ``from_pt=True`` retry also fails.
    """
    # Value of the API_KEY environment variable (None when unset), passed
    # through as the auth token for both load attempts.
    auth_token = os.getenv('API_KEY')
    try:
        return TFBertForSequenceClassification.from_pretrained(
            model_name,
            use_auth_token=auth_token,
        )
    except OSError:
        # The first load failed; fall back to loading with from_pt=True.
        return TFBertForSequenceClassification.from_pretrained(
            model_name,
            use_auth_token=auth_token,
            from_pt=True,
        )
|
|
|
def load_tokenizer(model_name):
    """Load the fast BERT tokenizer stored under *model_name*.

    Args:
        model_name: Identifier or path accepted by ``from_pretrained``.

    Returns:
        The ``BertTokenizerFast`` instance.
    """
    # API_KEY (None when unset) is forwarded as the auth token, the same
    # environment variable used when loading the model.
    return BertTokenizerFast.from_pretrained(
        model_name,
        use_auth_token=os.getenv('API_KEY'),
    )
|
|
|
def predict(text, model, tokenizer):
    """Tokenize *text* and run *model* on it.

    Args:
        text: A single string, or a list of strings to score as one batch.
        model: Callable model that accepts the tokenizer's output as keyword
            arguments (e.g. the result of ``load_model``).
        tokenizer: Hugging Face tokenizer (e.g. the result of
            ``load_tokenizer``).

    Returns:
        Whatever ``model(**inputs)`` returns. For the classifier loaded in
        this file that is presumably an output object holding ``logits`` —
        TODO confirm.
    """
    # padding=True lets a list of texts of different lengths form one
    # rectangular "tf" batch; it is a no-op for a single string.
    # truncation=True clips inputs longer than the tokenizer's maximum length
    # instead of feeding the model more positions than it supports.
    inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
    return model(**inputs)
|
|
|
def main():
    """Load the model and tokenizer named by ``MODEL_PATH`` and run a demo.

    Reads the model identifier from the ``MODEL_PATH`` environment variable,
    loads the model and tokenizer, classifies a fixed sample sentence and
    prints the model output.

    Raises:
        ValueError: if ``MODEL_PATH`` is unset or empty.
    """
    model_name = os.getenv('MODEL_PATH')
    # Reject an empty value as well as an unset one. Otherwise
    # from_pretrained("") would fail later with a confusing loader error.
    if not model_name:
        raise ValueError("MODEL_PATH environment variable not set or is empty")

    model = load_model(model_name)
    tokenizer = load_tokenizer(model_name)

    text = "Sample input text"
    result = predict(text, model, tokenizer)
    print(result)
|
|
|
if __name__ == "__main__":
    # Script entry point: run the demo only when this file is executed
    # directly, not when it is imported as a module.
    main()