Erfan11 commited on
Commit
634cec7
1 Parent(s): 7339af2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -1,31 +1,36 @@
1
- import tensorflow as tf
2
- from transformers import TFBertForSequenceClassification
3
- from flask import Flask, request, jsonify
4
 
5
- app = Flask(__name__)
 
 
 
 
 
 
 
6
 
7
- # Load the model
8
- model_name = "Erfan11/Neuracraft"
9
- model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token="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")
10
 
11
- @app.route('/predict', methods=['POST'])
12
- def predict():
13
- data = request.get_json()
14
- # Preprocess input data
15
- inputs = preprocess_data(data)
16
- # Make prediction
17
- predictions = model.predict(inputs)
18
- # Postprocess and return results
19
- results = postprocess_predictions(predictions)
20
- return jsonify(results)
21
 
22
- def preprocess_data(data):
23
- # Implement your data preprocessing here
24
- pass
 
25
 
26
- def postprocess_predictions(predictions):
27
- # Implement your result postprocessing here
28
- pass
29
 
30
- if __name__ == '__main__':
31
- app.run(debug=True)
 
 
 
 
 
 
1
+ import os
2
+ from transformers import TFBertForSequenceClassification, BertTokenizerFast
 
3
 
4
+ def load_model(model_name):
5
+ try:
6
+ # Load TensorFlow model from Hugging Face
7
+ model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'))
8
+ except OSError:
9
+ # Fallback to PyTorch model if TensorFlow fails
10
+ model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'), from_pt=True)
11
+ return model
12
 
13
+ def load_tokenizer(model_name):
14
+ tokenizer = BertTokenizerFast.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'))
15
+ return tokenizer
16
 
17
+ def predict(text, model, tokenizer):
18
+ inputs = tokenizer(text, return_tensors="tf")
19
+ outputs = model(**inputs)
20
+ return outputs
 
 
 
 
 
 
21
 
22
+ def main():
23
+ model_name = os.getenv('MODEL_PATH')
24
+ if model_name is None:
25
+ raise ValueError("MODEL_PATH environment variable not set or is None")
26
 
27
+ model = load_model(model_name)
28
+ tokenizer = load_tokenizer(model_name)
 
29
 
30
+ # Example prediction
31
+ text = "Sample input text"
32
+ result = predict(text, model, tokenizer)
33
+ print(result)
34
+
35
+ if __name__ == "__main__":
36
+ main()