Erfan11 commited on
Commit
b58127a
1 Parent(s): c40189d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -19
main.py CHANGED
@@ -1,38 +1,33 @@
1
  import os
 
2
  from dotenv import load_dotenv
3
- import torch
4
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
  # Load environment variables
7
  load_dotenv()
8
 
9
  def load_model(model_path):
10
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
11
- tokenizer = AutoTokenizer.from_pretrained(model_path)
12
- return model, tokenizer
 
 
 
 
13
 
14
  def predict(text, model, tokenizer):
15
- inputs = tokenizer(text, return_tensors="pt")
16
- outputs = model(**inputs)
17
  return outputs
18
 
19
  def main():
20
  model_path = os.getenv('MODEL_PATH')
21
- model, tokenizer = load_model(model_path)
 
22
  # Example usage
23
  text = "Sample input text"
24
  result = predict(text, model, tokenizer)
25
  print(result)
26
 
27
  if __name__ == "__main__":
28
- main()
29
- from transformers import BertForSequenceClassification
30
-
31
- # Load the TensorFlow model using from_tf=True
32
- model = BertForSequenceClassification.from_pretrained(
33
- "Erfan11/Neuracraft",
34
- from_tf=True,
35
- use_auth_token="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"
36
- )
37
-
38
- # Additional code to run your app can go here (for example, Streamlit or Gradio interface)
 
1
  import os
2
+ import tensorflow as tf
3
  from dotenv import load_dotenv
4
+ from transformers import BertTokenizerFast
 
5
 
6
  # Load environment variables
7
  load_dotenv()
8
 
9
  def load_model(model_path):
10
+ # Load the TensorFlow model using from_tf=True
11
+ model = tf.keras.models.load_model(model_path)
12
+ return model
13
+
14
+ def load_tokenizer(model_path):
15
+ tokenizer = BertTokenizerFast.from_pretrained(model_path)
16
+ return tokenizer
17
 
18
  def predict(text, model, tokenizer):
19
+ inputs = tokenizer(text, return_tensors="tf")
20
+ outputs = model(inputs)
21
  return outputs
22
 
23
  def main():
24
  model_path = os.getenv('MODEL_PATH')
25
+ model = load_model(model_path)
26
+ tokenizer = load_tokenizer(model_path)
27
  # Example usage
28
  text = "Sample input text"
29
  result = predict(text, model, tokenizer)
30
  print(result)
31
 
32
  if __name__ == "__main__":
33
+ main()