test / pipeline.py
davanstrien's picture
davanstrien HF staff
Update pipeline.py
17febd0
raw
history blame
1.02 kB
from typing import Dict, List, Any
import os
import json
import numpy as np
from fastai.learner import load_learner
class PreTrainedPipeline:
    """Custom inference pipeline wrapping an exported fastai ``Learner``.

    The model and the ``id2label`` mapping are loaded once in ``__init__``.
    Each call then returns a ``{"label", "score"}`` entry for every class.
    """

    def __init__(self, path=""):
        """Load the exported learner and label mapping from ``path``.

        All heavy I/O happens here, so it runs once rather than on every call.

        Args:
            path: Directory containing ``20211115-model.pkl`` and
                ``config.json``. The default ``""`` makes both paths relative
                to the current working directory.

        Raises:
            KeyError: if ``config.json`` has no ``"id2label"`` entry.
        """
        # NOTE(review): fastai's load_learner deserializes a pickled object;
        # only point this at trusted model artifacts.
        self.model = load_learner(os.path.join(path, "20211115-model.pkl"))
        with open(os.path.join(path, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)
        # JSON object keys are always strings, so lookups below use "0", "1", ...
        self.id2label = config["id2label"]

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        """Run the learner on ``inputs`` and report a score for every class.

        Args:
            inputs: Raw input, passed unchanged to ``self.model.predict``.

        Returns:
            One ``{"label": ..., "score": ...}`` dict per entry of the
            prediction vector, in class-index order. Entry ``i`` is labelled
            ``str(self.id2label[str(i)])``.

        Raises:
            KeyError: if the prediction vector has an index with no matching
                key in ``id2label``.
        """
        # Assumes the third element of predict()'s result is the per-class
        # probability tensor (fastai returns (decoded, index, probs)) — confirm
        # against the fastai version in use.
        _, _, probs = self.model.predict(inputs)
        # Enumerate every class rather than hard-coding "0"/"1", so models
        # with more (or fewer) than two outputs report all their scores.
        return [
            {"label": str(self.id2label[str(idx)]), "score": score}
            for idx, score in enumerate(probs.tolist())
        ]