Spaces:
Sleeping
Sleeping
File size: 3,732 Bytes
bbae066 70135dd bbae066 70135dd bbae066 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import argparse
from pathlib import Path
from typing import Sequence, Union
from PIL import Image
import torch
import torchvision
import numpy as np
import re
import chess
from fentoimage.board import BoardImage
import image_to_fen.util as util
# Directory of the staged model artifacts, resolved relative to this file.
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "image-to-fen"
# File name of the serialized model inside STAGED_MODEL_DIRNAME.
MODEL_FILE = "model.pt"
class ImageToFen:
    """Predict the piece placement of a chess board from an image of it."""

    def __init__(self, model_path=None):
        """Load the TorchScript model used for detection.

        Args:
            model_path: Location of the serialized model. When ``None``, the
                staged ``MODEL_FILE`` inside ``STAGED_MODEL_DIRNAME`` is used.
        """
        self.model = torch.jit.load(
            STAGED_MODEL_DIRNAME / MODEL_FILE if model_path is None else model_path
        )

    @torch.no_grad()
    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Return the predicted board as a dash-separated placement string.

        Args:
            image: A PIL image, or a path that is read as grayscale via
                ``util.read_image_pil``.

        Returns:
            Eight ranks joined by ``'-'``, as built by ``boxes_labels_to_fen``.
        """
        # NOTE(review): a PIL image passed in directly is used as-is (no
        # grayscale conversion), unlike images loaded from a path - confirm
        # that the model accepts both.
        if isinstance(image, Image.Image):
            board_picture = image
        else:
            board_picture = util.read_image_pil(image, grayscale=True)

        # The model input is a 200x200 image scaled by 1/255.
        board_picture = board_picture.resize((200, 200))
        model_input = torchvision.transforms.PILToTensor()(board_picture) / 255

        # Presumably the scripted model returns (losses, detections); the
        # detections of the single input image are taken - TODO confirm.
        detections = self.model([model_input])[1][0]
        kept = apply_nms(detections, iou_thresh=0.2)
        return boxes_labels_to_fen(kept['boxes'], kept['labels'])
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Filter overlapping detections with non-maximum suppression.

    Args:
        orig_prediction: Mapping with ``'boxes'``, ``'scores'`` and
            ``'labels'`` entries; each entry must be indexable by the tensor
            of kept indices returned by ``torchvision.ops.nms``.
        iou_thresh: IoU threshold handed to ``torchvision.ops.nms``.

    Returns:
        A new dict with the same keys as ``orig_prediction`` in which
        ``'boxes'``, ``'scores'`` and ``'labels'`` hold only the kept
        detections. The input mapping itself is left unmodified.
    """
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Work on a shallow copy so the caller's prediction is not altered in place.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
def boxes_labels_to_fen(boxes, labels, square_size=25):
    """Convert detected boxes and labels into a dash-separated placement string.

    Each box is assigned to the board square whose top-left corner is nearest
    to the box's first two coordinates; the square then receives the piece
    class derived from the matching label.

    Args:
        boxes: Tensor of boxes, one per row; only the first two values of each
            row (x and y) are used. Presumably pixel coordinates on an image
            of ``8 * square_size`` pixels per side - TODO confirm.
        labels: Per-box class labels; label ``k`` is mapped to one-hot class
            ``12 - k``. Presumably labels lie in 0..12 - TODO confirm.
        square_size: Edge length of one board square, in the units of
            ``boxes``. It is used for snapping as well as for indexing.

    Returns:
        The board as eight ranks joined by ``'-'`` (see ``fen_from_onehot``).
    """
    eye = np.eye(13)
    one_hot = onehot_from_fen("8-8-8-8-8-8-8-8")

    # Snap every coordinate to the nearest grid line, expressed in squares.
    cells = torch.round(boxes / square_size)
    for i, cell in enumerate(cells):
        col = int(cell[0])
        row = int(cell[1])
        # Skip boxes that snap outside the 8x8 board. Checking only the flat
        # index would let column 8 wrap around into the next rank.
        if not (0 <= col < 8 and 0 <= row < 8):
            continue
        one_hot[row * 8 + col] = eye[12 - int(labels[i])]
    return fen_from_onehot(one_hot)
def onehot_from_fen(fen):
    """Encode a dash-separated placement string as one-hot rows.

    Args:
        fen: Piece placement using the symbols ``prbnkqPRBNKQ`` and the
            digits 1-8 for runs of empty squares; ``'-'`` separators are
            ignored.

    Returns:
        A float array of shape ``(squares, 13)``. Columns 0-11 follow the
        order of ``prbnkqPRBNKQ`` and column 12 marks an empty square.

    Raises:
        ValueError: If ``fen`` contains any other character.
    """
    piece_symbols = 'prbnkqPRBNKQ'
    empty_class = 12

    class_ids = []
    for symbol in re.sub('[-]', '', fen):
        if symbol in '12345678':
            # A digit stands for that many consecutive empty squares.
            class_ids.extend([empty_class] * int(symbol))
        else:
            class_ids.append(piece_symbols.index(symbol))

    # Selecting identity-matrix rows yields the one-hot vectors in one step;
    # the explicit index dtype keeps the empty case at shape (0, 13).
    return np.eye(13)[np.array(class_ids, dtype=np.intp)]
def fen_from_onehot(one_hot):
    """Decode 64 one-hot rows into a dash-separated placement string.

    Args:
        one_hot: Indexable of at least 64 rows. In each row the position of
            the first entry equal to 1 selects the class: 0-11 follow the
            order of ``prbnkqPRBNKQ`` and 12 marks an empty square.

    Returns:
        Eight ranks joined by ``'-'``, with runs of empty squares written as
        their length.

    Raises:
        IndexError: If one of the rows contains no entry equal to 1.
    """
    piece_symbols = 'prbnkqPRBNKQ'

    ranks = []
    for rank in range(8):
        squares = []
        for file in range(8):
            class_id = np.where(one_hot[rank * 8 + file] == 1)[0][0]
            squares.append(' ' if class_id == 12 else piece_symbols[class_id])
        ranks.append(''.join(squares))
    placement = '-'.join(ranks)

    # Collapse runs of blanks, longest first, so that a run is never split
    # into several shorter counts.
    for run_length in range(8, 0, -1):
        placement = placement.replace(' ' * run_length, str(run_length))
    return placement
def fen_and_image(input):
    """Predict the board in ``input`` and render the predicted position.

    A new ``ImageToFen`` with the default model is created on every call.

    Args:
        input: Anything accepted by ``ImageToFen.predict``.

    Returns:
        A ``(fen, image)`` tuple: the prediction with ``'-'`` replaced by
        ``'/'``, and the result of ``BoardImage(fen).render()``.
    """
    dashed_placement = ImageToFen().predict(input)
    fen = dashed_placement.replace('-', '/')
    return fen, BoardImage(fen).render()
def main():
    """Parse the command line, run one prediction and print the result."""
    cli = argparse.ArgumentParser(description="Predict FEN string for image of chess board.")
    cli.add_argument("image", type=Path, help="Path to image file.")
    cli.add_argument("--model-path", type=Path, help="Path to model file.")
    options = cli.parse_args()

    # Without --model-path the option is None, which selects the staged model.
    predictor = ImageToFen(options.model_path)
    print(f"Prediction: {predictor.predict(options.image)}")
# image_to_fen/tests/support/boards/phpSrRLQ1.png
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()