Erfan11 commited on
Commit
8bad412
1 Parent(s): e1fff91

Update load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +4 -34
load_model.py CHANGED
@@ -1,39 +1,9 @@
1
  import os
 
2
  from dotenv import load_dotenv
3
- from transformers import TFBertForSequenceClassification, BertTokenizerFast
4
 
5
- # Load environment variables from .env file
6
  load_dotenv()
 
7
 
8
- def load_model(model_name):
9
- try:
10
- # Load TensorFlow model from Hugging Face
11
- model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'), from_tf=True)
12
- except OSError:
13
- raise ValueError("Model loading failed.")
14
- return model
15
-
16
- def load_tokenizer(model_name):
17
- tokenizer = BertTokenizerFast.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'))
18
- return tokenizer
19
-
20
- def predict(text, model, tokenizer):
21
- inputs = tokenizer(text, return_tensors="tf")
22
- outputs = model(**inputs)
23
- return outputs
24
-
25
- def main():
26
- model_name = os.getenv('MODEL_PATH')
27
- if model_name is None:
28
- raise ValueError("MODEL_PATH environment variable not set or is None")
29
-
30
- model = load_model(model_name)
31
- tokenizer = load_tokenizer(model_name)
32
-
33
- # Example prediction
34
- text = "Sample input text"
35
- result = predict(text, model, tokenizer)
36
- print(result)
37
-
38
- if __name__ == "__main__":
39
- main()
 
1
  import os
2
+ import tensorflow as tf
3
  from dotenv import load_dotenv
 
4
 
 
5
  load_dotenv()
6
+ model_path = os.getenv('MODEL_PATH')
7
 
8
+ def load_model():
9
+ return tf.keras.models.load_model(model_path)