shivamjadhav commited on
Commit
21dc881
·
1 Parent(s): 2cdfdec

started backend model

Browse files
Files changed (4) hide show
  1. DockerFile +0 -0
  2. api.py +19 -0
  3. classifier/Albert_latest.py +41 -0
  4. requirements.txt +33 -0
DockerFile ADDED
File without changes
api.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from classifier.Albert_latest import get_model
3
+
4
+ app = FastAPI()
5
+ model = get_model()
6
+
7
+ @app.post("/predict")
8
+ async def predict(issue: str):
9
+ predictions = model.predict(issue)
10
+ print(f"Predictions: {predictions}")
11
+ id = predictions
12
+ print(f"Predictions: {predictions}")
13
+ return {
14
+ "priority 1": str(predictions[0]),
15
+ "priority 2": str(predictions[1]),
16
+ "priority 3": str(predictions[2]),
17
+ "priority 4": str(predictions[3]),
18
+ "priority 5": str(predictions[4])
19
+ }
classifier/Albert_latest.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification
2
+ import torch
3
+
4
+ class Model:
5
+ def __init__(self, model_weights):
6
+ self.tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
7
+ self.model = AlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=2)
8
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Load the checkpoint
11
+ checkpoint = torch.load(model_weights, map_location=self.device)
12
+
13
+ # Load the model's state dictionary
14
+ self.model.load_state_dict(checkpoint['model_state_dict'])
15
+ self.currepoch = checkpoint['epoch']
16
+ self.loss = checkpoint['loss']
17
+ print(f"Loaded model state: Current epoch {self.currepoch}, current loss {self.loss}")
18
+
19
+ self.model.to(self.device)
20
+ self.model.eval()
21
+
22
+ def predict(self, text):
23
+ inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
24
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
25
+
26
+ with torch.no_grad():
27
+ outputs = self.model(**inputs)
28
+
29
+ logits = outputs.logits
30
+ print(f"logits: {logits}")
31
+ predictions = torch.nn.functional.softmax(logits, dim=-1)
32
+ return predictions[0].tolist()
33
+
34
+ model_instance = None
35
+ model_weights = "../assets/albert_sentiment_checkpoint_58.pt"
36
+
37
+ def get_model():
38
+ global model_instance
39
+ if model_instance is None:
40
+ model_instance = Model(model_weights)
41
+ return model_instance
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ anyio==3.5.0
2
+ asgiref==3.5.0
3
+ certifi==2021.10.8
4
+ charset-normalizer==2.0.12
5
+ click==8.0.4
6
+ colorama==0.4.4
7
+ fastapi==0.75.0
8
+ filelock==3.6.0
9
+ gunicorn==20.1.0
10
+ h11==0.13.0
11
+ huggingface-hub==0.4.0
12
+ idna==3.3
13
+ joblib==1.1.0
14
+ numpy==1.22.3
15
+ packaging==21.3
16
+ pydantic==1.9.0
17
+ pyparsing==3.0.7
18
+ PyYAML==6.0
19
+ regex==2022.3.15
20
+ requests==2.27.1
21
+ sacremoses==0.0.49
22
+ sentencepiece==0.1.96
23
+ six==1.16.0
24
+ sniffio==1.2.0
25
+ starlette==0.17.1
26
+ tokenizers==0.11.6
27
+ --find-links https://download.pytorch.org/whl/torch_stable.html
28
+ torch==1.11.0+cpu
29
+ tqdm==4.63.0
30
+ transformers==4.17.0
31
+ typing_extensions==4.1.1
32
+ urllib3==1.26.8
33
+ uvicorn==0.17.6