DerekLiu35's picture
Update image_to_fen/fen.py
70135dd
raw
history blame
No virus
3.73 kB
import argparse
from pathlib import Path
from typing import Sequence, Union
from PIL import Image
import torch
import torchvision
import numpy as np
import re
import chess
from fentoimage.board import BoardImage
import image_to_fen.util as util
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "image-to-fen"
MODEL_FILE = "model.pt"
class ImageToFen:
"""Takes image of chess board and returns FEN string."""
def __init__(self, model_path=None):
if model_path is None:
model_path = STAGED_MODEL_DIRNAME / MODEL_FILE
self.model = torch.jit.load(model_path)
@torch.no_grad()
def predict(self, image: Union[str, Path, Image.Image]) -> str:
"""Predict FEN string for image of chess board."""
image = image
if not isinstance(image, Image.Image):
image = util.read_image_pil(image, grayscale=True)
image = image.resize((200, 200))
image = torchvision.transforms.PILToTensor()(image)/255
pred = self.model([image])[1][0]
nms_pred = apply_nms(pred, iou_thresh=0.2)
pred_str = boxes_labels_to_fen(nms_pred['boxes'], nms_pred['labels'])
return pred_str
def apply_nms(orig_prediction, iou_thresh=0.3):
# torchvision returns the indices of the bboxes to keep
keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
final_prediction = orig_prediction
final_prediction['boxes'] = final_prediction['boxes'][keep]
final_prediction['scores'] = final_prediction['scores'][keep]
final_prediction['labels'] = final_prediction['labels'][keep]
return final_prediction
def boxes_labels_to_fen(boxes, labels, square_size=25):
boxes = torch.round(boxes / 25) * 25
eye = np.eye(13)
one_hot = onehot_from_fen("8-8-8-8-8-8-8-8")
for i, box in enumerate(boxes):
x = box[0]
y = box[1]
ind = int((x / square_size) + (y / square_size) * 8)
if (ind >= 64):
continue
one_hot[ind] = eye[12 - labels[i]].reshape((1, 13)).astype(int)
return fen_from_onehot(one_hot)
def onehot_from_fen(fen):
piece_symbols = 'prbnkqPRBNKQ'
eye = np.eye(13)
output = np.empty((0, 13))
fen = re.sub('[-]', '', fen)
for char in fen:
if(char in '12345678'):
output = np.append(
output, np.tile(eye[12], (int(char), 1)), axis=0)
else:
idx = piece_symbols.index(char)
output = np.append(output, eye[idx].reshape((1, 13)), axis=0)
return output
def fen_from_onehot(one_hot):
piece_symbols = 'prbnkqPRBNKQ'
output = ''
for j in range(8):
for i in range(8):
idx = np.where(one_hot[j*8 + i]==1)[0][0]
if(idx == 12):
output += ' '
else:
output += piece_symbols[idx]
if(j != 7):
output += '-'
for i in range(8, 0, -1):
output = output.replace(' ' * i, str(i))
return output
def fen_and_image(input):
itf = ImageToFen()
output = itf.predict(input)
fen = output.replace('-', '/')
renderer = BoardImage(fen)
image = renderer.render()
return fen, image
def main():
"""Run prediction on image."""
parser = argparse.ArgumentParser(description="Predict FEN string for image of chess board.")
parser.add_argument("image", type=Path, help="Path to image file.")
parser.add_argument("--model-path", type=Path, help="Path to model file.")
args = parser.parse_args()
image_to_fen = ImageToFen(args.model_path)
pred = image_to_fen.predict(args.image)
print(f"Prediction: {pred}")
# image_to_fen/tests/support/boards/phpSrRLQ1.png
if __name__ == "__main__":
main()