File size: 655 Bytes
654e1b3
fab504a
 
 
4414af8
654e1b3
fab504a
 
 
 
4414af8
 
fab504a
 
 
 
4414af8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Hugging Face Hub id of the checkpoint this handler is presumably meant to
# serve (the handler itself takes the id/path as a constructor argument).
model_name = "alfaxadeyembe/gemma2-27b-swahili-it"
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

class EndpointHandler:
    """Inference handler wrapping a Hugging Face sequence-classification model.

    The tokenizer and model are loaded once at construction time; each call
    tokenizes the request's ``"inputs"`` and runs a no-grad forward pass.
    """

    def __init__(self, model_name):
        """Load the tokenizer and model.

        Args:
            model_name: Hub model id or local directory, passed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        # Both loaders must use the constructor argument; there is no other
        # name in scope that identifies the checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): the checkpoint id at the top of this file looks like an
        # instruction-tuned causal LM; loading it through a sequence-
        # classification auto class may create a freshly initialised head.
        # Confirm this is the intended model class.
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()  # put the model in inference (eval) mode

    def __call__(self, data):
        """Run the model on one request payload.

        Args:
            data: Mapping-like request payload; the text is read from its
                ``"inputs"`` key, defaulting to ``""`` when the key is absent.
                The value may be a single string or a list/tuple of strings.

        Returns:
            The raw output object of the model's forward pass.
            NOTE(review): this contains torch tensors, which a serving layer
            may not be able to JSON-serialise. Confirm what the caller expects.
        """
        inputs = data.get("inputs", "")
        # A batch of strings with different token lengths cannot be stacked
        # into one tensor without padding. Single strings are left unpadded,
        # exactly as before.
        tokens = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=isinstance(inputs, (list, tuple)),
        )
        with torch.no_grad():  # inference only: no autograd graph needed
            outputs = self.model(**tokens)
        return outputs