|
import torch |
|
from PIL import Image |
|
from classes.genomic_plip_model import GenomicPLIPModel |
|
from classes.binary_neural_classifier import SimpleNN |
|
from transformers import CLIPImageProcessor |
|
|
|
def load_and_preprocess_image(image_path, clip_processor_path):
    """Load an image from disk and turn it into CLIP pixel values.

    Args:
        image_path: Path of the image file; it is opened with PIL and
            converted to RGB.
        clip_processor_path: Name or path handed to
            ``CLIPImageProcessor.from_pretrained``.

    Returns:
        The ``pixel_values`` entry produced by the processor for a batch
        containing just this one image (``return_tensors="pt"``, so a torch
        tensor; presumably shaped (1, C, H, W) — confirm against the
        processor config).
    """
    clip_processor = CLIPImageProcessor.from_pretrained(clip_processor_path)

    # ``convert`` loads the pixels and returns a new in-memory image, so the
    # file can be closed right away instead of leaking the handle until GC.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = clip_processor(images=[image], return_tensors="pt")
    return inputs["pixel_values"]
|
|
|
def genomic_plip_predictions(image_tensor, model_path):
    """Run a pretrained ``GenomicPLIPModel`` on a preprocessed image tensor.

    The model is loaded from ``model_path`` on every call, placed on CUDA when
    available (CPU otherwise), switched to eval mode, and run without
    gradient tracking.

    Args:
        image_tensor: Tensor of model inputs, e.g. the output of
            ``load_and_preprocess_image``; it is copied to the chosen device.
        model_path: Name or path handed to ``GenomicPLIPModel.from_pretrained``.

    Returns:
        Whatever the model's forward call returns; it is left on the chosen
        device.
    """
    use_gpu = torch.cuda.is_available()
    target_device = torch.device('cuda' if use_gpu else 'cpu')

    plip_model = GenomicPLIPModel.from_pretrained(model_path).to(target_device)
    plip_model.eval()

    with torch.no_grad():
        batch_on_device = image_tensor.to(target_device)
        return plip_model(batch_on_device)
|
|
|
def classify_tiles(pred_data, model_path):
    """Score features with a trained ``SimpleNN`` and return the mean output.

    Args:
        pred_data: Input passed to ``SimpleNN``; presumably the tensor
            returned by ``genomic_plip_predictions`` (TODO confirm the shape
            SimpleNN expects). Tensor inputs are moved onto the same device
            as the model.
        model_path: Path to a checkpoint loaded with ``torch.load`` and fed
            to ``load_state_dict``.

    Returns:
        A Python float: the mean over every element of the model's output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SimpleNN().to(device)
    # NOTE(security): torch.load unpickles the checkpoint and can execute
    # arbitrary code from an untrusted file. Only a state_dict is expected
    # here, so on torch >= 1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # The model lives on `device`. A tensor that is still on another device
    # (e.g. CPU input while CUDA is available) would raise a device-mismatch
    # error in the forward pass.
    if isinstance(pred_data, torch.Tensor):
        pred_data = pred_data.to(device)

    with torch.no_grad():
        output = model(pred_data).mean()

    return output.item()
|
|