File size: 3,732 Bytes
bbae066
 
 
 
 
 
 
 
 
70135dd
 
bbae066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70135dd
 
 
 
 
 
 
 
bbae066
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
from pathlib import Path
from typing import Sequence, Union

from PIL import Image
import torch
import torchvision
import numpy as np
import re
import chess
from fentoimage.board import BoardImage

import image_to_fen.util as util


# Default model location: <directory of this module>/artifacts/image-to-fen/model.pt,
# loaded by ImageToFen when no explicit model_path is given.
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent.joinpath("artifacts", "image-to-fen")
MODEL_FILE = "model.pt"

class ImageToFen:
    """Takes image of chess board and returns FEN string."""

    def __init__(self, model_path=None):
        """Load the TorchScript model used for prediction.

        Args:
            model_path: Path to a TorchScript file. Defaults to
                ``STAGED_MODEL_DIRNAME / MODEL_FILE``.
        """
        if model_path is None:
            model_path = STAGED_MODEL_DIRNAME / MODEL_FILE
        self.model = torch.jit.load(model_path)

    @torch.no_grad()
    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Predict FEN string for image of chess board.

        Args:
            image: A path to an image file, or an already-loaded PIL image.
                Both are converted to grayscale before inference.

        Returns:
            Piece placement with ranks separated by '-' (see
            ``boxes_labels_to_fen``).
        """
        if isinstance(image, Image.Image):
            # Paths are read with grayscale=True below; convert PIL inputs too so
            # the same board gives the same tensor however it is passed in.
            # NOTE(review): assumes util.read_image_pil(grayscale=True) yields
            # mode "L" as well — confirm.
            image = image.convert("L")
        else:
            image = util.read_image_pil(image, grayscale=True)
        # 200 px / 8 squares = 25 px per square, matching the default
        # square_size of boxes_labels_to_fen.
        image = image.resize((200, 200))
        # PILToTensor yields integer values in [0, 255]; scale to [0, 1].
        image = torchvision.transforms.PILToTensor()(image) / 255
        # NOTE(review): [1][0] presumably selects the detections dict for the
        # single input image from a scripted detection model's
        # (losses, detections) return value — confirm against the exported model.
        pred = self.model([image])[1][0]
        nms_pred = apply_nms(pred, iou_thresh=0.2)
        return boxes_labels_to_fen(nms_pred['boxes'], nms_pred['labels'])
    
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Drop overlapping detections with non-maximum suppression.

    Args:
        orig_prediction: Dict with at least 'boxes', 'scores' and 'labels'
            tensors indexed in parallel (one entry per detection). It is not
            modified.
        iou_thresh: Boxes whose IoU with a higher-scoring kept box exceeds
            this value are discarded.

    Returns:
        A new dict with the same keys as ``orig_prediction``. Its 'boxes',
        'scores' and 'labels' contain only the surviving detections; other
        keys are carried over unchanged.
    """
    # torchvision returns the indices of the bboxes to keep, in decreasing score order
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Shallow-copy so the caller's dict keeps its unfiltered tensors.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]

    return final_prediction

def boxes_labels_to_fen(boxes, labels, square_size=25):
    """Place detected pieces on an 8x8 grid and return the board as a FEN string.

    Args:
        boxes: Tensor of detection boxes. Only elements 0 and 1 of each box
            are used (presumably x_min, y_min in pixels — confirm).
        labels: Integer class labels parallel to ``boxes``. Label L is written
            as ``'prbnkqPRBNKQ'[12 - L]`` (L == 12 -> 'p', L == 1 -> 'Q',
            L == 0 -> empty square).
        square_size: Side length of one board square, in the same units as
            ``boxes``.

    Returns:
        Piece placement with ranks separated by '-', as produced by
        ``fen_from_onehot``. Boxes that snap outside the 8x8 grid are ignored.
    """
    eye = np.eye(13)
    one_hot = onehot_from_fen("8-8-8-8-8-8-8-8")  # start from an empty board
    for box, label in zip(boxes, labels):
        # Snap the box corner to the nearest grid line, expressed as a square index.
        col = int(torch.round(box[0] / square_size))
        row = int(torch.round(box[1] / square_size))
        # Check row and column separately: with only a flat-index check, a box
        # at the right edge (col == 8) would wrap onto the next rank.
        if not (0 <= col < 8 and 0 <= row < 8):
            continue
        one_hot[row * 8 + col] = eye[12 - int(label)]
    return fen_from_onehot(one_hot)

def onehot_from_fen(fen):
    """Encode a FEN piece-placement string as a one-hot array.

    Args:
        fen: Piece placement whose ranks are separated by '-' (as produced by
            ``fen_from_onehot``) or '/' (standard FEN). Digits 1-8 expand to
            that many empty squares.

    Returns:
        Float array of shape (N, 13) with one row per square, in string
        order. Columns 0-11 follow ``'prbnkqPRBNKQ'``; column 12 marks an
        empty square. An empty string gives shape (0, 13).

    Raises:
        ValueError: If ``fen`` contains a character that is not a piece
            symbol, a digit 1-8, or a separator.
    """
    piece_symbols = 'prbnkqPRBNKQ'
    eye = np.eye(13)
    rows = []
    for char in fen.replace('-', '').replace('/', ''):
        if char in '12345678':
            rows.extend([eye[12]] * int(char))
        elif char in piece_symbols:
            rows.append(eye[piece_symbols.index(char)])
        else:
            raise ValueError(f"Invalid FEN character: {char!r}")
    # Collect rows in a list and stack once (np.append in a loop copies every time);
    # reshape keeps the (0, 13) shape for an empty input.
    return np.array(rows).reshape(-1, 13)

def fen_from_onehot(one_hot):
    """Decode a one-hot board back into a '-'-separated FEN placement string.

    The first 64 rows of ``one_hot`` are read as 8 ranks of 8 squares. In each
    row, the first column equal to exactly 1 selects the square: column 12 is
    empty, and columns 0-11 map to ``'prbnkqPRBNKQ'``. Consecutive empty
    squares within a rank are written as their count. A row with no 1 raises
    IndexError.
    """
    symbols = 'prbnkqPRBNKQ'
    ranks = []
    for rank in range(8):
        text = ''
        gap = 0  # empty squares seen since the last piece in this rank
        for col in range(8):
            idx = np.where(one_hot[rank * 8 + col] == 1)[0][0]
            if idx == 12:
                gap += 1
                continue
            if gap:
                text += str(gap)
                gap = 0
            text += symbols[idx]
        if gap:
            text += str(gap)
        ranks.append(text)
    return '-'.join(ranks)

def fen_and_image(input):
    """Predict a board's FEN from an image and render that FEN back as a picture.

    Args:
        input: Anything ``ImageToFen.predict`` accepts (a path or a PIL image).

    Returns:
        A ``(fen, image)`` tuple. ``fen`` is the predicted piece placement with
        '/' rank separators. ``image`` is what ``BoardImage(fen).render()``
        returns.
    """
    # NOTE(review): a new ImageToFen (and thus a fresh torch.jit.load) is
    # created on every call; consider reusing one instance if this is hot.
    placement = ImageToFen().predict(input).replace('-', '/')
    board_picture = BoardImage(placement).render()
    return placement, board_picture
    
def main():
    """Run prediction on image.

    Parses the command line, loads the model (``--model-path``, or the staged
    default), and prints the predicted FEN placement.

    Example input image:
        image_to_fen/tests/support/boards/phpSrRLQ1.png
    """
    cli = argparse.ArgumentParser(description="Predict FEN string for image of chess board.")
    cli.add_argument("image", type=Path, help="Path to image file.")
    cli.add_argument("--model-path", type=Path, help="Path to model file.")
    options = cli.parse_args()
    fen = ImageToFen(options.model_path).predict(options.image)
    print(f"Prediction: {fen}")


if __name__ == "__main__":
    main()