|
# Checkpoint id this handler was written for (looks like a Hugging Face Hub
# "owner/repo" id).
# NOTE(review): nothing visible in this file reads this constant —
# EndpointHandler loads whatever id/path is passed to its constructor.
model_name = "alfaxadeyembe/gemma2-27b-swahili-it"
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
import torch |
|
|
|
class EndpointHandler:
    """Load a tokenizer and sequence-classification model, then run them on requests.

    NOTE(review): the name and the ``__init__(location)`` / ``__call__(data)``
    shape suggest a Hugging Face Inference Endpoints custom handler — confirm
    against the deployment.
    """

    def __init__(self, model_name):
        """Load the tokenizer and model identified by *model_name*.

        Args:
            model_name: Hub model id or local directory; passed unchanged to
                both ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): the checkpoint id at the top of this file ("...-it")
        # reads like an instruction-tuned causal LM. Loading such a checkpoint
        # with AutoModelForSequenceClassification attaches a newly initialized
        # classification head — confirm the intended Auto* class.
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: put the model in eval mode (disables dropout etc.).
        self.model.eval()

    def __call__(self, data):
        """Tokenize ``data["inputs"]`` and return the model's raw output.

        Args:
            data: Request mapping. ``"inputs"`` holds a string or a list/tuple
                of strings; a missing key is treated as ``""``.

        Returns:
            Exactly what ``self.model(**tokens)`` returns.
            NOTE(review): this is the model's output object (tensors), not
            plain JSON — confirm the serving layer can serialize it.
        """
        inputs = data.get("inputs", "")
        # A batch of strings with different token counts cannot be stacked into
        # one "pt" tensor without padding. Only request padding for batches so
        # single-string requests keep working with tokenizers that have no pad
        # token (asking such a tokenizer to pad raises).
        is_batch = isinstance(inputs, (list, tuple))
        tokens = self.tokenizer(inputs, padding=is_batch, return_tensors="pt")
        # Disable autograd bookkeeping: no gradients are needed for inference.
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs
|
|
|
|