File size: 587 Bytes
63d0aa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from fastai.collab import load_learner
from fastai.tabular.all import *

def custom_accuracy(prediction, target, threshold: float = 0.95):
    """Return the fraction of predictions that equal their target.

    Predictions strictly greater than ``threshold`` are snapped to 1.0 before
    the comparison. All other predictions are compared to the target unchanged.

    Args:
        prediction: Tensor of predicted values.
        target: Tensor of target values with the same number of elements as
            ``prediction`` (e.g. ``[bs, 1]`` targets for ``[bs]`` predictions).
        threshold: Predictions above this value count as 1.0. Defaults to 0.95.

    Returns:
        0-dim float tensor: the number of matching elements divided by the
        number of elements (NaN for an empty batch).
    """
    # NOTE(review): predictions <= threshold are compared by exact float
    # equality, so they only count as correct when they match the target
    # exactly (e.g. 0.0 vs 0.0). Confirm this is intended rather than
    # binarizing them to 0.
    # masked_fill keeps prediction's device and dtype (torch.tensor(1.0)
    # would always be created on the CPU).
    prediction = prediction.masked_fill(prediction > threshold, 1.0)
    # Align target with prediction (e.g. [64, 1] -> [64]). Matching shapes
    # exactly avoids [bs, 1] vs [bs] broadcasting into a [bs, bs] comparison.
    target = target.reshape_as(prediction)
    correct = (prediction == target).float()
    return correct.mean()

async def setup_learner(model_filename: str):
    """Load an exported fastai learner and set its DataLoaders' device to CPU.

    Args:
        model_filename: Path of the exported learner file, passed to
            ``load_learner``.

    Returns:
        The loaded learner object.
    """
    import asyncio

    # load_learner is a synchronous call that reads the model file. Run it in
    # the default executor so this coroutine does not block the event loop.
    # NOTE(review): fastai documents load_learner as deserializing a pickled
    # export, which can execute arbitrary code. Only load trusted model files.
    loop = asyncio.get_running_loop()
    learn = await loop.run_in_executor(None, load_learner, model_filename)
    # Presumably so inference batches are built on the CPU; confirm with the
    # serving environment.
    learn.dls.device = 'cpu'
    return learn