SnakeCLEF2024 / script.py
Someshfengde's picture
Upload folder using huggingface_hub
6b340f0
raw
history blame
2.58 kB
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch
from create_model import HieraForImageClassification
from transformers import AutoImageProcessor
def is_gpu_available() -> bool:
    """Return whether PyTorch reports a usable CUDA device.

    This is a thin wrapper around ``torch.cuda.is_available()``; it does not
    inspect which packages (e.g. ONNX runtimes) are installed.
    """
    return torch.cuda.is_available()
class PytorchWorker:
    """Hold a Hiera image classifier plus its image processor for per-image inference."""

    def __init__(self, model_dir="./hiera_model", num_labels=1784):
        """Load the image processor and classification model from ``model_dir``.

        :param model_dir: Path passed to ``from_pretrained`` for both the
            processor and the model (defaults to the previously hard-coded value).
        :param num_labels: Size of the classification head (defaults to the
            previously hard-coded value).
        """
        print("Setting up Pytorch Model")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.image_processor = AutoImageProcessor.from_pretrained(model_dir)
        # eval() disables training-only behaviour (e.g. dropout) for inference.
        self.model = (
            HieraForImageClassification.from_pretrained(model_dir, num_labels=num_labels)
            .to(self.device)
            .eval()
        )

    def predict_image(self, image: "Image.Image | np.ndarray") -> list:
        """Run one forward pass and return the raw model logits.

        :param image: Any image the image processor accepts; the caller in this
            file passes an RGB PIL image.
        :return: ``outputs.logits.tolist()``, i.e. nested Python lists of raw
            (un-normalised) scores. Presumably shape ``[[score_0, ..., score_C-1]]``
            for a single image — confirm against the processor's batching.
        """
        inputs = self.image_processor(images=image, return_tensors="pt")
        # The model was moved to self.device in __init__; the inputs must follow,
        # otherwise a CUDA run fails with a device-mismatch error.
        inputs = {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits.tolist()
def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Classify every test image and write one prediction per observation to CSV.

    :param test_metadata: DataFrame whose ``filename`` and ``observation_id``
        columns are read. It is not modified; predictions are added to a copy.
    :param model_path: Currently unused; the model is loaded by
        ``PytorchWorker`` from its own default directory. Kept for interface
        compatibility.
    :param model_name: Currently unused; kept for interface compatibility.
    :param output_csv_path: Destination CSV with columns ``observation_id,class_id``.
    :param images_root_path: Directory that each ``filename`` is joined onto.
    """
    model = PytorchWorker()
    predictions = []
    for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
        image_path = os.path.join(images_root_path, row.filename)
        # Close the file handle right away instead of leaking one per image.
        with Image.open(image_path) as img:
            test_image = img.convert("RGB")
        logits = model.predict_image(test_image)
        # argmax over the flattened logits; presumably a single-row [[...]]
        # list, so this is the class index. Stored as a plain int.
        predictions.append(int(np.argmax(logits)))

    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    scored = test_metadata.assign(class_id=predictions)
    # Several images may share an observation_id; keep the first image's prediction.
    user_pred_df = scored.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
if __name__ == "__main__":
    import zipfile

    # Unpack the private test-set archive into /tmp/data before inference.
    archive_path = "/tmp/data/private_testset.zip"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall("/tmp/data")

    # NOTE(review): make_submission does not use model_path/model_name to load
    # the model (PytorchWorker loads ./hiera_model), and MODEL_NAME names a
    # Swin-V2 checkpoint. These values are only forwarded to satisfy the signature.
    MODEL_PATH = "pytorch_model.bin"
    MODEL_NAME = "swinv2_tiny_window16_256.ms_in1k"

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        model_name=MODEL_NAME,
    )