diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..19f846cfea6e8d0869d82423a206db6e52aaa6ff 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..0ca0f2b622a97f2263fae8c2d423169db5ae3681
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,177 @@
+# Output directories
+outputs/
+multirun/
+ray_results/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+# requirements/core.*.txt
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+.python-version
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# IDE
+.idea/
+
+########## CUSTOM FOLDER ##############
+README_original.md
+
+results/
+images
+bharatSTR/East/tmp
+bharatSTR/models
+bharatSTR/images
+__pycache__/
+bharatSTR/
+
+IndicPhotoOCR/detection/East
+IndicPhotoOCR/recognition/models
+
+IndicPhotoOCR/script_identification/images
+IndicPhotoOCR/script_identification/models
+IndicPhotoOCR/script_identification/vit/models
+
+
+build/
+dist/
+test.png
+static/pics/IndicPhotoOCR.gif
+input_image.jpg
+output_image.png
+flagged
+
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..52fca244b7f25dededd8bb0497f103a89609485e
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,23 @@
+# Changelog
+
+## [1.2.0] - 2024-11-24
+### Added
+- Added the TextBPN++ detection model.
+
+## [1.1.0] - 2024-11-24
+### Added
+- Updated package naming convention
+
+## [1.0.3] - 2024-11-06
+### Added
+- Python package requirements sorted with setup.py
+
+## [1.0.2] - 2024-11-01
+### Added
+- Added language support for 10 additional models in the recognition module.
+
+## [1.0.1] - 2024-10-28
+### Added
+- Added support for detecting polygonal bounding boxes in `visualize_detection`.
+- Introduced the `show` argument in `visualize_detection` to control image display.
+
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..aaa81c57c8c05ec7968a7d6871f4cdeb75401105
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,45 @@
+# Use NVIDIA PyTorch as the base image
+FROM nvcr.io/nvidia/pytorch:23.12-py3
+
+# Install additional dependencies
+RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6
+
+# Set environment variables for Miniconda and Conda environment
+ENV CONDA_DIR /opt/conda
+ENV PATH $CONDA_DIR/bin:$PATH
+
+# Install Miniconda
+RUN apt-get update && apt-get install -y wget && \
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+ bash Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_DIR && \
+ rm Miniconda3-latest-Linux-x86_64.sh
+
+# Create a new Conda environment named "bocr" with Python 3.9
+RUN conda create -n bocr python=3.9 -y
+
+# Initialize conda
+RUN conda init
+
+# Reload the shell configuration so conda is initialised
+# (note: each RUN starts a fresh shell, so this does not persist into later layers)
+RUN bash -c "source ~/.bashrc"
+
+# Make RUN commands use the bocr environment
+SHELL ["conda", "run", "-n", "bocr", "/bin/bash", "-c"]
+
+# # Set default shell to bash
+# SHELL ["/bin/bash", "-c"]
+
+# # Clone BharatOCR repository
+# RUN git clone https://github.com/Bhashini-IITJ/BharatOCR.git && \
+# git switch photoOCR && \
+# cd IndicPhotoOCR && \
+# python setup.py sdist bdist_wheel && \
+# pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
+
+# # # Set default command to run BharatOCR
+# CMD ["conda", "run", "-n", "bocr", "python", "-m", "IndicPhotoOCR.ocr"]
+
+
+# cd IndicPhotoOCR
+# sudo docker build -t indicphotoocr:latest .
+# sudo docker run --gpus all --rm -it indicphotoocr:latest
\ No newline at end of file
diff --git a/IndicPhotoOCR/__init__.py b/IndicPhotoOCR/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/__init__.py b/IndicPhotoOCR/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/east_config.py b/IndicPhotoOCR/detection/east_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c157236b45de4c590d5c6e7cc147a2776f7af6
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_config.py
@@ -0,0 +1,39 @@
+
+# data-config
+import numpy as np
+
+train_data_path = './dataset/train/'
+train_batch_size_per_gpu = 14 # 14
+num_workers = 24 # 24
+gpu_ids = [0] # [0,1,2,3]
+gpu = 1 # 4
+input_size = 512  # image size after preprocessing and normalization
+background_ratio = 3. / 8  # ratio of pure-background samples
+random_scale = np.array([0.5, 1, 2.0, 3.0])  # multi-scale factors for image augmentation
+geometry = 'RBOX'  # type of geometry map to use
+max_image_large_side = 1280
+max_text_size = 800
+min_text_size = 10
+min_crop_side_ratio = 0.1
+means=[100, 100, 100]
+pretrained = True  # whether to load pretrained weights for the backbone
+pretrained_basemodel_path = 'IndicPhotoOCR/detection/East/tmp/backbone_net/mobilenet_v2.pth.tar'
+pre_lr = 1e-4  # initial learning rate for the backbone
+lr = 1e-3  # initial learning rate for the detection head
+decay_steps = 50 # decayed_learning_rate = learning_rate * decay_rate ^ (global_epoch / decay_steps)
+decay_rate = 0.97
+init_type = 'xavier'  # weight initialization method
+resume = True  # whether to resume the whole network from a saved checkpoint
+checkpoint = 'IndicPhotoOCR/detection/East/tmp/epoch_990_checkpoint.pth.tar'  # checkpoint path and file name
+max_epochs = 1000  # maximum number of training epochs
+l2_weight_decay = 1e-6  # weight of the L2 regularization penalty
+print_freq = 10  # print the loss every 10 batches
+save_eval_iteration = 50  # interval (in epochs) for saving and evaluating the model
+save_model_path = './tmp/'  # directory for saved models
+test_img_path = './dataset/full_set'  # demo test images: './demo/test_img/'; dataset evaluation: './dataset/test/'
+res_img_path = 'results'  # demo results: './demo/result_img/'; dataset evaluation: './dataset/test_result/'
+write_images = True  # whether to write visualization images
+score_map_thresh = 0.8  # confidence threshold for the score map
+box_thresh = 0.1  # threshold on the mean score inside a text box
+nms_thres = 0.2  # IoU threshold for locality-aware NMS
+compute_hmean_path = './dataset/test_compute_hmean/'
diff --git a/IndicPhotoOCR/detection/east_detector.py b/IndicPhotoOCR/detection/east_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7dcbfe16151121f758f6ad7cf1b07698abf8a4
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_detector.py
@@ -0,0 +1,88 @@
+import os
+import torch
+import cv2
+import numpy as np
+import time
+import warnings
+
+
+import IndicPhotoOCR.detection.east_config as cfg
+from IndicPhotoOCR.detection.east_utils import ModelManager
+from IndicPhotoOCR.detection.east_model import East
+import IndicPhotoOCR.detection.east_utils as utils
+
+# Suppress warnings
+warnings.filterwarnings("ignore")
+
+class EASTdetector:
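+    """Thin wrapper around the EAST text detector: loads a checkpoint and returns
+    quadrilateral word boxes for an input image."""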
+    def __init__(self, model_name="east", model_path=None):
+ self.model_path = model_path
+ # self.model_manager = ModelManager()
+ # self.model_manager.ensure_model(model_name)
+ # self.ensure_model(self.model_name)
+ # self.root_model_dir = "BharatSTR/bharatOCR/detection/East/tmp"
+
+ def detect(self, image_path, model_checkpoint, device):
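+        """Run EAST on `image_path` with the weights in `model_checkpoint` and
+        return {'detections': [[[x, y], ...], ...]}, one 4-point quadrilateral per
+        detected word, in original-image coordinates."""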
+ # Load image
+ im = cv2.imread(image_path)
+ # im = cv2.imread(image_path)[:, :, ::-1]
+
+ # Initialize the EAST model and load checkpoint
+ model = East()
+ model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
+
+ # Load the model checkpoint with weights_only=True
+ checkpoint = torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True)
+ model.load_state_dict(checkpoint['state_dict'])
+ model.eval()
+
+ # Resize image and convert to tensor format
+ im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
+ im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
+ im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()
+
+ # Inference
+ timer = {'net': 0, 'restore': 0, 'nms': 0}
+ start = time.time()
+ score, geometry = model(im_tensor)
+ timer['net'] = time.time() - start
+
+ # Process output
+ score = score.permute(0, 2, 3, 1).data.cpu().numpy()
+ geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy()
+
+ # Detect boxes
+ boxes, timer = utils.detect(
+ score_map=score, geo_map=geometry, timer=timer,
+ score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
+            nms_thres=cfg.nms_thres
+ )
+ bbox_result_dict = {'detections': []}
+
+ # Parse detected boxes and adjust coordinates
+ if boxes is not None:
+ boxes = boxes[:, :8].reshape((-1, 4, 2))
+ boxes[:, :, 0] /= ratio_w
+ boxes[:, :, 1] /= ratio_h
+ for box in boxes:
+ box = utils.sort_poly(box.astype(np.int32))
+ if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
+ continue
+ bbox_result_dict['detections'].append([
+ [int(coord[0]), int(coord[1])] for coord in box
+ ])
+
+ return bbox_result_dict
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description='Text detection using EAST model')
+ parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+ parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
+ parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
+ args = parser.parse_args()
+
+ # Run prediction and get results as dictionary
+ east = EASTdetector(model_path = args.model_checkpoint)
+ detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
+ # print(detection_result)
diff --git a/IndicPhotoOCR/detection/east_locality_aware_nms.py b/IndicPhotoOCR/detection/east_locality_aware_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..63eea757d4dde10d6dc2ab216c150fd5f7cdaec2
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_locality_aware_nms.py
@@ -0,0 +1,75 @@
+
+import numpy as np
+from shapely.geometry import Polygon
+
+
+def intersection(g, p):
+ g = Polygon(g[:8].reshape((4, 2)))
+ p = Polygon(p[:8].reshape((4, 2)))
+ if not g.is_valid or not p.is_valid:
+ return 0
+    inter = g.intersection(p).area
+ union = g.area + p.area - inter
+ if union == 0:
+ return 0
+ else:
+ return inter/union
+
+
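+# Merge two overlapping quads: their 8 coordinates are averaged, weighted by the
+# scores stored at index 8, and the two scores are summed.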
+def weighted_merge(g, p):
+ # g[0]=min(g[0],p[0])
+ # g[1] = min(g[1], p[1])
+ # g[4] = max(g[4], p[4])
+ # g[5]= max(g[5],p[5])
+ #
+ # g[2] = max(g[2], p[2])
+ # g[3] = min(g[3], p[3])
+ # g[6] = min(g[6], p[6])
+ # g[7] = max(g[7], p[7])
+
+ g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
+ g[8] = (g[8] + p[8])
+ return g
+
+
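+# Standard NMS over N*9 boxes (8 coordinates + score): repeatedly keep the
+# highest-scoring box and discard remaining boxes whose IoU with it exceeds thres.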
+def standard_nms(S, thres):
+ order = np.argsort(S[:, 8])[::-1]
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
+
+ inds = np.where(ovr <= thres)[0]
+ order = order[inds+1]
+
+ return S[keep]
+
+
+def nms_locality(polys, thres=0.3):
+ '''
+ locality aware nms of EAST
+ :param polys: a N*9 numpy array. first 8 coordinates, then prob
+ :return: boxes after nms
+ '''
+ S = []
+ p = None
+ for g in polys:
+ if p is not None and intersection(g, p) > thres:
+ p = weighted_merge(g, p)
+ else:
+ if p is not None:
+ S.append(p)
+ p = g
+ if p is not None:
+ S.append(p)
+
+ if len(S) == 0:
+ return np.array([])
+ return standard_nms(np.array(S), thres)
+
+
+if __name__ == '__main__':
+ # 343,350,448,135,474,143,369,359
+ print(Polygon(np.array([[343, 350], [448, 135],
+ [474, 143], [369, 359]])).area)
diff --git a/IndicPhotoOCR/detection/east_model.py b/IndicPhotoOCR/detection/east_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..912bb2df9bcafc068811cdbf3b81864754f5a990
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_model.py
@@ -0,0 +1,242 @@
+
+import torch.nn as nn
+import math
+import torch
+
+
+from IndicPhotoOCR.detection import east_config as cfg
+from IndicPhotoOCR.detection import east_utils as utils
+
+
+def conv_bn(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expand_ratio)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expand_ratio == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, width_mult=1.):
+ super(MobileNetV2, self).__init__()
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ interverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ # [6, 320, 1, 1],
+ ]
+
+ # building first layer
+ # assert input_size % 32 == 0
+ input_channel = int(input_channel * width_mult)
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+ self.features = [conv_bn(3, input_channel, 2)]
+ # building inverted residual blocks
+ for t, c, n, s in interverted_residual_setting:
+ output_channel = int(c * width_mult)
+ for i in range(n):
+ if i == 0:
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+ else:
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+ input_channel = output_channel
+
+ # make it nn.Sequential
+ self.features = nn.Sequential(*self.features)
+
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.features(x)
+ # x = x.mean(3).mean(2)
+ # x = self.classifier(x)
+ return x
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ n = m.weight.size(1)
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+
+def mobilenet(pretrained=True, **kwargs):
+ """
+    Constructs a MobileNetV2 backbone.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = MobileNetV2()
+ if pretrained:
+ model_dict = model.state_dict()
+ pretrained_dict = torch.load(cfg.pretrained_basemodel_path,map_location=torch.device('cpu'), weights_only=True)
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ # state_dict = torch.load(cfg.pretrained_basemodel_path) # add map_location='cpu' if no gpu
+ # model.load_state_dict(state_dict)
+
+ return model
+
+
+class East(nn.Module):
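+    # EAST head on a MobileNetV2 backbone: feature maps from four backbone stages
+    # are merged top-down (U-shape) to predict a score map, a 4-channel distance
+    # map and an angle map, all at 1/4 of the input resolution.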
+ def __init__(self):
+ super(East, self).__init__()
+ self.mobilenet = mobilenet(True)
+ # self.si for stage i
+ self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
+ self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
+ self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14])
+ self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17])
+
+ self.conv1 = nn.Conv2d(160+96, 128, 1)
+ self.bn1 = nn.BatchNorm2d(128)
+ self.relu1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
+ self.bn2 = nn.BatchNorm2d(128)
+ self.relu2 = nn.ReLU()
+
+ self.conv3 = nn.Conv2d(128+32, 64, 1)
+ self.bn3 = nn.BatchNorm2d(64)
+ self.relu3 = nn.ReLU()
+
+ self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
+ self.bn4 = nn.BatchNorm2d(64)
+ self.relu4 = nn.ReLU()
+
+ self.conv5 = nn.Conv2d(64+24, 64, 1)
+ self.bn5 = nn.BatchNorm2d(64)
+ self.relu5 = nn.ReLU()
+
+ self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
+ self.bn6 = nn.BatchNorm2d(32)
+ self.relu6 = nn.ReLU()
+
+ self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
+ self.bn7 = nn.BatchNorm2d(32)
+ self.relu7 = nn.ReLU()
+
+ self.conv8 = nn.Conv2d(32, 1, 1)
+ self.sigmoid1 = nn.Sigmoid()
+ self.conv9 = nn.Conv2d(32, 4, 1)
+ self.sigmoid2 = nn.Sigmoid()
+ self.conv10 = nn.Conv2d(32, 1, 1)
+ self.sigmoid3 = nn.Sigmoid()
+ self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear')
+ self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear')
+ self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear')
+
+ # utils.init_weights([self.conv1,self.conv2,self.conv3,self.conv4,
+ # self.conv5,self.conv6,self.conv7,self.conv8,
+ # self.conv9,self.conv10,self.bn1,self.bn2,
+ # self.bn3,self.bn4,self.bn5,self.bn6,self.bn7])
+
+ def forward(self, images):
+ images = utils.mean_image_subtraction(images)
+
+ f0 = self.s1(images)
+ f1 = self.s2(f0)
+ f2 = self.s3(f1)
+ f3 = self.s4(f2)
+
+ # _, f = self.mobilenet(images)
+ h = f3 # bs 2048 w/32 h/32
+ g = (self.unpool1(h)) # bs 2048 w/16 h/16
+ c = self.conv1(torch.cat((g, f2), 1))
+ c = self.bn1(c)
+ c = self.relu1(c)
+
+ h = self.conv2(c) # bs 128 w/16 h/16
+ h = self.bn2(h)
+ h = self.relu2(h)
+ g = self.unpool2(h) # bs 128 w/8 h/8
+ c = self.conv3(torch.cat((g, f1), 1))
+ c = self.bn3(c)
+ c = self.relu3(c)
+
+ h = self.conv4(c) # bs 64 w/8 h/8
+ h = self.bn4(h)
+ h = self.relu4(h)
+ g = self.unpool3(h) # bs 64 w/4 h/4
+ c = self.conv5(torch.cat((g, f0), 1))
+ c = self.bn5(c)
+ c = self.relu5(c)
+
+ h = self.conv6(c) # bs 32 w/4 h/4
+ h = self.bn6(h)
+ h = self.relu6(h)
+ g = self.conv7(h) # bs 32 w/4 h/4
+ g = self.bn7(g)
+ g = self.relu7(g)
+
+ F_score = self.conv8(g) # bs 1 w/4 h/4
+ F_score = self.sigmoid1(F_score)
+ geo_map = self.conv9(g)
+ geo_map = self.sigmoid2(geo_map) * 512
+ angle_map = self.conv10(g)
+ angle_map = self.sigmoid3(angle_map)
+ angle_map = (angle_map - 0.5) * math.pi / 2
+
+ F_geometry = torch.cat((geo_map, angle_map), 1) # bs 5 w/4 h/4
+
+ return F_score, F_geometry
+
+
+model = East()
diff --git a/IndicPhotoOCR/detection/east_preprossing.py b/IndicPhotoOCR/detection/east_preprossing.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7eeffbfa163cf133736dae9b85967f2432f78e
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_preprossing.py
@@ -0,0 +1,681 @@
+
+# coding:utf-8
+import glob
+import csv
+import cv2
+import os
+import numpy as np
+from shapely.geometry import Polygon
+
+
+from IndicPhotoOCR.detection import east_config as cfg
+from IndicPhotoOCR.detection import east_utils
+
+
+def get_images(img_root):
+ files = []
+ for ext in ['jpg']:
+ files.extend(glob.glob(
+ os.path.join(img_root, '*.{}'.format(ext))))
+ # print(glob.glob(
+ # os.path.join(FLAGS.training_data_path, '*.{}'.format(ext))))
+ return files
+
+
+def load_annoataion(p):
+ '''
+ load annotation from the text file
+ :param p:
+ :return:
+ '''
+ text_polys = []
+ text_tags = []
+ if not os.path.exists(p):
+        return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool_)
+ with open(p, 'r', encoding='UTF-8') as f:
+ reader = csv.reader(f)
+ for line in reader:
+ label = line[-1]
+            # strip BOM: \ufeff for Python 3, \xef\xbb\xbf for Python 2
+ line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
+
+ x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
+ text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
+ # print(text_polys)
+ if label == '*' or label == '###':
+ text_tags.append(True)
+ else:
+ text_tags.append(False)
+    return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool_)
+
+
+def polygon_area(poly):
+ '''
+ compute area of a polygon
+ :param poly:
+ :return:
+ '''
+ edge = [
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
+ ]
+ return np.sum(edge) / 2.
+
+
+def check_and_validate_polys(polys, tags, xxx_todo_changeme):
+ '''
+ check so that the text poly is in the same direction,
+ and also filter some invalid polygons
+ :param polys:
+ :param tags:
+ :return:
+ '''
+ (h, w) = xxx_todo_changeme
+ if polys.shape[0] == 0:
+ return polys
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+ validated_polys = []
+ validated_tags = []
+
+    # check the winding direction of each quadrilateral and filter out invalid ones
+ for poly, tag in zip(polys, tags):
+ p_area = polygon_area(poly)
+ if abs(p_area) < 1:
+ # print poly
+ print('invalid poly')
+ continue
+ if p_area > 0:
+ print('poly in wrong direction')
+ poly = poly[(0, 3, 2, 1), :]
+ validated_polys.append(poly)
+ validated_tags.append(tag)
+ return np.array(validated_polys), np.array(validated_tags)
+
+
+def crop_area(im, polys, tags, crop_background=False, max_tries=100):
+ '''
+ make random crop from the input image
+ :param im:
+ :param polys:
+ :param tags:
+ :param crop_background:
+ :param max_tries:
+ :return:
+ '''
+ h, w, _ = im.shape
+ pad_h = h // 10
+ pad_w = w // 10
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+ for poly in polys:
+ poly = np.round(poly, decimals=0).astype(np.int32)
+ minx = np.min(poly[:, 0])
+ maxx = np.max(poly[:, 0])
+ w_array[minx + pad_w:maxx + pad_w] = 1
+ miny = np.min(poly[:, 1])
+ maxy = np.max(poly[:, 1])
+ h_array[miny + pad_h:maxy + pad_h] = 1
+    # ensure the cropped area does not cut across any text region
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return im, polys, tags
+    for i in range(max_tries):  # try up to max_tries times
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+ if xmax - xmin < cfg.min_crop_side_ratio * w or ymax - ymin < cfg.min_crop_side_ratio * h:
+ # area too small
+ continue
+ if polys.shape[0] != 0:
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ else:
+ selected_polys = []
+ if len(selected_polys) == 0:
+ # no text in this area
+ if crop_background:
+ return im[ymin:ymax + 1, xmin:xmax + 1, :], polys[selected_polys], tags[selected_polys]
+ else:
+ continue
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ polys = polys[selected_polys]
+ tags = tags[selected_polys]
+ polys[:, :, 0] -= xmin
+ polys[:, :, 1] -= ymin
+ return im, polys, tags
+
+ return im, polys, tags
+
+
+def shrink_poly(poly, r):
+ '''
+    fit a poly inside the original poly (there may be bugs here);
+    used to generate the score map
+    :param poly: the text poly
+    :param r: r in the paper
+    :return: the shrunken poly
+ '''
+ # shrink ratio
+ R = 0.3
+ # find the longer pair
+ if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
+ np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]):
+ # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
+ ## p0, p1
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
+ poly[0][0] += R * r[0] * np.cos(theta)
+ poly[0][1] += R * r[0] * np.sin(theta)
+ poly[1][0] -= R * r[1] * np.cos(theta)
+ poly[1][1] -= R * r[1] * np.sin(theta)
+ ## p2, p3
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
+ poly[3][0] += R * r[3] * np.cos(theta)
+ poly[3][1] += R * r[3] * np.sin(theta)
+ poly[2][0] -= R * r[2] * np.cos(theta)
+ poly[2][1] -= R * r[2] * np.sin(theta)
+ ## p0, p3
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
+ poly[0][0] += R * r[0] * np.sin(theta)
+ poly[0][1] += R * r[0] * np.cos(theta)
+ poly[3][0] -= R * r[3] * np.sin(theta)
+ poly[3][1] -= R * r[3] * np.cos(theta)
+ ## p1, p2
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
+ poly[1][0] += R * r[1] * np.sin(theta)
+ poly[1][1] += R * r[1] * np.cos(theta)
+ poly[2][0] -= R * r[2] * np.sin(theta)
+ poly[2][1] -= R * r[2] * np.cos(theta)
+ else:
+ ## p0, p3
+ # print poly
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
+ poly[0][0] += R * r[0] * np.sin(theta)
+ poly[0][1] += R * r[0] * np.cos(theta)
+ poly[3][0] -= R * r[3] * np.sin(theta)
+ poly[3][1] -= R * r[3] * np.cos(theta)
+ ## p1, p2
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
+ poly[1][0] += R * r[1] * np.sin(theta)
+ poly[1][1] += R * r[1] * np.cos(theta)
+ poly[2][0] -= R * r[2] * np.sin(theta)
+ poly[2][1] -= R * r[2] * np.cos(theta)
+ ## p0, p1
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
+ poly[0][0] += R * r[0] * np.cos(theta)
+ poly[0][1] += R * r[0] * np.sin(theta)
+ poly[1][0] -= R * r[1] * np.cos(theta)
+ poly[1][1] -= R * r[1] * np.sin(theta)
+ ## p2, p3
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
+ poly[3][0] += R * r[3] * np.cos(theta)
+ poly[3][1] += R * r[3] * np.sin(theta)
+ poly[2][0] -= R * r[2] * np.cos(theta)
+ poly[2][1] -= R * r[2] * np.sin(theta)
+ return poly
+
+
+# def point_dist_to_line(p1, p2, p3):
+# # compute the distance from p3 to p1-p2
+# return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
+
+
+# distance from point p3 to the line through p1 and p2
+def point_dist_to_line(p1, p2, p3):
+ # compute the distance from p3 to p1-p2
+ # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
+ a = np.linalg.norm(p1 - p2)
+ b = np.linalg.norm(p2 - p3)
+ c = np.linalg.norm(p3 - p1)
+ s = (a + b + c) / 2.0
+ area = np.abs((s * (s - a) * (s - b) * (s - c))) ** 0.5
+ if a < 1.0:
+ return (b + c) / 2.0
+ return 2 * area / a
+
+
+def fit_line(p1, p2):
+ # fit a line ax+by+c = 0
+ if p1[0] == p1[1]:
+ return [1., 0., -p1[0]]
+ else:
+ [k, b] = np.polyfit(p1, p2, deg=1)
+ return [k, -1., b]
+
+
+def line_cross_point(line1, line2):
+    # line1: ax + by + c = 0; compute the cross point of line1 and line2
+ if line1[0] != 0 and line1[0] == line2[0]:
+ print('Cross point does not exist')
+ return None
+ if line1[0] == 0 and line2[0] == 0:
+ print('Cross point does not exist')
+ return None
+ if line1[1] == 0:
+ x = -line1[2]
+ y = line2[0] * x + line2[2]
+ elif line2[1] == 0:
+ x = -line2[2]
+ y = line1[0] * x + line1[2]
+ else:
+ k1, _, b1 = line1
+ k2, _, b2 = line2
+ x = -(b1 - b2) / (k1 - k2)
+ y = k1 * x + b1
+ return np.array([x, y], dtype=np.float32)
+
+
+def line_verticle(line, point):
+    # get the line perpendicular to `line` that passes through `point`
+ if line[1] == 0:
+ verticle = [0, -1, point[1]]
+ else:
+ if line[0] == 0:
+ verticle = [1, 0, -point[0]]
+ else:
+ verticle = [-1. / line[0], -1, point[1] - (-1 / line[0] * point[0])]
+ return verticle
+
+
+def rectangle_from_parallelogram(poly):
+ '''
+ fit a rectangle from a parallelogram
+ :param poly:
+ :return:
+ '''
+ p0, p1, p2, p3 = poly
+ angle_p0 = np.arccos(np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0)))
+ if angle_p0 < 0.5 * np.pi:
+ if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
+ # p0 and p2
+ ## p0
+ p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
+ p2p3_verticle = line_verticle(p2p3, p0)
+
+ new_p3 = line_cross_point(p2p3, p2p3_verticle)
+ ## p2
+ p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
+ p0p1_verticle = line_verticle(p0p1, p2)
+
+ new_p1 = line_cross_point(p0p1, p0p1_verticle)
+ return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
+ else:
+ p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
+ p1p2_verticle = line_verticle(p1p2, p0)
+
+ new_p1 = line_cross_point(p1p2, p1p2_verticle)
+ p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
+ p0p3_verticle = line_verticle(p0p3, p2)
+
+ new_p3 = line_cross_point(p0p3, p0p3_verticle)
+ return np.array([p0, new_p1, p2, new_p3], dtype=np.float32)
+ else:
+ if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3):
+ # p1 and p3
+ ## p1
+ p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]])
+ p2p3_verticle = line_verticle(p2p3, p1)
+
+ new_p2 = line_cross_point(p2p3, p2p3_verticle)
+ ## p3
+ p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
+ p0p1_verticle = line_verticle(p0p1, p3)
+
+ new_p0 = line_cross_point(p0p1, p0p1_verticle)
+ return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
+ else:
+ p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
+ p0p3_verticle = line_verticle(p0p3, p1)
+
+ new_p0 = line_cross_point(p0p3, p0p3_verticle)
+ p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
+ p1p2_verticle = line_verticle(p1p2, p3)
+
+ new_p2 = line_cross_point(p1p2, p1p2_verticle)
+ return np.array([new_p0, p1, new_p2, p3], dtype=np.float32)
+
+
+def sort_rectangle(poly):
+ # sort the four coordinates of the polygon, points in poly should be sorted clockwise
+ # First find the lowest point
+ p_lowest = np.argmax(poly[:, 1])
+ if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2:
+        # the bottom edge is parallel to the x-axis, so p0 is the upper-left corner
+ p0_index = np.argmin(np.sum(poly, axis=1))
+ p1_index = (p0_index + 1) % 4
+ p2_index = (p0_index + 2) % 4
+ p3_index = (p0_index + 3) % 4
+ return poly[[p0_index, p1_index, p2_index, p3_index]], 0.
+ else:
+        # find the point to the right of the lowest point
+ p_lowest_right = (p_lowest - 1) % 4
+ p_lowest_left = (p_lowest + 1) % 4
+ angle = np.arctan(
+ -(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0]))
+ # assert angle > 0
+ if angle <= 0:
+ print(angle, poly[p_lowest], poly[p_lowest_right])
+ if angle / np.pi * 180 > 45:
+            # this point is p2
+ p2_index = p_lowest
+ p1_index = (p2_index - 1) % 4
+ p0_index = (p2_index - 2) % 4
+ p3_index = (p2_index + 1) % 4
+ return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle)
+ else:
+            # this point is p3
+ p3_index = p_lowest
+ p0_index = (p3_index + 1) % 4
+ p1_index = (p3_index + 2) % 4
+ p2_index = (p3_index + 3) % 4
+ return poly[[p0_index, p1_index, p2_index, p3_index]], angle
+
+
+def restore_rectangle_rbox(origin, geometry):
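+    # Invert the RBOX encoding: from per-pixel distances to the four rectangle
+    # edges plus a rotation angle, reconstruct the four corner points of each box.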
+ d = geometry[:, :4]
+ angle = geometry[:, 4]
+ # for angle > 0
+ origin_0 = origin[angle >= 0]
+ d_0 = d[angle >= 0]
+ angle_0 = angle[angle >= 0]
+ if origin_0.shape[0] > 0:
+ p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2],
+ d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2],
+ d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]),
+ np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]),
+ d_0[:, 3], -d_0[:, 2]])
+ p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
+
+ rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0))
+ rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
+
+ rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0))
+ rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
+
+ p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
+ p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
+
+ p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
+
+ p3_in_origin = origin_0 - p_rotate[:, 4, :]
+ new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
+ new_p1 = p_rotate[:, 1, :] + p3_in_origin
+ new_p2 = p_rotate[:, 2, :] + p3_in_origin
+ new_p3 = p_rotate[:, 3, :] + p3_in_origin
+
+ new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
+ new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
+ else:
+ new_p_0 = np.zeros((0, 4, 2))
+ # for angle < 0
+ origin_1 = origin[angle < 0]
+ d_1 = d[angle < 0]
+ angle_1 = angle[angle < 0]
+ if origin_1.shape[0] > 0:
+ p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2],
+ np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2],
+ np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]),
+ -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]),
+ -d_1[:, 1], -d_1[:, 2]])
+ p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2
+
+ rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0))
+ rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2
+
+ rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0))
+ rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
+
+ p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1
+ p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1
+
+ p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2
+
+ p3_in_origin = origin_1 - p_rotate[:, 4, :]
+ new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2
+ new_p1 = p_rotate[:, 1, :] + p3_in_origin
+ new_p2 = p_rotate[:, 2, :] + p3_in_origin
+ new_p3 = p_rotate[:, 3, :] + p3_in_origin
+
+ new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
+ new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2
+ else:
+ new_p_1 = np.zeros((0, 4, 2))
+ return np.concatenate([new_p_0, new_p_1])
+
+
+def restore_rectangle(origin, geometry):
+ return restore_rectangle_rbox(origin, geometry)
+
+
+def generate_rbox(im_size, polys, tags):
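+    # Build training targets for one image: a score map over shrunken polygons,
+    # a 5-channel RBOX geometry map (4 edge distances + rotation angle) and a
+    # training mask that ignores tiny or unreadable ('###') text instances.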
+ h, w = im_size
+ poly_mask = np.zeros((h, w), dtype=np.uint8)
+ score_map = np.zeros((h, w), dtype=np.uint8)
+ geo_map = np.zeros((h, w, 5), dtype=np.float32)
+    # mask used during training to ignore hard areas and text that is too small
+ training_mask = np.ones((h, w), dtype=np.uint8)
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
+ poly = poly_tag[0]
+ tag = poly_tag[1]
+
+        # for each vertex, take the shorter of its two adjacent edges
+ r = [None, None, None, None]
+ for i in range(4):
+ r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
+ np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
+ # score map
+        # shrink the polygon to 0.3 of its size and fill the corresponding area in the score map
+ shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
+ cv2.fillPoly(score_map, shrinked_poly, 1)
+ cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
+ # if the poly is too small, then ignore it during training
+        # if the box is too small, or its transcription is unspecified ('*' or '###'), mask it out during training
+ poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
+ poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
+ if min(poly_h, poly_w) < cfg.min_text_size:
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
+ if tag:
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
+
+        # pixels belonging to the text region just added
+ xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
+ # if geometry == 'RBOX':
+        # generate a parallelogram for any combination of two vertices
+ fitted_parallelograms = []
+ for i in range(4):
+            # take the edge joining p0 and p1 and generate two candidate parallelograms
+ p0 = poly[i]
+ p1 = poly[(i + 1) % 4]
+ p2 = poly[(i + 2) % 4]
+ p3 = poly[(i + 3) % 4]
+            # fit the line ax + by + c = 0
+ edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
+ backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
+ forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
+            # use the distances of the other two points to this edge to decide whether its parallel edge passes through p2 or p3
+ if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
+                # the parallel line passes through p2
+ if edge[1] == 0:
+ edge_opposite = [1, 0, -p2[0]]
+ else:
+ edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
+ else:
+                # the parallel line passes through p3
+ if edge[1] == 0:
+ edge_opposite = [1, 0, -p3[0]]
+ else:
+ edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
+ # move forward edge
+ new_p0 = p0
+ new_p1 = p1
+ new_p2 = p2
+ new_p3 = p3
+ new_p2 = line_cross_point(forward_edge, edge_opposite)
+ if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
+ # across p0
+ if forward_edge[1] == 0:
+ forward_opposite = [1, 0, -p0[0]]
+ else:
+ forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
+ else:
+ # across p3
+ if forward_edge[1] == 0:
+ forward_opposite = [1, 0, -p3[0]]
+ else:
+ forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
+ new_p0 = line_cross_point(forward_opposite, edge)
+ new_p3 = line_cross_point(forward_opposite, edge_opposite)
+ fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
+ # or move backward edge
+ new_p0 = p0
+ new_p1 = p1
+ new_p2 = p2
+ new_p3 = p3
+ new_p3 = line_cross_point(backward_edge, edge_opposite)
+ if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
+ # across p1
+ if backward_edge[1] == 0:
+ backward_opposite = [1, 0, -p1[0]]
+ else:
+ backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
+ else:
+ # across p2
+ if backward_edge[1] == 0:
+ backward_opposite = [1, 0, -p2[0]]
+ else:
+ backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
+ new_p1 = line_cross_point(backward_opposite, edge)
+ new_p2 = line_cross_point(backward_opposite, edge_opposite)
+ fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
+
+        # pick the parallelogram with the smallest area
+ areas = [Polygon(t).area for t in fitted_parallelograms]
+ parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)
+        # sort the polygon
+ parallelogram_coord_sum = np.sum(parallelogram, axis=1)
+ min_coord_idx = np.argmin(parallelogram_coord_sum)
+ parallelogram = parallelogram[
+ [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]
+
+        # fit the enclosing rectangle and its rotation angle
+ rectange = rectangle_from_parallelogram(parallelogram)
+ rectange, rotate_angle = sort_rectangle(rectange)
+
+ p0_rect, p1_rect, p2_rect, p3_rect = rectange
+        # for every pixel of this text region, write its distances to the four rectangle edges into geo_map
+ for y, x in xy_in_poly:
+ point = np.array([x, y], dtype=np.float32)
+ # top
+ geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
+ # right
+ geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
+ # down
+ geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
+ # left
+ geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
+ # angle
+ geo_map[y, x, 4] = rotate_angle
+ return score_map, geo_map, training_mask
+
+
+def generator(index,
+ input_size=512,
+              background_ratio=3. / 8,  # ratio of pure-background samples
+              random_scale=np.array([0.5, 1, 2.0, 3.0]),  # multi-scale factors for image augmentation
+ image_list=None):
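+    # Produce one training sample: read image `index` from `image_list`, randomly
+    # scale and crop it (pure background with probability `background_ratio`),
+    # pad/resize to `input_size`, and return the image together with its score
+    # map, geometry map and training mask, each subsampled by a factor of 4.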
+ try:
+ im_fn = image_list[index]
+ im = cv2.imread(im_fn)
+ if im is None:
+ print("can't find image")
+ return None, None, None, None, None
+ h, w, _ = im.shape
+        # derive the annotation file name by swapping the image extension for .txt
+ txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt')
+ if not os.path.exists(txt_fn):
+            print('text file {} does not exist'.format(txt_fn))
+ return None, None, None, None, None
+        # load the annotated text polygons
+ text_polys, text_tags = load_annoataion(txt_fn)
+
+ text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
+
+        # randomly pick a scale and resize the image
+ rd_scale = np.random.choice(random_scale)
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+ text_polys *= rd_scale
+
+        # with probability background_ratio (3/8), crop a pure-background area from the image
+ if np.random.rand() < background_ratio:
+ # crop background
+ im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
+ if text_polys.shape[0] > 0:
+ # print("cannot find background")
+ return None, None, None, None, None
+ # pad and resize image
+ new_h, new_w, _ = im.shape
+ max_h_w_i = np.max([new_h, new_w, input_size])
+ im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
+ im_padded[:new_h, :new_w, :] = im.copy()
+            # pad the cropped image and resize it to input_size x input_size
+ im = cv2.resize(im_padded, dsize=(input_size, input_size))
+ score_map = np.zeros((input_size, input_size), dtype=np.uint8)
+ geo_map_channels = 5 if cfg.geometry == 'RBOX' else 8
+ geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
+ training_mask = np.ones((input_size, input_size), dtype=np.uint8)
+ else:
+            # otherwise (probability 5/8), crop an area that contains text
+ im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
+ if text_polys.shape[0] == 0:
+ # print("cannot find txt ground")
+ return None, None, None, None, None
+ h, w, _ = im.shape
+ # pad the image to the training input size or the longer side of image
+ new_h, new_w, _ = im.shape
+ max_h_w_i = np.max([new_h, new_w, input_size])
+ im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
+ im_padded[:new_h, :new_w, :] = im.copy()
+ im = im_padded
+ # resize the image to input size
+            # pad, then resize the image to the configured input size
+ new_h, new_w, _ = im.shape
+ resize_h = input_size
+ resize_w = input_size
+ im = cv2.resize(im, dsize=(resize_w, resize_h))
+            # rescale the text polygon coordinates by the same ratios
+ resize_ratio_3_x = resize_w / float(new_w)
+ resize_ratio_3_y = resize_h / float(new_h)
+ text_polys[:, :, 0] *= resize_ratio_3_x
+ text_polys[:, :, 1] *= resize_ratio_3_y
+ new_h, new_w, _ = im.shape
+ score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)
+
+        # assemble the sample contents and label information
+ images = im[:,:,::-1].astype(np.float32)
+        # keep the image file name
+ image_fns = im_fn
+        # subsample every 4th row and column of the 512x512 maps
+ score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32)
+ geo_maps = geo_map[::4, ::4, :].astype(np.float32)
+ training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32)
+        # return one complete sample
+ return images, image_fns, score_maps, geo_maps, training_masks
+
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+
+ # print("Exception is exist!")
+ return None, None, None, None, None
diff --git a/IndicPhotoOCR/detection/east_utils.py b/IndicPhotoOCR/detection/east_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..03c2a3838052579aa766ec00fd6f94fafd5554e4
--- /dev/null
+++ b/IndicPhotoOCR/detection/east_utils.py
@@ -0,0 +1,283 @@
+
+import torch
+import os
+from torch.nn import init
+import cv2
+import numpy as np
+import time
+import requests
+
+from IndicPhotoOCR.detection import east_config as cfg
+from IndicPhotoOCR.detection import east_preprossing as preprossing
+from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms
+
+
+
+# Example usage:
+model_info = {
+ "east": {
+ "paths": [ cfg.checkpoint, cfg.pretrained_basemodel_path],
+ "urls" : ["https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar", "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar"]
+ },
+}
+
+class ModelManager:
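+    # Downloads the EAST checkpoint and the MobileNetV2 backbone weights listed in
+    # `model_info` when they are not already present at the configured paths.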
+ def __init__(self):
+ # self.root_model_dir = "bharatOCR/detection/"
+ pass
+
+ def download_model(self, url, path):
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ with open(path, 'wb') as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if chunk: # Filter out keep-alive chunks
+ f.write(chunk)
+ print(f"Downloaded: {path}")
+ else:
+ print(f"Failed to download from {url}")
+
+ def ensure_model(self, model_name):
+ model_paths = model_info[model_name]["paths"] # Changed to handle multiple paths
+ urls = model_info[model_name]["urls"] # Changed to handle multiple URLs
+
+
+ for model_path, url in zip(model_paths, urls):
+ # full_model_path = os.path.join(self.root_model_dir, model_path)
+
+ # Ensure the model path directory exists
+ os.makedirs(os.path.dirname(os.path.join(*cfg.pretrained_basemodel_path.split("/"))), exist_ok=True)
+
+ if not os.path.exists(model_path):
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
+ self.download_model(url, model_path)
+ else:
+ print(f"Model already exists at {model_path}. No need to download.")
+
+
+
+# # Initialize ModelManager and ensure Hindi models are downloaded
+model_manager = ModelManager()
+model_manager.ensure_model("east")
+
+
+
+def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
+ print("EAST <==> Prepare <==> Init Network'{}' <==> Begin".format(cfg.init_type))
+ # this will apply to each layer
+ for m in m_list:
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # good for relu
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1:
+ init.normal_(m.weight.data, 1.0, gain)
+ init.constant_(m.bias.data, 0.0)
+
+ print("EAST <==> Prepare <==> Init Network'{}' <==> Done".format(cfg.init_type))
+
+
+def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
+ """[summary]
+ [description]
+ Arguments:
+ state {[type]} -- [description] a dict describe some params
+ Keyword Arguments:
+ filename {str} -- [description] (default: {'checkpoint.pth.tar'})
+ """
+ weightpath = os.path.abspath(cfg.checkpoint)
+ print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
+ checkpoint = torch.load(weightpath)
+ start_epoch = checkpoint['epoch'] + 1
+ model.load_state_dict(checkpoint['state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
+
+ return start_epoch
+
+
+def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
+ """[summary]
+ [description]
+ Arguments:
+ state {[type]} -- [description] a dict describe some params
+ Keyword Arguments:
+ filename {str} -- [description] (default: {'checkpoint.pth.tar'})
+ """
+ print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
+ state = {
+ 'epoch': epoch,
+ 'state_dict': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'scheduler': scheduler.state_dict()
+ }
+ weight_dir = cfg.save_model_path
+ if not os.path.exists(weight_dir):
+ os.mkdir(weight_dir)
+ filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
+ file_path = os.path.join(weight_dir, filename)
+ torch.save(state, file_path)
+ print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
+
+
+class Regularization(torch.nn.Module):
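+    # Explicit Lp (default L2) penalty over all parameters whose name contains
+    # "weight", scaled by `weight_decay`.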
+ def __init__(self, model, weight_decay, p=2):
+ super(Regularization, self).__init__()
+ if weight_decay < 0:
+ print("param weight_decay can not <0")
+ exit(0)
+ self.model = model
+ self.weight_decay = weight_decay
+ self.p = p
+ self.weight_list = self.get_weight(model)
+ # self.weight_info(self.weight_list)
+
+ def to(self, device):
+ self.device = device
+ super().to(device)
+ return self
+
+ def forward(self, model):
+        self.weight_list = self.get_weight(model)  # fetch the latest weights
+ reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
+ return reg_loss
+
+ def get_weight(self, model):
+ weight_list = []
+ for name, param in model.named_parameters():
+ if 'weight' in name:
+ weight = (name, param)
+ weight_list.append(weight)
+ return weight_list
+
+ def regularization_loss(self, weight_list, weight_decay, p=2):
+ reg_loss = 0
+ for name, w in weight_list:
+ l2_reg = torch.norm(w, p=p)
+ reg_loss = reg_loss + l2_reg
+
+ reg_loss = weight_decay * reg_loss
+ return reg_loss
+
+ def weight_info(self, weight_list):
+ print("---------------regularization weight---------------")
+ for name, w in weight_list:
+ print(name)
+ print("---------------------------------------------------")
+
+
+def resize_image(im, max_side_len=2400):
+ '''
+ resize image to a size multiple of 32 which is required by the network
+    :param im: the input image
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
+ :return: the resized image and the resize ratio
+ '''
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ # limit the max side
+ """
+ if max(resize_h, resize_w) > max_side_len:
+ ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
+ else:
+ ratio = 1.
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+ """
+
+ resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
+ resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
+ #resize_h, resize_w = 512, 512
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
+
+
+def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
+ '''
+ restore text boxes from score map and geo map
+ :param score_map:
+ :param geo_map:
+ :param timer:
+    :param score_map_thresh: threshold for the score map
+    :param box_thresh: threshold for boxes
+ :param nms_thres: threshold for nms
+ :return:
+ '''
+
+    # adjust the dimensions of score_map and geo_map
+ if len(score_map.shape) == 4:
+ score_map = score_map[0, :, :, 0]
+ geo_map = geo_map[0, :, :, :]
+ # filter the score map
+ xy_text = np.argwhere(score_map > score_map_thresh)
+ # sort the text boxes via the y axis
+ xy_text = xy_text[np.argsort(xy_text[:, 0])]
+ # restore
+ start = time.time()
+ text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
+ geo_map[xy_text[:, 0], xy_text[:, 1], :]) # N*4*2
+ # print('{} text boxes before nms'.format(text_box_restored.shape[0]))
+ boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
+ boxes[:, :8] = text_box_restored.reshape((-1, 8))
+ boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
+ timer['restore'] = time.time() - start
+ # nms part
+ start = time.time()
+ boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
+ timer['nms'] = time.time() - start
+ # print(timer['nms'])
+ if boxes.shape[0] == 0:
+ return None, timer
+
+    # here we filter out low-score boxes by the average score inside each box; this differs from the original paper
+ for i, box in enumerate(boxes):
+ mask = np.zeros_like(score_map, dtype=np.uint8)
+ cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
+ boxes[i, 8] = cv2.mean(score_map, mask)[0]
+ boxes = boxes[boxes[:, 8] > box_thresh]
+ return boxes, timer
+
+
+def sort_poly(p):
+ min_axis = np.argmin(np.sum(p, axis=1))
+ p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
+ if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
+ return p
+ else:
+ return p[[0, 3, 2, 1]]
+
+
+def mean_image_subtraction(images, means=cfg.means):
+ '''
+ image normalization
+ :param images: bs * w * h * channel
+ :param means:
+ :return:
+ '''
+ num_channels = images.data.shape[1]
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+ for i in range(num_channels):
+ images.data[:, i, :, :] -= means[i]
+
+ return images
diff --git a/IndicPhotoOCR/detection/textbpn/__init__.py b/IndicPhotoOCR/detection/textbpn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/textbpn/cfglib/config.py b/IndicPhotoOCR/detection/textbpn/cfglib/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1ea831dd23c535b9d220932a26e5ba823f474b
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/cfglib/config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+import torch
+import os
+
+config = EasyDict()
+
+
+# Normalize image
+config.means = (0.485, 0.456, 0.406)
+config.stds = (0.229, 0.224, 0.225)
+
+config.gpu = "1"
+
+# Experiment name #
+config.exp_name = "Synthtext"
+
+# dataloader jobs number
+config.num_workers = 24
+
+# batch_size
+config.batch_size = 12
+
+# training epoch number
+config.max_epoch = 200
+
+config.start_epoch = 0
+
+# learning rate
+config.lr = 1e-4
+
+# using GPU
+config.cuda = False
+
+config.output_dir = 'output'
+
+config.input_size = 640
+
+# max polygon per image
+# synText, total-text:64; CTW1500: 64; icdar: 64; MLT: 32; TD500: 64.
+config.max_annotation = 64
+
+# adj num for graph
+config.adj_num = 4
+
+# control points number
+config.num_points = 20
+
+# use hard examples (annotated as '#')
+config.use_hard = True
+
+# Load data into memory at one time
+config.load_memory = False
+
+# prediction on 1/scale feature map
+config.scale = 1
+
+# # clip gradient of loss
+config.grad_clip = 25
+
+# demo tcl threshold
+config.dis_threshold = 0.4
+
+config.cls_threshold = 0.8
+
+# Contour approximation factor
+config.approx_factor = 0.004
+
+
+def update_config(config, extra_config):
+ for k, v in vars(extra_config).items():
+ config[k] = v
+ # print(config.gpu)
+ # config.device = torch.device('cuda') if config.cuda else torch.device('cpu')
+ config.device = torch.device('cpu')
+
+
+def print_config(config):
+ print('==========Options============')
+ for k, v in config.items():
+ print('{}: {}'.format(k, v))
+ print('=============End=============')
+
+
+
+################### MY Settings ##################
+config.resume=True
+
+config.device="cpu"
+
+# config.test_size = [224, 224]
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/cfglib/option.py b/IndicPhotoOCR/detection/textbpn/cfglib/option.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b487d0a37e6ea9eaa4ab9c6d901f2253f38855c
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/cfglib/option.py
@@ -0,0 +1,123 @@
+import argparse
+import torch
+import os
+import torch.backends.cudnn as cudnn
+
+from datetime import datetime
+
+
+def str2bool(v):
+ return v.lower() in ("yes", "true", "t", "1")
+
+
+def arg2str(args):
+ args_dict = vars(args)
+ option_str = datetime.now().strftime('%b%d_%H-%M-%S') + '\n'
+
+ for k, v in sorted(args_dict.items()):
+ option_str += ('{}: {}\n'.format(str(k), str(v)))
+
+ return option_str
+
+
+class BaseOptions(object):
+
+ def __init__(self):
+
+ self.parser = argparse.ArgumentParser()
+
+ # basic opts
+ self.parser.add_argument('--exp_name', default="TD500", type=str,
+ choices=['Synthtext', 'Totaltext', 'Ctw1500','Icdar2015',
+ "MLT2017", 'TD500', "MLT2019", "ArT", "ALL"], help='Experiment name')
+ self.parser.add_argument("--gpu", default="1", help="set gpu id", type=str)
+ self.parser.add_argument('--resume', default=None, type=str, help='Path to target resume checkpoint')
+ self.parser.add_argument('--num_workers', default=24, type=int, help='Number of workers used in dataloading')
+ self.parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
+ self.parser.add_argument('--mgpu', action='store_true', help='Use multi-gpu to train model')
+ self.parser.add_argument('--save_dir', default='./model/', help='Path to save checkpoint models')
+ self.parser.add_argument('--vis_dir', default='./vis/', help='Path to save visualization images')
+ self.parser.add_argument('--log_dir', default='./logs/', help='Path to tensorboard log')
+ self.parser.add_argument('--loss', default='CrossEntropyLoss', type=str, help='Training Loss')
+ # self.parser.add_argument('--input_channel', default=1, type=int, help='number of input channels' )
+ self.parser.add_argument('--pretrain', default=False, type=str2bool, help='Pretrained AutoEncoder model')
+ self.parser.add_argument('--verbose', '-v', default=True, type=str2bool, help='Whether to output debug info')
+ self.parser.add_argument('--viz', action='store_true', help='Whether to output debug info')
+ # self.parser.add_argument('--viz', default=True, type=str2bool, help='Whether to output debug info')
+
+ # train opts
+ self.parser.add_argument('--max_epoch', default=250, type=int, help='Max epochs')
+ self.parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
+ self.parser.add_argument('--lr_adjust', default='fix',
+ choices=['fix', 'poly'], type=str, help='Learning Rate Adjust Strategy')
+ self.parser.add_argument('--stepvalues', default=[], nargs='+', type=int, help='# of iter to change lr')
+ self.parser.add_argument('--weight_decay', '--wd', default=0., type=float, help='Weight decay for SGD')
+ self.parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD lr')
+ self.parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
+ self.parser.add_argument('--batch_size', default=6, type=int, help='Batch size for training')
+ self.parser.add_argument('--optim', default='Adam', type=str, choices=['SGD', 'Adam'], help='Optimizer')
+ self.parser.add_argument('--save_freq', default=5, type=int, help='save weights every # epoch')
+ self.parser.add_argument('--display_freq', default=10, type=int, help='display training metrics every # iter')
+ self.parser.add_argument('--viz_freq', default=50, type=int, help='visualize training process every # iter')
+ self.parser.add_argument('--log_freq', default=10000, type=int, help='log to tensorboard every # iterations')
+ self.parser.add_argument('--val_freq', default=1000, type=int, help='do validation every # iterations')
+
+ # backbone
+ self.parser.add_argument('--scale', default=1, type=int, help='prediction on 1/scale feature map')
+ self.parser.add_argument('--net', default='resnet50', type=str,
+ choices=['vgg', 'resnet50', 'resnet18',
+ "deformable_resnet18", "deformable_resnet50"],
+ help='Network architecture')
+ # data args
+ self.parser.add_argument('--load_memory', default=False, type=str2bool, help='Load data into memory')
+ self.parser.add_argument('--rescale', type=float, default=255.0, help='rescale factor')
+ self.parser.add_argument('--input_size', default=640, type=int, help='model input size')
+ self.parser.add_argument('--test_size', default=[640, 960], type=int, nargs='+', help='test size')
+
+        # eval args
+ self.parser.add_argument('--checkepoch', default=1070, type=int, help='Load checkpoint number')
+ self.parser.add_argument('--start_epoch', default=0, type=int, help='start epoch number')
+ self.parser.add_argument('--cls_threshold', default=0.875, type=float, help='threshold of pse')
+        self.parser.add_argument('--dis_threshold', default=0.35, type=float, help='filter out predictions with score below this threshold')
+
+ # demo args
+ self.parser.add_argument('--img_root', default=None, type=str, help='Path to deploy images')
+
+ def parse(self, fixed=None):
+
+ if fixed is not None:
+ args = self.parser.parse_args(fixed)
+ else:
+ args = self.parser.parse_args()
+
+ return args
+
+ def initialize(self, fixed=None):
+
+ # Parse options
+ self.args = self.parse(fixed)
+ os.environ['CUDA_VISIBLE_DEVICES'] = self.args.gpu
+
+ # Setting default torch Tensor type
+ if self.args.cuda and torch.cuda.is_available():
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
+ cudnn.benchmark = True
+ else:
+ torch.set_default_tensor_type('torch.FloatTensor')
+
+ # Create weights saving directory
+ if not os.path.exists(self.args.save_dir):
+ os.mkdir(self.args.save_dir)
+
+ # Create weights saving directory of target model
+ model_save_path = os.path.join(self.args.save_dir, self.args.exp_name)
+
+ if not os.path.exists(model_save_path):
+ os.mkdir(model_save_path)
+
+ return self.args
+
+ def update(self, args, extra_options):
+
+ for k, v in extra_options.items():
+ setattr(args, k, v)
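A short sketch of how BaseOptions is typically consumed together with the config module above; passing a fixed argument list parses options programmatically instead of reading sys.argv (note that initialize() additionally sets CUDA_VISIBLE_DEVICES, the default tensor type, and creates the checkpoint directories, so parse() is the lighter choice for tests or notebooks):

    from IndicPhotoOCR.detection.textbpn.cfglib.option import BaseOptions
    from IndicPhotoOCR.detection.textbpn.cfglib.config import config, update_config

    # parse a fixed argument list instead of the real command line
    args = BaseOptions().parse(fixed=["--exp_name", "TD500", "--cuda", "False"])
    update_config(config, args)   # merge the argparse.Namespace into the global config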
diff --git a/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth b/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8426c51dd291d2932737b611c63aa87d5f4098c9
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b735b9c93c8758972d3b8cfd3ef8e1c09afa8cd9106f4cb11406300b141b1d78
+size 145703602
diff --git a/IndicPhotoOCR/detection/textbpn/network/Reg_loss.py b/IndicPhotoOCR/detection/textbpn/network/Reg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ca0c0dcd96b44d2a1dead384eacd7c06f69d75b
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/Reg_loss.py
@@ -0,0 +1,196 @@
+# -*- coding: utf-8 -*-
+# @Time : 10/1/21
+# @Author : GXYM
+import torch
+from torch import nn
+import numpy as np
+import torch.nn.functional as F
+
+
+class PolyMatchingLoss(nn.Module):
+ def __init__(self, pnum, device, loss_type="L1"):
+ super(PolyMatchingLoss, self).__init__()
+
+ self.pnum = pnum
+ self.device = device
+ self.loss_type = loss_type
+ self.smooth_L1 = F.smooth_l1_loss
+        self.L2_loss = torch.nn.MSELoss(reduction='none')  # element-wise L2 (replaces the deprecated reduce/size_average flags)
+
+ batch_size = 1
+ pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
+ for b in range(batch_size):
+ for i in range(pnum):
+ pidx = (np.arange(pnum) + i) % pnum
+ pidxall[b, i] = pidx
+
+ pidxall = torch.from_numpy(np.reshape(pidxall, newshape=(batch_size, -1))).to(device)
+ self.feature_id = pidxall.unsqueeze_(2).long().expand(pidxall.size(0), pidxall.size(1), 2).detach()
+        # print(self.feature_id.shape)  # debug only: (1, pnum * pnum, 2)
+
+ def match_loss(self, pred, gt):
+ batch_size = pred.shape[0]
+ feature_id = self.feature_id.expand(batch_size, self.feature_id.size(1), 2)
+
+ gt_expand = torch.gather(gt, 1, feature_id).view(batch_size, self.pnum, self.pnum, 2)
+ pred_expand = pred.unsqueeze(1)
+
+ if self.loss_type == "L2":
+ dis = self.L2_loss(pred_expand, gt_expand)
+ dis = dis.sum(3).sqrt().mean(2)
+ elif self.loss_type == "L1":
+ dis = self.smooth_L1(pred_expand, gt_expand, reduction='none')
+ dis = dis.sum(3).mean(2)
+
+ min_dis, min_id = torch.min(dis, dim=1, keepdim=True)
+
+ return min_dis
+
+ def forward(self, pred_list, gt):
+ loss = torch.tensor(0.)
+ for pred in pred_list:
+ loss += torch.mean(self.match_loss(pred, gt))
+
+ return loss / torch.tensor(len(pred_list))
+
+ # los = []
+ # for pred in pred_list:
+ # los.append(self.match_loss(pred, gt))
+ #
+ # los_b = torch.tensor(0.)
+ # loss_c = torch.tensor(0.)
+ # for i, _ in enumerate(los):
+ # los_b += torch.mean(los[i])
+ # loss_c += (torch.mean(torch.clamp(los[i] - los[i - 1], min=0.0)) if i > 0 else torch.tensor(0.))
+ # loss = los_b / torch.tensor(len(los)) + 0.5*loss_c / torch.tensor(len(los)-1)
+ #
+ # return loss
+
+
+class AttentionLoss(nn.Module):
+ def __init__(self, beta=4, gamma=0.5):
+ super(AttentionLoss, self).__init__()
+
+ self.beta = beta
+ self.gamma = gamma
+
+ def forward(self, pred, gt):
+ num_pos = torch.sum(gt)
+ num_neg = torch.sum(1 - gt)
+ alpha = num_neg / (num_pos + num_neg)
+ edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
+ bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))
+
+ loss = 0
+ loss = loss - alpha * edge_beta * torch.log(pred) * gt
+ loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
+ return torch.mean(loss)
+
+
+class GeoCrossEntropyLoss(nn.Module):
+ def __init__(self):
+ super(GeoCrossEntropyLoss, self).__init__()
+
+ def forward(self, output, target, poly):
+ output = torch.nn.functional.softmax(output, dim=1)
+ output = torch.log(torch.clamp(output, min=1e-4))
+ poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
+ target = target[..., None, None].expand(poly.size(0), poly.size(1), 1, poly.size(3))
+ target_poly = torch.gather(poly, 2, target)
+ sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
+ kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
+ loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
+ return loss
+
+
+class AELoss(nn.Module):
+ def __init__(self):
+ super(AELoss, self).__init__()
+
+ def forward(self, ae, ind, ind_mask):
+ """
+ ae: [b, 1, h, w]
+ ind: [b, max_objs, max_parts]
+ ind_mask: [b, max_objs, max_parts]
+ obj_mask: [b, max_objs]
+ """
+ # first index
+ b, _, h, w = ae.shape
+ b, max_objs, max_parts = ind.shape
+ obj_mask = torch.sum(ind_mask, dim=2) != 0
+
+ ae = ae.view(b, h * w, 1)
+ seed_ind = ind.view(b, max_objs * max_parts, 1)
+ tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)
+
+ # compute the mean
+ tag_mean = tag * ind_mask
+ tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)
+
+ # pull ae of the same object to their mean
+ pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
+ obj_num = obj_mask.sum(dim=1).float()
+ pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
+ pull /= b
+
+ # push away the mean of different objects
+ push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
+ push_dist = 1 - push_dist
+ push_dist = nn.functional.relu(push_dist, inplace=True)
+ obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
+ push_dist = push_dist * obj_mask.float()
+ push = ((push_dist.sum(dim=(1, 2)) - obj_num) / (obj_num * (obj_num - 1) + 1e-4)).sum()
+ push /= b
+ return pull, push
+
+
+def smooth_l1_loss(inputs, target, sigma=9.0):
+ try:
+ diff = torch.abs(inputs - target)
+ less_one = (diff < 1.0 / sigma).float()
+ loss = less_one * 0.5 * diff ** 2 * sigma \
+ + torch.abs(torch.tensor(1.0) - less_one) * (diff - 0.5 / sigma)
+ loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
+ except Exception as e:
+ print('RPN_REGR_Loss Exception:', e)
+ loss = torch.tensor(0.0)
+
+ return loss
+
+
+def _neg_loss(pred, gt):
+ ''' Modified focal loss. Exactly the same as CornerNet.
+ Runs faster and costs a little bit more memory
+ Arguments:
+ pred (batch x c x h x w)
+      gt (batch x c x h x w)
+ '''
+ pos_inds = gt.eq(1).float()
+ neg_inds = gt.lt(1).float()
+
+ neg_weights = torch.pow(1 - gt, 4)
+
+ loss = 0
+
+ pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
+ neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
+
+ num_pos = pos_inds.float().sum()
+ pos_loss = pos_loss.sum()
+ neg_loss = neg_loss.sum()
+
+ if num_pos == 0:
+ loss = loss - neg_loss
+ else:
+ loss = loss - (pos_loss + neg_loss) / num_pos
+ return loss
+
+
+class FocalLoss(nn.Module):
+    '''nn.Module wrapper for focal loss'''
+ def __init__(self):
+ super(FocalLoss, self).__init__()
+ self.neg_loss = _neg_loss
+
+ def forward(self, out, target):
+ return self.neg_loss(out, target)
\ No newline at end of file
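A toy call of PolyMatchingLoss with made-up shapes (batch of 2 contours, 20 control points each); it is only meant to show the expected (batch, num_points, 2) layout and that forward averages over a list of predictions, e.g. one per boundary-refinement step:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.Reg_loss import PolyMatchingLoss

    pnum = 20
    criterion = PolyMatchingLoss(pnum, device=torch.device("cpu"), loss_type="L1")

    pred = torch.rand(2, pnum, 2)        # predicted boundary points (x, y)
    gt = torch.rand(2, pnum, 2)          # ground-truth boundary points
    loss = criterion([pred, pred], gt)   # list of predictions, one per refinement step
    print(loss.item())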
diff --git a/IndicPhotoOCR/detection/textbpn/network/Seg_loss.py b/IndicPhotoOCR/detection/textbpn/network/Seg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a03507f650f17d747c7f2eba2f1d31f57e411cb1
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/Seg_loss.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+# @Time : 10/1/21
+# @Author : GXYM
+import torch
+from torch import nn
+import numpy as np
+
+
+class SegmentLoss(nn.Module):
+ def __init__(self, Lambda, ratio=3, reduction='mean'):
+ """Implement PSE Loss.
+ """
+ super(SegmentLoss, self).__init__()
+        assert reduction in ['mean', 'sum'], "reduction must be 'mean' or 'sum'"
+ self.Lambda = Lambda
+ self.ratio = ratio
+ self.reduction = reduction
+
+ def forward(self, outputs, labels, training_masks, th=0.5):
+ texts = outputs[:, -1, :, :]
+ kernels = outputs[:, :-1, :, :]
+ gt_texts = labels[:, -1, :, :]
+ gt_kernels = labels[:, :-1, :, :]
+
+ selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
+ selected_masks = selected_masks.to(outputs.device)
+
+ loss_text = self.dice_loss(texts, gt_texts, selected_masks)
+
+ loss_kernels = []
+ # mask0 = torch.sigmoid(texts).data.cpu().numpy()
+ mask0 = texts.data.cpu().numpy()
+ mask1 = training_masks.data.cpu().numpy()
+ selected_masks = ((mask0 > th) & (mask1 > th)).astype('float32')
+ selected_masks = torch.from_numpy(selected_masks).float()
+ selected_masks = selected_masks.to(outputs.device)
+ kernels_num = gt_kernels.size()[1]
+ for i in range(kernels_num):
+ kernel_i = kernels[:, i, :, :]
+ gt_kernel_i = gt_kernels[:, i, :, :]
+ loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
+ loss_kernels.append(loss_kernel_i)
+ loss_kernels = torch.stack(loss_kernels).mean(0)
+ if self.reduction == 'mean':
+ loss_text = loss_text.mean()
+ loss_kernels = loss_kernels.mean()
+ elif self.reduction == 'sum':
+ loss_text = loss_text.sum()
+ loss_kernels = loss_kernels.sum()
+
+        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
+ return loss_text, loss_kernels, loss
+
+ def dice_loss(self, input, target, mask):
+ # input = torch.sigmoid(input)
+
+ input = input.contiguous().view(input.size()[0], -1)
+ target = target.contiguous().view(target.size()[0], -1)
+ mask = mask.contiguous().view(mask.size()[0], -1)
+
+ input = input * mask
+ target = (target.float()) * mask
+
+ a = torch.sum(input * target, 1)
+ b = torch.sum(input * input, 1) + 0.001
+ c = torch.sum(target * target, 1) + 0.001
+ d = (2 * a) / (b + c)
+ return 1 - d
+
+ def ohem_single(self, score, gt_text, training_mask, th=0.5):
+ pos_num = (int)(np.sum(gt_text > th)) - (int)(np.sum((gt_text > th) & (training_mask <= th)))
+
+ if pos_num == 0:
+ # selected_mask = gt_text.copy() * 0 # may be not good
+ selected_mask = training_mask
+ selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+ return selected_mask
+
+ neg_num = (int)(np.sum(gt_text <= th))
+ neg_num = (int)(min(pos_num * 3, neg_num))
+
+ if neg_num == 0:
+ selected_mask = training_mask
+ selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+ return selected_mask
+
+ neg_score = score[gt_text <= th]
+        # sort negative-sample scores from high to low
+ neg_score_sorted = np.sort(-neg_score)
+ threshold = -neg_score_sorted[neg_num - 1]
+        # keep the mask of high-scoring negatives plus all positives (within the training mask)
+ selected_mask = ((score >= threshold) | (gt_text > th)) & (training_mask > th)
+ selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
+ return selected_mask
+
+ def ohem_batch(self, scores, gt_texts, training_masks):
+ scores = scores.data.cpu().numpy()
+ gt_texts = gt_texts.data.cpu().numpy()
+ training_masks = training_masks.data.cpu().numpy()
+ selected_masks = []
+ for i in range(scores.shape[0]):
+ selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))
+
+ selected_masks = np.concatenate(selected_masks, 0)
+ selected_masks = torch.from_numpy(selected_masks).float()
+
+ return selected_masks
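A runnable toy call for SegmentLoss with invented shapes (batch of 2, one kernel map plus the text map in the last channel, 64x64 inputs); it illustrates that outputs and labels share the (batch, num_kernels + 1, H, W) layout while training_masks is (batch, H, W):

    import torch
    from IndicPhotoOCR.detection.textbpn.network.Seg_loss import SegmentLoss

    criterion = SegmentLoss(Lambda=0.7)

    outputs = torch.sigmoid(torch.randn(2, 2, 64, 64))   # [kernel map, text map] probabilities
    labels = (torch.rand(2, 2, 64, 64) > 0.5).float()    # binary ground truth, same layout
    training_masks = torch.ones(2, 64, 64)               # 1 = pixel counted in the loss, 0 = ignored

    loss_text, loss_kernels, loss = criterion(outputs, labels, training_masks)
    print(loss.item())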
diff --git a/IndicPhotoOCR/detection/textbpn/network/__init__.py b/IndicPhotoOCR/detection/textbpn/network/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9742821a6f164200bc145e7a847382f08778303
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/__init__.py
@@ -0,0 +1 @@
+from . import *
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py b/IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05af9e17f2bb084365379d39f38305dc23f339e
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py
@@ -0,0 +1 @@
+from .resnet import resnet18, resnet34, resnet50, resnet101, deformable_resnet50, deformable_resnet18
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..c50242699a4ddb7d97650378ef2a199fde6b3d99
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile
@@ -0,0 +1,6 @@
+#!/bin/bash
+rm *.so
+python setup.py build_ext --inplace
+rm -rf ./build
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c50242699a4ddb7d97650378ef2a199fde6b3d99
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+rm *.so
+python setup.py build_ext --inplace
+rm -rf ./build
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..165e63725354de429a448d866f665cccca991916
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py
@@ -0,0 +1,13 @@
+from .functions.deform_conv import deform_conv, modulated_deform_conv
+from .functions.deform_pool import deform_roi_pooling
+from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
+ DeformConvPack, ModulatedDeformConvPack)
+from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
+ ModulatedDeformRoIPoolingPack)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
+ 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
+ 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
+ 'deform_roi_pooling'
+]
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af75a758b8448ca1d981054525259f536d99d1e
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py
@@ -0,0 +1,181 @@
+import torch
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from .. import deform_conv_cuda
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ "Expected 4D tensor as input, got {}D tensor instead.".format(
+ input.dim()))
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConvFunction._output_size(input, weight, ctx.padding,
+ ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_cuda.deform_conv_forward_cuda(
+ input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
+ weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+ cur_im2col_step)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_cuda.deform_conv_backward_input_cuda(
+ input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+ cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_cuda.deform_conv_backward_parameters_cuda(
+ input, offset, grad_output,
+ grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None,
+ None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ "convolution input is too small (output would be {})".format(
+ 'x'.join(map(str, output_size))))
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(
+ ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_cuda.modulated_deform_conv_cuda_forward(
+ input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_cuda.modulated_deform_conv_cuda_backward(
+ input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+ None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding -
+ (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding -
+ (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..65ff0efb5737e87ccf49387b2d24abcbeedd6497
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py
@@ -0,0 +1,69 @@
+import torch
+from torch.autograd import Function
+
+from .. import deform_pool_cuda
+
+
+class DeformRoIPoolingFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ data,
+ rois,
+ offset,
+ spatial_scale,
+ out_size,
+ out_channels,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0):
+ ctx.spatial_scale = spatial_scale
+ ctx.out_size = out_size
+ ctx.out_channels = out_channels
+ ctx.no_trans = no_trans
+ ctx.group_size = group_size
+ ctx.part_size = out_size if part_size is None else part_size
+ ctx.sample_per_part = sample_per_part
+ ctx.trans_std = trans_std
+
+ assert 0.0 <= ctx.trans_std <= 1.0
+ if not data.is_cuda:
+ raise NotImplementedError
+
+ n = rois.shape[0]
+ output = data.new_empty(n, out_channels, out_size, out_size)
+ output_count = data.new_empty(n, out_channels, out_size, out_size)
+ deform_pool_cuda.deform_psroi_pooling_cuda_forward(
+ data, rois, offset, output, output_count, ctx.no_trans,
+ ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
+ ctx.part_size, ctx.sample_per_part, ctx.trans_std)
+
+ if data.requires_grad or rois.requires_grad or offset.requires_grad:
+ ctx.save_for_backward(data, rois, offset)
+ ctx.output_count = output_count
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+
+ data, rois, offset = ctx.saved_tensors
+ output_count = ctx.output_count
+ grad_input = torch.zeros_like(data)
+ grad_rois = None
+ grad_offset = torch.zeros_like(offset)
+
+ deform_pool_cuda.deform_psroi_pooling_cuda_backward(
+ grad_output, data, rois, offset, output_count, grad_input,
+ grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
+ ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
+ ctx.trans_std)
+ return (grad_input, grad_rois, grad_offset, None, None, None, None,
+ None, None, None, None)
+
+
+deform_roi_pooling = DeformRoIPoolingFunction.apply
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..50d15d1513f0ebc145982e04958f76a5f1ca1343
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py
@@ -0,0 +1,157 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair
+
+from ..functions.deform_conv import deform_conv, modulated_deform_conv
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+        assert in_channels % groups == 0, \
+            'in_channels {} is not divisible by groups {}'.format(
+                in_channels, groups)
+        assert out_channels % groups == 0, \
+            'out_channels {} is not divisible by groups {}'.format(
+                out_channels, groups)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups,
+ *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ return deform_conv(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+class DeformConvPack(DeformConv):
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] *
+ self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups,
+ *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset_mask = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] *
+ self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset_mask.weight.data.zero_()
+ self.conv_offset_mask.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset_mask(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
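A sketch of the packed modulated deformable convolution in use; it assumes the deform_conv_cuda extension from the setup.py under assets/dcn has been built and a CUDA device is available, since the autograd Functions above raise NotImplementedError on CPU:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import ModulatedDeformConvPack

    # 16 -> 32 channels, 3x3 kernel; offsets and modulation masks are predicted
    # internally by conv_offset_mask, so the call site only passes the feature map
    conv = ModulatedDeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
    x = torch.randn(1, 16, 64, 64, device="cuda")
    y = conv(x)
    print(y.shape)   # torch.Size([1, 32, 64, 64]) with stride 1 and padding 1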
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0196753ee1b427263bc397e0ae842af6a9938b
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py
@@ -0,0 +1,172 @@
+from torch import nn
+
+from ..functions.deform_pool import deform_roi_pooling
+
+
+class DeformRoIPooling(nn.Module):
+
+ def __init__(self,
+ spatial_scale,
+ out_size,
+ out_channels,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0):
+ super(DeformRoIPooling, self).__init__()
+ self.spatial_scale = spatial_scale
+ self.out_size = out_size
+ self.out_channels = out_channels
+ self.no_trans = no_trans
+ self.group_size = group_size
+ self.part_size = out_size if part_size is None else part_size
+ self.sample_per_part = sample_per_part
+ self.trans_std = trans_std
+
+ def forward(self, data, rois, offset):
+ if self.no_trans:
+ offset = data.new_empty(0)
+ return deform_roi_pooling(
+ data, rois, offset, self.spatial_scale, self.out_size,
+ self.out_channels, self.no_trans, self.group_size, self.part_size,
+ self.sample_per_part, self.trans_std)
+
+
+class DeformRoIPoolingPack(DeformRoIPooling):
+
+ def __init__(self,
+ spatial_scale,
+ out_size,
+ out_channels,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0,
+ num_offset_fcs=3,
+ deform_fc_channels=1024):
+ super(DeformRoIPoolingPack,
+ self).__init__(spatial_scale, out_size, out_channels, no_trans,
+ group_size, part_size, sample_per_part, trans_std)
+
+ self.num_offset_fcs = num_offset_fcs
+ self.deform_fc_channels = deform_fc_channels
+
+ if not no_trans:
+ seq = []
+ ic = self.out_size * self.out_size * self.out_channels
+ for i in range(self.num_offset_fcs):
+ if i < self.num_offset_fcs - 1:
+ oc = self.deform_fc_channels
+ else:
+ oc = self.out_size * self.out_size * 2
+ seq.append(nn.Linear(ic, oc))
+ ic = oc
+ if i < self.num_offset_fcs - 1:
+ seq.append(nn.ReLU(inplace=True))
+ self.offset_fc = nn.Sequential(*seq)
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ def forward(self, data, rois):
+ assert data.size(1) == self.out_channels
+ if self.no_trans:
+ offset = data.new_empty(0)
+ return deform_roi_pooling(
+ data, rois, offset, self.spatial_scale, self.out_size,
+ self.out_channels, self.no_trans, self.group_size,
+ self.part_size, self.sample_per_part, self.trans_std)
+ else:
+ n = rois.shape[0]
+ offset = data.new_empty(0)
+ x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
+ self.out_size, self.out_channels, True,
+ self.group_size, self.part_size,
+ self.sample_per_part, self.trans_std)
+ offset = self.offset_fc(x.view(n, -1))
+ offset = offset.view(n, 2, self.out_size, self.out_size)
+ return deform_roi_pooling(
+ data, rois, offset, self.spatial_scale, self.out_size,
+ self.out_channels, self.no_trans, self.group_size,
+ self.part_size, self.sample_per_part, self.trans_std)
+
+
+class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
+
+ def __init__(self,
+ spatial_scale,
+ out_size,
+ out_channels,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0,
+ num_offset_fcs=3,
+ num_mask_fcs=2,
+ deform_fc_channels=1024):
+ super(ModulatedDeformRoIPoolingPack, self).__init__(
+ spatial_scale, out_size, out_channels, no_trans, group_size,
+ part_size, sample_per_part, trans_std)
+
+ self.num_offset_fcs = num_offset_fcs
+ self.num_mask_fcs = num_mask_fcs
+ self.deform_fc_channels = deform_fc_channels
+
+ if not no_trans:
+ offset_fc_seq = []
+ ic = self.out_size * self.out_size * self.out_channels
+ for i in range(self.num_offset_fcs):
+ if i < self.num_offset_fcs - 1:
+ oc = self.deform_fc_channels
+ else:
+ oc = self.out_size * self.out_size * 2
+ offset_fc_seq.append(nn.Linear(ic, oc))
+ ic = oc
+ if i < self.num_offset_fcs - 1:
+ offset_fc_seq.append(nn.ReLU(inplace=True))
+ self.offset_fc = nn.Sequential(*offset_fc_seq)
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ mask_fc_seq = []
+ ic = self.out_size * self.out_size * self.out_channels
+ for i in range(self.num_mask_fcs):
+ if i < self.num_mask_fcs - 1:
+ oc = self.deform_fc_channels
+ else:
+ oc = self.out_size * self.out_size
+ mask_fc_seq.append(nn.Linear(ic, oc))
+ ic = oc
+ if i < self.num_mask_fcs - 1:
+ mask_fc_seq.append(nn.ReLU(inplace=True))
+ else:
+ mask_fc_seq.append(nn.Sigmoid())
+ self.mask_fc = nn.Sequential(*mask_fc_seq)
+ self.mask_fc[-2].weight.data.zero_()
+ self.mask_fc[-2].bias.data.zero_()
+
+ def forward(self, data, rois):
+ assert data.size(1) == self.out_channels
+ if self.no_trans:
+ offset = data.new_empty(0)
+ return deform_roi_pooling(
+ data, rois, offset, self.spatial_scale, self.out_size,
+ self.out_channels, self.no_trans, self.group_size,
+ self.part_size, self.sample_per_part, self.trans_std)
+ else:
+ n = rois.shape[0]
+ offset = data.new_empty(0)
+ x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
+ self.out_size, self.out_channels, True,
+ self.group_size, self.part_size,
+ self.sample_per_part, self.trans_std)
+ offset = self.offset_fc(x.view(n, -1))
+ offset = offset.view(n, 2, self.out_size, self.out_size)
+ mask = self.mask_fc(x.view(n, -1))
+ mask = mask.view(n, 1, self.out_size, self.out_size)
+ return deform_roi_pooling(
+ data, rois, offset, self.spatial_scale, self.out_size,
+ self.out_channels, self.no_trans, self.group_size,
+ self.part_size, self.sample_per_part, self.trans_std) * mask
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a9a0ecb742599cbeaa7ccc753418087704e1cfc
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py
@@ -0,0 +1,19 @@
+import os
+# make sure the CUDA toolkit's nvcc (here assumed under /opt/cuda/bin) is on PATH for the build
+PATH = "{}:{}".format(os.environ['PATH'], "/opt/cuda/bin")
+# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
+os.environ['PATH'] = PATH
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='deform_conv',
+ ext_modules=[
+ CUDAExtension('deform_conv_cuda', [
+ 'src/deform_conv_cuda.cpp',
+ 'src/deform_conv_cuda_kernel.cu',
+ ]),
+ CUDAExtension('deform_pool_cuda', [
+ 'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
+ ]),
+ ],
+ cmdclass={'build_ext': BuildExtension})
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e45155b94442f228760db21536f61948d7f1056e
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,695 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
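+ // Group-wise GEMM: weight[g].flatten(1) is (channels_out/group, channels_kernel*kernel_h*kernel_w)
+ // and columns[g] is (channels_kernel*kernel_h*kernel_w, height_out*width_out), so each
+ // addmm_ fills output[b][g] with their (channels_out/group, height_out*width_out) product.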
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w_, kernel_h, kernel_w);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide into group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
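+ // Accumulate parameter gradients: grad_weight[g] += grad_output[b][g] * columns[g]^T,
+ // and grad_bias is the spatial sum of grad_output, obtained by multiplying with a
+ // vector of ones.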
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
+ "deform forward (CUDA)");
+ m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
+ "deform_conv_backward_input (CUDA)");
+ m.def("deform_conv_backward_parameters_cuda",
+ &deform_conv_backward_parameters_cuda,
+ "deform_conv_backward_parameters (CUDA)");
+ m.def("modulated_deform_conv_cuda_forward",
+ &modulated_deform_conv_cuda_forward,
+ "modulated deform conv forward (CUDA)");
+ m.def("modulated_deform_conv_cuda_backward",
+ &modulated_deform_conv_cuda_backward,
+ "modulated deform conv backward (CUDA)");
+}
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..48c6d8825387ce4b248f07f77f5eeb65ab9bcb49
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,866 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
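+// get_gradient_weight returns the bilinear coefficient that input pixel (h, w)
+// contributed to the sample taken at the fractional location (argmax_h, argmax_w);
+// it is used to scatter column gradients back onto the input image.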
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.type(), "deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ scalar_t *data_col_ = data_col.data<scalar_t>();
+
+ deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
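+ // Scatter the column gradient onto the (at most 2x2) input pixels that supported
+ // the bilinear sample; the 5x5 window plus the |distance| < 1 test conservatively
+ // enumerates every integer location within one pixel of the fractional coordinate.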
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.type(), "deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data<scalar_t>();
+
+ deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
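+ // Each offset channel (one of the 2*kernel_h*kernel_w h/w offsets of a deformable
+ // group) is shared by all input channels of that group, so its gradient sums the
+ // coordinate weights over those channels; bp_dir selects d/dh (0) or d/dw (1).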
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.type(), "deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
+
+ deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
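+ // val is the gradient w.r.t. this offset channel; mval is the gradient w.r.t. the
+ // mask, which is shared by the paired h/w offset channels and therefore written
+ // only once (for the even, h-offset channel) below.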
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+ scalar_t *data_col_ = data_col.data<scalar_t>();
+
+ modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ // printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data<scalar_t>();
+
+ modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
+ scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e19cf42aee6149a52d45c54f09dcb9afdc9dbe92
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp
@@ -0,0 +1,87 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+// based on
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+void DeformablePSROIPoolForward(
+ const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+ at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+ const int height, const int width, const int num_bbox,
+ const int channels_trans, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std);
+
+void DeformablePSROIPoolBackwardAcc(
+ const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+ const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+ at::Tensor trans_grad, const int batch, const int channels,
+ const int height, const int width, const int num_bbox,
+ const int channels_trans, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std);
+
+void deform_psroi_pooling_cuda_forward(
+ at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+ at::Tensor top_count, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+
+ const int num_bbox = bbox.size(0);
+ if (num_bbox != out.size(0))
+ AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+ out.size(0), num_bbox);
+
+ DeformablePSROIPoolForward(
+ input, bbox, trans, out, top_count, batch, channels, height, width,
+ num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+ pooled_size, part_size, sample_per_part, trans_std);
+}
+
+void deform_psroi_pooling_cuda_backward(
+ at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+ at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+ const int no_trans, const float spatial_scale, const int output_dim,
+ const int group_size, const int pooled_size, const int part_size,
+ const int sample_per_part, const float trans_std) {
+ TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+
+ const int num_bbox = bbox.size(0);
+ if (num_bbox != out_grad.size(0))
+ AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+ out_grad.size(0), num_bbox);
+
+ DeformablePSROIPoolBackwardAcc(
+ out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+ channels, height, width, num_bbox, channels_trans, no_trans,
+ spatial_scale, output_dim, group_size, pooled_size, part_size,
+ sample_per_part, trans_std);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
+ "deform psroi pooling forward(CUDA)");
+ m.def("deform_psroi_pooling_cuda_backward",
+ &deform_psroi_pooling_cuda_backward,
+ "deform psroi pooling backward(CUDA)");
+}
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e49446005679c0d8d7b7bd6fb84250325c37828f
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu
@@ -0,0 +1,364 @@
+/*!
+ * Copyright (c) 2017 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file deformable_psroi_pooling.cu
+ * \brief
+ * \author Yi Li, Guodong Zhang, Jifeng Dai
+*/
+/***************** Adapted by Charles Shang *********************/
+// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+ return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename scalar_t>
+__device__ scalar_t bilinear_interp(
+ const scalar_t *data,
+ const scalar_t x,
+ const scalar_t y,
+ const int width,
+ const int height)
+{
+ int x1 = floor(x);
+ int x2 = ceil(x);
+ int y1 = floor(y);
+ int y2 = ceil(y);
+ scalar_t dist_x = (scalar_t)(x - x1);
+ scalar_t dist_y = (scalar_t)(y - y1);
+ scalar_t value11 = data[y1 * width + x1];
+ scalar_t value12 = data[y2 * width + x1];
+ scalar_t value21 = data[y1 * width + x2];
+ scalar_t value22 = data[y2 * width + x2];
+ scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
+ return value;
+}
+
+template <typename scalar_t>
+__global__ void DeformablePSROIPoolForwardKernel(
+ const int count,
+ const scalar_t *bottom_data,
+ const scalar_t spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const scalar_t *bottom_rois, const scalar_t *bottom_trans,
+ const int no_trans,
+ const scalar_t trans_std,
+ const int sample_per_part,
+ const int output_dim,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class,
+ scalar_t *top_data,
+ scalar_t *top_count)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+ scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+ scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+ scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+ scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+ scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+ scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+ scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+ scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+ int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+ int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+ int class_id = ctop / channels_each_class;
+ scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+ scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
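+ // trans_x/trans_y are the learned normalized part offsets for this class and part
+ // cell, scaled by trans_std; they shift the pooling bin by a fraction of the ROI
+ // width/height below.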
+
+ scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+ wstart += trans_x * roi_width;
+ scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+ hstart += trans_y * roi_height;
+
+ scalar_t sum = 0;
+ int count = 0;
+ int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+ int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ scalar_t w = wstart + iw * sub_bin_size_w;
+ scalar_t h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+ sum += val;
+ count++;
+ }
+ }
+ top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
+ top_count[index] = count;
+ }
+}
+
+template <typename scalar_t>
+__global__ void DeformablePSROIPoolBackwardAccKernel(
+ const int count,
+ const scalar_t *top_diff,
+ const scalar_t *top_count,
+ const int num_rois,
+ const scalar_t spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const int output_dim,
+ scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
+ const scalar_t *bottom_data,
+ const scalar_t *bottom_rois,
+ const scalar_t *bottom_trans,
+ const int no_trans,
+ const scalar_t trans_std,
+ const int sample_per_part,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+ scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+ scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+ scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+ scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+ scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+ scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+ scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+ scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+ int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+ int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+ int class_id = ctop / channels_each_class;
+ scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+ scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+ scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+ wstart += trans_x * roi_width;
+ scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+ hstart += trans_y * roi_height;
+
+ if (top_count[index] <= 0)
+ {
+ continue;
+ }
+ scalar_t diff_val = top_diff[index] / top_count[index];
+ const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+ scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+ int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+ int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ scalar_t w = wstart + iw * sub_bin_size_w;
+ scalar_t h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ // backward on feature
+ int x0 = floor(w);
+ int x1 = ceil(w);
+ int y0 = floor(h);
+ int y1 = ceil(h);
+ scalar_t dist_x = w - x0, dist_y = h - y0;
+ scalar_t q00 = (1 - dist_x) * (1 - dist_y);
+ scalar_t q01 = (1 - dist_x) * dist_y;
+ scalar_t q10 = dist_x * (1 - dist_y);
+ scalar_t q11 = dist_x * dist_y;
+ int bottom_index_base = c * height * width;
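+        // scatter the gradient to the four neighbouring pixels with their bilinear weights (atomicAdd: bins/RoIs may overlap)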
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
+
+ if (no_trans)
+ {
+ continue;
+ }
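+        // gradient w.r.t. the learned offsets: differentiate the bilinear sample w.r.t. (w, h), scaled by trans_std and the RoI size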
+ scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+ scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+ scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+ scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+ scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+ diff_x *= roi_width;
+ scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+ diff_y *= roi_height;
+
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
+ }
+ }
+ }
+}
+
+void DeformablePSROIPoolForward(const at::Tensor data,
+ const at::Tensor bbox,
+ const at::Tensor trans,
+ at::Tensor out,
+ at::Tensor top_count,
+ const int batch,
+ const int channels,
+ const int height,
+ const int width,
+ const int num_bbox,
+ const int channels_trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ const int pooled_height = pooled_size;
+ const int pooled_width = pooled_size;
+ const int count = num_bbox * output_dim * pooled_height * pooled_width;
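+    // one output element per (n, ctop, ph, pw); the kernel iterates over all `count` of them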
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data.type(), "deformable_psroi_pool_forward", ([&] {
+        const scalar_t *bottom_data = data.data<scalar_t>();
+        const scalar_t *bottom_rois = bbox.data<scalar_t>();
+        const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
+        scalar_t *top_data = out.data<scalar_t>();
+        scalar_t *top_count_data = top_count.data<scalar_t>();
+
+        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+ count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
+ bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
+ group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
+ const at::Tensor data,
+ const at::Tensor bbox,
+ const at::Tensor trans,
+ const at::Tensor top_count,
+ at::Tensor in_grad,
+ at::Tensor trans_grad,
+ const int batch,
+ const int channels,
+ const int height,
+ const int width,
+ const int num_bbox,
+ const int channels_trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ // LOG(INFO) << "DeformablePSROIPoolBackward";
+ const int num_rois = num_bbox;
+ const int pooled_height = pooled_size;
+ const int pooled_width = pooled_size;
+ const int count = num_bbox * output_dim * pooled_height * pooled_width;
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
+        const scalar_t *top_diff = out_grad.data<scalar_t>();
+        const scalar_t *bottom_data = data.data<scalar_t>();
+        const scalar_t *bottom_rois = bbox.data<scalar_t>();
+        const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
+        scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
+        scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
+        const scalar_t *top_count_data = top_count.data<scalar_t>();
+
+        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+ count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
+ pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
+ bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
+ group_size, part_size, num_classes, channels_each_class);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+    printf("error in DeformablePSROIPoolBackwardAcc: %s\n", cudaGetErrorString(err));
+ }
+}
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py b/IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab4346d6115ace46a085496751291864a576bea
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py
@@ -0,0 +1,336 @@
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+BatchNorm2d = nn.BatchNorm2d
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def constant_init(module, constant, bias=0):
+ nn.init.constant_(module.weight, constant)
+ if hasattr(module, 'bias'):
+ nn.init.constant_(module.bias, bias)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
+ super(BasicBlock, self).__init__()
+ self.with_dcn = dcn is not None
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.get('fallback_on_stride', False)
+ self.with_modulated_dcn = dcn.get('modulated', False)
+ # self.conv2 = conv3x3(planes, planes)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+ padding=1, bias=False)
+ else:
+ deformable_groups = dcn.get('deformable_groups', 1)
+ if not self.with_modulated_dcn:
+ from network.backbone.assets.dcn import DeformConv
+ conv_op = DeformConv
+ offset_channels = 18
+ else:
+ from network.backbone.assets.dcn import ModulatedDeformConv
+ conv_op = ModulatedDeformConv
+ offset_channels = 27
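+            # conv2_offset predicts 18 channels (x/y offsets for a 3x3 kernel) or 27 for modulated DCN (offsets + 3x3 mask)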
+ self.conv2_offset = nn.Conv2d(
+ planes,
+ deformable_groups * offset_channels,
+ kernel_size=3,
+ padding=1)
+ self.conv2 = conv_op(
+ planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ deformable_groups=deformable_groups,
+ bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ # out = self.conv2(out)
+ if not self.with_dcn:
+ out = self.conv2(out)
+ elif self.with_modulated_dcn:
+ offset_mask = self.conv2_offset(out)
+ offset = offset_mask[:, :18, :, :]
+ mask = offset_mask[:, -9:, :, :].sigmoid()
+ out = self.conv2(out, offset, mask)
+ else:
+ offset = self.conv2_offset(out)
+ out = self.conv2(out, offset)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
+ super(Bottleneck, self).__init__()
+ self.with_dcn = dcn is not None
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.get('fallback_on_stride', False)
+ self.with_modulated_dcn = dcn.get('modulated', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+ stride=stride, padding=1, bias=False)
+ else:
+ deformable_groups = dcn.get('deformable_groups', 1)
+ if not self.with_modulated_dcn:
+ from network.backbone.assets.dcn import DeformConv
+ conv_op = DeformConv
+ offset_channels = 18
+ else:
+ from network.backbone.assets.dcn import ModulatedDeformConv
+ conv_op = ModulatedDeformConv
+ offset_channels = 27
+ self.conv2_offset = nn.Conv2d(
+ planes, deformable_groups * offset_channels,
+ kernel_size=3,
+ padding=1)
+ self.conv2 = conv_op(
+ planes, planes, kernel_size=3, padding=1, stride=stride,
+ deformable_groups=deformable_groups, bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ # out = self.conv2(out)
+ if not self.with_dcn:
+ out = self.conv2(out)
+ elif self.with_modulated_dcn:
+ offset_mask = self.conv2_offset(out)
+ offset = offset_mask[:, :18, :, :]
+ mask = offset_mask[:, -9:, :, :].sigmoid()
+ out = self.conv2(out, offset, mask)
+ else:
+ offset = self.conv2_offset(out)
+ out = self.conv2(out, offset)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers, num_classes=1000,
+ dcn=None, stage_with_dcn=(False, False, False, False)):
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(
+ block, 128, layers[1], stride=2, dcn=dcn)
+ self.layer3 = self._make_layer(
+ block, 256, layers[2], stride=2, dcn=dcn)
+ self.layer4 = self._make_layer(
+ block, 512, layers[3], stride=2, dcn=dcn)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
+ if hasattr(m, 'conv2_offset'):
+ constant_init(m.conv2_offset, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes,
+ stride, downsample, dcn=dcn))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dcn=dcn))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x1 = self.maxpool(x)
+
+ x2 = self.layer1(x1)
+ x3 = self.layer2(x2)
+ x4 = self.layer3(x3)
+ x5 = self.layer4(x4)
+
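+        # multi-scale features at 1/4, 1/4, 1/8, 1/16, 1/32 of the input resolution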
+ return x1, x2, x3, x4, x5
+
+
+def resnet18(pretrained=True, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet18']), strict=False)
+ return model
+
+def deformable_resnet18(pretrained=True, **kwargs):
+    """Constructs a ResNet-18 model with deformable conv.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2],
+ dcn=dict(modulated=True,
+ deformable_groups=1,
+ fallback_on_stride=False),
+ stage_with_dcn=[False, True, True, True], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet18']), strict=False)
+ return model
+
+
+def resnet34(pretrained=True, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet34']), strict=False)
+ return model
+
+
+def resnet50(pretrained=True, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet50']), strict=False)
+ return model
+
+
+def deformable_resnet50(pretrained=True, **kwargs):
+ """Constructs a ResNet-50 model with deformable conv.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3],
+ dcn=dict(modulated=True,
+ deformable_groups=1,
+ fallback_on_stride=False),
+ stage_with_dcn=[False, True, True, True],
+ **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet50']), strict=False)
+ return model
+
+
+def resnet101(pretrained=True, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet101']), strict=False)
+ return model
+
+
+def resnet152(pretrained=True, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(
+ model_urls['resnet152']), strict=False)
+ return model
diff --git a/IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py b/IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0932835b7a213614d826dcb832c7adf9d89f07d5
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py
@@ -0,0 +1,60 @@
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+
+model_urls = {
+ 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+ 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+ 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+ 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class VggNet(nn.Module):
+ def __init__(self, name="vgg16", pretrain=True):
+ super().__init__()
+ if name == "vgg16":
+ base_net = models.vgg16(pretrained=False)
+ elif name == "vgg16_bn":
+ base_net = models.vgg16_bn(pretrained=False)
+ else:
+            print("base model is not supported!")
+ if pretrain:
+ print("load the {} weight from ./cache".format(name))
+ base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache"))
+
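+        # split torchvision's VGG "features" into five stages, one per max-pool block (strides 2, 4, 8, 16, 32)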
+ if name == "vgg16":
+ self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
+ self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
+ self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
+ self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
+ self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
+ elif name == "vgg16_bn":
+ self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
+ self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
+ self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
+ self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
+ self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
+
+ def forward(self, x):
+ C1 = self.stage1(x)
+ C2 = self.stage2(C1)
+ C3 = self.stage3(C2)
+ C4 = self.stage4(C3)
+ C5 = self.stage5(C4)
+
+ return C1, C2, C3, C4, C5
+
+
+if __name__ == '__main__':
+ import torch
+ input = torch.randn((4, 3, 512, 512))
+ net = VggNet()
+ C1, C2, C3, C4, C5 = net(input)
+ print(C1.size())
+ print(C2.size())
+ print(C3.size())
+ print(C4.size())
+ print(C5.size())
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py b/IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..89557292a4cec6733f658d397fbe213b864e5685
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py
@@ -0,0 +1,88 @@
+###################################################################
+# File Name: AdaptiveDeformation.py
+# Author: S.X.Zhang
+###################################################################
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+
+class MeanAggregator(nn.Module):
+ def __init__(self):
+ super(MeanAggregator, self).__init__()
+
+ def forward(self, features, A):
+ x = torch.bmm(A, features)
+ return x
+
+
+class GraphConv(nn.Module):
+ def __init__(self, in_dim, out_dim, agg):
+ super(GraphConv, self).__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
+ self.bias = nn.Parameter(torch.FloatTensor(out_dim))
+ init.xavier_uniform_(self.weight)
+ init.constant_(self.bias, 0)
+ self.agg = agg()
+
+ def forward(self, features, A):
+ b, n, d = features.shape
+ assert (d == self.in_dim)
+ agg_feats = self.agg(features, A)
+ cat_feats = torch.cat([features, agg_feats], dim=2)
+ out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
+ out = F.relu(out + self.bias)
+ return out
+
+
+class AdaptiveDeformation(nn.Module):
+ def __init__(self, input, state_dim):
+ super(AdaptiveDeformation, self).__init__()
+ self.bn0 = nn.BatchNorm1d(input, affine=False)
+ self.conv1 = nn.Conv1d(input, state_dim, 1)
+ self.rnn = nn.LSTM(input, state_dim, 1, bidirectional=True)
+ self.gconv1 = GraphConv(input, 256, MeanAggregator)
+ self.gconv2 = GraphConv(256, 1024, MeanAggregator)
+ self.gconv3 = GraphConv(1024, 512, MeanAggregator)
+ self.gconv4 = GraphConv(512, state_dim, MeanAggregator)
+
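+        # prediction head: concat of bi-LSTM (2*state_dim), GCN (state_dim) and 1x1-conv (state_dim) branches -> 2 channels per contour point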
+ self.prediction = nn.Sequential(
+ nn.Conv1d(4*state_dim, 128, 1),
+ nn.ReLU(inplace=True),
+ nn.Dropout(0.1),
+ nn.Conv1d(128, 64, 1),
+ nn.ReLU(inplace=True),
+ nn.Dropout(0.1),
+ nn.Conv1d(64, 2, 1))
+
+ def forward(self, x, A):
+ x = self.bn0(x)
+
+ # # rnn block
+ yl = x.permute(2, 0, 1)
+ yl, _ = self.rnn(yl)
+ yl = yl.permute(1, 2, 0)
+
+ # # gcn block
+ yg = x.permute(0, 2, 1)
+ b, n, c = yg.shape
+ A = A.expand(b, n, n)
+ yg = self.gconv1(yg, A)
+ yg = self.gconv2(yg, A)
+ yg = self.gconv3(yg, A)
+ yg = self.gconv4(yg, A)
+ yg = yg.permute(0, 2, 1)
+
+ # res block
+ x = torch.cat([yl, yg, self.conv1(x)], dim=1)
+ pred = self.prediction(x)
+
+ return pred
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py b/IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4d24d097ad1b9ceb5f92503eef4094e235f0e4
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py
@@ -0,0 +1,91 @@
+import torch.nn as nn
+import torch
+
+
+class CircConv(nn.Module):
+ def __init__(self, state_dim, out_state_dim=None, n_adj=4):
+ super(CircConv, self).__init__()
+
+ self.n_adj = n_adj
+ out_state_dim = state_dim if out_state_dim is None else out_state_dim
+ self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1)
+
+ def forward(self, input, adj):
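+        # concatenate n_adj points from each end so the 1-D convolution wraps around the closed contour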
+ input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2)
+ return self.fc(input)
+
+
+class DilatedCircConv(nn.Module):
+ def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
+ super(DilatedCircConv, self).__init__()
+
+ self.n_adj = n_adj
+ self.dilation = dilation
+ out_state_dim = state_dim if out_state_dim is None else out_state_dim
+ self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation)
+
+ def forward(self, input, adj):
+ if self.n_adj != 0:
+ input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2)
+ return self.fc(input)
+
+
+_conv_factory = {
+ 'grid': CircConv,
+ 'dgrid': DilatedCircConv
+}
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
+ super(BasicBlock, self).__init__()
+
+ self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation)
+ self.relu = nn.ReLU(inplace=True)
+ self.norm = nn.BatchNorm1d(out_state_dim)
+
+ def forward(self, x, adj=None):
+ x = self.conv(x, adj)
+ x = self.relu(x)
+ x = self.norm(x)
+ return x
+
+
+class DeepSnake(nn.Module):
+ def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
+ super(DeepSnake, self).__init__()
+
+ self.head = BasicBlock(feature_dim, state_dim, conv_type)
+
+ self.res_layer_num = 7
+ dilation = [1, 1, 1, 2, 2, 4, 4]
+ for i in range(self.res_layer_num):
+ conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i])
+ self.__setattr__('res'+str(i), conv)
+
+ fusion_state_dim = 256
+ self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1)
+ self.prediction = nn.Sequential(
+ nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(256, 64, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(64, 2, 1)
+ )
+
+ def forward(self, x, adj):
+ states = []
+
+ x = self.head(x, adj)
+ states.append(x)
+ for i in range(self.res_layer_num):
+ x = self.__getattr__('res'+str(i))(x, adj) + x
+ states.append(x)
+
+ state = torch.cat(states, dim=1)
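+        # global contour descriptor: max over points after fusion, broadcast back and concatenated with the per-point states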
+ global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0]
+ global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2))
+ state = torch.cat([global_state, state], dim=1)
+ x = self.prediction(state)
+
+ return x
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/GCN.py b/IndicPhotoOCR/detection/textbpn/network/layers/GCN.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2a5aa045b46e04c7393a1900b0e652de15de41d
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/GCN.py
@@ -0,0 +1,77 @@
+###################################################################
+# File Name: GCN.py
+# Author: S.X.Zhang
+###################################################################
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+
+class MeanAggregator(nn.Module):
+ def __init__(self):
+ super(MeanAggregator, self).__init__()
+
+ def forward(self, features, A):
+ x = torch.bmm(A, features)
+ return x
+
+
+class GraphConv(nn.Module):
+ def __init__(self, in_dim, out_dim, agg):
+ super(GraphConv, self).__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
+ self.bias = nn.Parameter(torch.FloatTensor(out_dim))
+ init.xavier_uniform_(self.weight)
+ init.constant_(self.bias, 0)
+ self.agg = agg()
+
+ def forward(self, features, A):
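+        # aggregate neighbour features via A (a mean when A is row-normalised), concatenate with the node features, then project + ReLU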
+ b, n, d = features.shape
+ assert (d == self.in_dim)
+ agg_feats = self.agg(features, A)
+ cat_feats = torch.cat([features, agg_feats], dim=2)
+ out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
+ out = F.relu(out + self.bias)
+ return out
+
+
+class GCN(nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super(GCN, self).__init__()
+ self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
+
+ self.conv1 = GraphConv(in_dim, 256, MeanAggregator)
+ self.conv2 = GraphConv(256, 1024, MeanAggregator)
+ self.conv3 = GraphConv(1024, 512, MeanAggregator)
+ self.conv4 = GraphConv(512, out_dim, MeanAggregator)
+
+ self.prediction = nn.Sequential(
+ nn.Conv1d(out_dim, 128, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(128, 64, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(64, 2, 1))
+
+ def forward(self, x, A):
+ x = self.bn0(x)
+ x = x.permute(0, 2, 1)
+ b, n, c = x.shape
+ A = A.expand(b, n, n)
+
+ x = self.conv1(x, A)
+ x = self.conv2(x, A)
+ x = self.conv3(x, A)
+ x = self.conv4(x, A)
+
+ x = x.permute(0, 2, 1)
+ pred = self.prediction(x)
+
+ return pred
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py b/IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dd9f87320fbb5dbf14af519aff143b6c83a77e1
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py
@@ -0,0 +1,45 @@
+import math
+
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch.nn import init
+
+
+class GraphConvolution(Module):
+ """
+ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
+ """
+
+ def __init__(self, in_features, out_features, bias=True):
+ super(GraphConvolution, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+ init.xavier_uniform_(self.weight)
+ if bias:
+ self.bias = Parameter(torch.FloatTensor(out_features))
+ init.constant_(self.bias, 0)
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ stdv = 1. / math.sqrt(self.weight.size(1))
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.uniform_(-stdv, stdv)
+
+ def forward(self, input, adj):
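+        # support = input @ W; output = adj @ support (torch.spmm accepts a sparse adjacency), plus bias if present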
+ support = torch.mm(input, self.weight)
+ output = torch.spmm(adj, support)
+ if self.bias is not None:
+ return output + self.bias
+ else:
+ return output
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + str(self.in_features) + ' -> ' \
+ + str(self.out_features) + ')'
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/RNN.py b/IndicPhotoOCR/detection/textbpn/network/layers/RNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cafd28a0c920fe2bd799ec43868c73fc4e93c25
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/RNN.py
@@ -0,0 +1,35 @@
+###################################################################
+# File Name: RNN.py
+# Author: S.X.Zhang
+###################################################################
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+
+class RNN(nn.Module):
+ def __init__(self, input, state_dim):
+ super(RNN, self).__init__()
+ self.bn0 = nn.BatchNorm1d(input, affine=False)
+ self.rnn = nn.LSTM(input, state_dim, 1, dropout=0.1, bidirectional=True)
+ self.prediction = nn.Sequential(
+ nn.Conv1d(state_dim*2, 128, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(128, 64, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(64, 2, 1))
+
+ def forward(self, x, adj):
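+        # (B, C, N) -> (N, B, C) for the LSTM, then back to (B, 2*state_dim, N) for the Conv1d head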
+ x = self.bn0(x)
+ x = x.permute(2, 0, 1)
+ x, _ = self.rnn(x)
+ x = x.permute(1, 2, 0)
+ pred = self.prediction(x)
+
+ return pred
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py b/IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9890b316bbea13537c8d3a839328028589225b41
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py
@@ -0,0 +1,140 @@
+###################################################################
+# File Name: Transformer.py
+# Author: S.X.Zhang
+###################################################################
+import torch
+from torch import nn, Tensor
+import numpy as np
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+
+class Positional_encoding(nn.Module):
+ def __init__(self, PE_size, n_position=256):
+ super(Positional_encoding, self).__init__()
+ self.PE_size = PE_size
+ self.n_position = n_position
+ self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
+
+ def get_encoding_table(self, n_position, PE_size):
+ position_table = np.array(
+ [[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)] for pos in range(n_position)])
+ position_table[:, 0::2] = np.sin(position_table[:, 0::2])
+ position_table[:, 1::2] = np.cos(position_table[:, 1::2])
+ return torch.FloatTensor(position_table).unsqueeze(0)
+
+ def forward(self, inputs):
+ return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, num_heads, embed_dim, dropout=0.1, if_resi=True):
+ super(MultiHeadAttention, self).__init__()
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ self.MultiheadAttention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
+ self.Q_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
+ self.K_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
+ self.V_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
+ self.if_resi = if_resi
+
+ def forward(self, inputs):
+ query = self.layer_norm(inputs)
+ q = self.Q_proj(query)
+ k = self.K_proj(query)
+ v = self.V_proj(query)
+ attn_output, attn_output_weights = self.MultiheadAttention(q, k, v)
+ if self.if_resi:
+ attn_output += inputs
+ else:
+ attn_output = attn_output
+
+ return attn_output
+
+
+class FeedForward(nn.Module):
+ def __init__(self, in_channel, FFN_channel, if_resi=True):
+ super(FeedForward, self).__init__()
+ """
+ 1024 2048
+ """
+ output_channel = (FFN_channel, in_channel)
+ self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
+ self.fc2 = nn.Linear(output_channel[0], output_channel[1])
+ self.layer_norm = nn.LayerNorm(in_channel)
+ self.if_resi = if_resi
+
+ def forward(self, inputs):
+ outputs = self.layer_norm(inputs)
+ outputs = self.fc1(outputs)
+ outputs = self.fc2(outputs)
+ if self.if_resi:
+ outputs += inputs
+ else:
+ outputs = outputs
+ return outputs
+
+
+class TransformerLayer(nn.Module):
+ def __init__(self, out_dim, in_dim, num_heads, attention_size,
+ dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
+ super(TransformerLayer, self).__init__()
+ self.block_nums = block_nums
+ self.if_resi = if_resi
+ self.linear = nn.Linear(in_dim, attention_size)
+ for i in range(self.block_nums):
+ self.__setattr__('MHA_self_%d' % i, MultiHeadAttention(num_heads, attention_size,
+ dropout=drop_rate, if_resi=if_resi))
+ self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
+
+ def forward(self, query):
+ inputs = self.linear(query)
+ # outputs = inputs
+ for i in range(self.block_nums):
+ outputs = self.__getattr__('MHA_self_%d' % i)(inputs)
+ outputs = self.__getattr__('FFN_%d' % i)(outputs)
+ if self.if_resi:
+ inputs = inputs+outputs
+ else:
+ inputs = outputs
+ # outputs = inputs
+ return inputs
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, in_dim, out_dim, num_heads=8,
+ dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
+ super().__init__()
+
+ self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
+ self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
+
+ # self.pos_embedding = Positional_encoding(in_dim)
+ self.transformer = TransformerLayer(out_dim, in_dim, num_heads, attention_size=out_dim,
+ dim_feedforward=dim_feedforward, drop_rate=drop_rate,
+ if_resi=if_resi, block_nums=block_nums)
+
+ self.prediction = nn.Sequential(
+ nn.Conv1d(2*out_dim, 128, 1),
+ nn.ReLU(inplace=True),
+ nn.Dropout(0.1),
+ nn.Conv1d(128, 64, 1),
+ nn.ReLU(inplace=True),
+ # nn.Dropout(0.1),
+ nn.Conv1d(64, 2, 1))
+
+ def forward(self, x, adj):
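+        # x: (B, C, N) point features; adj is unused here and kept only for interface parity with the GCN/RNN heads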
+ x = self.bn0(x)
+
+ x1 = x.permute(0, 2, 1)
+ # x1 = self.pos_embedding(x1)
+ x1 = self.transformer(x1)
+ x1 = x1.permute(0, 2, 1)
+
+ x = torch.cat([x1, self.conv1(x)], dim=1)
+ # x = x1+self.conv1(x)
+ pred = self.prediction(x)
+
+ return pred
+
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py b/IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b7a7e485d7afb4b11048061391f6eb7ed2c274
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py
@@ -0,0 +1,171 @@
+###################################################################
+# File Name: Transformer_old.py
+# Author: S.X.Zhang
+###################################################################
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torch.autograd import Variable
+import numpy as np
+from cfglib.config import config as cfg
+
+
+class Positional_encoding(nn.Module):
+ def __init__(self, PE_size, n_position=200):
+ super(Positional_encoding, self).__init__()
+ self.PE_size = PE_size
+ self.n_position = n_position
+ self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
+
+ def get_encoding_table(self, n_position, PE_size):
+ position_table = np.array(
+ [[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)] for pos in range(n_position)])
+ position_table[:, 0::2] = np.sin(position_table[:, 0::2])
+ position_table[:, 1::2] = np.cos(position_table[:, 1::2])
+ return torch.FloatTensor(position_table).unsqueeze(0)
+
+ def forward(self, inputs):
+ return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, num_heads, embedding_size, attention_size,
+ drop_rate, future_blind=True, query_mask=False, if_resi=True):
+ super(MultiHeadAttention, self).__init__()
+ self.num_heads = num_heads
+ self.embedding_size = embedding_size
+ self.attention_size = attention_size
+ self.drop_rate = drop_rate
+ self.future_blind = future_blind
+
+ self.Q_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+ self.K_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+ self.V_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+
+ self.drop_out = nn.Dropout(p=self.drop_rate)
+ self.layer_norm = nn.LayerNorm(self.attention_size)
+ self.if_resi = if_resi
+
+ def forward(self, query, key, value):
+ q = self.Q_proj(query)
+ k = self.K_proj(key)
+ v = self.V_proj(value)
+
+ q_ = torch.cat(torch.chunk(q, self.num_heads, dim=2), dim=0)
+ k_ = torch.cat(torch.chunk(k, self.num_heads, dim=2), dim=0)
+ v_ = torch.cat(torch.chunk(v, self.num_heads, dim=2), dim=0)
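+        # split the channel dim into num_heads chunks and fold them into the batch dim: (h*N, T, C/h)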
+
+ outputs = torch.bmm(q_, k_.permute(0, 2, 1))
+ outputs = outputs / (k_.size()[-1] ** 0.5)
+
+ # key mask
+
+ # future mask
+ if self.future_blind:
+ diag_vals = torch.ones_like(outputs[0, :, :]).to(cfg.device)
+ tril = torch.tril(diag_vals, diagonal=0)
+ masks = Variable(torch.unsqueeze(tril, 0).repeat(outputs.size()[0], 1, 1)) # (h*N,T_q,T_k)
+ padding = Variable(torch.ones_like(masks).to(cfg.device) * (-2 ** 32 + 1))
+ condition = masks.eq(0)
+ outputs = torch.where(condition, padding, outputs)
+
+ outputs = F.softmax(outputs, dim=-1)
+        # if self.future_blind == True:
+ # print(outputs[0])
+ outputs = self.drop_out(outputs)
+
+ outputs = torch.bmm(outputs, v_)
+ outputs = torch.cat(torch.chunk(outputs, self.num_heads, dim=0), dim=2) # N,T_q,C
+
+ if self.if_resi:
+ # outputs += query
+ outputs += q
+ else:
+ outputs = outputs
+ outputs = self.layer_norm(outputs)
+
+ return outputs
+
+
+class FeedForward(nn.Module):
+ def __init__(self, in_channel, FFN_channel, if_resi=True):
+ super(FeedForward, self).__init__()
+ """
+ 1024 2048
+ """
+ output_channel = (FFN_channel, in_channel)
+ self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
+ self.fc2 = nn.Linear(output_channel[0], output_channel[1])
+ self.layer_norm = nn.LayerNorm(in_channel)
+ self.if_resi = if_resi
+
+ def forward(self, inputs):
+ outputs = self.fc1(inputs)
+ outputs = self.fc2(outputs)
+ if self.if_resi:
+ outputs += inputs
+ else:
+ outputs = outputs
+ outputs = self.layer_norm(outputs)
+ return outputs
+
+
+class TransformerLayer(nn.Module):
+ def __init__(self, out_dim, num_heads, embedding_size, attention_size,
+ dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
+ super(TransformerLayer, self).__init__()
+ self.block_nums = block_nums
+ self.if_resi = if_resi
+ for i in range(self.block_nums):
+ self.__setattr__('MHA_self_%d' % i, MultiHeadAttention(num_heads, embedding_size, attention_size,
+ drop_rate, future_blind=False, if_resi=if_resi))
+ self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
+
+ def forward(self, query):
+ outputs = None
+ for i in range(self.block_nums):
+ outputs = self.__getattr__('MHA_self_%d' % i)(query, query, query)
+ outputs = self.__getattr__('FFN_%d' % i)(outputs)
+ return outputs
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, in_dim, out_dim, num_heads=8,
+ dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
+ super().__init__()
+
+ self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
+ self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
+
+ embed_dim = in_dim
+ # self.pos_embedding = Positional_encoding(embed_dim)
+ self.transformer = TransformerLayer(out_dim, num_heads, embedding_size=embed_dim,
+ attention_size=out_dim, dim_feedforward=dim_feedforward,
+ drop_rate=drop_rate, if_resi=if_resi, block_nums=block_nums)
+
+ self.prediction = nn.Sequential(
+ nn.Conv1d(out_dim*2, 128, 1),
+ nn.ReLU(inplace=True),
+ nn.Dropout(0.1),
+ nn.Conv1d(128, 64, 1),
+ nn.ReLU(inplace=True),
+ # nn.Dropout(0.1),
+ nn.Conv1d(64, 2, 1))
+
+ def forward(self, x, adj):
+ x = self.bn0(x)
+
+ x1 = x.permute(0, 2, 1)
+ x1 = self.transformer(x1)
+ x1 = x1.permute(0, 2, 1)
+
+ x = torch.cat([x1, self.conv1(x)], dim=1)
+ # x = x1+self.conv1(x)
+ pred = self.prediction(x)
+
+ return pred
+
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/__init__.py b/IndicPhotoOCR/detection/textbpn/network/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py b/IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6605ee61ae43097f1e6ef8848adbb3cff100829f
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+__author__ = "S.X.Zhang"
+import torch
+import numpy as np
+import cv2
+import torch.nn as nn
+from torch.autograd import Variable
+
+
+def normalize_adj(A, type="AD"):
+ if type == "DAD":
+ A = A + np.eye(A.shape[0]) # A=A+I
+ d = np.sum(A, axis=0)
+ d_inv = np.power(d, -0.5).flatten()
+ d_inv[np.isinf(d_inv)] = 0.0
+ d_inv = np.diag(d_inv)
+ G = A.dot(d_inv).transpose().dot(d_inv) # L = D^-1/2 A D^-1/2
+ G = torch.from_numpy(G)
+ elif type == "AD":
+ A = A + np.eye(A.shape[0]) # A=A+I
+ A = torch.from_numpy(A)
+ D = A.sum(1, keepdim=True)
+ G = A.div(D) # L= A/D
+ else:
+ A = A + np.eye(A.shape[0]) # A=A+I
+ D = A.sum(1, keepdim=True)
+ D = np.diag(D)
+ G = torch.from_numpy(D - A) # L = D-A
+ return G
+
+
+def np_to_variable(x, is_cuda=True, dtype=torch.FloatTensor):
+ v = Variable(torch.from_numpy(x).type(dtype))
+ if is_cuda:
+ v = v.cuda()
+ return v
+
+
+def set_trainable(model, requires_grad):
+ for param in model.parameters():
+ param.requires_grad = requires_grad
+
+
+def weights_normal_init(model, dev=0.01):
+ if isinstance(model, list):
+ for m in model:
+ weights_normal_init(m, dev)
+ else:
+ for m in model.modules():
+ if isinstance(m, nn.Conv2d):
+ m.weight.data.normal_(0.0, dev)
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0.0, dev)
+
+
+def clip_gradient(model, clip_norm):
+ """Computes a gradient clipping coefficient based on gradient norm."""
+ totalnorm = 0
+ for p in model.parameters():
+ if p.requires_grad:
+ modulenorm = p.grad.data.norm()
+ totalnorm += modulenorm ** 2
+ totalnorm = np.sqrt(totalnorm)
+
+ norm = clip_norm / max(totalnorm, clip_norm)
+ for p in model.parameters():
+ if p.requires_grad:
+ p.grad.mul_(norm)
+
+
+def EuclideanDistances(A, B):
+ BT = B.transpose()
+ vecProd = np.dot(A,BT)
+ SqA = A**2
+ sumSqA = np.matrix(np.sum(SqA, axis=1))
+ sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
+
+ SqB = B**2
+ sumSqB = np.sum(SqB, axis=1)
+ sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
+ SqED = sumSqBEx + sumSqAEx - 2*vecProd
+ SqED[SqED<0]=0.0
+ ED = np.sqrt(SqED)
+ return ED
+
+
+def get_center_feature(cnn_feature, img_poly, ind, h, w):
+ batch_size = cnn_feature.size(0)
+ for i in range(batch_size):
+ poly = img_poly[ind == i].cpu().numpy()
+ mask = np.zeros((h, w), dtype=np.uint8)
+ cv2.fillPoly(mask, poly.astype(np.int32), color=(1,))
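+        # rasterises each instance polygon into a mask, but the pooled centre feature is not implemented; the function returns None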
+ return None
+
+
+def get_node_feature(cnn_feature, img_poly, ind, h, w):
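+    # map point coordinates to [-1, 1] and bilinearly sample per-point features from the CNN map via grid_sample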
+ img_poly = img_poly.clone().float()
+ img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
+ img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
+
+ batch_size = cnn_feature.size(0)
+ gcn_feature = torch.zeros([img_poly.size(0), cnn_feature.size(1), img_poly.size(1)]).to(img_poly.device)
+ for i in range(batch_size):
+ poly = img_poly[ind == i].unsqueeze(0)
+ gcn_feature[ind == i] = torch.nn.functional.grid_sample(cnn_feature[i:i + 1], poly)[0].permute(1, 0, 2)
+ return gcn_feature
+
+
+def get_adj_mat(n_adj, n_nodes):
+    a = np.zeros([n_nodes, n_nodes], dtype=np.float64)  # np.float was removed in NumPy 1.24
+
+ for i in range(n_nodes):
+ for j in range(-n_adj // 2, n_adj // 2 + 1):
+ if j != 0:
+ a[i][(i + j) % n_nodes] = 1
+ a[(i + j) % n_nodes][i] = 1
+ return a
+
+
+def get_adj_ind(n_adj, n_nodes, device):
+ ind = torch.tensor([i for i in range(-n_adj // 2, n_adj // 2 + 1) if i != 0]).long()
+ ind = (torch.arange(n_nodes)[:, None] + ind[None]) % n_nodes
+ return ind.to(device)
+
+
+def coord_embedding(b, w, h, device):
+ x_range = torch.linspace(0, 1, w, device=device)
+ y_range = torch.linspace(0, 1, h, device=device)
+ y, x = torch.meshgrid(y_range, x_range)
+ y = y.expand([b, 1, -1, -1])
+ x = x.expand([b, 1, -1, -1])
+ coord_map = torch.cat([x, y], 1)
+
+ return coord_map
+
+
+def img_poly_to_can_poly(img_poly):
+ if len(img_poly) == 0:
+ return torch.zeros_like(img_poly)
+ x_min = torch.min(img_poly[..., 0], dim=-1)[0]
+ y_min = torch.min(img_poly[..., 1], dim=-1)[0]
+ can_poly = img_poly.clone()
+ can_poly[..., 0] = can_poly[..., 0] - x_min[..., None]
+ can_poly[..., 1] = can_poly[..., 1] - y_min[..., None]
+ # x_max = torch.max(img_poly[..., 0], dim=-1)[0]
+ # y_max = torch.max(img_poly[..., 1], dim=-1)[0]
+ # h, w = y_max - y_min + 1, x_max - x_min + 1
+ # long_side = torch.max(h, w)
+ # can_poly = can_poly / long_side[..., None, None]
+ return can_poly
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/model_block.py b/IndicPhotoOCR/detection/textbpn/network/layers/model_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6b0bf9ba52f72cfa29c8f1ce389d90b5d04422
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/model_block.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+__author__ = "S.X.Zhang"
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from IndicPhotoOCR.detection.textbpn.network.layers.vgg import VggNet
+from IndicPhotoOCR.detection.textbpn.network.layers.resnet import ResNet
+from IndicPhotoOCR.detection.textbpn.network.layers.resnet_dcn import ResNet_DCN
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+
+class UpBlok(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, upsampled, shortcut):
+ x = torch.cat([upsampled, shortcut], dim=1)
+ x = self.conv1x1(x)
+ x = F.relu(x)
+ x = self.conv3x3(x)
+ x = F.relu(x)
+ x = self.deconv(x)
+ return x
+
+
+class MergeBlok(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, upsampled, shortcut):
+ x = torch.cat([upsampled, shortcut], dim=1)
+ x = self.conv1x1(x)
+ x = F.relu(x)
+ x = self.conv3x3(x)
+ return x
+
+
+class FPN(nn.Module):
+
+ def __init__(self, backbone='resnet50', is_training=True):
+ super().__init__()
+ self.is_training = is_training
+ self.backbone_name = backbone
+
+ if backbone in ['vgg_bn', 'vgg']:
+ self.backbone = VggNet(name=backbone, pretrain=is_training)
+ self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+ self.merge4 = UpBlok(512 + 256, 128)
+ self.merge3 = UpBlok(256 + 128, 64)
+ if cfg.scale == 1:
+ self.merge2 = UpBlok(128 + 64, 32) # FPN 1/2
+ self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
+ elif cfg.scale == 2:
+ self.merge2 = UpBlok(128 + 64, 32) # FPN 1/2
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
+ elif cfg.scale == 4:
+ self.merge2 = MergeBlok(128 + 64, 32) # FPN 1/4
+
+ elif backbone in ['resnet50']:
+ self.backbone = ResNet(name=backbone, pretrain=is_training)
+ self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
+ self.merge4 = UpBlok(1024 + 256, 128)
+ self.merge3 = UpBlok(512 + 128, 64)
+ if cfg.scale == 1:
+ self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
+ self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
+ elif cfg.scale == 2:
+ self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
+ elif cfg.scale == 4:
+ self.merge2 = MergeBlok(256 + 64, 32) # FPN 1/4
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
+
+ elif backbone in ['resnet18']:
+ self.backbone = ResNet(name=backbone, pretrain=is_training)
+ self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+ self.merge4 = UpBlok(256 + 256, 128)
+ self.merge3 = UpBlok(128 + 128, 64)
+ if cfg.scale == 1:
+ self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
+ self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
+ elif cfg.scale == 2:
+ self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
+ elif cfg.scale == 4:
+ self.merge2 = MergeBlok(64 + 64, 32) # FPN 1/4
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
+
+ elif backbone in ["deformable_resnet18"]:
+ self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
+ self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+ self.merge4 = UpBlok(256 + 256, 128)
+ self.merge3 = UpBlok(128 + 128, 64)
+ if cfg.scale == 1:
+ self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
+ self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
+ elif cfg.scale == 2:
+ self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
+ elif cfg.scale == 4:
+ self.merge2 = MergeBlok(64 + 64, 32) # FPN 1/4
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
+
+ elif backbone in ["deformable_resnet50"]:
+ self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
+ self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
+ self.merge4 = UpBlok(1024 + 256, 128)
+ self.merge3 = UpBlok(512 + 128, 64)
+ if cfg.scale == 1:
+ self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
+ self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
+ elif cfg.scale == 2:
+ self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
+ elif cfg.scale == 4:
+ self.merge2 = MergeBlok(256 + 64, 32) # FPN 1/4
+ self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
+ else:
+            print("backbone is not supported!")
+
+ def forward(self, x):
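+        # top-down decoding: upsample C5 with a transposed conv, then merge with progressively finer backbone features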
+ C1, C2, C3, C4, C5 = self.backbone(x)
+ #print(C5.size())
+ #print(C4.size())
+ #print(C3.size())
+ #print(C2.size())
+ #print(C1.size())
+ up5 = self.deconv5(C5)
+ up5 = F.relu(up5)
+
+ up4 = self.merge4(C4, up5)
+ up4 = F.relu(up4)
+
+ up3 = self.merge3(C3, up4)
+ up3 = F.relu(up3)
+
+ up2 = self.merge2(C2, up3)
+ up2 = F.relu(up2)
+
+ up1 = self.merge1(C1, up2)
+ up1 = F.relu(up1)
+
+ return up1, up2, up3, up4, up5
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py b/IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..73ae39edf24659e226dc6d96c7c5cbf8bef579ca
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
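+        # cumulative sums over the unmasked (valid) region give each pixel the row/column index used for the sine embedding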
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ elif args.position_embedding in ('v3', 'learned'):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/resnet.py b/IndicPhotoOCR/detection/textbpn/network/layers/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6da7208d9e30bbe6022c3ae5b04e9f0c08f8483
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/resnet.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from torchvision.models import resnet
+import torch.utils.model_zoo as model_zoo
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+
+}
+
+
+class ResNet(nn.Module):
+ def __init__(self, name="resnet50", pretrain=True):
+ super().__init__()
+
+ if name == "resnet50":
+ base_net = resnet.resnet50(pretrained=False)
+ elif name == "resnet101":
+ base_net = resnet.resnet101(pretrained=False)
+ elif name == "resnet18":
+ base_net = resnet.resnet18(pretrained=False)
+ elif name == "resnet34":
+ base_net = resnet.resnet34(pretrained=False)
+
+ else:
+            print("base model is not supported!")
+
+ if pretrain:
+ print("load the {} weight from ./cache".format(name))
+ base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache",
+ map_location=torch.device(cfg.device)), strict=False)
+ # print(base_net)
+ self.stage1 = nn.Sequential(
+ base_net.conv1,
+ base_net.bn1,
+ base_net.relu,
+ base_net.maxpool
+ )
+ self.stage2 = base_net.layer1
+ self.stage3 = base_net.layer2
+ self.stage4 = base_net.layer3
+ self.stage5 = base_net.layer4
+ self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, x):
+ C1 = self.stage1(x)
+ C2 = self.stage2(C1)
+ C3 = self.stage3(C2)
+ C4 = self.stage4(C3)
+ C5 = self.stage5(C4)
+
+ if cfg.scale == 2 or cfg.scale == 1:
+ # up2 --> 1/2
+ C1 = self.up2(C1)
+
+ return C1, C2, C3, C4, C5
+
+
+if __name__ == '__main__':
+ import torch
+ input = torch.randn((4, 3, 512, 512))
+ net = ResNet()
+ C1, C2, C3, C4, C5 = net(input)
+ print(C1.size())
+ print(C2.size())
+ print(C3.size())
+ print(C4.size())
+ print(C5.size())
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py b/IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..918b5b84b8448adbc221c0ae93de3c7db71cfd1c
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+from IndicPhotoOCR.detection.textbpn.network.backbone.resnet import deformable_resnet18,deformable_resnet50
+import torch.utils.model_zoo as model_zoo
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+
+}
+
+
+class ResNet_DCN(nn.Module):
+ def __init__(self, name="deformable_resnet18", pretrain=False):
+ super().__init__()
+
+ if name == "deformable_resnet18":
+ self.base_net = deformable_resnet18(pretrained=False)
+ if pretrain:
+ print("load the {} weight from ./cache".format(name))
+ self.base_net.load_state_dict(
+ model_zoo.load_url(model_urls["resnet18"], model_dir="./cache",
+ map_location=torch.device(cfg.device)), strict=False)
+
+ elif name == "deformable_resnet50":
+ self.base_net = deformable_resnet50(pretrained=False)
+ if pretrain:
+ print("load the {} weight from ./cache".format(name))
+ self.base_net.load_state_dict(
+ model_zoo.load_url(model_urls["resnet50"], model_dir="./cache",
+ map_location=torch.device(cfg.device)), strict=False)
+ else:
+ print(" base model is not support !")
+
+ # print(base_net)
+ self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, x):
+ C1, C2, C3, C4, C5 = self.base_net(x)
+ # up2 --> 1/2
+ C1 = self.up2(C1)
+
+ return C1, C2, C3, C4, C5
+
+
+if __name__ == '__main__':
+ import torch
+ input = torch.randn((4, 3, 512, 512))
+ net = ResNet_DCN()
+ C1, C2, C3, C4, C5 = net(input)
+ print(C1.size())
+ print(C2.size())
+ print(C3.size())
+ print(C4.size())
+ print(C5.size())
diff --git a/IndicPhotoOCR/detection/textbpn/network/layers/vgg.py b/IndicPhotoOCR/detection/textbpn/network/layers/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1796e65aeba98c320163b9aa1852f03bb25ef99
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/layers/vgg.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+model_urls = {
+ 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+ 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+ 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+ 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class VggNet(nn.Module):
+ def __init__(self, name="vgg16", pretrain=True):
+ super().__init__()
+ if name == "vgg16":
+ base_net = models.vgg16(pretrained=False)
+ elif name == "vgg16_bn":
+ base_net = models.vgg16_bn(pretrained=False)
+ else:
+ print(" base model is not support !")
+ if pretrain:
+ print("load the {} weight from ./cache".format(name))
+ base_net.load_state_dict(model_zoo.load_url(model_urls[name],
+ model_dir="./cache",map_location=torch.device(cfg.device)))
+
+ if name == "vgg16":
+ self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
+ self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
+ self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
+ self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
+ self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
+ elif name == "vgg16_bn":
+ self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
+ self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
+ self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
+ self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
+ self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
+
+ def forward(self, x):
+ C1 = self.stage1(x)
+ C2 = self.stage2(C1)
+ C3 = self.stage3(C2)
+ C4 = self.stage4(C3)
+ C5 = self.stage5(C4)
+
+ return C1, C2, C3, C4, C5
+
+
+if __name__ == '__main__':
+ import torch
+ input = torch.randn((4, 3, 512, 512))
+ net = VggNet()
+ C1, C2, C3, C4, C5 = net(input)
+ print(C1.size())
+ print(C2.size())
+ print(C3.size())
+ print(C4.size())
+ print(C5.size())
diff --git a/IndicPhotoOCR/detection/textbpn/network/loss.py b/IndicPhotoOCR/detection/textbpn/network/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..08392b9ba27e4bd9d49b87e8e90687e8e8249896
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/loss.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# @Time : 10/1/21
+# @Author : GXYM
+import torch
+import torch.nn as nn
+from cfglib.config import config as cfg
+from network.Seg_loss import SegmentLoss
+from network.Reg_loss import PolyMatchingLoss
+import torch.nn.functional as F
+
+
+class TextLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.MSE_loss = torch.nn.MSELoss(reduce=False, size_average=False)
+ self.BCE_loss = torch.nn.BCELoss(reduce=False, size_average=False)
+ self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
+ self.KL_loss = torch.nn.KLDivLoss(reduce=False, size_average=False)
+
+ @staticmethod
+ def single_image_loss(pre_loss, loss_label):
+ batch_size = pre_loss.shape[0]
+ sum_loss = torch.mean(pre_loss.view(-1)) * 0
+ pre_loss = pre_loss.view(batch_size, -1)
+ loss_label = loss_label.view(batch_size, -1)
+ eps = 0.001
+ for i in range(batch_size):
+ average_number = 0
+ positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
+ average_number += positive_pixel
+ if positive_pixel != 0:
+ posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
+ sum_loss += posi_loss
+ if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
+ nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
+ average_number += len(pre_loss[i][(loss_label[i] < eps)])
+ else:
+ nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
+ average_number += 3 * positive_pixel
+ sum_loss += nega_loss
+ else:
+ nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
+ average_number += 100
+ sum_loss += nega_loss
+ # sum_loss += loss/average_number
+
+ return sum_loss/batch_size
+
+ def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
+ pos = (target * train_mask).bool()
+ neg = ((1 - target) * train_mask).bool()
+
+ n_pos = pos.float().sum()
+ if n_pos.item() > 0:
+ loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
+ loss_neg = self.BCE_loss(predict[neg], target[neg])
+ n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
+ else:
+ loss_pos = torch.tensor(0.)
+ loss_neg = self.BCE_loss(predict[neg], target[neg])
+ n_neg = 100
+ loss_neg, _ = torch.topk(loss_neg, n_neg)
+
+ return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
+
+ @staticmethod
+ def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
+
+ # norm loss
+ gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
+ norm_loss = weight_matrix * torch.mean((pred_flux - gt_flux) ** 2, dim=1)*train_mask
+ norm_loss = norm_loss.sum(-1).mean()
+ # norm_loss = norm_loss.sum()
+
+ # angle loss
+ mask = train_mask * mask
+ pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
+ # angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
+ # angle_loss = angle_loss.sum(-1).mean()
+ angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
+ angle_loss = angle_loss[mask].mean()
+
+ return norm_loss, angle_loss
+
+ @staticmethod
+ def get_poly_energy(energy_field, img_poly, ind, h, w):
+ img_poly = img_poly.clone().float()
+ img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
+ img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
+
+ batch_size = energy_field.size(0)
+ gcn_feature = torch.zeros([img_poly.size(0), energy_field.size(1), img_poly.size(1)]).to(img_poly.device)
+ for i in range(batch_size):
+ poly = img_poly[ind == i].unsqueeze(0)
+ gcn_feature[ind == i] = torch.nn.functional.grid_sample(energy_field[i:i + 1], poly)[0].permute(1, 0, 2)
+ return gcn_feature
+
+ def loss_energy_regularization(self, energy_field, img_poly, inds, h, w):
+ energys = []
+ for i, py in enumerate(img_poly):
+ energy = self.get_poly_energy(energy_field.unsqueeze(1), py, inds, h, w)
+ energys.append(energy.squeeze(1).sum(-1))
+
+ regular_loss = torch.tensor(0.)
+ energy_loss = torch.tensor(0.)
+ for i, e in enumerate(energys[1:]):
+ regular_loss += torch.clamp(e - energys[i], min=0.0).mean()
+ energy_loss += torch.where(e <= 0.01, torch.tensor(0.), e).mean()
+
+ return (energy_loss+regular_loss)/len(energys[1:])
+
+ def forward(self, input_dict, output_dict, eps=None):
+ """
+ calculate boundary proposal network loss
+ """
+ # tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
+
+ fy_preds = output_dict["fy_preds"]
+ py_preds = output_dict["py_preds"]
+ inds = output_dict["inds"]
+
+ train_mask = input_dict['train_mask']
+ tr_mask = input_dict['tr_mask'] > 0
+ distance_field = input_dict['distance_field']
+ direction_field = input_dict['direction_field']
+ weight_matrix = input_dict['weight_matrix']
+ gt_tags = input_dict['gt_points']
+
+ # # scale the prediction map
+ # fy_preds = F.interpolate(fy_preds, scale_factor=cfg.scale, mode='bilinear')
+
+ if cfg.scale > 1:
+ train_mask = F.interpolate(train_mask.float().unsqueeze(1),
+ scale_factor=1/cfg.scale, mode='bilinear').squeeze().bool()
+ tr_mask = F.interpolate(tr_mask.float().unsqueeze(1),
+ scale_factor=1/cfg.scale, mode='bilinear').squeeze().bool()
+
+ distance_field = F.interpolate(distance_field.unsqueeze(1),
+ scale_factor=1/cfg.scale, mode='bilinear').squeeze()
+ direction_field = F.interpolate(direction_field,
+ scale_factor=1 / cfg.scale, mode='bilinear')
+ weight_matrix = F.interpolate(weight_matrix.unsqueeze(1),
+ scale_factor=1/cfg.scale, mode='bilinear').squeeze()
+
+ # pixel class loss
+ # cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask)
+ cls_loss = self.BCE_loss(fy_preds[:, 0, :, :], tr_mask.float())
+ cls_loss = torch.mul(cls_loss, train_mask.float()).mean()
+
+ # distance field loss
+ dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
+ dis_loss = torch.mul(dis_loss, train_mask.float())
+ dis_loss = self.single_image_loss(dis_loss, distance_field)
+
+ # # direction field loss
+ norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :], direction_field,
+ weight_matrix, tr_mask, train_mask)
+
+ # boundary point loss
+ point_loss = self.PolyMatchingLoss(py_preds[1:], gt_tags[inds])
+
+ # Minimum energy loss regularization
+ h, w = distance_field.size(1) * cfg.scale, distance_field.size(2) * cfg.scale
+ energy_loss = self.loss_energy_regularization(distance_field, py_preds, inds[0], h, w)
+
+        if eps is None:
+            alpha, beta, theta, gama = 1.0, 3.0, 0.5, 0.05
+        else:
+            alpha, beta, theta = 1.0, 3.0, 0.5
+            gama = 0.1 * torch.sigmoid(torch.tensor((eps - cfg.max_epoch) / cfg.max_epoch))
+ loss = alpha*cls_loss + beta*dis_loss + theta*(norm_loss + angle_loss) + gama*(point_loss + energy_loss)
+
+ loss_dict = {
+ 'total_loss': loss,
+ 'cls_loss': alpha*cls_loss,
+ 'distance loss': beta*dis_loss,
+ 'dir_loss': theta*(norm_loss + angle_loss),
+ 'norm_loss': theta*norm_loss,
+ 'angle_loss': theta*angle_loss,
+ 'point_loss': gama*point_loss,
+ 'energy_loss': gama*energy_loss,
+
+ }
+
+ return loss_dict
+
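+
+# Editor's addition: a small standalone sketch (not in the original file) of the weight
+# schedule that TextLoss.forward above applies to the point/energy terms.
+def _gamma_weight_sketch(eps, max_epoch):
+    """With eps=None the weight is a fixed 0.05; otherwise it follows a sigmoid centred
+    on max_epoch, rising from roughly 0.027 at epoch 0 to 0.05 at the final epoch."""
+    if eps is None:
+        return 0.05
+    return float(0.1 * torch.sigmoid(torch.tensor((eps - max_epoch) / max_epoch)))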
diff --git a/IndicPhotoOCR/detection/textbpn/network/loss_org.py b/IndicPhotoOCR/detection/textbpn/network/loss_org.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee12ea7acee5984cf89c6921228eaef0be8bb1bb
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/loss_org.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+# @Time : 10/1/21
+# @Author : GXYM
+import torch
+import torch.nn as nn
+from cfglib.config import config as cfg
+from network.Seg_loss import SegmentLoss
+from network.Reg_loss import PolyMatchingLoss
+
+
+class TextLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.MSE_loss = torch.nn.MSELoss(reduce=False, size_average=False)
+ self.BCE_loss = torch.nn.BCELoss(reduce=False, size_average=False)
+ self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
+ self.KL_loss = torch.nn.KLDivLoss(reduce=False, size_average=False)
+
+ @staticmethod
+ def single_image_loss(pre_loss, loss_label):
+ batch_size = pre_loss.shape[0]
+ sum_loss = torch.mean(pre_loss.view(-1)) * 0
+ pre_loss = pre_loss.view(batch_size, -1)
+ loss_label = loss_label.view(batch_size, -1)
+ eps = 0.001
+ for i in range(batch_size):
+ average_number = 0
+ positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
+ average_number += positive_pixel
+ if positive_pixel != 0:
+ posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
+ sum_loss += posi_loss
+ if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
+ nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
+ average_number += len(pre_loss[i][(loss_label[i] < eps)])
+ else:
+ nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
+ average_number += 3 * positive_pixel
+ sum_loss += nega_loss
+ else:
+ nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
+ average_number += 100
+ sum_loss += nega_loss
+ # sum_loss += loss/average_number
+
+ return sum_loss/batch_size
+
+ def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
+ pos = (target * train_mask).bool()
+ neg = ((1 - target) * train_mask).bool()
+
+ n_pos = pos.float().sum()
+
+ if n_pos.item() > 0:
+ loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
+ loss_neg = self.BCE_loss(predict[neg], target[neg])
+ n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
+ else:
+ loss_pos = torch.tensor(0.)
+ loss_neg = self.BCE_loss(predict[neg], target[neg])
+ n_neg = 100
+ loss_neg, _ = torch.topk(loss_neg, n_neg)
+
+ return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
+
+ @staticmethod
+ def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
+
+ # norm loss
+ gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
+ norm_loss = weight_matrix * torch.sum((pred_flux - gt_flux) ** 2, dim=1)*train_mask
+ norm_loss = norm_loss.sum(-1).mean()
+
+ # angle loss
+ mask = train_mask * mask
+ pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
+ # angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
+ # angle_loss = angle_loss.sum(-1).mean()
+ angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
+ angle_loss = angle_loss[mask].mean()
+
+ return norm_loss, angle_loss
+
+ def forward(self, input_dict, output_dict, eps=None):
+ """
+ calculate boundary proposal network loss
+ """
+ # tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
+
+ fy_preds = output_dict["fy_preds"]
+ py_preds = output_dict["py_preds"]
+ inds = output_dict["inds"]
+
+ train_mask = input_dict['train_mask']
+ tr_mask = input_dict['tr_mask'] > 0
+ distance_field = input_dict['distance_field']
+ direction_field = input_dict['direction_field']
+ weight_matrix = input_dict['weight_matrix']
+ gt_tags = input_dict['gt_points']
+
+ # pixel class loss
+ cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask.bool())
+
+ # distance field loss
+ dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
+ dis_loss = torch.mul(dis_loss, train_mask.float())
+ dis_loss = self.single_image_loss(dis_loss, distance_field)
+
+ # direction field loss
+ norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :],
+ direction_field, weight_matrix, tr_mask, train_mask)
+
+ # boundary point loss
+ point_loss = self.PolyMatchingLoss(py_preds, gt_tags[inds])
+
+ if eps is None:
+ loss_b = 0.05*point_loss
+ loss = cls_loss + 3.0*dis_loss + norm_loss + angle_loss + loss_b
+ else:
+ loss_b = 0.1*(torch.sigmoid(torch.tensor((eps - cfg.max_epoch)/cfg.max_epoch))) * point_loss
+ loss = cls_loss + 3.0*dis_loss + norm_loss + angle_loss + loss_b
+
+ loss_dict = {
+ 'total_loss': loss,
+ 'cls_loss': cls_loss,
+ 'distance loss': 3.0*dis_loss,
+ 'dir_loss': norm_loss + angle_loss,
+ 'point_loss': loss_b,
+ 'norm_loss': norm_loss,
+ 'angle_loss': angle_loss,
+
+ }
+
+ return loss_dict
+
diff --git a/IndicPhotoOCR/detection/textbpn/network/textnet.py b/IndicPhotoOCR/detection/textbpn/network/textnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..dae70d8453fc606354a66295f1896a78b881477f
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/network/textnet.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+# @Time : 10/1/21
+# @Author : GXYM
+import torch
+import torch.nn as nn
+from IndicPhotoOCR.detection.textbpn.network.layers.model_block import FPN
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+import numpy as np
+from IndicPhotoOCR.detection.textbpn.network.layers.CircConv import DeepSnake
+from IndicPhotoOCR.detection.textbpn.network.layers.GCN import GCN
+from IndicPhotoOCR.detection.textbpn.network.layers.RNN import RNN
+from IndicPhotoOCR.detection.textbpn.network.layers.Adaptive_Deformation import AdaptiveDeformation
+# from IndicPhotoOCR.detection.textbpn.network.layers.Transformer_old import Transformer_old
+from IndicPhotoOCR.detection.textbpn.network.layers.Transformer import Transformer
+import cv2
+from IndicPhotoOCR.detection.textbpn.util.misc import get_sample_point, fill_hole
+from IndicPhotoOCR.detection.textbpn.network.layers.gcn_utils import get_node_feature, \
+ get_adj_mat, get_adj_ind, coord_embedding, normalize_adj
+import torch.nn.functional as F
+import time
+
+
+class Evolution(nn.Module):
+ def __init__(self, node_num, adj_num, is_training=True, device=None, model="snake"):
+ super(Evolution, self).__init__()
+ self.node_num = node_num
+ self.adj_num = adj_num
+ self.device = device
+ self.is_training = is_training
+ self.clip_dis = 16
+
+ self.iter = 3
+ if model == "gcn":
+ self.adj = get_adj_mat(self.adj_num, self.node_num)
+ self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
+ for i in range(self.iter):
+ evolve_gcn = GCN(36, 128)
+ self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+ elif model == "rnn":
+ self.adj = None
+ for i in range(self.iter):
+ evolve_gcn = RNN(36, 128)
+ self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+ elif model == "AD":
+ self.adj = get_adj_mat(self.adj_num, self.node_num)
+ self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
+ for i in range(self.iter):
+ evolve_gcn = AdaptiveDeformation(36, 128)
+ self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+ # elif model == "BT_old":
+ # self.adj = None
+ # for i in range(self.iter):
+ # evolve_gcn = Transformer_old(36, 512, num_heads=8,
+ # dim_feedforward=2048, drop_rate=0.0, if_resi=True, block_nums=4)
+ # self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+ elif model == "BT":
+ self.adj = None
+ for i in range(self.iter):
+ evolve_gcn = Transformer(36, 128, num_heads=8,
+ dim_feedforward=1024, drop_rate=0.0, if_resi=True, block_nums=3)
+ self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+ else:
+ self.adj = get_adj_ind(self.adj_num, self.node_num, self.device)
+ for i in range(self.iter):
+ evolve_gcn = DeepSnake(state_dim=128, feature_dim=36, conv_type='dgrid')
+ self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
+ m.weight.data.normal_(0.0, 0.02)
+ # nn.init.kaiming_normal_(m.weight, mode='fan_in')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ @staticmethod
+ def get_boundary_proposal(input=None, seg_preds=None, switch="gt"):
+
+ if switch == "gt":
+ inds = torch.where(input['ignore_tags'] > 0)
+ # if len(inds[0]) > 320:
+ # inds = (inds[0][:320], inds[1][:320])
+ init_polys = input['proposal_points'][inds]
+ else:
+ tr_masks = input['tr_mask'].cpu().numpy()
+ tcl_masks = seg_preds[:, 0, :, :].detach().cpu().numpy() > cfg.threshold
+ inds = []
+ init_polys = []
+ for bid, tcl_mask in enumerate(tcl_masks):
+ ret, labels = cv2.connectedComponents(tcl_mask.astype(np.uint8), connectivity=8)
+ for idx in range(1, ret):
+ text_mask = labels == idx
+ ist_id = int(np.sum(text_mask*tr_masks[bid])/np.sum(text_mask))-1
+ inds.append([bid, ist_id])
+ poly = get_sample_point(text_mask, cfg.num_points, cfg.approx_factor)
+ init_polys.append(poly)
+ inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device)
+ init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device)
+
+ return init_polys, inds, None
+
+ def get_boundary_proposal_eval(self, input=None, seg_preds=None):
+
+ # if cfg.scale > 1:
+ # seg_preds = F.interpolate(seg_preds, scale_factor=cfg.scale, mode='bilinear')
+ cls_preds = seg_preds[:, 0, :, :].detach().cpu().numpy()
+        dis_preds = seg_preds[:, 1, :, :].detach().cpu().numpy()
+
+ inds = []
+ init_polys = []
+ confidences = []
+ for bid, dis_pred in enumerate(dis_preds):
+ # # dis_mask = (dis_pred / np.max(dis_pred)) > cfg.dis_threshold
+ dis_mask = dis_pred > cfg.dis_threshold
+ # dis_mask = fill_hole(dis_mask)
+ ret, labels = cv2.connectedComponents(dis_mask.astype(np.uint8), connectivity=8, ltype=cv2.CV_16U)
+ for idx in range(1, ret):
+ text_mask = labels == idx
+ confidence = round(cls_preds[bid][text_mask].mean(), 3)
+                # use 50 for MLT2017 and ArT (or when DCN is used in the backbone); otherwise 150.
+                # Using 50 everywhere has little effect on performance.
+ if np.sum(text_mask) < 50/(cfg.scale*cfg.scale) or confidence < cfg.cls_threshold:
+ continue
+ confidences.append(confidence)
+ inds.append([bid, 0])
+
+ poly = get_sample_point(text_mask, cfg.num_points,
+ cfg.approx_factor, scales=np.array([cfg.scale, cfg.scale]))
+ init_polys.append(poly)
+
+ if len(inds) > 0:
+ inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device, non_blocking=True)
+ init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
+ else:
+ init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
+ inds = torch.from_numpy(np.array(inds)).to(input["img"].device, non_blocking=True)
+
+ return init_polys, inds, confidences
+
+ def evolve_poly(self, snake, cnn_feature, i_it_poly, ind):
+ if len(i_it_poly) == 0:
+ return torch.zeros_like(i_it_poly)
+ h, w = cnn_feature.size(2)*cfg.scale, cnn_feature.size(3)*cfg.scale
+ node_feats = get_node_feature(cnn_feature, i_it_poly, ind, h, w)
+ i_poly = i_it_poly + torch.clamp(snake(node_feats, self.adj).permute(0, 2, 1), -self.clip_dis, self.clip_dis)
+ if self.is_training:
+ i_poly = torch.clamp(i_poly, 0, w-1)
+ else:
+ i_poly[:, :, 0] = torch.clamp(i_poly[:, :, 0], 0, w - 1)
+ i_poly[:, :, 1] = torch.clamp(i_poly[:, :, 1], 0, h - 1)
+ return i_poly
+
+ def forward(self, embed_feature, input=None, seg_preds=None, switch="gt"):
+ if self.is_training:
+ init_polys, inds, confidences = self.get_boundary_proposal(input=input, seg_preds=seg_preds, switch=switch)
+ # TODO sample fix number
+ else:
+ init_polys, inds, confidences = self.get_boundary_proposal_eval(input=input, seg_preds=seg_preds)
+ if init_polys.shape[0] == 0:
+ return [init_polys for i in range(self.iter+1)], inds, confidences
+
+ py_preds = [init_polys, ]
+ for i in range(self.iter):
+ evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
+ init_polys = self.evolve_poly(evolve_gcn, embed_feature, init_polys, inds[0])
+ py_preds.append(init_polys)
+
+ return py_preds, inds, confidences
+
+
+class TextNet(nn.Module):
+
+ def __init__(self, backbone='vgg', is_training=True):
+ super().__init__()
+ self.is_training = is_training
+ self.backbone_name = backbone
+ self.fpn = FPN(self.backbone_name, is_training=(not cfg.resume and is_training))
+
+ self.seg_head = nn.Sequential(
+ nn.Conv2d(32, 16, kernel_size=3, padding=2, dilation=2),
+ nn.PReLU(),
+ nn.Conv2d(16, 16, kernel_size=3, padding=4, dilation=4),
+ nn.PReLU(),
+ nn.Conv2d(16, 4, kernel_size=1, stride=1, padding=0),
+ )
+ self.BPN = Evolution(cfg.num_points, adj_num=4,
+ is_training=is_training, device=cfg.device, model="BT")
+
+ def load_model(self, model_path):
+ print('Loading from {}'.format(model_path))
+ state_dict = torch.load(model_path, map_location=torch.device(cfg.device))
+ self.load_state_dict(state_dict['model'], strict=(not self.is_training))
+
+ def forward(self, input_dict, test_speed=False):
+ output = {}
+ b, c, h, w = input_dict["img"].shape
+ if self.is_training or cfg.exp_name in ['ArT', 'MLT2017', "MLT2019"] or test_speed:
+ image = input_dict["img"]
+ else:
+ image = torch.zeros((b, c, cfg.test_size[1], cfg.test_size[1]), dtype=torch.float32).to(cfg.device)
+ image[:, :, :h, :w] = input_dict["img"][:, :, :, :]
+
+ up1, _, _, _, _ = self.fpn(image)
+ up1 = up1[:, :, :h // cfg.scale, :w // cfg.scale]
+
+ preds = self.seg_head(up1)
+ fy_preds = torch.cat([torch.sigmoid(preds[:, 0:2, :, :]), preds[:, 2:4, :, :]], dim=1)
+ cnn_feats = torch.cat([up1, fy_preds], dim=1)
+
+ py_preds, inds, confidences = self.BPN(cnn_feats, input=input_dict, seg_preds=fy_preds, switch="gt")
+
+ output["fy_preds"] = fy_preds
+ output["py_preds"] = py_preds
+ output["inds"] = inds
+ output["confidences"] = confidences
+
+ return output
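+
+# Editor's note: a compact summary of the tensors TextNet exchanges, added for reference.
+# Shapes assume batch size B and the cfg.scale / cfg.num_points values set in cfglib;
+# they are an editorial reading of the code above, not assertions from the upstream repo.
+#
+#   input_dict["img"]      : float tensor (B, 3, H, W), H and W multiples of 32
+#   output["fy_preds"]     : (B, 4, H/scale, W/scale) -> [classification, distance, flux_x, flux_y]
+#   output["py_preds"]     : list of (N, num_points, 2) polygons, one per refinement iteration
+#   output["inds"]         : indices mapping each polygon back to its image in the batch
+#   output["confidences"]  : per-polygon classification confidence (populated in eval mode)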
diff --git a/IndicPhotoOCR/detection/textbpn/output.png b/IndicPhotoOCR/detection/textbpn/output.png
new file mode 100644
index 0000000000000000000000000000000000000000..865e3cd37791f717b44344b2e9971898f4004b30
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/output.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44b8104e8e5d470e051d4b568214e3591e098f2677d1c0e1a0d6594e2d049636
+size 8695740
diff --git a/IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py b/IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..f360a23918e941e97ce52a3cf0695ad774f44f57
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py
@@ -0,0 +1,197 @@
+import torch
+import cv2
+import numpy as np
+from IndicPhotoOCR.detection.textbpn.network.textnet import TextNet
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+import warnings
+import os
+import requests
+from tqdm import tqdm
+
+# Suppress warnings
+warnings.filterwarnings("ignore")
+
+model_info = {
+ "textbpnpp": {
+ "path": "models/TextBPN_resnet50_300.pth",
+ "url" : "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet50_300.pth",
+ },
+ "textbpnpp_deformable": {
+ "path":"models/TextBPN_deformable_resnet50_300.pth",
+ "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_deformable_resnet50_300.pth",
+ },
+ "textbpn_resnet18" : {
+ "path":"models/TextBPN_resnet18_300.pth",
+ "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet18_300.pth",
+
+ }
+}
+# Ensure the model file exists locally; download it if not.
+def ensure_model(model_name):
+ model_path = model_info[model_name]["path"]
+ url = model_info[model_name]["url"]
+ root_model_dir = "IndicPhotoOCR/detection/textbpn"
+ model_path = os.path.join(root_model_dir, model_path)
+
+ if not os.path.exists(model_path):
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+ # Start the download with a progress bar
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get('content-length', 0))
+ os.makedirs(f"{root_model_dir}/models", exist_ok=True)
+
+ with open(model_path, "wb") as f, tqdm(
+ desc=model_name,
+ total=total_size,
+ unit='B',
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar:
+ for data in response.iter_content(chunk_size=1024):
+ f.write(data)
+ bar.update(len(data))
+
+ print(f"Downloaded model for {model_name}.")
+
+ return model_path
+
+class TextBPNpp_detector:
+ def __init__(self, model_name="textbpnpp", backbone="resnet50", device="cpu"):
+ """
+ Initialize the TextBPN model.
+ :param model_path: Path to the pre-trained model.
+ :param backbone: Backbone architecture (default: "resnet50").
+ :param device: Device to run the model on (default: "cpu").
+ """
+ self.model_path = ensure_model(model_name)
+ self.device = torch.device(device)
+ self.model = TextNet(is_training=False, backbone=backbone)
+ self.model.load_model(self.model_path)
+ self.model.eval()
+ self.model.to(self.device)
+
+ @staticmethod
+ def to_device(tensor, device):
+ """
+ Move tensor to the specified device.
+ :param tensor: Tensor to move.
+ :param device: Target device.
+ :return: Tensor on the target device.
+ """
+ return tensor.to(device, non_blocking=True)
+
+ @staticmethod
+ def pad_image(image, stride=32):
+ """
+ Pad the image to make its dimensions divisible by the stride.
+ :param image: Input image.
+ :param stride: Stride size.
+ :return: Padded image and original dimensions.
+ """
+ h, w = image.shape[:2]
+ new_h = (h + stride - 1) // stride * stride
+ new_w = (w + stride - 1) // stride * stride
+ padded_image = cv2.copyMakeBorder(
+ image, 0, new_h - h, 0, new_w - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
+ )
+ return padded_image, (h, w)
+
+ @staticmethod
+ def rescale_result(image, bbox_contours, original_height, original_width):
+ """
+ Rescale the bounding box contours to the original image size.
+ :param image: Image after resizing.
+ :param bbox_contours: Bounding box contours.
+ :param original_height: Original image height.
+ :param original_width: Original image width.
+ :return: Original image and rescaled contours.
+ """
+ contours = []
+ for cont in bbox_contours:
+ cont[:, 0] = (cont[:, 0] * original_width / image.shape[1]).astype(int)
+ cont[:, 1] = (cont[:, 1] * original_height / image.shape[0]).astype(int)
+ contours.append(cont)
+ return contours
+
+ def detect(self, image_path):
+ """
+ Perform text detection on the given image.
+ :param image_path: Path to the input image.
+ :return: Dictionary with detection results.
+ """
+ image = cv2.imread(image_path)
+ if image is None:
+ raise ValueError(f"Failed to read the image at {image_path}")
+
+ padded_image, original_size = self.pad_image(image)
+ padded_tensor = (
+ torch.from_numpy(padded_image).permute(2, 0, 1).float() / 255.0
+ ).unsqueeze(0) # Convert to tensor and add batch dimension
+
+ cfg.test_size = [padded_image.shape[0], padded_image.shape[1]]
+
+ input_dict = {"img": self.to_device(padded_tensor, self.device)}
+ with torch.no_grad():
+ output_dict = self.model(input_dict, padded_image.shape)
+
+ contours = output_dict["py_preds"][-1].int().cpu().numpy()
+ contours = self.rescale_result(image, contours, *original_size)
+
+ bbox_result_dict = {"detections": []}
+ for contour in contours:
+ # x_min, y_min = np.min(contour, axis=0)
+ # x_max, y_max = np.max(contour, axis=0)
+ # bbox_result_dict["detections"].append([x_min, y_min, x_max, y_max])
+ bbox_result_dict["detections"].append(contour.tolist())
+
+ return bbox_result_dict
+
+ def visualize_detections(self, image_path, bbox_result_dict, output_path="output.png"):
+ """
+ Visualize detections on the image.
+ :param image_path: Path to the input image.
+ :param bbox_result_dict: Detection results in the format:
+ {'detections': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...]}.
+ :param output_path: Path to save the visualized image. If None, the image is only displayed.
+ """
+ # Load the image
+ image = cv2.imread(image_path)
+ if image is None:
+ raise ValueError(f"Failed to read the image at {image_path}")
+
+ # Draw each detection
+ for bbox in bbox_result_dict.get("detections", []):
+ points = np.array(bbox, dtype=np.int32) # Convert to numpy array
+ cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
+
+ # Display or save the visualized image
+ if output_path:
+ cv2.imwrite(output_path, image)
+ print(f"Visualization saved to {output_path}")
+ else:
+ cv2.imshow("Detections", image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+ import argparse
+    parser = argparse.ArgumentParser(description='Text detection using the TextBPN++ model')
+ parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+ parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
+    parser.add_argument('--model_name', type=str, required=True, help='Model name: "textbpnpp", "textbpnpp_deformable", or "textbpn_resnet18"')
+ args = parser.parse_args()
+
+
+
+ # model_path = "/DATA1/ocrteam/anik/git/IndicPhotoOCR/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth"
+ # image_path = "/DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg"
+
+    detector = TextBPNpp_detector(args.model_name, device=args.device)
+ result = detector.detect(args.image_path)
+ print(result)
+ # detector.visualize_detections(image_path, result)
+
+ # python -m IndicPhotoOCR.detection.textbpn.textbpnpp_detector \
+ # --image_path /DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg \
+ # --model_name textbpnpp
\ No newline at end of file
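+
+# Editor's illustrative API sketch (assumptions: weights downloadable from the release URLs
+# above and a local "sample.jpg"); the CLI invocation above remains the documented entry point.
+#
+#   from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector
+#   detector = TextBPNpp_detector(model_name="textbpnpp", device="cpu")
+#   result = detector.detect("sample.jpg")   # {'detections': [[[x, y], ...], ...]}
+#   detector.visualize_detections("sample.jpg", result, output_path="output.png")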
diff --git a/IndicPhotoOCR/detection/textbpn/util/__init__.py b/IndicPhotoOCR/detection/textbpn/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df2e390c73c946d5974e63c65456acc55b7a080
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/__init__.py
@@ -0,0 +1,2 @@
+from .visualize import *
+from .pbox import *
diff --git a/IndicPhotoOCR/detection/textbpn/util/augmentation.py b/IndicPhotoOCR/detection/textbpn/util/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..09fe7f348f10d44903c5e0914b65903d91280b08
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/augmentation.py
@@ -0,0 +1,794 @@
+# -*- coding: utf-8 -*-
+__author__ = "S.X.Zhang"
+import numpy as np
+import math
+import cv2
+import copy
+import numpy.random as random
+from shapely.geometry import Polygon
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+from PIL import ImageEnhance, Image
+
+
+###<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<###
+###<<<<<<<<< Function >>>>>>>>>>>>###
+###>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>###
+def crop_first(image, polygons, scale =10):
+ polygons_new = copy.deepcopy(polygons)
+ h, w, _ = image.shape
+ pad_h = h // scale
+ pad_w = w // scale
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+
+ text_polys = []
+ pos_polys = []
+ for polygon in polygons_new:
+ rect = cv2.minAreaRect(polygon.points.astype(np.int32))
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ text_polys.append([box[0], box[1], box[2], box[3]])
+ if polygon.label != -1:
+ pos_polys.append([box[0], box[1], box[2], box[3]])
+
+ polys = np.array(text_polys, dtype=np.int32)
+ for poly in polys:
+        poly = np.round(poly, decimals=0).astype(np.int32)  # round to nearest integer
+ minx = np.min(poly[:, 0])
+ maxx = np.max(poly[:, 0])
+ w_array[minx + pad_w:maxx + pad_w] = 1
+ miny = np.min(poly[:, 1])
+ maxy = np.max(poly[:, 1])
+ h_array[miny + pad_h:maxy + pad_h] = 1
+    # ensure the cropped area does not cut across any text instance
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+ pp_polys = np.array(pos_polys, dtype=np.int32)
+
+ return h_axis, w_axis, pp_polys
+
+####<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<####
+####<<<<<<<<<<< Class >>>>>>>>>>>>>####
+####>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>####
+class Compose(object):
+ """Composes several augmentations together.
+ Args:
+ transforms (List[Transform]): list of transforms to compose.
+ Example:
+ >>> augmentations.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img, pts=None):
+ for t in self.transforms:
+ img, pts = t(img, pts)
+ return img, pts
+
+
+class Normalize(object):
+ def __init__(self, mean, std):
+ self.mean = np.array(mean)
+ self.std = np.array(std)
+
+ def __call__(self, image, polygons=None):
+ image = image.astype(np.float32)
+ image /= 255.0
+ image -= self.mean
+ image /= self.std
+ return image, polygons
+
+
+class MinusMean(object):
+ def __init__(self, mean):
+ self.mean = np.array(mean)
+
+ def __call__(self, image, polygons=None):
+ image = image.astype(np.float32)
+ image -= self.mean
+ return image, polygons
+
+
+class RandomMirror(object):
+    # random horizontal mirror flip
+ def __init__(self):
+ pass
+
+ def __call__(self, image, polygons=None):
+ if polygons is None:
+ return image, polygons
+ if random.random()< 0.3:
+ image = np.ascontiguousarray(image[:, ::-1])
+ _, width, _ = image.shape
+ for polygon in polygons:
+ polygon.points[:, 0] = width - polygon.points[:, 0]
+ return image, polygons
+
+
+class AugmentColor(object):
+    # color augmentation (adds PCA-style color noise)
+ def __init__(self):
+ self.U = np.array([[-0.56543481, 0.71983482, 0.40240142],
+ [-0.5989477, -0.02304967, -0.80036049],
+ [-0.56694071, -0.6935729, 0.44423429]], dtype=np.float32)
+ self.EV = np.array([1.65513492, 0.48450358, 0.1565086], dtype=np.float32)
+ self.sigma = 0.1
+ self.color_vec = None
+
+ def __call__(self, img, polygons=None):
+ color_vec = self.color_vec
+ if self.color_vec is None:
+ if not self.sigma > 0.0:
+ color_vec = np.zeros(3, dtype=np.float32)
+ else:
+ color_vec = np.random.normal(0.0, self.sigma, 3)
+
+ alpha = color_vec.astype(np.float32) * self.EV
+ noise = np.dot(self.U, alpha.T) * 255
+ return np.clip(img + noise[np.newaxis, np.newaxis, :], 0, 255), polygons
+
+
+class RandomContrast(object):
+ def __init__(self, lower=0.5, upper=1.5):
+ self.lower = lower
+ self.upper = upper
+ assert self.upper >= self.lower, "contrast upper must be >= lower."
+ assert self.lower >= 0, "contrast lower must be non-negative."
+
+ # expects float image
+ def __call__(self, image, polygons=None):
+ if random.randint(2):
+ alpha = random.uniform(self.lower, self.upper)
+ image *= alpha
+ return np.clip(image, 0, 255), polygons
+
+
+class RandomBrightness(object):
+ def __init__(self, delta=32):
+ assert delta >= 0.0
+ assert delta <= 255.0
+ self.delta = delta
+
+ def __call__(self, image, polygons=None):
+ image = image.astype(np.float32)
+ if random.randint(2):
+ delta = random.uniform(-self.delta, self.delta)
+ image += delta
+ return np.clip(image, 0, 255), polygons
+
+
+class RandomErasing(object):
+ def __init__(self, sr=(0.0004, 0.01), scale=(0.5, 3), ratio=0.2, Type ="Erasing"):
+ """
+
+ :param area:
+ :param type: Erasing or Cutout
+ """
+ self.sr = sr
+ self.scale= scale
+ self.ratio=ratio
+ self.type=Type
+
+ def __call__(self, img, polygons=None):
+
+ if random.random()< self.ratio:
+ return img, polygons
+ area=img.shape[0]*img.shape[1]
+        target_area = random.uniform(*self.sr) * area
+ aspect_ratio=random.uniform(*self.scale)
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+
+ if w < img.shape[1] and h < img.shape[0]:
+ x1 = random.randint(0, img.shape[1] - w)
+ y1 = random.randint(0, img.shape[0] - h)
+ if self.type == "Erasing":
+ color=(random.randint(0, 255),random.randint(0, 255),random.randint(0, 255))
+ img[y1:y1+h, x1:x1+h,:]=color
+ else:
+ Gray_value=random.randint(0, 255)
+ color = (Gray_value, Gray_value ,Gray_value)
+ img[y1:y1 + h, x1:x1 + h, :] = color
+
+ return img, polygons
+
+
+class RandomMixUp(object):
+ def __init__(self, mixup_alpha=2):
+ self.mixup_alpha = mixup_alpha
+
+ def __call__(self, img1, img2, label1=[], label2=[]):
+ beta=np.random.beta(self.mixup_alpha,self.mixup_alpha)
+
+        # image = img1 * beta + (1 - beta) * img2
+ image=cv2.addWeighted(img1, beta, img2, (1-beta), 0)
+
+ if label1 is None or label2 is None:
+ return img1, label1
+ if isinstance(label1, list) and isinstance(label2, list):
+ label=[]
+ for id in range(len(label1)):
+ lab = beta*label1[id]+ (1-beta)*label2[id]
+ label.append(lab)
+ return image, label
+ else:
+ print("Error: label is not a list type")
+
+ return img1, label1
+
+
+class Rotate(object):
+ def __init__(self, up=30):
+ self.up = up
+
+ @staticmethod
+    def rotate(center, pt, theta):  # 2-D rotation about a given center
+ xr, yr = center
+ yr = -yr
+ x, y = pt[:, 0], pt[:, 1]
+ y = -y
+
+ theta = theta / 180 * math.pi
+ cos = math.cos(theta)
+ sin = math.sin(theta)
+
+ _x = xr + (x - xr) * cos - (y - yr) * sin
+ _y = yr + (x - xr) * sin + (y - yr) * cos
+
+ return _x, -_y
+
+ def __call__(self, img, polygons=None):
+ if np.random.randint(2):
+ return img, polygons
+        angle = np.random.normal(loc=0.0, scale=0.5) * self.up  # angle drawn from a Gaussian distribution
+ rows, cols = img.shape[0:2]
+ M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1.0)
+ img = cv2.warpAffine(img, M, (cols, rows), borderValue=[0, 0, 0])
+ center = cols / 2.0, rows / 2.0
+ if polygons is not None:
+ for polygon in polygons:
+ x, y = self.rotate(center, polygon.points, angle)
+ pts = np.vstack([x, y]).T
+ polygon.points = pts
+ return img, polygons
+
+
+class RotatePadding(object):
+ def __init__(self, up=60,colors=True):
+ self.up = up
+ self.colors = colors
+ self.ratio = 0.5
+
+ @staticmethod
+    def rotate(center, pt, theta, movSize=[0, 0], scale=1):  # 2-D rotation about a given center
+ (xr, yr) = center
+ yr = -yr
+ x, y = pt[:, 0], pt[:, 1]
+ y = -y
+
+ theta = theta / 180 * math.pi
+ cos = math.cos(theta)
+ sin = math.sin(theta)
+
+ x = (x - xr) * scale
+ y = (y - yr) * scale
+
+ _x = xr + x * cos - y * sin + movSize[0]
+ _y = -(yr + x * sin + y * cos) + movSize[1]
+
+ return _x, _y
+
+ @staticmethod
+ def shift(size, degree):
+ angle = degree * math.pi / 180.0
+ width = size[0]
+ height = size[1]
+
+ alpha = math.cos(angle)
+ beta = math.sin(angle)
+ new_width = int(width * math.fabs(alpha) + height * math.fabs(beta))
+ new_height = int(width * math.fabs(beta) + height * math.fabs(alpha))
+
+ size = [new_width, new_height]
+ return size
+
+ def __call__(self, image, polygons=None, scale=1.0):
+ if np.random.random() <= self.ratio:
+ return image, polygons
+        angle = np.random.normal(loc=0.0, scale=0.5) * self.up  # angle drawn from a Gaussian distribution
+ rows, cols = image.shape[0:2]
+ center = (cols / 2.0, rows / 2.0)
+ newSize = self.shift([cols * scale, rows * scale], angle)
+ movSize = [int((newSize[0] - cols) / 2), int((newSize[1] - rows) / 2)]
+
+ M = cv2.getRotationMatrix2D(center, angle, scale)
+ M[0, 2] += int((newSize[0] - cols) / 2)
+ M[1, 2] += int((newSize[1] - rows) / 2)
+
+ if self.colors:
+ H, W, _ = image.shape
+ mask = np.zeros_like(image)
+ (h_index, w_index) = (np.random.randint(0, H * 7 // 8), np.random.randint(0, W * 7 // 8))
+ img_cut = image[h_index:(h_index + H // 9), w_index:(w_index + W // 9)]
+ img_cut = cv2.resize(img_cut, (newSize[0], newSize[1]))
+ mask = cv2.warpAffine(mask, M, (newSize[0], newSize[1]), borderValue=[1, 1, 1])
+ image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=[0,0,0])
+ image=image+img_cut*mask
+ else:
+ color = [0, 0, 0]
+ image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=color)
+
+ if polygons is not None:
+ for polygon in polygons:
+ x, y = self.rotate(center, polygon.points, angle,movSize,scale)
+ pts = np.vstack([x, y]).T
+ polygon.points = pts
+ return image, polygons
+
+
+class SquarePadding(object):
+
+ def __call__(self, image, polygons=None):
+
+ H, W, _ = image.shape
+
+ if H == W:
+ return image, polygons
+
+ padding_size = max(H, W)
+ (h_index, w_index) = (np.random.randint(0, H*7//8),np.random.randint(0, W*7//8))
+ img_cut = image[h_index:(h_index+H//9),w_index:(w_index+W//9)]
+ expand_image = cv2.resize(img_cut,(padding_size, padding_size))
+ #expand_image = np.zeros((padding_size, padding_size, 3), dtype=image.dtype)
+ #expand_image=img_cut[:,:,:]
+ if H > W:
+ y0, x0 = 0, (H - W) // 2
+ else:
+ y0, x0 = (W - H) // 2, 0
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points += np.array([x0, y0])
+ expand_image[y0:y0+H, x0:x0+W] = image
+ image = expand_image
+
+ return image, polygons
+
+
+class RandomImgCropPatch(object):
+ def __init__(self, up=30, beta=0.3):
+ self.up = up
+        self.beta = beta
+ self.scale = 10
+
+ @staticmethod
+ def get_contour_min_area_box(contour):
+ rect = cv2.minAreaRect(contour)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ return box
+
+ def CropWH(self, image, cut_w, cut_h, polygons=None):
+ h_axis, w_axis, polys = crop_first(image, polygons, scale=self.scale)
+ h, w, _ = image.shape
+ pad_h = h // self.scale
+ pad_w = w // self.scale
+ # TODO try Flip
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = xmin + cut_w
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = ymin + cut_h
+ if polys.shape[0] != 0:
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ else:
+ selected_polys = []
+
+ cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
+ polygons_new = []
+ for idx in selected_polys:
+ polygon = polygons[idx]
+ polygon.points -= np.array([xmin, ymin])
+ polygons_new.append(polygon)
+ image = cropped
+ polygon = polygons_new
+
+ return image, polygon
+
+ def __call__(self, images, polygons_list=None):
+ I_x, I_y = 1024,1024
+
+ w = int(round(I_x * random.beta(self.beta, self.beta)))
+ h = int(round(I_y * random.beta(self.beta, self.beta)))
+ w_ = [w, I_x - w, w, I_x - w]
+ h_ = [h, h, I_y - h, I_y - h]
+ new_img = np.zeros((I_x, I_y, 3), dtype=images[0].dtype)
+ imgs=[]
+ new_polygons=[]
+ for i, im in enumerate(images):
+ img, polygons = self.CropWH(im, w_[i], h_[i], polygons=polygons_list[i])
+ imgs.append(img)
+ new_polygons.append(polygons)
+ new_img[0:w, 0:h, :] = imgs[0]
+ new_img[w:I_x, 0:h, :] = imgs[1]
+ new_img[0:w, h:I_y, :] = imgs[2]
+ new_img[w:I_x, h:I_y, :] = imgs[3]
+ for polygon in new_polygons[1]:
+ polygon.points += np.array([w, 0])
+ for polygon in new_polygons[2]:
+ polygon.points += np.array([0, h])
+ for polygon in new_polygons[3]:
+ polygon.points += np.array([w, h])
+
+ polygons=new_polygons[0]+new_polygons[1]+new_polygons[2]+new_polygons[3]
+
+ return new_img, polygons
+
+
+class RandomCropFlip(object):
+
+ def __init__(self, min_crop_side_ratio=0.01):
+ self.scale = 10
+ self.ratio = 0.2
+ self.epsilon = 10.0
+ self.min_crop_side_ratio = min_crop_side_ratio
+
+ def __call__(self, image, polygons=None):
+
+ if polygons is None:
+ return image, polygons
+
+ if np.random.random() <= self.ratio:
+ return image, polygons
+
+        # compute the valid crop regions so that valid seed points can be sampled
+ h_axis, w_axis, pp_polys = crop_first(image, polygons, scale =self.scale)
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return image, polygons
+
+ # TODO try crop
+ attempt = 0
+ h, w, _ = image.shape
+ area = h * w
+ pad_h = h // self.scale
+ pad_w = w // self.scale
+ while attempt < 10:
+ attempt += 1
+ polygons_new = []
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+ if (xmax - xmin) * (ymax - ymin) < area * self.min_crop_side_ratio:
+ # area too small
+ continue
+
+ pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+ pp = Polygon(pts).buffer(0)
+ Fail_flag = False
+ for polygon in polygons:
+ ppi = Polygon(polygon.points).buffer(0)
+ ppiou = float(ppi.intersection(pp).area)
+ if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
+ Fail_flag = True
+ break
+ if np.abs(ppiou - float(ppi.area)) < self.epsilon:
+ polygons_new.append(polygon)
+
+ if Fail_flag:
+ continue
+ else:
+ break
+
+ if len(polygons_new) == 0:
+ cropped = image[ymin:ymax, xmin:xmax, :]
+ select_type = random.randint(3)
+ if select_type == 0:
+ img = np.ascontiguousarray(cropped[:, ::-1])
+ elif select_type == 1:
+ img = np.ascontiguousarray(cropped[::-1, :])
+ else:
+ img = np.ascontiguousarray(cropped[::-1, ::-1])
+ image[ymin:ymax, xmin:xmax, :] = img
+ return image, polygons
+
+ else:
+ cropped = image[ymin:ymax, xmin:xmax, :]
+ height, width, _ = cropped.shape
+ select_type = random.randint(3)
+ if select_type == 0:
+ img = np.ascontiguousarray(cropped[:, ::-1])
+ for polygon in polygons_new:
+ polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
+ elif select_type == 1:
+ img = np.ascontiguousarray(cropped[::-1, :])
+ for polygon in polygons_new:
+ polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
+ else:
+ img = np.ascontiguousarray(cropped[::-1, ::-1])
+ for polygon in polygons_new:
+ polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
+ polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
+ image[ymin:ymax, xmin:xmax, :] = img
+
+ return image, polygons
+
+
+class RandomResizedCrop(object):
+ def __init__(self, min_crop_side_ratio=0.1):
+ self.scale = 10
+ self.epsilon = 1e-2
+ self.min_crop_side_ratio = min_crop_side_ratio
+
+ def __call__(self, image, polygons):
+
+ if polygons is None:
+ return image, polygons
+
+        # compute the valid crop regions so that valid seed points can be sampled
+ h_axis, w_axis, pp_polys = crop_first(image, polygons, scale =self.scale)
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return image, polygons
+
+ # TODO try crop
+ attempt = 0
+ h, w, _ = image.shape
+ area = h * w
+ pad_h = h // self.scale
+ pad_w = w // self.scale
+ while attempt < 10:
+ attempt += 1
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+            if (xmax - xmin) * (ymax - ymin) < area * self.min_crop_side_ratio:
+                # area too small
+                continue
+            if pp_polys.shape[0] != 0:
+                poly_axis_in_area = (pp_polys[:, :, 0] >= xmin) & (pp_polys[:, :, 0] <= xmax) \
+                                    & (pp_polys[:, :, 1] >= ymin) & (pp_polys[:, :, 1] <= ymax)
+                selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+            else:
+                selected_polys = []
+
+ if len(selected_polys) == 0:
+ continue
+ else:
+ pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+ pp = Polygon(pts).buffer(0)
+ polygons_new = []
+ Fail_flag = False
+ for polygon in copy.deepcopy(polygons):
+ ppi = Polygon(polygon.points).buffer(0)
+ ppiou = float(ppi.intersection(pp).area)
+ if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
+ Fail_flag = True
+ break
+ elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
+ # polygon.points -= np.array([xmin, ymin])
+ polygons_new.append(polygon)
+
+ if Fail_flag:
+ continue
+ else:
+ cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
+ for polygon in polygons_new:
+ polygon.points -= np.array([xmin, ymin])
+
+ return cropped, polygons_new
+
+ return image, polygons
+
+
+class RandomResizeScale(object):
+ def __init__(self, size=512, ratio=(3./4, 5./2)):
+ self.size = size
+ self.ratio = ratio
+
+ def __call__(self, image, polygons=None):
+
+ aspect_ratio = np.random.uniform(self.ratio[0], self.ratio[1])
+ h, w, _ = image.shape
+ scales = self.size*1.0/max(h, w)
+ aspect_ratio = scales * aspect_ratio
+ aspect_ratio = int(w * aspect_ratio)*1.0/w
+ image = cv2.resize(image, (int(w * aspect_ratio), int(h*aspect_ratio)))
+ scales = np.array([aspect_ratio, aspect_ratio])
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points = polygon.points * scales
+
+ return image, polygons
+
+
+class Resize(object):
+ def __init__(self, size=1024):
+ self.size = size
+ self.SP = SquarePadding()
+
+ def __call__(self, image, polygons=None):
+ h, w, _ = image.shape
+ image = cv2.resize(image, (self.size,
+ self.size))
+ scales = np.array([self.size / w, self.size / h])
+
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points = polygon.points * scales
+
+ return image, polygons
+
+
+class ResizeSquare(object):
+ def __init__(self, size=(480, 1280)):
+ self.size = size
+
+ def __call__(self, image, polygons=None):
+ h, w, _ = image.shape
+ img_size_min = min(h, w)
+ img_size_max = max(h, w)
+
+ if img_size_min < self.size[0]:
+ im_scale = float(self.size[0]) / float(img_size_min) # expand min to size[0]
+ if np.ceil(im_scale * img_size_max) > self.size[1]: # expand max can't > size[1]
+ im_scale = float(self.size[1]) / float(img_size_max)
+ elif img_size_max > self.size[1]:
+ im_scale = float(self.size[1]) / float(img_size_max)
+ else:
+ im_scale = 1.0
+
+ new_h = int(int(h * im_scale/32)*32)
+ new_w = int(int(w * im_scale/32)*32)
+ # if new_h*new_w > 1600*1920:
+ # im_scale = 1600 / float(img_size_max)
+ # new_h = int(int(h * im_scale/32)*32)
+ # new_w = int(int(w * im_scale/32)*32)
+ image = cv2.resize(image, (new_w, new_h))
+ scales = np.array([new_w / w, new_h / h])
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points = polygon.points * scales
+
+ return image, polygons
+
+
+class ResizeLimitSquare(object):
+ def __init__(self, size=512, ratio=0.6):
+ self.size = size
+ self.ratio = ratio
+ self.SP = SquarePadding()
+
+ def __call__(self, image, polygons=None):
+ if np.random.random() <= self.ratio:
+ image, polygons = self.SP(image, polygons)
+ h, w, _ = image.shape
+ image = cv2.resize(image, (self.size,self.size))
+ scales = np.array([self.size*1.0/ w, self.size*1.0 / h])
+
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points = polygon.points * scales
+
+ return image, polygons
+
+
+class RandomResizePadding(object):
+ def __init__(self, size=512, random_scale=np.array([0.75, 1.0, 1.25,1.5,2.0]),stride=32, ratio=0.6667):
+ self.random_scale = random_scale
+ self.size = size
+ self.ratio=ratio
+ self.stride=stride
+ self.SP=SquarePadding()
+
+        ########### Random size for different epochs ########################
+ rd_scale = np.random.choice(self.random_scale)
+        step_num = round(np.random.normal(loc=0.0, scale=0.35) * 8)  # step drawn from a Gaussian distribution
+ self.input_size = np.clip(int(self.size * rd_scale + step_num * self.stride),
+ (int(self.size * self.random_scale[0] - self.stride)),
+ int(self.size * self.random_scale[-1] + self.stride))
+ ############################ end ########################
+
+ def __call__(self, image, polygons=None):
+
+ if np.random.random() <= self.ratio:
+ image, polygons = self.SP(image, polygons)
+ h, w, _ = image.shape
+ image = cv2.resize(image, (self.input_size,self.input_size))
+ scales = np.array([self.input_size*1.0/ w, self.input_size*1.0 / h])
+
+ if polygons is not None:
+ for polygon in polygons:
+ polygon.points = polygon.points * scales
+
+ return image, polygons
+
+transform_type_dict = dict(
+ brightness=ImageEnhance.Brightness, contrast=ImageEnhance.Contrast,
+ sharpness=ImageEnhance.Sharpness, color=ImageEnhance.Color
+)
+
+
+class RandomDistortion(object):
+ def __init__(self, transform_dict, prob=0.5):
+ self.transforms = [(transform_type_dict[k], transform_dict[k]) for k in transform_dict]
+ self.prob = prob
+
+ def __call__(self, img, target):
+ if random.random() > self.prob:
+ return img, target
+ out = Image.fromarray(img)
+ rand_num = np.random.uniform(0, 1, len(self.transforms))
+
+ for i, (transformer, alpha) in enumerate(self.transforms):
+ r = alpha * (rand_num[i] * 2.0 - 1.0) + 1 # r in [1-alpha, 1+alpha)
+ out = transformer(out).enhance(r)
+
+ return np.array(out), target
+
+
+class Augmentation(object):
+ def __init__(self, size, mean, std):
+ self.size = size
+ self.mean = mean
+ self.std = std
+ self._transform_dict = {'brightness': 0.5, 'contrast': 0.5, 'sharpness': 0.8386, 'color': 0.5}
+ self.augmentation = Compose([
+ RandomCropFlip(),
+ RandomResizeScale(size=self.size, ratio=(3. / 8, 5. / 2)),
+ RandomResizedCrop(),
+ RotatePadding(up=60, colors=True), # use up=30 when pretraining on synthetic data, otherwise up=60
+ ResizeLimitSquare(size=self.size),
+ RandomMirror(),
+ RandomDistortion(self._transform_dict),
+ Normalize(mean=self.mean, std=self.std),
+ ])
+
+ def __call__(self, image, polygons=None):
+ return self.augmentation(image, polygons)
+
+
+class BaseTransform(object):
+ def __init__(self, size, mean, std):
+ self.size = size
+ self.mean = mean
+ self.std = std
+ self.augmentation = Compose([
+ # Resize(size=640),
+ ResizeSquare(size=self.size),
+ Normalize(mean, std)
+ ])
+
+ def __call__(self, image, polygons=None):
+ return self.augmentation(image, polygons)
+
+
+class BaseTransformNresize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.augmentation = Compose([
+ Normalize(mean, std)
+ ])
+
+ def __call__(self, image, polygons=None):
+ return self.augmentation(image, polygons)
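+
+# Usage sketch (illustrative values only; the real size and mean/std come from the training/inference cfg):
+# transform = BaseTransform(size=(640, 1024), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+# image, polygons = transform(image, polygons=None)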
diff --git a/IndicPhotoOCR/detection/textbpn/util/canvas.py b/IndicPhotoOCR/detection/textbpn/util/canvas.py
new file mode 100644
index 0000000000000000000000000000000000000000..555da8fc3e47755371c41f8e0b160a901f25d68c
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/canvas.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+__author__ = '古溪'
+
+import numpy as np
+import random
+import matplotlib.pyplot as plt
+
+
+def heatmap(im_gray):
+ cmap = plt.get_cmap('jet')
+ rgba_img = cmap(255 - im_gray)
+ Hmap = np.delete(rgba_img, 3, 2)
+ # print(Hmap.shape, Hmap.max(), Hmap.min())
+ # cv2.imshow("heat_img", Hmap)
+ # cv2.waitKey(0)
+ return Hmap
+
+
+def loss_ploy(loss_list, steps, period, name=""):
+ fig1, ax1 = plt.subplots(figsize=(16, 9))
+ ax1.plot(range(steps // period), loss_list)
+ ax1.set_title("Average loss vs step*{}".format(period))
+ ax1.set_xlabel("step*{}".format(period))
+ ax1.set_ylabel("Current loss")
+ plt.savefig('{}@loss_vs_step*{}.png'.format(name,period))
+ plt.clf()
+
+
+def plt_ploys(ploys, period, name=""):
+ fig1, ax1 = plt.subplots(figsize=(16, 9))
+ cnames = ['aliceblue','antiquewhite','aqua','aquamarine','azure',
+ 'blanchedalmond','blue','blueviolet','brown','burlywood',
+ 'coral','cornflowerblue','cornsilk','crimson','cyan',
+ 'darkblue','deeppink','deepskyblue','dodgerblue','forestgreen',
+ 'gold','goldenrod','green','greenyellow','honeydew','hotpink',
+ 'lawngreen','lightblue','lightgreen','lightpink','lightsalmon',
+ 'lightseagreen','lightsteelblue','lightyellow','lime','limegreen',
+ 'mediumseagreen','mediumspringgreen','midnightblue','orange','orangered',
+ 'pink','red','royalblue','seagreen','skyblue','springgreen','steelblue',
+ 'tan','teal','thistle','yellow','yellowgreen']
+
+ color = random.sample(cnames, len(ploys.keys()))
+ for ii, key in enumerate(ploys.keys()):
+ ax1.plot(range(1, len(ploys[key])+1), ploys[key],color=color[ii], label=key)
+ ax1.set_title("Loss Carve line")
+ ax1.set_xlabel("step*{}".format(period))
+ ax1.set_ylabel("Current loss")
+ plt.legend(ploys.keys())
+ plt.savefig('{}@loss_vs_step*{}.png'.format(name, period))
+ plt.clf()
+
+if __name__ == '__main__':
+ # TODO ADD CODE
+ pass
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/util/detection.py b/IndicPhotoOCR/detection/textbpn/util/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2dd4adb8ce8aa2d507e43c31386d9abd8e70c8
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/detection.py
@@ -0,0 +1,48 @@
+# C++ version of PSE decoding, based on OpenCV 3+
+from pse import decode as pse_decode
+from cfglib.config import config as cfg
+
+
+class TextDetector(object):
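+ """Inference wrapper: runs the model forward and decodes the output map
+ into text regions, boxes and contours via PSE post-processing."""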
+
+ def __init__(self, model):
+ # evaluation mode
+ self.model = model
+ model.eval()
+ # parameter
+ self.scale = cfg.scale
+ self.threshold = cfg.threshold
+
+ def detect(self, image, img_show):
+ # get model output
+ preds = self.model.forward(image)
+ preds, boxes, contours = pse_decode(preds[0], self.scale, self.threshold)
+
+ output = {
+ 'image': image,
+ 'tr': preds,
+ 'bbox': boxes
+ }
+ return contours, output
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/eval.py b/IndicPhotoOCR/detection/textbpn/util/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b36cabb7b4f17276ca9f1b5d740354fab4d827c
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/eval.py
@@ -0,0 +1,228 @@
+import os
+import cv2
+import numpy as np
+import subprocess
+from cfglib.config import config as cfg
+from util.misc import mkdirs
+
+
+def osmkdir(out_dir):
+ import shutil
+ if os.path.exists(out_dir):
+ shutil.rmtree(out_dir)
+ os.makedirs(out_dir)
+
+
+def analysize_result(source_dir, fid_path, outpt_dir, name):
+
+ bad_txt = open("{}/eval.txt".format(outpt_dir), 'w')
+ all_eval = open("{}/{}/{}_eval.txt".format(cfg.output_dir, "Analysis", name), 'a+')
+ sel_list = list()
+ with open(fid_path) as f:
+ lines = f.read().split("\n")
+ for line in lines:
+ line_items = line.split(" ")
+ id = line_items[0]
+ precision = float(line_items[2].split('=')[-1])
+ recall = float(line_items[4].split('=')[-1])
+ if id != "ALL" and (precision < 0.5 or recall < 0.5):
+ img_path = os.path.join(source_dir, line_items[0].replace(".txt", ".jpg"))
+ if os.path.exists(img_path):
+ os.system('cp {} {}'.format(img_path, outpt_dir))
+ sel_list.append((int(id.replace(".txt", "").replace("img", "").replace("_", "")), line))
+ if id == "ALL":
+ all_eval.write("{} {} {}\n".format(
+ outpt_dir.split('/')[-1],
+ "{}/{}".format(cfg.dis_threshold, cfg.cls_threshold),
+ line))
+ sel_list = sorted(sel_list, key=lambda its: its[0])
+ bad_txt.write('\n'.join([its[1] for its in sel_list]))
+ all_eval.close()
+ bad_txt.close()
+
+
+def deal_eval_total_text(debug=False):
+ # compute DetEval
+ eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+ if not os.path.exists(eval_dir):
+ os.makedirs(eval_dir)
+
+ print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+ subprocess.call(
+ ['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.7',
+ '--tp', '0.6'])
+ subprocess.call(
+ ['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.8',
+ '--tp', '0.4'])
+
+ if debug:
+ source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+ outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "total_text")
+ if not os.path.exists(outpt_dir_base):
+ mkdirs(outpt_dir_base)
+
+ outpt_dir1 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
+ .format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.7, 0.6))
+ osmkdir(outpt_dir1)
+ fid_path1 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.7, 0.6)
+
+ analysize_result(source_dir, fid_path1, outpt_dir1, "totalText")
+
+ outpt_dir2 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
+ .format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.8, 0.4))
+ osmkdir(outpt_dir2)
+ fid_path2 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.8, 0.4)
+
+ analysize_result(source_dir, fid_path2, outpt_dir2, "totalText")
+
+ print('End.')
+
+
+def deal_eval_ctw1500(debug=False):
+ # compute DetEval
+ eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+ if not os.path.exists(eval_dir):
+ os.makedirs(eval_dir)
+
+ print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+ subprocess.call(['python', 'dataset/ctw1500/Evaluation_Protocol/ctw1500_eval.py', cfg.exp_name])
+
+ if debug:
+ source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+ outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "ctw1500")
+ if not os.path.exists(outpt_dir_base):
+ mkdirs(outpt_dir_base)
+
+ outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+ osmkdir(outpt_dir)
+ fid_path1 = '{}/Eval_ctw1500_{}.txt'.format(eval_dir, 0.5)
+
+ analysize_result(source_dir, fid_path1, outpt_dir, "ctw1500")
+
+ print('End.')
+
+
+def deal_eval_icdar15(debug=False):
+ # compute DetEval
+ eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+ if not os.path.exists(eval_dir):
+ os.makedirs(eval_dir)
+
+ input_dir = 'output/{}'.format(cfg.exp_name)
+ father_path = os.path.abspath(input_dir)
+ print(father_path)
+ print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+ subprocess.call(['sh', 'dataset/icdar15/eval.sh', father_path])
+
+ if debug:
+ source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+ outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "icdar15")
+ if not os.path.exists(outpt_dir_base):
+ mkdirs(outpt_dir_base)
+
+ outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+ osmkdir(outpt_dir)
+ fid_path1 = '{}/Eval_icdar15.txt'.format(eval_dir)
+
+ analysize_result(source_dir, fid_path1, outpt_dir, "icdar15")
+
+ print('End.')
+
+
+
+def deal_eval_TD500(debug=False):
+ # compute DetEval
+ eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+ if not os.path.exists(eval_dir):
+ os.makedirs(eval_dir)
+
+ input_dir = 'output/{}'.format(cfg.exp_name)
+ father_path = os.path.abspath(input_dir)
+ print(father_path)
+ print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+ subprocess.call(['sh', 'dataset/TD500/eval.sh', father_path])
+
+ if debug:
+ source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+ outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "TD500")
+ if not os.path.exists(outpt_dir_base):
+ mkdirs(outpt_dir_base)
+
+ outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+ osmkdir(outpt_dir)
+ fid_path1 = '{}/Eval_TD500.txt'.format(eval_dir)
+
+ analysize_result(source_dir, fid_path1, outpt_dir, "TD500")
+
+ print('End.')
+
+
+def data_transfer_ICDAR(contours):
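+ """Convert detected contours into 4-point minimum-area rectangles,
+ dropping boxes whose shorter side is 5 px or less."""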
+ cnts = list()
+ for cont in contours:
+ rect = cv2.minAreaRect(cont)
+ if min(rect[1][0], rect[1][1]) <= 5:
+ continue
+ points = cv2.boxPoints(rect)
+ points = np.int0(points)
+ # print(points.shape)
+ # points = np.reshape(points, (4, 2))
+ cnts.append(points)
+ return cnts
+
+
+def data_transfer_TD500(contours, res_file, img=None):
+ with open(res_file, 'w') as f:
+ for cont in contours:
+ rect = cv2.minAreaRect(cont)
+ if min(rect[1][0], rect[1][1]) <= 5:
+ continue
+ points = cv2.boxPoints(rect)
+ box = np.int0(points)
+ cv2.drawContours(img, [box], 0, (0, 255, 0), 3)
+
+ cx, cy = rect[0]
+ w_, h_ = rect[1]
+ angle = rect[2]
+ if angle > 45:
+ angle = 90 - angle
+ w_, h_ = h_, w_
+ elif angle < -45:
+ angle = 90 + angle
+ w_, h_ = h_, w_
+ angle = angle / 180 * 3.141592653589
+
+ x_min = int(cx - w_ / 2)
+ x_max = int(cx + w_ / 2)
+ y_min = int(cy - h_ / 2)
+ y_max = int(cy + h_ / 2)
+ f.write('{},{},{},{},{}\r\n'.format(x_min, y_min, x_max, y_max, angle))
+
+ return img
+
+
+def data_transfer_MLT2017(contours, res_file):
+ with open(res_file, 'w') as f:
+ for cont in contours:
+ rect = cv2.minAreaRect(cont)
+ if min(rect[1][0], rect[1][1]) <= 5:
+ continue
+ poly_area = cv2.contourArea(cont)
+ rect_area = rect[1][0] * rect[1][1]
+ solidity = poly_area / rect_area
+ width = rect[1][0] - np.clip(rect[1][0] * (1-np.sqrt(solidity)), 0, 6)
+ height = rect[1][1] - np.clip(rect[1][1] * (1-np.sqrt(solidity)), 0, 4)
+ points = cv2.boxPoints((rect[0], (width, height), rect[2]))
+ points = np.int0(points)
+ p = np.reshape(points, -1)
+ f.write('{},{},{},{},{},{},{},{},{}\r\n'
+ .format(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], 1))
+
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/graph.py b/IndicPhotoOCR/detection/textbpn/util/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0175ad72d59bf414a1b902bc8b6c3bcf918e4d5
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/graph.py
@@ -0,0 +1,309 @@
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import numpy as np
+import time
+from util.misc import norm2
+
+class Data(object):
+ def __init__(self, name):
+ self.__name = name
+ self.__links = set()
+
+ @property
+ def name(self):
+ return self.__name
+
+ @property
+ def links(self):
+ return set(self.__links)
+
+ def add_link(self, other, score):
+ self.__links.add(other)
+ other.__links.add(self)
+
+
+def connected_components(nodes, score_dict, th):
+ '''
+ conventional connected components searching
+ '''
+ result = []
+ nodes = set(nodes)
+ while nodes:
+ n = nodes.pop()
+ group = {n}
+ queue = [n]
+ while queue:
+ n = queue.pop(0)
+ if th is not None:
+ neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
+ else:
+ neighbors = n.links
+ neighbors.difference_update(group)
+ nodes.difference_update(neighbors)
+ group.update(neighbors)
+ queue.extend(neighbors)
+ result.append(group)
+ return result
+
+
+def connected_components_constraint(nodes, max_sz, score_dict=None, th=None):
+ '''
+ only use edges whose scores are above `th`
+ if a component is larger than `max_sz`, all the nodes in this component are added into `remain` and returned for next iteration.
+ '''
+ result = []
+ remain = set()
+ nodes = set(nodes)
+ while nodes:
+ n = nodes.pop()
+ group = {n}
+ queue = [n]
+ valid = True
+ while queue:
+ n = queue.pop(0)
+ if th is not None:
+ neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
+ else:
+ neighbors = n.links
+ neighbors.difference_update(group)
+ nodes.difference_update(neighbors)
+ group.update(neighbors)
+ queue.extend(neighbors)
+ if len(group) > max_sz or len(remain.intersection(neighbors)) > 0:
+ # if this group is larger than `max_sz`, add the nodes into `remain`
+ valid = False
+ remain.update(group)
+ break
+ if valid: # if this group is smaller than or equal to `max_sz`, finalize it.
+ result.append(group)
+ return result, remain
+
+
+def graph_propagation_naive(edges, score, th, bboxs=None, dis_thresh=50, pool='avg'):
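+ """Build a graph from scored edges (optionally zeroing edges whose box
+ centres are farther apart than `dis_thresh`) and return the connected
+ components whose edge scores exceed `th`."""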
+
+ edges = np.sort(edges, axis=1)
+
+ score_dict = {} # score lookup table
+ if pool is None:
+ for i, e in enumerate(edges):
+ score_dict[e[0], e[1]] = score[i]
+ elif pool == 'avg':
+ for i, e in enumerate(edges):
+ if bboxs is not None:
+ box1 = bboxs[e[0]][:8].reshape(4, 2)
+ box2 = bboxs[e[1]][:8].reshape(4, 2)
+ c1 = np.mean(box1, 0); c2 = np.mean(box2, 0)
+ dst = norm2(c1 - c2)
+ if dst > dis_thresh:
+ score[i] = 0
+ if (e[0], e[1]) in score_dict:
+ score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + score[i])
+ else:
+ score_dict[e[0], e[1]] = score[i]
+
+ elif pool == 'max':
+ for i, e in enumerate(edges):
+ if (e[0], e[1]) in score_dict:
+ score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]], score[i])
+ else:
+ score_dict[e[0], e[1]] = score[i]
+ else:
+ raise ValueError('Pooling operation not supported')
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+ mapping[nodes] = np.arange(nodes.shape[0])
+ link_idx = mapping[edges]
+ vertex = [Data(n) for n in nodes]
+ for l, s in zip(link_idx, score):
+ vertex[l[0]].add_link(vertex[l[1]], s)
+
+ # first iteration
+ comps = connected_components(vertex, score_dict, th)
+
+ return comps
+
+
+def graph_search(edges, scores, edges_num, th=None):
+ # graph search
+ scores = scores.reshape((-1, edges_num))
+ select_index = np.argsort(scores, axis=1)[:, -2:]
+ edges = np.sort(edges, axis=1).reshape((-1, edges_num, 2))
+
+ score_dict = {}
+ for i, ips in enumerate(select_index):
+ edg = edges[i]
+ si = scores[i]
+ for j, idx in enumerate(ips):
+ e = edg[idx, :]
+ if (e[0], e[1]) in score_dict:
+ score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + si[j])
+ else:
+ score_dict[e[0], e[1]] = si[j]
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ vertex = [Data(n) for n in nodes]
+ for (key, value) in score_dict.items():
+ vertex[key[0]].add_link(vertex[key[1]], value)
+
+ comps = connected_components(vertex, score_dict, th)
+
+ return comps
+
+
+def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.5, pool=None):
+
+ edges = np.sort(edges, axis=1)
+ th = score.min()
+ # th = beg_th
+ # construct graph
+ score_dict = {} # score lookup table
+ if pool is None:
+ for i,e in enumerate(edges):
+ score_dict[e[0], e[1]] = score[i]
+ elif pool == 'avg':
+ for i,e in enumerate(edges):
+ if (e[0], e[1]) in score_dict:
+ score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i])
+ else:
+ score_dict[e[0], e[1]] = score[i]
+
+ elif pool == 'max':
+ for i,e in enumerate(edges):
+ if (e[0],e[1]) in score_dict:
+ score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i])
+ else:
+ score_dict[e[0], e[1]] = score[i]
+ else:
+ raise ValueError('Pooling operation not supported')
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+ mapping[nodes] = np.arange(nodes.shape[0])
+ link_idx = mapping[edges]
+ vertex = [Data(n) for n in nodes]
+ for l, s in zip(link_idx, score):
+ vertex[l[0]].add_link(vertex[l[1]], s)
+
+ # first iteration
+ comps, remain = connected_components_constraint(vertex, max_sz)
+
+ # iteration
+ components = comps[:]
+ while remain:
+ th = th + (1 - th) * step
+ comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
+ components.extend(comps)
+ return components
+
+
+def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs):
+
+ edges = np.sort(edges, axis=1)
+ th = score.min()
+
+ # construct graph
+ score_dict = {} # score lookup table
+ for i,e in enumerate(edges):
+ score_dict[e[0], e[1]] = score[i]
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+ mapping[nodes] = np.arange(nodes.shape[0])
+ link_idx = mapping[edges]
+ vertex = [Data(n) for n in nodes]
+ for l, s in zip(link_idx, score):
+ vertex[l[0]].add_link(vertex[l[1]], s)
+
+ # first iteration
+ comps, remain = connected_components_constraint(vertex, max_sz)
+ first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c])
+ fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True)
+ # iteration
+ components = comps[:]
+ while remain:
+ th = th + (1 - th) * step
+ comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
+ components.extend(comps)
+ label_dict = {}
+ for i,c in enumerate(components):
+ for n in c:
+ label_dict[n.name] = i
+ print('Propagation ...')
+ prop_vertex = [vertex[idx] for idx in fusion_vertex_idx]
+ label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs)
+ return label, label_fusion
+
+
+def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True):
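+ """Soft label diffusion: starting from each vertex, breadth-first propagate
+ its label to unvisited neighbours, weighting by the edge score times
+ `weight_decay` per hop, up to `max_depth` hops; optionally normalise each
+ node's fused label weights."""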
+ class BFSNode():
+ def __init__(self, node, depth, value):
+ self.node = node
+ self.depth = depth
+ self.value = value
+
+ label_fusion = {}
+ for name in label.keys():
+ label_fusion[name] = {label[name]: 1.0}
+ prog = 0
+ prog_step = max(1, len(vertex) // 20) # avoid division by zero for small vertex sets
+ start = time.time()
+ for root in vertex:
+ if prog % prog_step == 0:
+ print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start))
+ prog += 1
+ #queue = {[root, 0, 1.0]}
+ queue = {BFSNode(root, 0, 1.0)}
+ visited = [root.name]
+ root_label = label[root.name]
+ while queue:
+ curr = queue.pop()
+ if curr.depth >= max_depth: # pruning
+ continue
+ neighbors = curr.node.links
+ tmp_value = []
+ tmp_neighbor = []
+ for n in neighbors:
+ if n.name not in visited:
+ sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value
+ tmp_value.append(sub_value)
+ tmp_neighbor.append(n)
+ if root_label not in label_fusion[n.name].keys():
+ label_fusion[n.name][root_label] = sub_value
+ else:
+ label_fusion[n.name][root_label] += sub_value
+ visited.append(n.name)
+ #queue.add([n, curr.depth+1, sub_value])
+ sortidx = np.argsort(tmp_value)[::-1]
+ for si in sortidx:
+ queue.add(BFSNode(tmp_neighbor[si], curr.depth+1, tmp_value[si]))
+ if normalize:
+ for name in label_fusion.keys():
+ summ = sum(label_fusion[name].values())
+ for k in label_fusion[name].keys():
+ label_fusion[name][k] /= summ
+ return label, label_fusion
+
+
+def clusters2labels(clusters, n_nodes):
+ labels = -1 * np.ones((n_nodes,))
+ for ci, c in enumerate(clusters):
+ for xid in c:
+ labels[xid.name] = ci
+ assert np.sum(labels < 0) < 1
+ return labels
+
+
+def single_remove(bbox, pred):
+ single_idcs = np.zeros_like(pred)
+ pred_unique = np.unique(pred)
+ for u in pred_unique:
+ idcs = pred == u
+ if np.sum(idcs) == 1:
+ single_idcs[np.where(idcs)[0][0]] = 1
+ remain_idcs = [i for i in range(len(pred)) if not single_idcs[i]]
+ remain_idcs = np.asarray(remain_idcs)
+ return bbox[remain_idcs, :], pred[remain_idcs]
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/io.py b/IndicPhotoOCR/detection/textbpn/util/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdcdadc7b802e622d6079fc679d87e134541e0bf
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/io.py
@@ -0,0 +1,233 @@
+#coding=utf-8
+'''
+Created on 2016-09-27
+
+@author: dengdan
+
+Tool functions for file-system operations and I/O,
+in the style of Linux shell commands.
+'''
+import os
+import pickle as pkl
+import subprocess
+import logging
+from . import strs, io
+
+
+def mkdir(path):
+ """
+ If the target directory does not exist, it and its parent directories will be created.
+ """
+ path = get_absolute_path(path)
+ if not exists(path):
+ os.makedirs(path)
+ return path
+
+def make_parent_dir(path):
+ """make the parent directories for a file."""
+ parent_dir = get_dir(path)
+ mkdir(parent_dir)
+
+
+def pwd():
+ return os.getcwd()
+
+def dump(path, obj):
+ path = get_absolute_path(path)
+ parent_path = get_dir(path)
+ mkdir(parent_path)
+ with open(path, 'wb') as f: # pickle requires binary mode
+ logging.info('dumping file: ' + path)
+ pkl.dump(obj, f)
+
+def load(path):
+ path = get_absolute_path(path)
+ with open(path, 'rb') as f: # pickle requires binary mode
+ data = pkl.load(f)
+ return data
+
+def join_path(a, *p):
+ return os.path.join(a, *p)
+
+def is_dir(path):
+ path = get_absolute_path(path)
+ return os.path.isdir(path)
+
+is_directory = is_dir
+
+def is_path(path):
+ path = get_absolute_path(path)
+ return os.path.exists(path)
+
+def get_dir(path):
+ '''
+ Return the directory a path belongs to.
+ If the path is itself a directory, it is returned unchanged.
+ '''
+ path = get_absolute_path(path)
+ if is_dir(path):
+ return path
+ return os.path.split(path)[0]
+
+def get_parent_dir(path):
+ current_dir = get_dir(path)
+ return get_absolute_path(join_path(current_dir, '..'))
+
+def get_filename(path):
+ return os.path.split(path)[1]
+
+def get_absolute_path(p):
+ if p.startswith('~'):
+ p = os.path.expanduser(p)
+ return os.path.abspath(p)
+
+def cd(p):
+ p = get_absolute_path(p)
+ os.chdir(p)
+
+def ls(path = '.', suffix = None):
+ """
+ list files in a directory.
+ return file names in a list
+ """
+ path = get_absolute_path(path)
+ files = os.listdir(path)
+
+ if suffix is None:
+ return files
+
+ filtered = []
+ for f in files:
+ if strs.ends_with(f, suffix, ignore_case=True):
+ filtered.append(f)
+
+ return filtered
+
+def find_files(pattern):
+ import glob
+ return glob.glob(pattern)
+
+def read_lines(p):
+ """return the text in a file in lines as a list """
+ p = get_absolute_path(p)
+ f = open(p,'r')
+ return f.readlines()
+
+def write_lines(p, lines, append_break = False):
+ p = get_absolute_path(p)
+ make_parent_dir(p)
+ with open(p, 'w') as f:
+ for line in lines:
+ if append_break:
+ f.write(line + '\n')
+ else:
+ f.write(line)
+
+def cat(p):
+ """return the text in a file as a whole"""
+ cmd = 'cat ' + p
+ return subprocess.getoutput(cmd)
+
+def exists(path):
+ path = get_absolute_path(path)
+ return os.path.exists(path)
+
+def not_exists(path):
+ return not exists(path)
+
+def load_mat(path):
+ import scipy.io as sio # type: ignore
+ path = get_absolute_path(path)
+ return sio.loadmat(path)
+
+def dump_mat(path, dict_obj, append = True):
+ import scipy.io as sio # type: ignore
+ path = get_absolute_path(path)
+ make_parent_dir(path)
+ sio.savemat(file_name = path, mdict = dict_obj, appendmat = append)
+
+def dir_mat(path):
+ '''
+ list the variables in mat file.
+ return a list: [(name, shape, dtype), ...]
+ '''
+ import scipy.io as sio # type: ignore
+ path = get_absolute_path(path)
+ return sio.whosmat(path)
+
+SIZE_UNIT_K = 1024
+SIZE_UNIT_M = SIZE_UNIT_K ** 2
+SIZE_UNIT_G = SIZE_UNIT_K ** 3
+def get_file_size(path, unit = SIZE_UNIT_K):
+ size = os.path.getsize(get_absolute_path(path))
+ return size * 1.0 / unit
+
+
+def create_h5(path):
+ import h5py # type: ignore
+ path = get_absolute_path(path)
+ make_parent_dir(path)
+ return h5py.File(path, 'w')
+
+def open_h5(path, mode = 'r'):
+ import h5py
+ path = get_absolute_path(path)
+ return h5py.File(path, mode)
+
+def read_h5(h5, key):
+ return h5[key][:]
+
+
+def read_h5_attrs(h5, key, attrs):
+ return h5[key].attrs[attrs]
+
+def copy(src, dest):
+ io.make_parent_dir(dest)
+ import shutil
+ shutil.copy(get_absolute_path(src), get_absolute_path(dest))
+
+cp = copy
+
+def remove(p):
+ os.remove(get_absolute_path(p))
+rm = remove
+
+def search(pattern, path, file_only = True):
+ """
+ Search files whose name matches the give pattern. The search scope
+ is the directory and sub-directories of 'path'.
+ """
+ path = get_absolute_path(path)
+ pattern_here = io.join_path(path, pattern)
+ targets = []
+
+ # find matchings in current directory
+ candidates = find_files(pattern_here)
+ for can in candidates:
+ if io.is_dir(can) and file_only:
+ continue
+ else:
+ targets.append(can)
+
+ # find matching in sub-dirs
+ files = ls(path)
+ for f in files:
+ fpath = io.join_path(path, f)
+ if is_dir(fpath):
+ targets_in_sub_dir = search(pattern, fpath, file_only)
+ targets.extend(targets_in_sub_dir)
+ return targets
+
+def dump_json(path, data):
+ import ujson as json
+ path = get_absolute_path(path)
+ make_parent_dir(path)
+
+ with open(path, 'w') as f:
+ json.dump(data, f)
+ return path
+
+def load_json(path):
+ import ujson as json
+ path = get_absolute_path(path)
+ with open(path, 'r') as f:
+ return json.load(f)
diff --git a/IndicPhotoOCR/detection/textbpn/util/logging.py b/IndicPhotoOCR/detection/textbpn/util/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..e928a4ac25de0531d0542e2de5b2eaf686a45261
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/logging.py
@@ -0,0 +1,129 @@
+from __future__ import absolute_import
+import os
+import sys
+import numpy as np
+import tensorflow as tf
+import scipy.misc
+try:
+ from StringIO import StringIO # Python 2.7
+except ImportError:
+ from io import BytesIO # Python 3.x
+
+from .osutils import mkdir_if_missing
+
+from config import get_args
+global_args = get_args(sys.argv[1:])
+
+if global_args.run_on_remote:
+ import moxing as mox
+ mox.file.shift("os", "mox")
+
+class Logger(object):
+ def __init__(self, fpath=None):
+ self.console = sys.stdout
+ self.file = None
+ if fpath is not None:
+ if global_args.run_on_remote:
+ dir_name = os.path.dirname(fpath)
+ if not mox.file.exists(dir_name):
+ mox.file.make_dirs(dir_name)
+ print('=> making dir ', dir_name)
+ self.file = mox.file.File(fpath, 'w')
+ # self.file = open(fpath, 'w')
+ else:
+ mkdir_if_missing(os.path.dirname(fpath))
+ self.file = open(fpath, 'w')
+
+ def __del__(self):
+ self.close()
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ self.close()
+
+ def write(self, msg):
+ self.console.write(msg)
+ if self.file is not None:
+ self.file.write(msg)
+
+ def flush(self):
+ self.console.flush()
+ if self.file is not None:
+ self.file.flush()
+ os.fsync(self.file.fileno())
+
+ def close(self):
+ self.console.close()
+ if self.file is not None:
+ self.file.close()
+
+
+class TFLogger(object):
+ def __init__(self, log_dir=None):
+ """Create a summary writer logging to log_dir."""
+ if log_dir is not None:
+ mkdir_if_missing(log_dir)
+ self.writer = tf.summary.FileWriter(log_dir)
+
+ def scalar_summary(self, tag, value, step):
+ """Log a scalar variable."""
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
+ def image_summary(self, tag, images, step):
+ """Log a list of images."""
+
+ img_summaries = []
+ for i, img in enumerate(images):
+ # Write the image to a string
+ try:
+ s = StringIO()
+ except:
+ s = BytesIO()
+ scipy.misc.toimage(img).save(s, format="png")
+
+ # Create an Image object
+ img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
+ height=img.shape[0],
+ width=img.shape[1])
+ # Create a Summary value
+ img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
+
+ # Create and write Summary
+ summary = tf.Summary(value=img_summaries)
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
+ def histo_summary(self, tag, values, step, bins=1000):
+ """Log a histogram of the tensor of values."""
+
+ # Create a histogram using numpy
+ counts, bin_edges = np.histogram(values, bins=bins)
+
+ # Fill the fields of the histogram proto
+ hist = tf.HistogramProto()
+ hist.min = float(np.min(values))
+ hist.max = float(np.max(values))
+ hist.num = int(np.prod(values.shape))
+ hist.sum = float(np.sum(values))
+ hist.sum_squares = float(np.sum(values**2))
+
+ # Drop the start of the first bin
+ bin_edges = bin_edges[1:]
+
+ # Add bin edges and counts
+ for edge in bin_edges:
+ hist.bucket_limit.append(edge)
+ for c in counts:
+ hist.bucket.append(c)
+
+ # Create and write Summary
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
+ def close(self):
+ self.writer.close()
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/util/meters.py b/IndicPhotoOCR/detection/textbpn/util/meters.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b98c6fd2d3260be44b1bd2fdda28d6e75979952
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/meters.py
@@ -0,0 +1,23 @@
+from __future__ import absolute_import
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/util/misc.py b/IndicPhotoOCR/detection/textbpn/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4231e2b2bbb6f6214082f0a4ee8333588b268e06
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/misc.py
@@ -0,0 +1,408 @@
+import numpy as np
+import errno
+import os
+import cv2
+import math
+from shapely.geometry import Polygon
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+from scipy import ndimage as ndimg
+
+def to_device(*tensors):
+ if len(tensors) < 2:
+ return tensors[0].to(cfg.device, non_blocking=True)
+ return (t.to(cfg.device, non_blocking=True) for t in tensors)
+
+
+def mkdirs(newdir):
+ """
+ make directory with parent path
+ :param newdir: target path
+ """
+ try:
+ if not os.path.exists(newdir):
+ os.makedirs(newdir)
+ except OSError as err:
+ # Reraise the error unless it's about an already existing directory
+ if err.errno != errno.EEXIST or not os.path.isdir(newdir):
+ raise
+
+
+def rescale_result(image, bbox_contours, H, W):
+ ori_H, ori_W = image.shape[:2]
+ image = cv2.resize(image, (W, H))
+ contours = list()
+ for cont in bbox_contours:
+ # if cv2.contourArea(cont) < 300:
+ # continue
+ cont[:, 0] = (cont[:, 0] * W / ori_W).astype(int)
+ cont[:, 1] = (cont[:, 1] * H / ori_H).astype(int)
+ contours.append(cont)
+ return image, contours
+
+
+def fill_hole(input_mask):
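+ """Fill interior holes of a binary mask: flood-fill the background from
+ the border, then OR the complement of the fill with the original mask."""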
+ h, w = input_mask.shape
+ canvas = np.zeros((h + 2, w + 2), np.uint8)
+ canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+
+ mask = np.zeros((h + 4, w + 4), np.uint8)
+
+ cv2.floodFill(canvas, mask, (0, 0), 1)
+ canvas = canvas[1:h + 1, 1:w + 1].astype(bool)
+
+ return (~canvas | input_mask.astype(np.uint8))
+
+
+def regularize_sin_cos(sin, cos):
+ # regularization
+ scale = np.sqrt(1.0 / (sin ** 2 + cos ** 2))
+ return sin * scale, cos * scale
+
+
+def gaussian2D(shape, sigma=1):
+ m, n = [(ss - 1.) / 2. for ss in shape]
+ y, x = np.ogrid[-m:m + 1, -n:n + 1]
+
+ h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def draw_gaussian(heatmap, center, radius, k=1, delte=6):
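+ """Splat a 2D Gaussian of the given radius onto `heatmap` at `center`,
+ keeping the element-wise maximum with the values already there."""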
+ diameter = 2 * radius + 1
+ gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte)
+
+ x, y = center
+
+ height, width = heatmap.shape[0:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+
+
+def gaussian_radius(det_size, min_overlap=0.7):
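+ """Gaussian radius for heatmap targets: solve three quadratic constraints
+ so that boxes shifted within the radius still keep at least `min_overlap`
+ overlap with the ground truth, and return the smallest of the three roots."""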
+ height, width = det_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
+ r1 = (b1 + sq1) / 2
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
+ r2 = (b2 + sq2) / 2
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / 2
+ return min(r1, r2, r3)
+
+
+def point_dist_to_line(line, p3):
+ # Distance from point p3 to the line through p1 and p2, where line = (p1, p2).
+ # Equivalent to np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1),
+ # i.e. the cross product divided by the norm of the direction vector.
+ p1, p2 = line
+ d = p2 - p1
+
+ def l2(p):
+ return math.sqrt(p[0] * p[0] + p[1] * p[1])
+
+ if l2(d) > 0:
+ distance = abs(d[1] * p3[0] - d[0] * p3[1] + p2[0] * p1[1] - p2[1] * p1[0]) / l2(d)
+ else:
+ distance = math.sqrt((p3[0]-p2[0])**2 + (p3[1]-p2[1])**2)
+
+ return distance
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def norm2(x, axis=None):
+ if axis is not None:
+ return np.sqrt(np.sum(x ** 2, axis=axis))
+ return np.sqrt(np.sum(x ** 2))
+
+
+def cos(p1, p2):
+ return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
+
+
+def vector_sin(v):
+ assert len(v) == 2
+ # sin = y / (sqrt(x^2 + y^2))
+ l = np.sqrt(v[0] ** 2 + v[1] ** 2) + 1e-5
+ return v[1] / l
+
+
+def vector_cos(v):
+ assert len(v) == 2
+ # cos = x / (sqrt(x^2 + y^2))
+ l = np.sqrt(v[0] ** 2 + v[1] ** 2) + 1e-5
+ return v[0] / l
+
+
+def find_bottom(pts):
+
+ if len(pts) > 4:
+ e = np.concatenate([pts, pts[:3]])
+ candidate = []
+ for i in range(1, len(pts) + 1):
+ v_prev = e[i] - e[i - 1]
+ v_next = e[i + 2] - e[i + 1]
+ if cos(v_prev, v_next) < -0.875:
+ candidate.append((i % len(pts), (i + 1) % len(pts), norm2(e[i] - e[i + 1])))
+
+ if len(candidate) != 2 or candidate[0][0] == candidate[1][1] or candidate[0][1] == candidate[1][0]:
+ # if candidate number < 2, or two bottom are joined, select 2 farthest edge
+ mid_list = []
+ dist_list = []
+ if len(candidate) > 2:
+
+ bottom_idx = np.argsort([angle for s1, s2, angle in candidate])[0:2]
+ bottoms = [candidate[bottom_idx[0]][:2], candidate[bottom_idx[1]][0:2]]
+ long_edge1, long_edge2 = find_long_edges(pts, bottoms)
+ edge_length1 = [norm2(pts[e1] - pts[e2]) for e1, e2 in long_edge1]
+ edge_length2 = [norm2(pts[e1] - pts[e2]) for e1, e2 in long_edge2]
+ l1 = sum(edge_length1)
+ l2 = sum(edge_length2)
+ len1 = len(edge_length1)
+ len2 = len(edge_length2)
+
+ if l1 > 2*l2 or l2 > 2*l1 or len1 == 0 or len2 == 0:
+ for i in range(len(pts)):
+ mid_point = (e[i] + e[(i + 1) % len(pts)]) / 2
+ mid_list.append((i, (i + 1) % len(pts), mid_point))
+
+ for i in range(len(pts)):
+ for j in range(len(pts)):
+ s1, e1, mid1 = mid_list[i]
+ s2, e2, mid2 = mid_list[j]
+ dist = norm2(mid1 - mid2)
+ dist_list.append((s1, e1, s2, e2, dist))
+ bottom_idx = np.argsort([dist for s1, e1, s2, e2, dist in dist_list])[-1]
+ bottoms = [dist_list[bottom_idx][:2], dist_list[bottom_idx][2:4]]
+ else:
+ mid_list = []
+ for i in range(len(pts)):
+ mid_point = (e[i] + e[(i + 1) % len(pts)]) / 2
+ mid_list.append((i, (i + 1) % len(pts), mid_point))
+
+ dist_list = []
+ for i in range(len(pts)):
+ for j in range(len(pts)):
+ s1, e1, mid1 = mid_list[i]
+ s2, e2, mid2 = mid_list[j]
+ dist = norm2(mid1 - mid2)
+ dist_list.append((s1, e1, s2, e2, dist))
+ bottom_idx = np.argsort([dist for s1, e1, s2, e2, dist in dist_list])[-2:]
+ bottoms = [dist_list[bottom_idx[0]][:2], dist_list[bottom_idx[1]][:2]]
+ else:
+ bottoms = [candidate[0][:2], candidate[1][:2]]
+ else:
+ d1 = norm2(pts[1] - pts[0]) + norm2(pts[2] - pts[3])
+ d2 = norm2(pts[2] - pts[1]) + norm2(pts[0] - pts[3])
+ bottoms = [(0, 1), (2, 3)] if d1 < d2 else [(1, 2), (3, 0)]
+ # bottoms = [(0, 1), (2, 3)] if 2 * d1 < d2 and d1 > 32 else [(1, 2), (3, 0)]
+ assert len(bottoms) == 2, 'fewer than 2 bottoms'
+ return bottoms
+
+
+def split_long_edges(points, bottoms):
+ """
+ Find two long edge sequence of and polygon
+ """
+ b1_start, b1_end = bottoms[0]
+ b2_start, b2_end = bottoms[1]
+ n_pts = len(points)
+
+ i = b1_end + 1
+ long_edge_1 = []
+ while i % n_pts != b2_end:
+ long_edge_1.append((i - 1, i))
+ i = (i + 1) % n_pts
+
+ i = b2_end + 1
+ long_edge_2 = []
+ while i % n_pts != b1_end:
+ long_edge_2.append((i - 1, i))
+ i = (i + 1) % n_pts
+ return long_edge_1, long_edge_2
+
+
+def find_long_edges(points, bottoms):
+ b1_start, b1_end = bottoms[0]
+ b2_start, b2_end = bottoms[1]
+ n_pts = len(points)
+ i = (b1_end + 1) % n_pts
+ long_edge_1 = []
+
+ while i % n_pts != b2_end:
+ start = (i - 1) % n_pts
+ end = i % n_pts
+ long_edge_1.append((start, end))
+ i = (i + 1) % n_pts
+
+ i = (b2_end + 1) % n_pts
+ long_edge_2 = []
+ while i % n_pts != b1_end:
+ start = (i - 1) % n_pts
+ end = i % n_pts
+ long_edge_2.append((start, end))
+ i = (i + 1) % n_pts
+ return long_edge_1, long_edge_2
+
+
+def split_edge_seqence(points, n_parts):
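+ """Resample the boundary of a closed polygon into control points spaced
+ roughly total_length / n_parts apart by walking the cumulative edge length;
+ the first and closing points are prepended and appended."""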
+ pts_num = points.shape[0]
+ long_edge = [(i, (i + 1) % pts_num) for i in range(pts_num)]
+ edge_length = [norm2(points[e1] - points[e2]) for e1, e2 in long_edge]
+ point_cumsum = np.cumsum([0] + edge_length)
+ total_length = sum(edge_length)
+ length_per_part = total_length / n_parts
+
+ cur_node = 0 # first point
+ splited_result = []
+
+ for i in range(1, n_parts):
+ cur_end = i * length_per_part
+
+ while cur_end > point_cumsum[cur_node + 1]:
+ cur_node += 1
+
+ e1, e2 = long_edge[cur_node]
+ e1, e2 = points[e1], points[e2]
+
+ # start_point = points[long_edge[cur_node]]
+ end_shift = cur_end - point_cumsum[cur_node]
+ ratio = end_shift / edge_length[cur_node]
+ new_point = e1 + ratio * (e2 - e1)
+ # print(cur_end, point_cumsum[cur_node], end_shift, edge_length[cur_node], '=', new_point)
+ splited_result.append(new_point)
+
+ # add first and last point
+ p_first = points[long_edge[0][0]]
+ p_last = points[long_edge[-1][1]]
+ splited_result = [p_first] + splited_result + [p_last]
+ return np.stack(splited_result)
+
+
+def split_edge_seqence_with_cell_division(points, n_parts):
+ points_seq = list(points)
+ pts_num = len(points_seq)
+
+ if pts_num <= n_parts:
+ long_edge = [(i, (i + 1) % pts_num) for i in range(pts_num)]
+ edge_length = [int(norm2(points[e1] - points[e2])) for e1, e2 in long_edge]
+ while pts_num < n_parts:
+ e = np.argmax(np.array(edge_length))
+ new_pts = (points_seq[e] + points_seq[(e+1) % pts_num])*0.5
+ points_seq.insert(e+1, new_pts)
+ d = int(0.5 * (edge_length[e]-1))
+ edge_length[e] = d
+ edge_length.insert(e+1, d)
+ pts_num = len(points_seq)
+ else:
+ pass
+
+ return np.stack(points_seq).astype(int)
+
+
+def split_edge_seqence_by_step(points, long_edge1, long_edge2, step=16.0):
+
+ edge_length1 = [norm2(points[e1] - points[e2]) for e1, e2 in long_edge1]
+ edge_length2 = [norm2(points[e1] - points[e2]) for e1, e2 in long_edge2]
+ # use the average long-edge length to compute the number of boxes
+ total_length = (sum(edge_length1)+sum(edge_length2))/2
+ n_parts = math.ceil(float(total_length) / step)
+ try:
+ inner1 = split_edge_seqence(points, long_edge1, n_parts=n_parts)
+ inner2 = split_edge_seqence(points, long_edge2, n_parts=n_parts)
+ except:
+ print(edge_length1)
+ print(edge_length2)
+
+ return inner1, inner2
+
+
+def disjoint_find(x, F):
+ if F[x] == x:
+ return x
+ F[x] = disjoint_find(F[x], F)
+ return F[x]
+
+
+def disjoint_merge(x, y, F):
+ x = disjoint_find(x, F)
+ y = disjoint_find(y, F)
+ if x == y:
+ return False
+ F[y] = x
+ return True
+
+
+def merge_polygons(polygons, merge_map):
+
+ def merge_two_polygon(p1, p2):
+ p2 = Polygon(p2)
+ merged = p1.union(p2)
+ return merged
+
+ merge_map = [disjoint_find(x, merge_map) for x in range(len(merge_map))]
+ merge_map = np.array(merge_map)
+ final_polygons = []
+
+ for i in np.unique(merge_map):
+ merge_idx = np.where(merge_map == i)[0]
+ if len(merge_idx) > 0:
+ merged = Polygon(polygons[merge_idx[0]])
+ for j in range(1, len(merge_idx)):
+ merged = merge_two_polygon(merged, polygons[merge_idx[j]])
+ x, y = merged.exterior.coords.xy
+ final_polygons.append(np.stack([x, y], axis=1).astype(int))
+
+ return final_polygons
+
+
+def get_sample_point(text_mask, num_points, approx_factor, scales=None):
+ # get sample point in contours
+ contours, _ = cv2.findContours(text_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ epsilon = approx_factor * cv2.arcLength(contours[0], True)
+ approx = cv2.approxPolyDP(contours[0], epsilon, True).reshape((-1, 2))
+ # approx = contours[0].reshape((-1, 2))
+ if scales is None:
+ ctrl_points = split_edge_seqence(approx, num_points)
+ else:
+ ctrl_points = split_edge_seqence(approx*scales, num_points)
+ ctrl_points = np.array(ctrl_points[:num_points, :]).astype(np.int32)
+
+ return ctrl_points
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/pbox.py b/IndicPhotoOCR/detection/textbpn/util/pbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43e1cbcef07f70086857e89fadb52bea957f036
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/pbox.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+__author__ = '古溪'
+
+import numpy as np
+from typing import List
+
+
+def functools_reduce(a):
+ # use the built-in functools module
+ import functools
+ import operator
+ return functools.reduce(operator.concat, a)
+
+
+def minConnectPath(list_all: List[list]):
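+ """Greedily chain the given points into a path by repeatedly attaching the
+ nearest remaining point to either end; returns the edge list and the
+ ordered index path."""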
+ list_nodo = list_all.copy()
+ res = []
+ ept = [0, 0]
+
+ def norm2(a, b):
+ """计算两点之间的距离"""
+ return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
+
+ dict00 = {} # format: {distance: [start point, end point]}
+ dict11 = {} # format: {distance: [start point, end point]}
+ # seed both end points with the first node
+ ept[0] = list_nodo[0] # left end point
+ ept[1] = list_nodo[0] # right end point
+ list_nodo.remove(list_nodo[0])
+ while list_nodo:
+ for i in list_nodo: # i: node not yet connected
+ length0 = norm2(i, ept[0]) # distance to the left end point
+ dict00[length0] = [i, ept[0]]
+ length1 = norm2(ept[1], i) # distance to the right end point
+ dict11[length1] = [ept[1], i]
+ key0 = min(dict00.keys())
+ key1 = min(dict11.keys())
+
+ if key0 <= key1:
+ ss = dict00[key0][0]
+ ee = dict00[key0][1]
+ res.insert(0, [list_all.index(ss), list_all.index(ee)])
+ list_nodo.remove(ss)
+ ept[0] = ss
+ else:
+ ss = dict11[key1][0]
+ ee = dict11[key1][1]
+ res.append([list_all.index(ss), list_all.index(ee)])
+ list_nodo.remove(ee)
+ ept[1] = ee
+
+ dict00 = {}
+ dict11 = {}
+
+ path = functools_reduce(res)
+ path = sorted(set(path), key=path.index) # deduplicate while preserving order
+
+ return res, path
+
+
+def bbox_transfor_inv(radius_map, sin_map, cos_map, score_map, wclip=(2, 8), expend=1.0):
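+ """Recover oriented candidate boxes from per-pixel radius/sin/cos/score
+ maps: each positive pixel is extended along its predicted direction into a
+ 4-point box with a clipped width, returned as an Nx9 array (8 coords + score)."""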
+ xy_text = np.argwhere(score_map > 0)
+ # sort the text boxes via the y axis
+ xy_text = xy_text[np.argsort(xy_text[:, 0])]
+ origin = xy_text
+ radius = radius_map[xy_text[:, 0], xy_text[:, 1], :]
+ sin = sin_map[xy_text[:, 0], xy_text[:, 1]]
+ cos = cos_map[xy_text[:, 0], xy_text[:, 1]]
+ dtx = radius[:, 0] * cos * expend
+ dty = radius[:, 0] * sin * expend
+ ddx = radius[:, 1] * cos * expend
+ ddy = radius[:, 1] * sin * expend
+ topp = origin + np.stack([dty, dtx], axis=-1)
+ botp = origin - np.stack([ddy, ddx], axis=-1)
+ width = (radius[:, 0] + radius[:, 1]) // 3
+ width = np.clip(width, wclip[0], wclip[1])
+
+ top1 = topp - np.stack([width * cos, -width * sin], axis=-1)
+ top2 = topp + np.stack([width * cos, -width * sin], axis=-1)
+ bot1 = botp - np.stack([width * cos, -width * sin], axis=-1)
+ bot2 = botp + np.stack([width * cos, -width * sin], axis=-1)
+
+ bbox = np.stack([top1, top2, bot2, bot1], axis=1)[:, :, ::-1]
+ bboxs = np.zeros((bbox.shape[0], 9), dtype=np.float32)
+ bboxs[:, :8] = bbox.reshape((-1, 8))
+ bboxs[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
+
+ return bboxs
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/serialization.py b/IndicPhotoOCR/detection/textbpn/util/serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8231f0139619b4fb9ef74382de46a3ac0663f59b
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/serialization.py
@@ -0,0 +1,89 @@
+from __future__ import print_function, absolute_import
+import json
+import os
+import sys
+# import moxing as mox
+import os.path as osp
+import shutil
+
+import torch
+from torch.nn import Parameter
+
+from .osutils import mkdir_if_missing
+
+from config import get_args
+global_args = get_args(sys.argv[1:])
+
+if global_args.run_on_remote:
+ import moxing as mox
+
+
+def read_json(fpath):
+ with open(fpath, 'r') as f:
+ obj = json.load(f)
+ return obj
+
+
+def write_json(obj, fpath):
+ mkdir_if_missing(osp.dirname(fpath))
+ with open(fpath, 'w') as f:
+ json.dump(obj, f, indent=4, separators=(',', ': '))
+
+
+def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
+ print('=> saving checkpoint ', fpath)
+ if global_args.run_on_remote:
+ dir_name = osp.dirname(fpath)
+ if not mox.file.exists(dir_name):
+ mox.file.make_dirs(dir_name)
+ print('=> making dir ', dir_name)
+ local_path = "local_checkpoint.pth.tar"
+ torch.save(state, local_path)
+ mox.file.copy(local_path, fpath)
+ if is_best:
+ mox.file.copy(local_path, osp.join(dir_name, 'model_best.pth.tar'))
+ else:
+ mkdir_if_missing(osp.dirname(fpath))
+ torch.save(state, fpath)
+ if is_best:
+ shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
+
+
+def load_checkpoint(fpath):
+ if global_args.run_on_remote:
+ mox.file.shift('os', 'mox')
+ checkpoint = torch.load(fpath)
+ print("=> Loaded checkpoint '{}'".format(fpath))
+ return checkpoint
+ else:
+ load_path = fpath
+
+ if osp.isfile(load_path):
+ checkpoint = torch.load(load_path)
+ print("=> Loaded checkpoint '{}'".format(load_path))
+ return checkpoint
+ else:
+ raise ValueError("=> No checkpoint found at '{}'".format(load_path))
+
+
+def copy_state_dict(state_dict, model, strip=None):
+ tgt_state = model.state_dict()
+ copied_names = set()
+ for name, param in state_dict.items():
+ if strip is not None and name.startswith(strip):
+ name = name[len(strip):]
+ if name not in tgt_state:
+ continue
+ if isinstance(param, Parameter):
+ param = param.data
+ if param.size() != tgt_state[name].size():
+ print('mismatch:', name, param.size(), tgt_state[name].size())
+ continue
+ tgt_state[name].copy_(param)
+ copied_names.add(name)
+
+ missing = set(tgt_state.keys()) - copied_names
+ if len(missing) > 0:
+ print("missing keys in state_dict:", missing)
+
+ return model
\ No newline at end of file
diff --git a/IndicPhotoOCR/detection/textbpn/util/shedule.py b/IndicPhotoOCR/detection/textbpn/util/shedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..338083533d0deb2970a36c60c5b0efb80d7064c9
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/shedule.py
@@ -0,0 +1,28 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+class FixLR(_LRScheduler):
+ """Sets the learning rate of each parameter group to the initial lr
+ decayed by gamma every step_size epochs. When last_epoch=-1, sets
+ initial lr as lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ step_size (int): Period of learning rate decay.
+ gamma (float): Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ Example:
+ >>> # Fixed leraning rate
+ >>> scheduler = FixLR(optimizer, step_size=30, gamma=0.1)
+ >>> for epoch in range(100):
+ >>> scheduler.step()
+ >>> train(...)
+ >>> validate(...)
+ """
+
+ def __init__(self, optimizer, last_epoch=-1):
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ return self.base_lrs
diff --git a/IndicPhotoOCR/detection/textbpn/util/strs.py b/IndicPhotoOCR/detection/textbpn/util/strs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5009be2a0cb7f56ce603535c79485c39349d2fbc
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/strs.py
@@ -0,0 +1,128 @@
+# encoding = utf-8
+def int_array_to_str(arr):
+ """turn an int array to a str"""
+ return "".join(map(chr, arr))
+
+
+def join(arr, splitter=','):
+ temp = []
+ for e in arr:
+ temp.append(e)
+ temp.append(splitter)
+ temp.pop()
+ return "".join(temp)
+
+
+def is_str(s):
+ return type(s) == str
+
+
+def to_lowercase(s):
+ return str.lower(s)
+
+
+def to_uppercase(s):
+ return str.upper(s)
+
+
+def ends_with(s, suffix, ignore_case = False):
+ """
+ suffix: str, list, or tuple
+ """
+ if is_str(suffix):
+ suffix = [suffix]
+ suffix = list(suffix)
+ if ignore_case:
+ for idx, suf in enumerate(suffix):
+ suffix[idx] = to_lowercase(suf)
+ s = to_lowercase(s)
+ suffix = tuple(suffix)
+ return s.endswith(suffix)
+
+
+def starts_with(s, prefix, ignore_case = False):
+ """
+ prefix: str, list, or tuple
+ """
+ if is_str(prefix):
+ prefix = [prefix]
+ prefix = list(prefix)
+ if ignore_case:
+ for idx, pre in enumerate(prefix):
+ prefix[idx] = to_lowercase(pre)
+ s = to_lowercase(s)
+ prefix = tuple(prefix)
+ return s.startswith(prefix)
+
+
+def contains(s, target, ignore_case = False):
+ if ignore_case:
+ s = to_lowercase(s)
+ target = to_lowercase(target)
+ return s.find(target) >= 0
+
+
+def index_of(s, target):
+ return s.find(target)
+
+
+def replace_all(s, old, new, reg = False):
+ if reg:
+ import re
+ targets = re.findall(old, s)
+ for t in targets:
+ s = s.replace(t, new)
+ else:
+ s = s.replace(old, new)
+ return s
+
+
+def remove_all(s, sub):
+ return replace_all(s, sub, '')
+
+
+def split(s, splitter, reg = False):
+ if not reg:
+ return s.split(splitter)
+ import re
+ return re.split(splitter, s)
+
+
+def remove_invisible(s):
+ s = replace_all(s, ' ', '')
+ s = replace_all(s, '\n', '')
+ s = replace_all(s, '\t', '')
+ s = replace_all(s, '\r', '')
+ s = replace_all(s, '\xef\xbb\xbf', '')
+ return s
+
+
+def find_all(s, pattern):
+ import re
+ return re.findall(pattern, s)
+
+
+def is_none_or_empty(s):
+ if s is None:
+ return True
+ return len(s) == 0
+
+
+def to_json(obj):
+ import ujson
+ return ujson.dumps(obj)
+
+
+def to_list(obj):
+ items=obj.replace("(", '').replace(")","")
+ items=items.split(",")
+ lst=[float(i) for i in items]
+
+ return lst
+
+
+def to_tuple(obj):
+ items=obj.replace("(", '').replace(")","")
+ items=items.split(",")
+ tpl=tuple([float(i) for i in items])
+ return tpl
diff --git a/IndicPhotoOCR/detection/textbpn/util/summary.py b/IndicPhotoOCR/detection/textbpn/util/summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e318549da42a0fbca488ea57b72e03e7e5b402
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/summary.py
@@ -0,0 +1,26 @@
+from tensorboardX import SummaryWriter
+from util.misc import mkdirs
+
+
+class LogSummary(object):
+
+ def __init__(self, log_path):
+
+ mkdirs(log_path)
+ self.writer = SummaryWriter(log_path)
+
+ def write_scalars(self, scalar_dict, n_iter, tag=None):
+
+ for name, scalar in scalar_dict.items():
+ if tag is not None:
+ name = '/'.join([tag, name])
+ self.writer.add_scalar(name, scalar, n_iter)
+
+ def write_hist_parameters(self, net, n_iter):
+ for name, param in net.named_parameters():
+ self.writer.add_histogram(name, param.clone().cpu().numpy(), n_iter)
+
+
+
+
+
diff --git a/IndicPhotoOCR/detection/textbpn/util/vis_flux.py b/IndicPhotoOCR/detection/textbpn/util/vis_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a80e5dab97e8fdb1f810af4bb86f69486e04475
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/vis_flux.py
@@ -0,0 +1,108 @@
+import sys
+import scipy.io as sio
+import math
+import numpy as np
+import cv2
+import matplotlib
+matplotlib.use('agg')
+import pylab as plt
+from matplotlib import cm
+import os
+
+def label2color(label):
+
+ label = label.astype(np.uint16)
+
+ height, width = label.shape
+ color3u = np.zeros((height, width, 3), dtype=np.uint8)
+ unique_labels = np.unique(label)
+
+ if unique_labels[-1] >= 2**24:
+ raise RuntimeError('Error: label overflow!')
+
+ for i in range(len(unique_labels)):
+
+ binary = '{:024b}'.format(unique_labels[i])
+ # split the 24-bit label into interleaved r/g/b bit patterns (3 x 8 bits each)
+ r = int(binary[::3][::-1], 2)
+ g = int(binary[1::3][::-1], 2)
+ b = int(binary[2::3][::-1], 2)
+
+ color3u[label == unique_labels[i]] = np.array([r, g, b])
+
+ return color3u
+
+
+def vis_direction_field(gt_flux):
+
+ norm_gt = np.sqrt(gt_flux[1, :, :] ** 2 + gt_flux[0, :, :] ** 2)
+ angle_gt = 180 / math.pi * np.arctan2(gt_flux[1, :, :], gt_flux[0, :, :])
+
+ fig = plt.figure(figsize=(10, 6))
+
+ ax1 = fig.add_subplot(121)
+ ax1.set_title('Norm_gt')
+ ax1.set_autoscale_on(True)
+ im1 = ax1.imshow(norm_gt, cmap=cm.jet)
+ plt.colorbar(im1, shrink=0.5)
+
+ ax2 = fig.add_subplot(122)
+ ax2.set_title('Angle_gt')
+ ax2.set_autoscale_on(True)
+ im2 = ax2.imshow(angle_gt, cmap=cm.jet)
+ plt.colorbar(im2, shrink=0.5)
+
+ plt.savefig('1.png')
+ plt.close(fig)
+
+
+def vis_flux(vis_image, pred_flux, gt_flux, gt_mask, image_name, save_dir):
+
+ vis_image = vis_image.data.cpu().numpy()[0, ...]
+ pred_flux = pred_flux.data.cpu().numpy()[0, ...]
+ gt_flux = gt_flux.data.cpu().numpy()[0, ...]
+ gt_mask = gt_mask.data.cpu().numpy()[0, ...]
+
+ image_name = image_name[0]
+
+ norm_pred = np.sqrt(pred_flux[1,:,:]**2 + pred_flux[0,:,:]**2)
+ angle_pred = 180/math.pi*np.arctan2(pred_flux[1,:,:], pred_flux[0,:,:])
+
+ norm_gt = np.sqrt(gt_flux[1,:,:]**2 + gt_flux[0,:,:]**2)
+ angle_gt = 180/math.pi*np.arctan2(gt_flux[1,:,:], gt_flux[0,:,:])
+
+ fig = plt.figure(figsize=(10,6))
+
+ ax0 = fig.add_subplot(231)
+ ax0.imshow(vis_image[:,:,::-1])
+
+ ax1 = fig.add_subplot(232)
+ ax1.set_title('Norm_gt')
+ ax1.set_autoscale_on(True)
+ im1 = ax1.imshow(norm_gt, cmap=cm.jet)
+ plt.colorbar(im1,shrink=0.5)
+
+ ax2 = fig.add_subplot(233)
+ ax2.set_title('Angle_gt')
+ ax2.set_autoscale_on(True)
+ im2 = ax2.imshow(angle_gt, cmap=cm.jet)
+ plt.colorbar(im2, shrink=0.5)
+
+ ax5 = fig.add_subplot(234)
+ color_mask = label2color(gt_mask)
+ ax5.imshow(color_mask)
+
+ ax4 = fig.add_subplot(235)
+ ax4.set_title('Norm_pred')
+ ax4.set_autoscale_on(True)
+ im4 = ax4.imshow(norm_pred, cmap=cm.jet)
+ plt.colorbar(im4,shrink=0.5)
+
+ ax5 = fig.add_subplot(236)
+ ax5.set_title('Angle_pred')
+ ax5.set_autoscale_on(True)
+ im5 = ax5.imshow(angle_pred, cmap=cm.jet)
+ plt.colorbar(im5, shrink=0.5)
+
+ plt.savefig(save_dir + image_name + '.png')
+ plt.close(fig)
diff --git a/IndicPhotoOCR/detection/textbpn/util/visualize.py b/IndicPhotoOCR/detection/textbpn/util/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf6d441171f73842b97fc46ce926b2cea318ebde
--- /dev/null
+++ b/IndicPhotoOCR/detection/textbpn/util/visualize.py
@@ -0,0 +1,245 @@
+import torch
+import numpy as np
+import cv2
+import os
+import math
+from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+from IndicPhotoOCR.detection.textbpn.util import canvas as cav
+import matplotlib
+matplotlib.use('agg')
+import pylab as plt
+from matplotlib import cm
+import torch.nn.functional as F
+
+
+def visualize_network_output(output_dict, input_dict, mode='train'):
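+ """Build debug visualisations comparing predicted and ground-truth mask,
+ distance and direction fields, and draw the initial/refined contour
+ iterations against the ground-truth polygons, for saving under cfg.vis_dir."""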
+ vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name + '_' + mode)
+ if not os.path.exists(vis_dir):
+ os.mkdir(vis_dir)
+
+ fy_preds = F.interpolate(output_dict["fy_preds"], scale_factor=cfg.scale, mode='bilinear')
+ fy_preds = fy_preds.data.cpu().numpy()
+
+ py_preds = output_dict["py_preds"][1:]
+ init_polys = output_dict["py_preds"][0]
+ inds = output_dict["inds"]
+
+ image = input_dict['img']
+ tr_mask = input_dict['tr_mask'].data.cpu().numpy() > 0
+ distance_field = input_dict['distance_field'].data.cpu().numpy()
+ direction_field = input_dict['direction_field']
+ weight_matrix = input_dict['weight_matrix']
+ gt_tags = input_dict['gt_points'].cpu().numpy()
+ ignore_tags = input_dict['ignore_tags'].cpu().numpy()
+
+ b, c, _, _ = fy_preds.shape
+ for i in range(b):
+
+ fig = plt.figure(figsize=(12, 9))
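+        # 3 x 4 panel grid: row 1 = predicted mask / distance / flux norm / flux angle,
+        # row 2 = the corresponding ground truth, row 3 = boundary proposals (initial polygon,
+        # then one panel per refinement iteration) drawn over the input image.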
+
+ mask_pred = fy_preds[i, 0, :, :]
+ distance_pred = fy_preds[i, 1, :, :]
+ norm_pred = np.sqrt(fy_preds[i, 2, :, :] ** 2 + fy_preds[i, 3, :, :] ** 2)
+ angle_pred = 180 / math.pi * np.arctan2(fy_preds[i, 2, :, :], fy_preds[i, 3, :, :] + 0.00001)
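+        # (the small 1e-5 offset presumably keeps arctan2 away from the ambiguous (0, 0) input)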
+
+ ax1 = fig.add_subplot(341)
+ ax1.set_title('mask_pred')
+ # ax1.set_autoscale_on(True)
+ im1 = ax1.imshow(mask_pred, cmap=cm.jet)
+ # plt.colorbar(im1, shrink=0.5)
+
+ ax2 = fig.add_subplot(342)
+ ax2.set_title('distance_pred')
+ # ax2.set_autoscale_on(True)
+ im2 = ax2.imshow(distance_pred, cmap=cm.jet)
+ # plt.colorbar(im2, shrink=0.5)
+
+ ax3 = fig.add_subplot(343)
+ ax3.set_title('norm_pred')
+ # ax3.set_autoscale_on(True)
+ im3 = ax3.imshow(norm_pred, cmap=cm.jet)
+ # plt.colorbar(im3, shrink=0.5)
+
+ ax4 = fig.add_subplot(344)
+ ax4.set_title('angle_pred')
+ # ax4.set_autoscale_on(True)
+ im4 = ax4.imshow(angle_pred, cmap=cm.jet)
+ # plt.colorbar(im4, shrink=0.5)
+
+ mask_gt = tr_mask[i]
+ distance_gt = distance_field[i]
+ # gt_flux = 0.999999 * direction_field[i] / (direction_field[i].norm(p=2, dim=0) + 1e-9)
+ gt_flux = direction_field[i].cpu().numpy()
+ norm_gt = np.sqrt(gt_flux[0, :, :] ** 2 + gt_flux[1, :, :] ** 2)
+ angle_gt = 180 / math.pi * np.arctan2(gt_flux[0, :, :], gt_flux[1, :, :]+0.00001)
+
+ ax11 = fig.add_subplot(345)
+ # ax11.set_title('mask_gt')
+ # ax11.set_autoscale_on(True)
+ im11 = ax11.imshow(mask_gt, cmap=cm.jet)
+ # plt.colorbar(im11, shrink=0.5)
+
+ ax22 = fig.add_subplot(346)
+ # ax22.set_title('distance_gt')
+ # ax22.set_autoscale_on(True)
+ im22 = ax22.imshow(distance_gt, cmap=cm.jet)
+ # plt.colorbar(im22, shrink=0.5)
+
+ ax33 = fig.add_subplot(347)
+ # ax33.set_title('norm_gt')
+ # ax33.set_autoscale_on(True)
+ im33 = ax33.imshow(norm_gt, cmap=cm.jet)
+ # plt.colorbar(im33, shrink=0.5)
+
+ ax44 = fig.add_subplot(348)
+ # ax44.set_title('angle_gt')
+ # ax44.set_autoscale_on(True)
+ im44 = ax44.imshow(angle_gt, cmap=cm.jet)
+ # plt.colorbar(im44, shrink=0.5)
+
+ img_show = image[i].permute(1, 2, 0).cpu().numpy()
+ img_show = ((img_show * cfg.stds + cfg.means) * 255).astype(np.uint8)
+ img_show = np.ascontiguousarray(img_show[:, :, ::-1])
+ shows = []
+ gt = gt_tags[i]
+ gt_idx = np.where(ignore_tags[i] > 0)
+ gt_py = gt[gt_idx[0], :, :]
+ index = torch.where(inds[0] == i)[0]
+ init_py = init_polys[index].detach().cpu().numpy()
+
+ image_show = img_show.copy()
+ cv2.drawContours(image_show, init_py.astype(np.int32), -1, (255, 255, 0), 2)
+ cv2.drawContours(image_show, gt_py.astype(np.int32), -1, (0, 255, 0), 2)
+ shows.append(image_show)
+ for py in py_preds:
+ contours = py[index].detach().cpu().numpy()
+ image_show = img_show.copy()
+ cv2.drawContours(image_show, init_py.astype(np.int32), -1, (255, 255, 0), 2)
+ cv2.drawContours(image_show, gt_py.astype(np.int32), -1, (0, 255, 0), 2)
+ cv2.drawContours(image_show, contours.astype(np.int32), -1, (0, 0, 255), 2)
+ shows.append(image_show)
+
+ for idx, im_show in enumerate(shows):
+ axb = fig.add_subplot(3, 4, 9+idx)
+ # axb.set_title('boundary_{}'.format(idx))
+ # axb.set_autoscale_on(True)
+ im11 = axb.imshow(im_show, cmap=cm.jet)
+ # plt.colorbar(im11, shrink=0.5)
+
+ path = os.path.join(vis_dir, '{}.png'.format(i))
+ plt.savefig(path)
+ plt.close(fig)
+
+
+def visualize_gt(image, contours, label_tag):
+
+ image_show = image.copy()
+ image_show = np.ascontiguousarray(image_show[:, :, ::-1])
+
+    image_show = cv2.polylines(image_show,
+                               [contours[i] for i, tag in enumerate(label_tag) if tag > 0], True, (0, 0, 255), 3)
+    image_show = cv2.polylines(image_show,
+                               [contours[i] for i, tag in enumerate(label_tag) if tag < 0], True, (0, 255, 0), 3)
+
+ show_gt = cv2.resize(image_show, (320, 320))
+
+ return show_gt
+
+
+def visualize_detection(image, output_dict, meta=None):
+ image_show = image.copy()
+ image_show = np.ascontiguousarray(image_show[:, :, ::-1])
+
+ cls_preds = F.interpolate(output_dict["fy_preds"], scale_factor=cfg.scale, mode='bilinear')
+ cls_preds = cls_preds[0].data.cpu().numpy()
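+    # fy_preds channel 0 is the text-region (classification) score map and channel 1 the
+    # distance field; both are rendered as heat maps at the end of this function.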
+
+ py_preds = output_dict["py_preds"][1:]
+ init_polys = output_dict["py_preds"][0]
+ shows = []
+
+ init_py = init_polys.data.cpu().numpy()
+ path = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
+ meta['image_id'][0].split(".")[0] + "_init.png")
+
+ im_show0 = image_show.copy()
+ for i, bpts in enumerate(init_py.astype(np.int32)):
+ cv2.drawContours(im_show0, [bpts.astype(np.int32)], -1, (255, 255, 0), 2)
+ for j, pp in enumerate(bpts):
+ if j == 0:
+ cv2.circle(im_show0, (int(pp[0]), int(pp[1])), 3, (255, 0, 255), -1)
+ elif j == 1:
+ cv2.circle(im_show0, (int(pp[0]), int(pp[1])), 3, (0, 255, 255), -1)
+ else:
+ cv2.circle(im_show0, (int(pp[0]), int(pp[1])), 3, (0, 0, 255), -1)
+
+ cv2.imwrite(path, im_show0)
+
+ for idx, py in enumerate(py_preds):
+ im_show = im_show0.copy()
+ contours = py.data.cpu().numpy()
+ cv2.drawContours(im_show, contours.astype(np.int32), -1, (0, 0, 255), 2)
+ for ppts in contours:
+ for j, pp in enumerate(ppts):
+ if j == 0:
+ cv2.circle(im_show, (int(pp[0]), int(pp[1])), 3, (255, 0, 255), -1)
+ elif j == 1:
+ cv2.circle(im_show, (int(pp[0]), int(pp[1])), 3, (0, 255, 255), -1)
+ else:
+ cv2.circle(im_show, (int(pp[0]), int(pp[1])), 3, (0, 255, 0), -1)
+ path = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
+ meta['image_id'][0].split(".")[0] + "_{}iter.png".format(idx))
+ cv2.imwrite(path, im_show)
+ shows.append(im_show)
+
+ # init_py = init_polys.data.cpu().numpy()
+ # im_show_score = image_show.copy()
+ # for in_py in init_py:
+ # mask = np.zeros_like(cls_preds[0], dtype=np.uint8)
+ # cv2.drawContours(mask, [in_py.astype(np.int32)], -1, (1,), -1)
+ # score = cls_preds[0][mask > 0].mean()
+ # if score > 0.9:
+ # cv2.drawContours(im_show_score, [in_py.astype(np.int32)], -1, (0, 255, 0), 2)
+ # else:
+ # cv2.drawContours(im_show_score, [in_py.astype(np.int32)], -1, (255, 0, 255), 2)
+ # cv2.putText(im_show_score, "{:.2f}".format(score),
+ # (int(np.mean(in_py[:, 0])), int(np.mean(in_py[:, 1]))), 1, 1, (0, 255, 255), 2)
+ # print(score)
+
+ # path = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
+ # meta['image_id'][0].split(".")[0] + "init.png")
+ # cv2.imwrite(path, im_show_score)
+
+ show_img = np.concatenate(shows, axis=1)
+ show_boundary = cv2.resize(show_img, (320 * len(py_preds), 320))
+
+ # fig = plt.figure(figsize=(5, 4))
+ # ax1 = fig.add_subplot(111)
+ # # ax1.set_title('distance_field')
+ # ax1.set_autoscale_on(True)
+ # im1 = ax1.imshow(cls_preds[0], cmap=cm.jet)
+ # plt.colorbar(im1, shrink=0.75)
+ # plt.axis("off")
+ # path = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
+ # meta['image_id'][0].split(".")[0] + "_cls.png")
+ # plt.savefig(path, dpi=300)
+ # plt.close(fig)
+ #
+ # fig = plt.figure(figsize=(5, 4))
+ # ax1 = fig.add_subplot(111)
+ # # ax1.set_title('distance_field')
+ # ax1.set_autoscale_on(True)
+ # im1 = ax1.imshow(np.array(cls_preds[1] / np.max(cls_preds[1])), cmap=cm.jet)
+ # plt.colorbar(im1, shrink=0.75)
+ # plt.axis("off")
+ # path = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
+ # meta['image_id'][0].split(".")[0] + "_dis.png")
+ # plt.savefig(path, dpi=300)
+ # plt.close(fig)
+
+ cls_pred = cav.heatmap(np.array(cls_preds[0] * 255, dtype=np.uint8))
+ dis_pred = cav.heatmap(np.array(cls_preds[1] * 255, dtype=np.uint8))
+
+ heat_map = np.concatenate([cls_pred*255, dis_pred*255], axis=1)
+ heat_map = cv2.resize(heat_map, (320 * 2, 320))
+
+ return show_boundary, heat_map
\ No newline at end of file
diff --git a/IndicPhotoOCR/ocr.py b/IndicPhotoOCR/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c462a20b1955f286e21a35fe5df0a813eebfa89
--- /dev/null
+++ b/IndicPhotoOCR/ocr.py
@@ -0,0 +1,186 @@
+import sys
+import os
+import torch
+from PIL import Image
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt  # used by visualize_detection(show=True)
+
+
+# from IndicPhotoOCR.detection.east_detector import EASTdetector
+# from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
+from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier
+from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
+import IndicPhotoOCR.detection.east_config as cfg
+from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector
+
+from IndicPhotoOCR.utils.helper import detect_para
+
+
+class OCR:
+ def __init__(self, device='cuda:0', verbose=False):
+ # self.detect_model_checkpoint = detect_model_checkpoint
+ self.device = device
+ self.verbose = verbose
+ # self.image_path = image_path
+ # self.detector = EASTdetector()
+ self.detector = TextBPNpp_detector(device=self.device)
+ self.recogniser = PARseqrecogniser()
+ # self.identifier = CLIPidentifier()
+ self.identifier = VIT_identifier()
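+        # Pipeline stages: TextBPN++ detects word polygons, the ViT classifier identifies
+        # the script of each crop, and PARseq recognises the text. The commented-out
+        # EAST/CLIP lines above are the earlier detector/identifier, kept for reference.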
+
+ # def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint):
+ # """Run the detection model to get bounding boxes of text areas."""
+
+ # if self.verbose:
+ # print("Running text detection...")
+ # detections = self.detector.detect(image_path, detect_model_checkpoint, self.device)
+ # # print(detections)
+ # return detections['detections']
+ def detect(self, image_path):
+ self.detections = self.detector.detect(image_path)
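+        # Each detection is a text polygon given as a list of [x, y] corner points.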
+ return self.detections['detections']
+
+ def visualize_detection(self, image_path, detections, save_path=None, show=False):
+ # Default save path if none is provided
+ default_save_path = "test.png"
+ path_to_save = save_path if save_path is not None else default_save_path
+
+ # Get the directory part of the path
+ directory = os.path.dirname(path_to_save)
+
+ # Check if the directory exists, and create it if it doesn’t
+ if directory and not os.path.exists(directory):
+ os.makedirs(directory)
+ print(f"Created directory: {directory}")
+
+ # Read the image and draw bounding boxes
+ image = cv2.imread(image_path)
+ for box in detections:
+ # Convert list of points to a numpy array with int type
+ points = np.array(box, np.int32)
+
+ # Compute the top-left and bottom-right corners of the bounding box
+ x_min = np.min(points[:, 0])
+ y_min = np.min(points[:, 1])
+ x_max = np.max(points[:, 0])
+ y_max = np.max(points[:, 1])
+
+ # Draw the rectangle
+ cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=3)
+
+ # Show the image if 'show' is True
+ if show:
+ plt.figure(figsize=(10, 10))
+ plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+ plt.axis("off")
+ plt.show()
+
+ # Save the annotated image
+ cv2.imwrite(path_to_save, image)
+ print(f"Image saved at: {path_to_save}")
+
+ def crop_and_identify_script(self, image, bbox):
+ """
+ Crop a text area from the image and identify its script language.
+
+ Args:
+ image (PIL.Image): The full image.
+ bbox (list): List of four corner points, each a [x, y] pair.
+
+ Returns:
+ str: Identified script language.
+ """
+ # Extract x and y coordinates from the four corner points
+ x_coords = [point[0] for point in bbox]
+ y_coords = [point[1] for point in bbox]
+
+ # Get the bounding box coordinates (min and max)
+ x_min, y_min = min(x_coords), min(y_coords)
+ x_max, y_max = max(x_coords), max(y_coords)
+
+ # Crop the image based on the bounding box
+ cropped_image = image.crop((x_min, y_min, x_max, y_max))
+ root_image_dir = "IndicPhotoOCR/script_identification"
+ os.makedirs(f"{root_image_dir}/images", exist_ok=True)
+ # Temporarily save the cropped image to pass to the script model
+ cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
+ cropped_image.save(cropped_path)
+
+ # Predict script language, here we assume "hindi" as the model name
+ if self.verbose:
+ print("Identifying script for the cropped area...")
+ script_lang = self.identifier.identify(cropped_path, "hindi", self.device) # Use "hindi" as the model name
+ # print(script_lang)
+
+ # Clean up temporary file
+ # os.remove(cropped_path)
+
+ return script_lang, cropped_path
+
+ def recognise(self, cropped_image_path, script_lang):
+ """Recognize text in a cropped image area using the identified script."""
+ if self.verbose:
+ print("Recognizing text in detected area...")
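+        # Note: the identified script name (e.g. "hindi") is also passed as the checkpoint key,
+        # so the matching language model is downloaded/loaded on demand by the recogniser.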
+ recognized_text = self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose, self.device)
+ # print(recognized_text)
+ return recognized_text
+
+ def ocr(self, image_path):
+ """Process the image by detecting text areas, identifying script, and recognizing text."""
+ recognized_texts = {}
+ recognized_words = []
+ image = Image.open(image_path)
+
+ # Run detection
+ detections = self.detect(image_path)
+
+ # Process each detected text area
+ # for bbox in detections:
+ # # Crop and identify script language
+ # script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+
+ # # Check if the script language is valid
+ # if script_lang:
+
+ # # Recognize text
+ # recognized_word = self.recognise(cropped_path, script_lang)
+ # recognized_words.append(recognized_word)
+
+ # if self.verbose:
+ # print(f"Recognized word: {recognized_word}")
+
+
+ for id, bbox in enumerate(detections):
+ # Identify the script and crop the image to this region
+ script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+
+ # Calculate bounding box coordinates
+ x1 = min([bbox[i][0] for i in range(len(bbox))])
+ y1 = min([bbox[i][1] for i in range(len(bbox))])
+ x2 = max([bbox[i][0] for i in range(len(bbox))])
+ y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+ if script_lang:
+ recognized_text = self.recognise(cropped_path, script_lang)
+ recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
+
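+        # Group the recognised words into reading-order lines; detect_para returns a
+        # list of lines, each a list of words sorted left to right.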
+ return detect_para(recognized_texts)
+ # return recognized_words
+
+if __name__ == '__main__':
+ # detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar'
+ sample_image_path = 'test_images/image_88.jpg'
+ cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'
+
+ ocr = OCR(device="cuda", verbose=False)
+
+ # detections = ocr.detect(sample_image_path)
+ # print(detections)
+
+ # ocr.visualize_detection(sample_image_path, detections)
+
+ # recognition = ocr.recognise(cropped_image_path, "hindi")
+ # print(recognition)
+
+ recognised_words = ocr.ocr(sample_image_path)
+ print(recognised_words)
\ No newline at end of file
diff --git a/IndicPhotoOCR/recognition/__init__.py b/IndicPhotoOCR/recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/recognition/parseq_recogniser.py b/IndicPhotoOCR/recognition/parseq_recogniser.py
new file mode 100644
index 0000000000000000000000000000000000000000..26f8a84cd399bc97dbde2ba26353f46a8b830137
--- /dev/null
+++ b/IndicPhotoOCR/recognition/parseq_recogniser.py
@@ -0,0 +1,215 @@
+import csv
+# import fire
+import json
+import numpy as np
+import os
+# import pandas as pd
+import sys
+import torch
+import requests
+
+from dataclasses import dataclass
+from PIL import Image
+from nltk import edit_distance
+from torchvision import transforms as T
+from typing import Optional, Callable, Sequence, Tuple
+from tqdm import tqdm
+
+
+from IndicPhotoOCR.utils.strhub.data.module import SceneTextDataModule
+from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint
+
+
+model_info = {
+ "assamese": {
+ "path": "models/assamese.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt",
+ },
+ "bengali": {
+ "path": "models/bengali.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt",
+ },
+ "hindi": {
+ "path": "models/hindi.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt",
+ },
+ "gujarati": {
+ "path": "models/gujarati.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt",
+ },
+ "marathi": {
+ "path": "models/marathi.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt",
+ },
+ "odia": {
+ "path": "models/odia.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt",
+ },
+ "punjabi": {
+ "path": "models/punjabi.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt",
+ },
+ "tamil": {
+ "path": "models/tamil.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt",
+ },
+ "telugu": {
+ "path": "models/telugu.ckpt",
+ "url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt",
+ }
+}
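+# Each entry above maps a language to a local checkpoint path (relative to
+# IndicPhotoOCR/recognition/) and a release URL; ensure_model() downloads the
+# checkpoint on first use.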
+
+class PARseqrecogniser:
+ def __init__(self):
+ pass
+
+ def get_transform(self, img_size: Tuple[int], augment: bool = False, rotation: int = 0):
+ transforms = []
+ if augment:
+ from .augment import rand_augment_transform
+ transforms.append(rand_augment_transform())
+ if rotation:
+ transforms.append(lambda img: img.rotate(rotation, expand=True))
+ transforms.extend([
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5)
+ ])
+ return T.Compose(transforms)
+
+
+ def load_model(self, device, checkpoint):
+ model = load_from_checkpoint(checkpoint).eval().to(device)
+ return model
+
+ def get_model_output(self, device, model, image_path):
+ hp = model.hparams
+ transform = self.get_transform(hp.img_size, rotation=0)
+
+ image_name = image_path.split("/")[-1]
+ img = Image.open(image_path).convert('RGB')
+ img = transform(img)
+ logits = model(img.unsqueeze(0).to(device))
+ probs = logits.softmax(-1)
+ preds, probs = model.tokenizer.decode(probs)
+ text = model.charset_adapter(preds[0])
+ scores = probs[0].detach().cpu().numpy()
+
+ return text
+
+ # Ensure model file exists; download directly if not
+ def ensure_model(self, model_name):
+ model_path = model_info[model_name]["path"]
+ url = model_info[model_name]["url"]
+ root_model_dir = "IndicPhotoOCR/recognition/"
+ model_path = os.path.join(root_model_dir, model_path)
+
+ if not os.path.exists(model_path):
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+ # Start the download with a progress bar
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get('content-length', 0))
+ os.makedirs(f"{root_model_dir}/models", exist_ok=True)
+
+ with open(model_path, "wb") as f, tqdm(
+ desc=model_name,
+ total=total_size,
+ unit='B',
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar:
+ for data in response.iter_content(chunk_size=1024):
+ f.write(data)
+ bar.update(len(data))
+
+ print(f"Downloaded model for {model_name}.")
+
+ return model_path
+
+    def bstr(self, checkpoint, language, image_dir, save_dir):
+ """
+ Runs the OCR model to process images and save the output as a JSON file.
+
+ Args:
+ checkpoint (str): Path to the model checkpoint file.
+ language (str): Language code (e.g., 'hindi', 'english').
+ image_dir (str): Directory containing the images to process.
+ save_dir (str): Directory where the output JSON file will be saved.
+
+ Example usage:
+ python your_script.py --checkpoint /path/to/checkpoint.ckpt --language hindi --image_dir /path/to/images --save_dir /path/to/save
+ """
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+
+ if language != "english":
+            model = self.load_model(device, checkpoint)
+ else:
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
+
+ parseq_dict = {}
+ for image_path in tqdm(os.listdir(image_dir)):
+ assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
+            text = self.get_model_output(device, model, os.path.join(image_dir, image_path))  # charset is determined by the loaded checkpoint
+
+ filename = image_path.split('/')[-1]
+ parseq_dict[filename] = text
+
+ os.makedirs(save_dir, exist_ok=True)
+ with open(f"{save_dir}/{language}_test.json", 'w') as json_file:
+ json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)
+
+
+    def bstr_onImage(self, checkpoint, language, image_path):
+        """
+        Runs the OCR model on a single image and returns the recognized text.
+
+        Args:
+            checkpoint (str): Path to the model checkpoint file.
+            language (str): Language code (e.g., 'hindi', 'english').
+            image_path (str): Path to the image to process.
+
+        Returns:
+            str: The recognized text.
+        """
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+
+ if language != "english":
+            model = self.load_model(device, checkpoint)
+ else:
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
+
+ # parseq_dict = {}
+ # for image_path in tqdm(os.listdir(image_dir)):
+ # assert os.path.exists(os.path.join(image_dir, image_path)) == True, f"{image_path}"
+        text = self.get_model_output(device, model, image_path)
+
+ return text
+
+
+ def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool, device: str) -> str:
+ """
+ Loads the desired model and returns the recognized word from the specified image.
+
+ Args:
+ checkpoint (str): Path to the model checkpoint file.
+ language (str): Language code (e.g., 'hindi', 'english').
+ image_path (str): Path to the image for which text recognition is needed.
+
+ Returns:
+ str: The recognized text from the image.
+ """
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ if language != "english":
+ model_path = self.ensure_model(checkpoint)
+ model = self.load_model(device, model_path)
+ else:
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device)
+
+ recognized_text = self.get_model_output(device, model, image_path)
+
+ return recognized_text
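+
+# Minimal usage sketch (hypothetical paths, not part of the original module):
+#   recogniser = PARseqrecogniser()
+#   text = recogniser.recognise("hindi", "crops/word_0.jpg", "hindi", False, "cuda:0")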
+# if __name__ == '__main__':
+# fire.Fire(main)
\ No newline at end of file
diff --git a/IndicPhotoOCR/script_identification/CLIP_identifier.py b/IndicPhotoOCR/script_identification/CLIP_identifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47531f88de065104461ef2f909ab80150e65e79
--- /dev/null
+++ b/IndicPhotoOCR/script_identification/CLIP_identifier.py
@@ -0,0 +1,201 @@
+
+import torch
+import clip
+from PIL import Image
+from io import BytesIO
+import os
+import requests
+
+# Model information dictionary containing model paths and language subcategories
+model_info = {
+ "hindi": {
+ "path": "models/clip_finetuned_hindienglish_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth",
+ "subcategories": ["hindi", "english"]
+ },
+ "hinengasm": {
+ "path": "models/clip_finetuned_hindienglishassamese_real.pth",
+ "url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishassamese_real.pth",
+ "subcategories": ["hindi", "english", "assamese"]
+ },
+ "hinengben": {
+ "path": "models/clip_finetuned_hindienglishbengali_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishbengali_real.pth",
+ "subcategories": ["hindi", "english", "bengali"]
+ },
+ "hinengguj": {
+ "path": "models/clip_finetuned_hindienglishgujarati_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishgujarati_real.pth",
+ "subcategories": ["hindi", "english", "gujarati"]
+ },
+ "hinengkan": {
+ "path": "models/clip_finetuned_hindienglishkannada_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishkannada_real.pth",
+ "subcategories": ["hindi", "english", "kannada"]
+ },
+ "hinengmal": {
+ "path": "models/clip_finetuned_hindienglishmalayalam_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmalayalam_real.pth",
+ "subcategories": ["hindi", "english", "malayalam"]
+ },
+ "hinengmar": {
+ "path": "models/clip_finetuned_hindienglishmarathi_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmarathi_real.pth",
+ "subcategories": ["hindi", "english", "marathi"]
+ },
+ "hinengmei": {
+ "path": "models/clip_finetuned_hindienglishmeitei_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmeitei_real.pth",
+ "subcategories": ["hindi", "english", "meitei"]
+ },
+ "hinengodi": {
+ "path": "models/clip_finetuned_hindienglishodia_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishodia_real.pth",
+ "subcategories": ["hindi", "english", "odia"]
+ },
+ "hinengpun": {
+ "path": "models/clip_finetuned_hindienglishpunjabi_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishpunjabi_real.pth",
+ "subcategories": ["hindi", "english", "punjabi"]
+ },
+ "hinengtam": {
+ "path": "models/clip_finetuned_hindienglishtamil_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtamil_real.pth",
+ "subcategories": ["hindi", "english", "tamil"]
+ },
+ "hinengtel": {
+ "path": "models/clip_finetuned_hindienglishtelugu_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtelugu_real.pth",
+ "subcategories": ["hindi", "english", "telugu"]
+ },
+ "hinengurd": {
+ "path": "models/clip_finetuned_hindienglishurdu_real.pth",
+ "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishurdu_real.pth",
+ "subcategories": ["hindi", "english", "urdu"]
+ },
+
+
+}
+
+
+# Set device to CUDA if available, otherwise use CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+clip_model, preprocess = clip.load("ViT-B/32", device=device)
+
+class CLIPFineTuner(torch.nn.Module):
+ """
+ Fine-tuning class for the CLIP model to adapt to specific tasks.
+
+ Attributes:
+ model (torch.nn.Module): The CLIP model to be fine-tuned.
+ classifier (torch.nn.Linear): A linear classifier to map features to the desired number of classes.
+ """
+ def __init__(self, model, num_classes):
+ """
+ Initializes the fine-tuner with the CLIP model and classifier.
+
+ Args:
+ model (torch.nn.Module): The base CLIP model.
+ num_classes (int): The number of target classes for classification.
+ """
+ super(CLIPFineTuner, self).__init__()
+ self.model = model
+ self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes)
+
+ def forward(self, x):
+ """
+ Forward pass for image classification.
+
+ Args:
+ x (torch.Tensor): Preprocessed input tensor for an image.
+
+ Returns:
+ torch.Tensor: Logits for each class.
+ """
+ with torch.no_grad():
+ features = self.model.encode_image(x).float() # Extract image features from CLIP model
+ return self.classifier(features) # Return class logits
+
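+# Design note: the CLIP image encoder stays frozen (its features are extracted under
+# torch.no_grad in forward), so only the linear classification head is trained per model.
+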
+class CLIPidentifier:
+ def __init__(self):
+ pass
+
+ # Ensure model file exists; download directly if not
+ def ensure_model(self, model_name):
+ model_path = model_info[model_name]["path"]
+ url = model_info[model_name]["url"]
+ root_model_dir = "IndicPhotoOCR/script_identification/"
+ model_path = os.path.join(root_model_dir, model_path)
+
+ if not os.path.exists(model_path):
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
+ response = requests.get(url, stream=True)
+ os.makedirs(f"{root_model_dir}/models", exist_ok=True)
+ with open(f"{model_path}", "wb") as f:
+ f.write(response.content)
+ print(f"Downloaded model for {model_name}.")
+
+ return model_path
+
+ # Prediction function to verify and load the model
+ def identify(self, image_path, model_name):
+ """
+ Predicts the class of an input image using a fine-tuned CLIP model.
+
+ Args:
+ image_path (str): Path to the input image file.
+ model_name (str): Name of the model (e.g., hineng, hinengpun, hinengguj) as specified in `model_info`.
+
+        Returns:
+            str: The predicted script name on success, or a dict with an `error` key if an exception occurs.
+
+        Example usage:
+            result = CLIPidentifier().identify("sample_image.jpg", "hinengguj")
+            print(result)  # Output might be 'hindi'
+ """
+ try:
+ # Validate model name and retrieve associated subcategories
+ if model_name not in model_info:
+ return {"error": "Invalid model name"}
+
+ # Ensure the model file is downloaded and accessible
+ model_path = self.ensure_model(model_name)
+
+
+ subcategories = model_info[model_name]["subcategories"]
+ num_classes = len(subcategories)
+
+ # Load the fine-tuned model with the specified number of classes
+ model_ft = CLIPFineTuner(clip_model, num_classes)
+ model_ft.load_state_dict(torch.load(model_path, map_location=device))
+ model_ft = model_ft.to(device)
+ model_ft.eval()
+
+ # Load and preprocess the image
+ image = Image.open(image_path).convert("RGB")
+ input_tensor = preprocess(image).unsqueeze(0).to(device)
+
+ # Run the model and get the prediction
+ outputs = model_ft(input_tensor)
+ _, predicted_idx = torch.max(outputs, 1)
+ predicted_class = subcategories[predicted_idx.item()]
+
+ return predicted_class
+
+ except Exception as e:
+ return {"error": str(e)}
+
+
+# if __name__ == "__main__":
+# import argparse
+
+# # Argument parser for command line usage
+# parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
+# parser.add_argument("image_path", type=str, help="Path to the input image")
+# parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
+
+# args = parser.parse_args()
+
+# # Execute prediction with command line inputs
+# result = predict(args.image_path, args.model_name)
+# print(result)
\ No newline at end of file
diff --git a/IndicPhotoOCR/script_identification/__init__.py b/IndicPhotoOCR/script_identification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/script_identification/vit/__init__.py b/IndicPhotoOCR/script_identification/vit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/script_identification/vit/config.py b/IndicPhotoOCR/script_identification/vit/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea7af657551e0f7bec82f88c7e14a5d0e4a5f70
--- /dev/null
+++ b/IndicPhotoOCR/script_identification/vit/config.py
@@ -0,0 +1,58 @@
+common_config={
+ 'pretrained_vit_model': 'google/vit-base-patch16-224-in21k'
+}
+
+train_config = {
+ 'epochs': 20,
+ 'max_images_real':1900,
+ 'classes':12,
+ 'hindi_path_real': '',
+ 'english_path_real':'',
+ 'gujarati_path_real':'',
+ 'punjabi_path_real':'',
+ 'assamese_path_real':'',
+ 'bengali_path_real':'',
+ 'kannada_path_real':'',
+ 'malayalam_path_real':'',
+ 'marathi_path_real':'',
+ 'odia_path_real':'',
+ 'tamil_path_real':'',
+ 'telugu_path_real':'',
+ 'checkpoints_dir': ''
+
+}
+train_config.update(common_config)
+
+test_config = {
+ 'reload_model': '',
+ 'max_images':2000,
+ 'classes':12,
+ 'hindi_path_real': '',
+ 'english_path_real':'',
+ 'gujarati_path_real':'',
+ 'punjabi_path_real':'',
+ 'assamese_path_real':'',
+ 'bengali_path_real':'',
+ 'kannada_path_real':'',
+ 'malayalam_path_real':'',
+ 'marathi_path_real':'',
+ 'odia_path_real':'',
+ 'tamil_path_real':'',
+ 'telugu_path_real':'',
+
+
+}
+test_config.update(common_config)
+
+
+
+infer_config = {
+ 'model_path':'',
+ 'img_path': 'image_path',
+ 'folder_path':'',
+ 'csv_path':'',
+}
+
+
+infer_config.update(common_config)
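+
+# Note: the empty-string paths above are deployment-specific placeholders; at inference
+# time vit_infer.py only reads 'pretrained_vit_model' from infer_config.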
+
diff --git a/IndicPhotoOCR/script_identification/vit/vit_infer.py b/IndicPhotoOCR/script_identification/vit/vit_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fed4612d2785065c582a607404eab150ca3c394
--- /dev/null
+++ b/IndicPhotoOCR/script_identification/vit/vit_infer.py
@@ -0,0 +1,213 @@
+from transformers import AutoImageProcessor,ViTForImageClassification,pipeline
+from PIL import Image
+from datasets import DatasetDict,Dataset,ClassLabel
+import torchvision.transforms as transforms
+import numpy as np
+import csv
+import os
+import argparse
+import requests
+from tqdm import tqdm
+import zipfile
+import time
+import glob
+from IndicPhotoOCR.script_identification.vit.config import infer_config as config
+
+model_info = {
+ "hindi": {
+ "path": "models/hindienglish",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglish.zip",
+ "subcategories": ["hindi", "english"]
+ },
+ "assamese": {
+ "path": "models/hindienglishassamese",
+ "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishassamese.zip",
+ "subcategories": ["hindi", "english", "assamese"]
+ },
+ "bengali": {
+ "path": "models/hindienglishbengali",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishbengali.zip",
+ "subcategories": ["hindi", "english", "bengali"]
+ },
+ "gujarati": {
+ "path": "models/hindienglishgujarati",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishgujarati.zip",
+ "subcategories": ["hindi", "english", "gujarati"]
+ },
+ "kannada": {
+ "path": "models/hindienglishkannada",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishkannada.zip",
+ "subcategories": ["hindi", "english", "kannada"]
+ },
+ "malayalam": {
+ "path": "models/hindienglishmalayalam",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmalayalam.zip",
+ "subcategories": ["hindi", "english", "malayalam"]
+ },
+ "marathi": {
+ "path": "models/hindienglishmarathi",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmarathi.zip",
+ "subcategories": ["hindi", "english", "marathi"]
+ },
+ "meitei": {
+ "path": "models/hindienglishmeitei",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmeitei.zip",
+ "subcategories": ["hindi", "english", "meitei"]
+ },
+ "odia": {
+ "path": "models/hindienglishodia",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishodia.zip",
+ "subcategories": ["hindi", "english", "odia"]
+ },
+ "punjabi": {
+ "path": "models/hindienglishpunjabi",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishpunjabi.zip",
+ "subcategories": ["hindi", "english", "punjabi"]
+ },
+ "tamil": {
+ "path": "models/hindienglishtamil",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtamil.zip",
+ "subcategories": ["hindi", "english", "tamil"]
+ },
+ "telugu": {
+ "path": "models/hindienglishtelugu",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtelugu.zip",
+ "subcategories": ["hindi", "english", "telugu"]
+ },
+ "12C": {
+ "path": "models/12_classes",
+ "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/12_classes.zip",
+        "subcategories": ["hindi", "english", "assamese", "bengali", "gujarati", "kannada", "malayalam", "marathi", "odia", "punjabi", "tamil", "telugu"]
+ },
+
+
+}
+
+pretrained_vit_model = config['pretrained_vit_model']
+processor = AutoImageProcessor.from_pretrained(pretrained_vit_model,use_fast=True)
+
+
+class VIT_identifier:
+ def __init__(self):
+ pass
+
+ def unzip_file(self, zip_path, extract_to):
+
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_to)
+ print(f"Extracted files to {extract_to}")
+
+
+
+
+ def ensure_model(self, model_name):
+ model_path = model_info[model_name]["path"]
+ url = model_info[model_name]["url"]
+ root_model_dir = "IndicPhotoOCR/script_identification/vit"
+ model_path = os.path.join(root_model_dir, model_path)
+
+ if not os.path.exists(model_path):
+ print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+ response = requests.get(url, stream=True)
+ zip_path = os.path.join(model_path, "temp_download.zip")
+
+ os.makedirs(model_path, exist_ok=True)
+
+ with open(zip_path, "wb") as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(model_path)
+
+ os.remove(zip_path)
+
+ print(f"Downloaded and extracted to {model_path}")
+
+ else:
+ # print(f"Model folder already exists: {model_path}")
+ pass
+
+ return model_path
+
+
+
+
+
+    def identify(self, image_path, model_name, device):
+        model_path = self.ensure_model(model_name)
+
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model = pipeline('image-classification', model=vit, feature_extractor=processor, device=device)
+
+ if image_path.endswith((".png", ".jpg", ".jpeg")):
+
+ image = Image.open(image_path)
+ output = model(image)
+ predicted_label = max(output, key=lambda x: x['score'])['label']
+
+ # print(f"image_path: {image_path}, predicted_label: {predicted_label}\n")
+
+ return predicted_label
+
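+    # Minimal usage sketch (hypothetical path, not part of the original module):
+    #   script = VIT_identifier().identify("crops/word_0.jpg", "hindi", device="cuda:0")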
+
+    def predict_batch(self, image_dir, model_name, time_show, output_csv="prediction.csv"):
+        model_path = self.ensure_model(model_name)
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model = pipeline('image-classification', model=vit, feature_extractor=processor, device=0)
+
+ start_time = time.time()
+ results=[]
+ image_count=0
+ for filename in os.listdir(image_dir):
+
+ if filename.endswith((".png", ".jpg", ".jpeg")):
+ img_path = os.path.join(image_dir, filename)
+ image = Image.open(img_path)
+
+
+ output = model(image)
+ predicted_label = max(output, key=lambda x: x['score'])['label'].capitalize()
+
+ results.append({"Filepath": filename, "Language": predicted_label})
+ image_count+=1
+
+ elapsed_time = time.time() - start_time
+
+ if time_show:
+ print(f"Time taken to process {image_count} images: {elapsed_time:.2f} seconds")
+
+ with open(output_csv, mode="w", newline="", encoding="utf-8") as csvfile:
+ writer = csv.DictWriter(csvfile, fieldnames=["Filepath", "Language"])
+ writer.writeheader()
+ writer.writerows(results)
+
+ return output_csv
+
+
+# if __name__ == "__main__":
+# # Argument parser for command line usage
+# parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
+# parser.add_argument("--image_path", type=str, help="Path to the input image")
+# parser.add_argument("--image_dir", type=str, help="Path to the input image directory")
+# parser.add_argument("--model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
+# parser.add_argument("--batch", action="store_true", help="Process images in batch mode if specified")
+# parser.add_argument("--time",type=bool, nargs="?", const=True, default=False, help="Prints the time required to process a batch of images")
+
+# args = parser.parse_args()
+
+
+# # Choose function based on the batch parameter
+# if args.batch:
+# if not args.image_dir:
+# print("Error: image_dir is required when batch is set to True.")
+# else:
+# result = predict_batch(args.image_dir, args.model_name, args.time)
+# print(result)
+# else:
+# if not args.image_path:
+# print("Error: image_path is required when batch is not set.")
+# else:
+# result = predict(args.image_path, args.model_name)
+# print(result)
\ No newline at end of file
diff --git a/IndicPhotoOCR/theme.py b/IndicPhotoOCR/theme.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca812bf7d38568c0d831dc1476bffa3486970e33
--- /dev/null
+++ b/IndicPhotoOCR/theme.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+from typing import Iterable
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+import time
+
+
+class Seafoam(Base):
+ def __init__(
+ self,
+ *,
+ primary_hue: colors.Color | str = colors.emerald,
+ secondary_hue: colors.Color | str = colors.blue,
+ neutral_hue: colors.Color | str = colors.gray,
+ spacing_size: sizes.Size | str = sizes.spacing_md,
+ radius_size: sizes.Size | str = sizes.radius_md,
+ text_size: sizes.Size | str = sizes.text_lg,
+ font: fonts.Font
+ | str
+ | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("Quicksand"),
+ "ui-sans-serif",
+ "sans-serif",
+ ),
+ font_mono: fonts.Font
+ | str
+ | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("IBM Plex Mono"),
+ "ui-monospace",
+ "monospace",
+ ),
+ ):
+ super().__init__(
+ primary_hue=primary_hue,
+ secondary_hue=secondary_hue,
+ neutral_hue=neutral_hue,
+ spacing_size=spacing_size,
+ radius_size=radius_size,
+ text_size=text_size,
+ font=font,
+ font_mono=font_mono,
+ )
\ No newline at end of file
diff --git a/IndicPhotoOCR/utils/__init__.py b/IndicPhotoOCR/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/utils/helper.py b/IndicPhotoOCR/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f023e50b81e13080f5a8b2f54ac0df1a90f50f
--- /dev/null
+++ b/IndicPhotoOCR/utils/helper.py
@@ -0,0 +1,220 @@
+import numpy as np
+
+# def detect_para(bbox_dict):
+# alpha1 = 0.2
+# alpha2 = 0.7
+# beta1 = 0.4
+# data = bbox_dict
+# word_crops = list(data.keys())
+# for i in word_crops:
+# data[i]["x1"], data[i]["y1"], data[i]["x2"], data[i]["y2"] = data[i]["bbox"]
+# data[i]["xc"] = (data[i]["x1"] + data[i]["x2"]) / 2
+# data[i]["yc"] = (data[i]["y1"] + data[i]["y2"]) / 2
+# data[i]["w"] = data[i]["x2"] - data[i]["x1"]
+# data[i]["h"] = data[i]["y2"] - data[i]["y1"]
+
+# patch_info = {}
+# while word_crops:
+# img_name = word_crops[0].split("_")[0]
+# word_crop_collection = [
+# word_crop for word_crop in word_crops if word_crop.startswith(img_name)
+# ]
+# centroids = {}
+# lines = []
+# img_word_crops = word_crop_collection.copy()
+# para = []
+# while img_word_crops:
+# clusters = []
+# para_words_group = [
+# img_word_crops[0],
+# ]
+# added = [
+# img_word_crops[0],
+# ]
+# img_word_crops.remove(img_word_crops[0])
+# ## determining the paragraph
+# while added:
+# word_crop = added.pop()
+# for i in range(len(img_word_crops)):
+# word_crop_ = img_word_crops[i]
+# if (
+# abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+# < data[word_crop]["h"] * alpha1
+# ):
+# if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+# if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+# word_crop
+# ]["h"] * alpha2:
+# para_words_group.append(word_crop_)
+# added.append(word_crop_)
+# else:
+# if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+# word_crop
+# ]["h"] * alpha2:
+# para_words_group.append(word_crop_)
+# added.append(word_crop_)
+# else:
+# if data[word_crop]["yc"] > data[word_crop_]["yc"]:
+# if (data[word_crop]["y1"] - data[word_crop_]["y2"]) < data[
+# word_crop
+# ]["h"] * beta1 and (
+# (
+# (data[word_crop_]["x1"] < data[word_crop]["x2"])
+# and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+# )
+# or (
+# (data[word_crop_]["x2"] < data[word_crop]["x2"])
+# and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+# )
+# or (
+# (data[word_crop]["x1"] > data[word_crop_]["x1"])
+# and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+# )
+# ):
+# para_words_group.append(word_crop_)
+# added.append(word_crop_)
+# else:
+# if (data[word_crop_]["y1"] - data[word_crop]["y2"]) < data[
+# word_crop
+# ]["h"] * beta1 and (
+# (
+# (data[word_crop_]["x1"] < data[word_crop]["x2"])
+# and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+# )
+# or (
+# (data[word_crop_]["x2"] < data[word_crop]["x2"])
+# and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+# )
+# or (
+# (data[word_crop]["x1"] > data[word_crop_]["x1"])
+# and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+# )
+# ):
+# para_words_group.append(word_crop_)
+# added.append(word_crop_)
+# img_word_crops = [p for p in img_word_crops if p not in para_words_group]
+# ## processing for the line
+# while para_words_group:
+# line_words_group = [
+# para_words_group[0],
+# ]
+# added = [
+# para_words_group[0],
+# ]
+# para_words_group.remove(para_words_group[0])
+# ## determining the line
+# while added:
+# word_crop = added.pop()
+# for i in range(len(para_words_group)):
+# word_crop_ = para_words_group[i]
+# if (
+# abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+# < data[word_crop]["h"] * alpha1
+# ):
+# if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+# if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+# word_crop
+# ]["h"] * alpha2:
+# line_words_group.append(word_crop_)
+# added.append(word_crop_)
+# else:
+# if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+# word_crop
+# ]["h"] * alpha2:
+# line_words_group.append(word_crop_)
+# added.append(word_crop_)
+# para_words_group = [
+# p for p in para_words_group if p not in line_words_group
+# ]
+# xc = [data[word_crop]["xc"] for word_crop in line_words_group]
+# idxs = np.argsort(xc)
+# patch_cluster_ = [line_words_group[i] for i in idxs]
+# line_words_group = patch_cluster_
+# x1 = [data[word_crop]["x1"] for word_crop in line_words_group]
+# x2 = [data[word_crop]["x2"] for word_crop in line_words_group]
+# y1 = [data[word_crop]["y1"] for word_crop in line_words_group]
+# y2 = [data[word_crop]["y2"] for word_crop in line_words_group]
+# txt_line = [data[word_crop]["txt"] for word_crop in line_words_group]
+# txt = " ".join(txt_line)
+# x = [x1[0]]
+# y1_ = [y1[0]]
+# y2_ = [y2[0]]
+# l = [len(txt_l) for txt_l in txt_line]
+# for i in range(1, len(x1)):
+# x.append((x1[i] + x2[i - 1]) / 2)
+# y1_.append((y1[i] + y1[i - 1]) / 2)
+# y2_.append((y2[i] + y2[i - 1]) / 2)
+# x.append(x2[-1])
+# y1_.append(y1[-1])
+# y2_.append(y2[-1])
+# line_info = {
+# "x": x,
+# "y1": y1_,
+# "y2": y2_,
+# "l": l,
+# "txt": txt,
+# "word_crops": line_words_group,
+# }
+# clusters.append(line_info)
+# y_ = [clusters[i]["y1"][0] for i in range(len(clusters))]
+# idxs = np.argsort(y_)
+# clusters_ = [clusters[i] for i in idxs]
+# txt = [clusters[i]["txt"] for i in idxs]
+# l = [len(t) for t in txt]
+# txt = " ".join(txt)
+# para_info = {"lines": clusters_, "l": l, "txt": txt}
+# para.append(para_info)
+
+# for word_crop in word_crop_collection:
+# word_crops.remove(word_crop)
+# return "\n".join([para[i]["txt"] for i in range(len(para))])
+
+
+def detect_para(recognized_texts):
+ """
+    Group recognized words into reading-order lines based on vertical overlap of their bounding boxes.
+
+    Args:
+        recognized_texts (dict): Maps detection ids to dicts with keys 'txt' (the recognized text)
+            and 'bbox' (a list [x1, y1, x2, y2]).
+
+    Returns:
+        list: A list of lines, where each line is a list of words sorted by x-coordinate.
+ """
+ def calculate_overlap(bbox1, bbox2):
+ """Calculate the vertical overlap between two bounding boxes."""
+ # Extract bounding box coordinates
+ x1_1, y1_1, x2_1, y2_1 = bbox1
+ x1_2, y1_2, x2_2, y2_2 = bbox2
+
+ overlap = max(0, min(y2_1, y2_2) - max(y1_1, y1_2))
+ height = min(y2_1 - y1_1, y2_2 - y1_2)
+ return overlap / height if height > 0 else 0
+
+ # Convert recognized_texts dictionary to a list of tuples for processing
+ items = list(recognized_texts.items())
+ lines = []
+
+ while items:
+ current_image, current_data = items.pop(0)
+ current_text, current_bbox = current_data['txt'], current_data['bbox']
+ current_line = [(current_text, current_bbox)]
+
+ remaining_items = []
+ for image, data in items:
+ text, bbox = data['txt'], data['bbox']
+ if calculate_overlap(current_bbox, bbox) > 0.4:
+ current_line.append((text, bbox))
+ else:
+ remaining_items.append((image, data))
+
+ items = remaining_items
+ lines.append(current_line)
+
+ # Sort words within each line based on x1 (horizontal position)
+ sorted_lines = [
+ [text for text, bbox in sorted(line, key=lambda x: x[1][0])] for line in lines
+ ]
+ return sorted_lines
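+
+# Worked example (illustrative, not from the original code): with
+#   {"img_0": {"txt": "नमस्ते", "bbox": [10, 5, 60, 25]},
+#    "img_1": {"txt": "India",  "bbox": [70, 6, 120, 24]}}
+# the two boxes overlap vertically by more than 40% of the smaller height, so both words
+# land on the same line and detect_para returns [["नमस्ते", "India"]].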
+
+
diff --git a/IndicPhotoOCR/utils/strhub/__init__.py b/IndicPhotoOCR/utils/strhub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d740e413e1fb8d335027dd8ff6d3aa393d45e84
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/__init__.py
@@ -0,0 +1,2 @@
+# from data.module import SceneTextDataModule
+# from model.utils import load_from_checkpoint
\ No newline at end of file
diff --git a/IndicPhotoOCR/utils/strhub/data/__init__.py b/IndicPhotoOCR/utils/strhub/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e3d6f0bae7320fb3ae022de41146597a5481b6
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/__init__.py
@@ -0,0 +1 @@
+# from .module import SceneTextDataModule
\ No newline at end of file
diff --git a/IndicPhotoOCR/utils/strhub/data/aa_overrides.py b/IndicPhotoOCR/utils/strhub/data/aa_overrides.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef374e2e4166c3847d80d30bab2b0eb6ba88d70c
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/aa_overrides.py
@@ -0,0 +1,46 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Extends default ops to accept optional parameters."""
+from functools import partial
+
+from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
+
+
+def rotate_expand(img, degrees, **kwargs):
+ """Rotate operation with expand=True to avoid cutting off the characters"""
+ kwargs['expand'] = True
+ return rotate(img, degrees, **kwargs)
+
+
+def _level_to_arg(level, hparams, key, default):
+ magnitude = hparams.get(key, default)
+ level = (level / _LEVEL_DENOM) * magnitude
+ level = _randomly_negate(level)
+ return (level,)
+
+
+def apply():
+ # Overrides
+ NAME_TO_OP.update({
+ 'Rotate': rotate_expand,
+ })
+ LEVEL_TO_ARG.update({
+ 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.0),
+ 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
+ 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
+ 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
+ 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
+ })
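+
+
+# Note: apply() must run before timm's rand-augment ops are built; augment.py in this
+# package calls it at import time so the overridden 'Rotate' (expand=True) and the
+# configurable shear/translate magnitudes take effect.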
diff --git a/IndicPhotoOCR/utils/strhub/data/augment.py b/IndicPhotoOCR/utils/strhub/data/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8832503693863907640b83d5771de90ed6e773
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/augment.py
@@ -0,0 +1,112 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import imgaug.augmenters as iaa
+import numpy as np
+from PIL import Image, ImageFilter
+
+from timm.data import auto_augment
+
+from IndicPhotoOCR.utils.strhub.data import aa_overrides
+
+aa_overrides.apply()
+
+_OP_CACHE = {}
+
+
+def _get_op(key, factory):
+ try:
+ op = _OP_CACHE[key]
+ except KeyError:
+ op = factory()
+ _OP_CACHE[key] = op
+ return op
+
+
+def _get_param(level, img, max_dim_factor, min_level=1):
+ max_level = max(min_level, max_dim_factor * max(img.size))
+ return round(min(level, max_level))
+
+
+def gaussian_blur(img, radius, **__):
+ radius = _get_param(radius, img, 0.02)
+ key = 'gaussian_blur_' + str(radius)
+ op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
+ return img.filter(op)
+
+
+def motion_blur(img, k, **__):
+ k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values
+ key = 'motion_blur_' + str(k)
+ op = _get_op(key, lambda: iaa.MotionBlur(k))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def gaussian_noise(img, scale, **_):
+ scale = _get_param(scale, img, 0.25) | 1 # bin to odd values
+ key = 'gaussian_noise_' + str(scale)
+ op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def poisson_noise(img, lam, **_):
+ lam = _get_param(lam, img, 0.2) | 1 # bin to odd values
+ key = 'poisson_noise_' + str(lam)
+ op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
+ return Image.fromarray(op(image=np.asarray(img)))
+
+
+def _level_to_arg(level, _hparams, max):
+ level = max * level / auto_augment._LEVEL_DENOM
+ return (level,)
+
+
+_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
+_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops
+_RAND_TRANSFORMS.extend([
+ 'GaussianBlur',
+ # 'MotionBlur',
+ # 'GaussianNoise',
+ 'PoissonNoise',
+])
+auto_augment.LEVEL_TO_ARG.update({
+ 'GaussianBlur': partial(_level_to_arg, max=4),
+ 'MotionBlur': partial(_level_to_arg, max=20),
+ 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
+ 'PoissonNoise': partial(_level_to_arg, max=40),
+})
+auto_augment.NAME_TO_OP.update({
+ 'GaussianBlur': gaussian_blur,
+ 'MotionBlur': motion_blur,
+ 'GaussianNoise': gaussian_noise,
+ 'PoissonNoise': poisson_noise,
+})
+
+
+def rand_augment_transform(magnitude=5, num_layers=3):
+ # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
+ hparams = {
+ 'rotate_deg': 30,
+ 'shear_x_pct': 0.9,
+ 'shear_y_pct': 0.2,
+ 'translate_x_pct': 0.10,
+ 'translate_y_pct': 0.30,
+ }
+ ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
+ # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
+ choice_weights = [1.0 / len(ra_ops) for _ in range(len(ra_ops))]
+ return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
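+
+
+# Typical use (sketch): build the transform once and apply it to PIL images, e.g.
+#   augment = rand_augment_transform()
+#   augmented = augment(pil_image)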
diff --git a/IndicPhotoOCR/utils/strhub/data/dataset.py b/IndicPhotoOCR/utils/strhub/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..65954c23127f02d1393179ddbc9fb175c88b8de9
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/dataset.py
@@ -0,0 +1,148 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import glob
+import io
+import logging
+import unicodedata
+from pathlib import Path, PurePath
+from typing import Callable, Optional, Union
+
+import lmdb
+from PIL import Image
+
+from torch.utils.data import ConcatDataset, Dataset
+
+from IndicPhotoOCR.utils.strhub.data.utils import CharsetAdapter
+
+log = logging.getLogger(__name__)
+
+
+def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
+ try:
+ kwargs.pop('root') # prevent 'root' from being passed via kwargs
+ except KeyError:
+ pass
+ root = Path(root).absolute()
+ log.info(f'dataset root:\t{root}')
+ datasets = []
+ for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
+ mdb = Path(mdb)
+ ds_name = str(mdb.parent.relative_to(root))
+ ds_root = str(mdb.parent.absolute())
+ dataset = LmdbDataset(ds_root, *args, **kwargs)
+ log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
+ datasets.append(dataset)
+ return ConcatDataset(datasets)
+
+
+class LmdbDataset(Dataset):
+ """Dataset interface to an LMDB database.
+
+ It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
+ as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
+ Labels are transformed according to the charset.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ charset: str,
+ max_label_len: int,
+ min_image_dim: int = 0,
+ remove_whitespace: bool = True,
+ normalize_unicode: bool = True,
+ unlabelled: bool = False,
+ transform: Optional[Callable] = None,
+ ):
+ self._env = None
+ self.root = root
+ self.unlabelled = unlabelled
+ self.transform = transform
+ self.labels = []
+ self.filtered_index_list = []
+ self.num_samples = self._preprocess_labels(
+ charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim
+ )
+
+ def __del__(self):
+ if self._env is not None:
+ self._env.close()
+ self._env = None
+
+ def _create_env(self):
+ return lmdb.open(
+ self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False
+ )
+
+ @property
+ def env(self):
+ if self._env is None:
+ self._env = self._create_env()
+ return self._env
+
+ def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
+ charset_adapter = CharsetAdapter(charset)
+ with self._create_env() as env, env.begin() as txn:
+ num_samples = int(txn.get('num-samples'.encode()))
+ if self.unlabelled:
+ return num_samples
+ for index in range(num_samples):
+ index += 1 # lmdb starts with 1
+ label_key = f'label-{index:09d}'.encode()
+ label = txn.get(label_key).decode()
+ # Normally, whitespace is removed from the labels.
+ if remove_whitespace:
+ label = ''.join(label.split())
+ # Normalize unicode composites (if any) and convert to compatible ASCII characters
+ if normalize_unicode:
+ label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
+ # Filter by length before removing unsupported characters. The original label might be too long.
+ if len(label) > max_label_len:
+ continue
+ label = charset_adapter(label)
+ # We filter out samples which don't contain any supported characters
+ if not label:
+ continue
+ # Filter images that are too small.
+ if min_image_dim > 0:
+ img_key = f'image-{index:09d}'.encode()
+ buf = io.BytesIO(txn.get(img_key))
+ w, h = Image.open(buf).size
+ if w < min_image_dim or h < min_image_dim:
+ continue
+ self.labels.append(label)
+ self.filtered_index_list.append(index)
+ return len(self.labels)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ if self.unlabelled:
+ label = index
+ else:
+ label = self.labels[index]
+ index = self.filtered_index_list[index]
+
+ img_key = f'image-{index:09d}'.encode()
+ with self.env.begin() as txn:
+ imgbuf = txn.get(img_key)
+ buf = io.BytesIO(imgbuf)
+ img = Image.open(buf).convert('RGB')
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, label
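+
+
+# Usage sketch (illustrative only; the LMDB path and charset below are assumptions):
+#   ds = LmdbDataset('data/train/real', charset='0123456789abcdefghijklmnopqrstuvwxyz', max_label_len=25)
+#   img, label = ds[0]   # PIL.Image (RGB) and its transcription
+#   (for unlabelled databases, pass unlabelled=True; the integer index is then returned as the label)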
diff --git a/IndicPhotoOCR/utils/strhub/data/module.py b/IndicPhotoOCR/utils/strhub/data/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..336a4ba467dd59d9b146e5eebf38e7225e348b12
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/module.py
@@ -0,0 +1,157 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import PurePath
+from typing import Callable, Optional, Sequence
+
+from torch.utils.data import DataLoader
+from torchvision import transforms as T
+
+import pytorch_lightning as pl
+
+from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset, build_tree_dataset
+
+
+class SceneTextDataModule(pl.LightningDataModule):
+ TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
+ TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
+ TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
+ TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))
+
+ def __init__(
+ self,
+ root_dir: str,
+ train_dir: str,
+ img_size: Sequence[int],
+ max_label_length: int,
+ charset_train: str,
+ charset_test: str,
+ batch_size: int,
+ num_workers: int,
+ augment: bool,
+ remove_whitespace: bool = True,
+ normalize_unicode: bool = True,
+ min_image_dim: int = 0,
+ rotation: int = 0,
+ collate_fn: Optional[Callable] = None,
+ ):
+ super().__init__()
+ self.root_dir = root_dir
+ self.train_dir = train_dir
+ self.img_size = tuple(img_size)
+ self.max_label_length = max_label_length
+ self.charset_train = charset_train
+ self.charset_test = charset_test
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.augment = augment
+ self.remove_whitespace = remove_whitespace
+ self.normalize_unicode = normalize_unicode
+ self.min_image_dim = min_image_dim
+ self.rotation = rotation
+ self.collate_fn = collate_fn
+ self._train_dataset = None
+ self._val_dataset = None
+
+ @staticmethod
+ def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0):
+ transforms = []
+ if augment:
+ from .augment import rand_augment_transform
+
+ transforms.append(rand_augment_transform())
+ if rotation:
+ transforms.append(lambda img: img.rotate(rotation, expand=True))
+ transforms.extend([
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(0.5, 0.5),
+ ])
+ return T.Compose(transforms)
+
+ @property
+ def train_dataset(self):
+ if self._train_dataset is None:
+ transform = self.get_transform(self.img_size, self.augment)
+ root = PurePath(self.root_dir, 'train', self.train_dir)
+ self._train_dataset = build_tree_dataset(
+ root,
+ self.charset_train,
+ self.max_label_length,
+ self.min_image_dim,
+ self.remove_whitespace,
+ self.normalize_unicode,
+ transform=transform,
+ )
+ return self._train_dataset
+
+ @property
+ def val_dataset(self):
+ if self._val_dataset is None:
+ transform = self.get_transform(self.img_size)
+ root = PurePath(self.root_dir, 'val')
+ self._val_dataset = build_tree_dataset(
+ root,
+ self.charset_test,
+ self.max_label_length,
+ self.min_image_dim,
+ self.remove_whitespace,
+ self.normalize_unicode,
+ transform=transform,
+ )
+ return self._val_dataset
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ persistent_workers=self.num_workers > 0,
+ pin_memory=True,
+ collate_fn=self.collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ persistent_workers=self.num_workers > 0,
+ pin_memory=True,
+ collate_fn=self.collate_fn,
+ )
+
+ def test_dataloaders(self, subset):
+ transform = self.get_transform(self.img_size, rotation=self.rotation)
+ root = PurePath(self.root_dir, 'test')
+ datasets = {
+ s: LmdbDataset(
+ str(root / s),
+ self.charset_test,
+ self.max_label_length,
+ self.min_image_dim,
+ self.remove_whitespace,
+ self.normalize_unicode,
+ transform=transform,
+ )
+ for s in subset
+ }
+ return {
+ k: DataLoader(
+ v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
+ )
+ for k, v in datasets.items()
+ }
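+
+
+# Usage sketch (the values below are placeholders, not the project's defaults):
+#   dm = SceneTextDataModule(root_dir='data', train_dir='real', img_size=(32, 128), max_label_length=25,
+#                            charset_train='0123456789abcdefghijklmnopqrstuvwxyz',
+#                            charset_test='0123456789abcdefghijklmnopqrstuvwxyz',
+#                            batch_size=256, num_workers=4, augment=True)
+#   train_loader = dm.train_dataloader()   # LMDB folders are discovered under <root_dir>/train/<train_dir>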
diff --git a/IndicPhotoOCR/utils/strhub/data/utils.py b/IndicPhotoOCR/utils/strhub/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..16fd30d0bba424361730b9d1c33016746de23f68
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/data/utils.py
@@ -0,0 +1,150 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from abc import ABC, abstractmethod
+from itertools import groupby
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.nn.utils.rnn import pad_sequence
+
+
+class CharsetAdapter:
+ """Transforms labels according to the target charset."""
+
+ def __init__(self, target_charset) -> None:
+ super().__init__()
+ self.lowercase_only = target_charset == target_charset.lower()
+ self.uppercase_only = target_charset == target_charset.upper()
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
+
+ def __call__(self, label):
+ if self.lowercase_only:
+ label = label.lower()
+ elif self.uppercase_only:
+ label = label.upper()
+ # Remove unsupported characters
+ label = self.unsupported.sub('', label)
+ return label
+
+
+class BaseTokenizer(ABC):
+
+ def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
+ self._itos = specials_first + tuple(charset) + specials_last
+ self._stoi = {s: i for i, s in enumerate(self._itos)}
+
+ def __len__(self):
+ return len(self._itos)
+
+ def _tok2ids(self, tokens: str) -> list[int]:
+ return [self._stoi[s] for s in tokens]
+
+ def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
+ tokens = [self._itos[i] for i in token_ids]
+ return ''.join(tokens) if join else tokens
+
+ @abstractmethod
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
+ """Encode a batch of labels to a representation suitable for the model.
+
+ Args:
+ labels: List of labels. Each can be of arbitrary length.
+ device: Create tensor on this device.
+
+ Returns:
+ Batched tensor representation padded to the max label length. Shape: N, L
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
+ """Internal method which performs the necessary filtering prior to decoding."""
+ raise NotImplementedError
+
+ def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]:
+ """Decode a batch of token distributions.
+
+ Args:
+ token_dists: softmax probabilities over the token distribution. Shape: N, L, C
+ raw: return unprocessed labels (will return list of list of strings)
+
+ Returns:
+ list of string labels (arbitrary length) and
+ their corresponding sequence probabilities as a list of Tensors
+ """
+ batch_tokens = []
+ batch_probs = []
+ for dist in token_dists:
+ probs, ids = dist.max(-1) # greedy selection
+ if not raw:
+ probs, ids = self._filter(probs, ids)
+ tokens = self._ids2tok(ids, not raw)
+ batch_tokens.append(tokens)
+ batch_probs.append(probs)
+ return batch_tokens, batch_probs
+
+
+class Tokenizer(BaseTokenizer):
+ BOS = '[B]'
+ EOS = '[E]'
+ PAD = '[P]'
+
+ def __init__(self, charset: str) -> None:
+ specials_first = (self.EOS,)
+ specials_last = (self.BOS, self.PAD)
+ super().__init__(charset, specials_first, specials_last)
+ self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
+
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
+ batch = [
+ torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
+ for y in labels
+ ]
+ return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
+
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
+ ids = ids.tolist()
+ try:
+ eos_idx = ids.index(self.eos_id)
+ except ValueError:
+ eos_idx = len(ids) # Nothing to truncate.
+ # Truncate after EOS
+ ids = ids[:eos_idx]
+ probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists)
+ return probs, ids
+
+
+class CTCTokenizer(BaseTokenizer):
+ BLANK = '[B]'
+
+ def __init__(self, charset: str) -> None:
+ # BLANK uses index == 0 by default
+ super().__init__(charset, specials_first=(self.BLANK,))
+ self.blank_id = self._stoi[self.BLANK]
+
+ def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
+ # We use a padded representation since we don't want to use CUDNN's CTC implementation
+ batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
+ return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
+
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
+ # Best path decoding:
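+ # e.g. ids [1, 1, 0, 2, 2, 0, 2] -> collapse repeats -> [1, 0, 2, 0, 2] -> drop blank (0) -> [1, 2, 2]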
+ ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
+ ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
+ # `probs` is just pass-through since all positions are considered part of the path
+ return probs, ids
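+
+
+# Round-trip sketch (charset and shapes are illustrative):
+#   tokenizer = Tokenizer('abc')
+#   ids = tokenizer.encode(['ab', 'abc'])                 # (N, L) LongTensor: [B] + chars + [E], padded with [P]
+#   dists = torch.rand(2, 6, len(tokenizer)).softmax(-1)  # stand-in for model output (N, L, C)
+#   labels, probs = tokenizer.decode(dists)               # greedy decode, truncated at [E]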
diff --git a/IndicPhotoOCR/utils/strhub/models/__init__.py b/IndicPhotoOCR/utils/strhub/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5570dcb032819c28de6b73c4de5bba4109e08c0f
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/__init__.py
@@ -0,0 +1 @@
+# from .utils import load_from_checkpoint
\ No newline at end of file
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE b/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..2f1d4adb4889b2719f13ed6edf56aed10246a516
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/LICENSE
@@ -0,0 +1,25 @@
+ABINet for non-commercial purposes
+
+Copyright (c) 2021, USTC
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py b/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..604811036fda52d8485eecfebd4ffeb7f7176042
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/__init__.py
@@ -0,0 +1,13 @@
+r"""
+Fang, Shancheng, Hongtao Xie, Yuxin Wang, Zhendong Mao, and Yongdong Zhang.
+"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition."
+In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 7098-7107, 2021.
+
+https://arxiv.org/abs/2103.06495
+
+All source files, except `system.py`, are based on the implementation listed below,
+and hence are released under the license of the original.
+
+Source: https://github.com/FangShancheng/ABINet
+License: 2-clause BSD License (see included LICENSE file)
+"""
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/attention.py b/IndicPhotoOCR/utils/strhub/models/abinet/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc8fba0638e7444fdffe964f72d0566c1a5bb818
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/attention.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+
+from .transformer import PositionalEncoding
+
+
+class Attention(nn.Module):
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
+ super().__init__()
+ self.max_length = max_length
+
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
+ self.w0 = nn.Linear(max_length, n_feature)
+ self.wv = nn.Linear(in_channels, in_channels)
+ self.we = nn.Linear(in_channels, max_length)
+
+ self.active = nn.Tanh()
+ self.softmax = nn.Softmax(dim=2)
+
+ def forward(self, enc_output):
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
+
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
+
+ attn = self.we(t) # b,256,25
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
+ g_output = torch.bmm(attn, enc_output) # b,25,512
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
+
+
+def encoder_layer(in_c, out_c, k=3, s=2, p=1):
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU(True))
+
+
+def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
+ align_corners = None if mode == 'nearest' else True
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
+ mode=mode, align_corners=align_corners),
+ nn.Conv2d(in_c, out_c, k, s, p),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU(True))
+
+
+class PositionAttention(nn.Module):
+ def __init__(self, max_length, in_channels=512, num_channels=64,
+ h=8, w=32, mode='nearest', **kwargs):
+ super().__init__()
+ self.max_length = max_length
+ self.k_encoder = nn.Sequential(
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
+ encoder_layer(num_channels, num_channels, s=(2, 2))
+ )
+ self.k_decoder = nn.Sequential(
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
+ )
+
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
+ self.project = nn.Linear(in_channels, in_channels)
+
+ def forward(self, x):
+ N, E, H, W = x.size()
+ k, v = x, x # (N, E, H, W)
+
+ # calculate key vector
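+ # (mini U-Net over the feature map: k_encoder downsamples the key map, k_decoder upsamples it back
+ # to (H, W), adding the matching encoder feature as a skip connection at each scale)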
+ features = []
+ for i in range(0, len(self.k_encoder)):
+ k = self.k_encoder[i](k)
+ features.append(k)
+ for i in range(0, len(self.k_decoder) - 1):
+ k = self.k_decoder[i](k)
+ k = k + features[len(self.k_decoder) - 2 - i]
+ k = self.k_decoder[-1](k)
+
+ # calculate query vector
+ # TODO q=f(q,k)
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
+ q = self.pos_encoder(zeros) # (T, N, E)
+ q = q.permute(1, 0, 2) # (N, T, E)
+ q = self.project(q) # (N, T, E)
+
+ # calculate attention
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
+ attn_scores = attn_scores / (E ** 0.5)
+ attn_scores = torch.softmax(attn_scores, dim=-1)
+
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
+
+ return attn_vecs, attn_scores.view(N, -1, H, W)
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py b/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..debcabd7f115db0e698a55175a01a0ff0131e10f
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/backbone.py
@@ -0,0 +1,24 @@
+import torch.nn as nn
+from torch.nn import TransformerEncoderLayer, TransformerEncoder
+
+from .resnet import resnet45
+from .transformer import PositionalEncoding
+
+
+class ResTranformer(nn.Module):
+ def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
+ super().__init__()
+ self.resnet = resnet45()
+ self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
+ encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
+ self.transformer = TransformerEncoder(encoder_layer, backbone_ln)
+
+ def forward(self, images):
+ feature = self.resnet(images)
+ n, c, h, w = feature.shape
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
+ feature = self.pos_encoder(feature)
+ feature = self.transformer(feature)
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
+ return feature
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model.py b/IndicPhotoOCR/utils/strhub/models/abinet/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0cd143d324822c57b897b6e5749024d857fd30
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/model.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+
+ def __init__(self, dataset_max_length: int, null_label: int):
+ super().__init__()
+ self.max_length = dataset_max_length + 1 # additional stop token
+ self.null_label = null_label
+
+ def _get_length(self, logit, dim=-1):
+ """ Greed decoder to obtain length from logit"""
+ out = (logit.argmax(dim=-1) == self.null_label)
+ abn = out.any(dim)
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
+ out = out + 1 # additional end token
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
+ return out
+
+ @staticmethod
+ def _get_padding_mask(length, max_length):
+ length = length.unsqueeze(-1)
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
+ return grid >= length
+
+ @staticmethod
+ def _get_location_mask(sz, device=None):
+ mask = torch.eye(sz, device=device)
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
+ return mask
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8523ff6431f991037d56dc8dd72ae67c7bf242
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_abinet_iter.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+
+from .model_alignment import BaseAlignment
+from .model_language import BCNLanguage
+from .model_vision import BaseVision
+
+
+class ABINetIterModel(nn.Module):
+ def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1,
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
+ v_loss_weight=1., v_attention='position', v_attention_mode='nearest',
+ v_backbone='transformer', v_num_layers=2,
+ l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False,
+ a_loss_weight=1.):
+ super().__init__()
+ self.iter_size = iter_size
+ self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode,
+ v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers)
+ self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout,
+ activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight)
+ self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight)
+
+ def forward(self, images):
+ v_res = self.vision(images)
+ a_res = v_res
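+ # Iterative refinement: the current prediction (vision first, then alignment) is re-read by the
+ # language model and fused with the visual features by the alignment model, for iter_size rounds.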
+ all_l_res, all_a_res = [], []
+ for _ in range(self.iter_size):
+ tokens = torch.softmax(a_res['logits'], dim=-1)
+ lengths = a_res['pt_lengths']
+ lengths.clamp_(2, self.language.max_length) # TODO: move to language model
+ l_res = self.language(tokens, lengths)
+ all_l_res.append(l_res)
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
+ all_a_res.append(a_res)
+ if self.training:
+ return all_a_res, all_l_res, v_res
+ else:
+ return a_res, all_l_res[-1], v_res
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ccfa95e65dbd7176c8bcee693bb0bcb8ad13c69
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_alignment.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+
+from .model import Model
+
+
+class BaseAlignment(Model):
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0):
+ super().__init__(dataset_max_length, null_label)
+ self.loss_weight = loss_weight
+ self.w_att = nn.Linear(2 * d_model, d_model)
+ self.cls = nn.Linear(d_model, num_classes)
+
+ def forward(self, l_feature, v_feature):
+ """
+ Args:
+ l_feature: (N, T, E) where N is the batch size, T is the sequence length and E is the model dimension
+ v_feature: (N, T, E), same shape as l_feature
+ """
+ f = torch.cat((l_feature, v_feature), dim=2)
+ f_att = torch.sigmoid(self.w_att(f))
+ output = f_att * v_feature + (1 - f_att) * l_feature
+
+ logits = self.cls(output) # (N, T, C)
+ pt_lengths = self._get_length(logits)
+
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight,
+ 'name': 'alignment'}
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8bb8f60b61ad96dca3c54f7db94e19ffd42b83
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_language.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+
+from .model import Model
+from .transformer import PositionalEncoding, TransformerDecoderLayer, TransformerDecoder
+
+
+class BCNLanguage(Model):
+ def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
+ activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0,
+ global_debug=False):
+ super().__init__(dataset_max_length, null_label)
+ self.detach = detach
+ self.loss_weight = loss_weight
+ self.proj = nn.Linear(num_classes, d_model, False)
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
+ activation, self_attn=use_self_attn, debug=global_debug)
+ self.model = TransformerDecoder(decoder_layer, num_layers)
+ self.cls = nn.Linear(d_model, num_classes)
+
+ def forward(self, tokens, lengths):
+ """
+ Args:
+ tokens: (N, T, C) where N is the batch size, T is the sequence length and C is the number of classes
+ lengths: (N,)
+ """
+ if self.detach:
+ tokens = tokens.detach()
+ embed = self.proj(tokens) # (N, T, E)
+ embed = embed.permute(1, 0, 2) # (T, N, E)
+ embed = self.token_encoder(embed) # (T, N, E)
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
+
+ zeros = embed.new_zeros(*embed.shape)
+ query = self.pos_encoder(zeros)
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
+ output = self.model(query, embed,
+ tgt_key_padding_mask=padding_mask,
+ memory_mask=location_mask,
+ memory_key_padding_mask=padding_mask) # (T, N, E)
+ output = output.permute(1, 0, 2) # (N, T, E)
+
+ logits = self.cls(output) # (N, T, C)
+ pt_lengths = self._get_length(logits)
+
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
+ 'loss_weight': self.loss_weight, 'name': 'language'}
+ return res
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py b/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..bddb7d5f237854b81c388090e2e20fc26632c431
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/model_vision.py
@@ -0,0 +1,45 @@
+from torch import nn
+
+from .attention import PositionAttention, Attention
+from .backbone import ResTranformer
+from .model import Model
+from .resnet import resnet45
+
+
+class BaseVision(Model):
+ def __init__(self, dataset_max_length, null_label, num_classes,
+ attention='position', attention_mode='nearest', loss_weight=1.0,
+ d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
+ backbone='transformer', backbone_ln=2):
+ super().__init__(dataset_max_length, null_label)
+ self.loss_weight = loss_weight
+ self.out_channels = d_model
+
+ if backbone == 'transformer':
+ self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
+ else:
+ self.backbone = resnet45()
+
+ if attention == 'position':
+ self.attention = PositionAttention(
+ max_length=self.max_length,
+ mode=attention_mode
+ )
+ elif attention == 'attention':
+ self.attention = Attention(
+ max_length=self.max_length,
+ n_feature=8 * 32,
+ )
+ else:
+ raise ValueError(f'invalid attention: {attention}')
+
+ self.cls = nn.Linear(self.out_channels, num_classes)
+
+ def forward(self, images):
+ features = self.backbone(images) # (N, E, H, W)
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
+ logits = self.cls(attn_vecs) # (N, T, C)
+ pt_lengths = self._get_length(logits)
+
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
+ 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py b/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bf38896987b3560e254e8037426d29bcdd5844
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/resnet.py
@@ -0,0 +1,72 @@
+import math
+from typing import Optional, Callable
+
+import torch.nn as nn
+from torchvision.models import resnet
+
+
+class BasicBlock(resnet.BasicBlock):
+
+ def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None,
+ groups: int = 1, base_width: int = 64, dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
+ super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
+ self.conv1 = resnet.conv1x1(inplanes, planes)
+ self.conv2 = resnet.conv3x3(planes, planes, stride)
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers):
+ super().__init__()
+ self.inplanes = 32
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(32)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer5(x)
+ return x
+
+
+def resnet45():
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/system.py b/IndicPhotoOCR/utils/strhub/models/abinet/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..f56e9d1dff021318095d28fb5eb99cace5371ecb
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/system.py
@@ -0,0 +1,215 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import OneCycleLR
+
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from timm.optim.optim_factory import param_groups_weight_decay
+
+from IndicPhotoOCR.utils.strhub.models.base import CrossEntropySystem
+from IndicPhotoOCR.utils.strhub.models.utils import init_weights
+
+from .model_abinet_iter import ABINetIterModel as Model
+
+log = logging.getLogger(__name__)
+
+
+class ABINet(CrossEntropySystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ iter_size: int,
+ d_model: int,
+ nhead: int,
+ d_inner: int,
+ dropout: float,
+ activation: str,
+ v_loss_weight: float,
+ v_attention: str,
+ v_attention_mode: str,
+ v_backbone: str,
+ v_num_layers: int,
+ l_loss_weight: float,
+ l_num_layers: int,
+ l_detach: bool,
+ l_use_self_attn: bool,
+ l_lr: float,
+ a_loss_weight: float,
+ lm_only: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.scheduler = None
+ self.save_hyperparameters()
+ self.max_label_length = max_label_length
+ self.num_classes = len(self.tokenizer) - 2 # We don't predict <bos> nor <pad>
+ self.model = Model(
+ max_label_length,
+ self.eos_id,
+ self.num_classes,
+ iter_size,
+ d_model,
+ nhead,
+ d_inner,
+ dropout,
+ activation,
+ v_loss_weight,
+ v_attention,
+ v_attention_mode,
+ v_backbone,
+ v_num_layers,
+ l_loss_weight,
+ l_num_layers,
+ l_detach,
+ l_use_self_attn,
+ a_loss_weight,
+ )
+ self.model.apply(init_weights)
+ # FIXME: doesn't support resumption from checkpoint yet
+ self._reset_alignment = True
+ self._reset_optimizers = True
+ self.l_lr = l_lr
+ self.lm_only = lm_only
+ # Train LM only. Freeze other submodels.
+ if lm_only:
+ self.l_lr = lr # for tuning
+ self.model.vision.requires_grad_(False)
+ self.model.alignment.requires_grad_(False)
+
+ @property
+ def _pretraining(self):
+ # In the original work, the VM was pretrained for 8 epochs while the full model was trained for an additional 10 epochs.
+ total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
+ return self.global_step < (8 / (8 + 10)) * total_steps
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'model.language.proj.weight'}
+
+ def _add_weight_decay(self, model: nn.Module, skip_list=()):
+ if self.weight_decay:
+ return param_groups_weight_decay(model, self.weight_decay, skip_list)
+ else:
+ return [{'params': model.parameters()}]
+
+ def configure_optimizers(self):
+ agb = self.trainer.accumulate_grad_batches
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
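+ # e.g. with accumulate_grad_batches=1, 2 devices and batch_size=384 (illustrative values):
+ # lr_scale = 1 * sqrt(2) * 384 / 256 ≈ 2.12, so the configured base LRs are roughly doubled.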
+ lr = lr_scale * self.lr
+ l_lr = lr_scale * self.l_lr
+ params = []
+ params.extend(self._add_weight_decay(self.model.vision))
+ params.extend(self._add_weight_decay(self.model.alignment))
+ # We use a different learning rate for the LM.
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
+ p['lr'] = l_lr
+ params.append(p)
+ max_lr = [p.get('lr', lr) for p in params]
+ optim = AdamW(params, lr)
+ self.scheduler = OneCycleLR(
+ optim, max_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
+ )
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ logits = self.model.forward(images)[0]['logits']
+ return logits[:, : max_length + 1] # truncate
+
+ def calc_loss(self, targets, *res_lists) -> Tensor:
+ total_loss = 0
+ for res_list in res_lists:
+ loss = 0
+ if isinstance(res_list, dict):
+ res_list = [res_list]
+ for res in res_list:
+ logits = res['logits'].flatten(end_dim=1)
+ loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
+ loss /= len(res_list)
+ self.log('loss_' + res_list[0]['name'], loss)
+ total_loss += res_list[0]['loss_weight'] * loss
+ return total_loss
+
+ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+ if not self._pretraining and self._reset_optimizers:
+ log.info('Pretraining ends. Updating base LRs.')
+ self._reset_optimizers = False
+ # Make base_lr the same for all groups
+ base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM
+ self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)
+
+ def _prepare_inputs_and_targets(self, labels):
+ # Use dummy label to ensure sequence length is constant.
+ dummy = ['0' * self.max_label_length]
+ targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
+ targets = targets[:, 1:] # remove <bos>. Unused here.
+ # Inputs are padded with eos_id
+ inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
+ inputs = F.one_hot(inputs, self.num_classes).float()
+ lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos
+ return inputs, lengths, targets
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
+ if self.lm_only:
+ l_res = self.model.language(inputs, lengths)
+ loss = self.calc_loss(targets, l_res)
+ # Pretrain submodels independently first
+ elif self._pretraining:
+ # Vision
+ v_res = self.model.vision(images)
+ # Language
+ l_res = self.model.language(inputs, lengths)
+ # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
+ # We'll reset its parameters prior to joint training.
+ a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
+ loss = self.calc_loss(targets, v_res, l_res, a_res)
+ else:
+ # Reset alignment model's parameters once prior to full model training.
+ if self._reset_alignment:
+ log.info('Pretraining ends. Resetting alignment model.')
+ self._reset_alignment = False
+ self.model.alignment.apply(init_weights)
+ all_a_res, all_l_res, v_res = self.model.forward(images)
+ loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
+ self.log('loss', loss)
+ return loss
+
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
+ if self.lm_only:
+ inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
+ l_res = self.model.language(inputs, lengths)
+ loss = self.calc_loss(targets, l_res)
+ loss_numel = (targets != self.pad_id).sum()
+ return l_res['logits'], loss, loss_numel
+ else:
+ return super().forward_logits_loss(images, labels)
diff --git a/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py b/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..03ae4b13976ddc67dfb2e2bfd83885a823cf9ecb
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/abinet/transformer.py
@@ -0,0 +1,198 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.modules.transformer import _get_activation_fn, _get_clones
+
+
+class TransformerDecoder(nn.Module):
+ r"""TransformerDecoder is a stack of N decoder layers
+
+ Args:
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+ num_layers: the number of sub-decoder-layers in the decoder (required).
+ norm: the layer normalization component (optional).
+
+ Examples::
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(20, 32, 512)
+ >>> out = transformer_decoder(tgt, memory)
+ """
+ __constants__ = ['norm']
+
+ def __init__(self, decoder_layer, num_layers, norm=None):
+ super(TransformerDecoder, self).__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None,
+ memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
+ memory_key_padding_mask=None, memory_key_padding_mask2=None):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
+
+ Args:
+ tgt: the sequence to the decoder (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ output = tgt
+
+ for mod in self.layers:
+ output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
+ memory_mask=memory_mask, memory_mask2=memory_mask2,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ memory_key_padding_mask2=memory_key_padding_mask2)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoderLayer(nn.Module):
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+ This standard decoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
+
+ Examples::
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(20, 32, 512)
+ >>> out = decoder_layer(tgt, memory)
+ """
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", self_attn=True, siamese=False, debug=False):
+ super().__init__()
+ self.has_self_attn, self.siamese = self_attn, siamese
+ self.debug = debug
+ if self.has_self_attn:
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ if self.siamese:
+ self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super().__setstate__(state)
+
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
+ tgt_key_padding_mask=None, memory_key_padding_mask=None,
+ memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
+ r"""Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt: the sequence to the decoder layer (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ if self.has_self_attn:
+ tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ if self.debug: self.attn = attn
+ tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ if self.debug: self.attn2 = attn2
+
+ if self.siamese:
+ tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
+ key_padding_mask=memory_key_padding_mask2)
+ tgt = tgt + self.dropout2(tgt3)
+ if self.debug: self.attn3 = attn3
+
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+
+ return tgt
+
+
+class PositionalEncoding(nn.Module):
+ r"""Inject some information about the relative or absolute position of the tokens
+ in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+ .. math::
+ \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i/d_{model}})
+ \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})
+
+ where :math:`pos` is the word position and :math:`i` is the embedding index.
+ Args:
+ d_model: the embed dim (required).
+ dropout: the dropout value (default=0.1).
+ max_len: the max. length of the incoming sequence (default=5000).
+ Examples:
+ >>> pos_encoder = PositionalEncoding(d_model)
+ """
+
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ r"""Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+ x: [sequence length, batch size, embed dim]
+ output: [sequence length, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+
+ x = x + self.pe[:x.size(0), :]
+ return self.dropout(x)
diff --git a/IndicPhotoOCR/utils/strhub/models/base.py b/IndicPhotoOCR/utils/strhub/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d61efd6311445c604446ddae208bdef3d7b115
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/base.py
@@ -0,0 +1,221 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+
+from nltk import edit_distance
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import OneCycleLR
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from timm.optim import create_optimizer_v2
+
+from IndicPhotoOCR.utils.strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer
+
+
+@dataclass
+class BatchResult:
+ num_samples: int
+ correct: int
+ ned: float
+ confidence: float
+ label_length: int
+ loss: Tensor
+ loss_numel: int
+
+
+EPOCH_OUTPUT = list[dict[str, BatchResult]]
+
+
+class BaseSystem(pl.LightningModule, ABC):
+
+ def __init__(
+ self,
+ tokenizer: BaseTokenizer,
+ charset_test: str,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ ) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.charset_adapter = CharsetAdapter(charset_test)
+ self.batch_size = batch_size
+ self.lr = lr
+ self.warmup_pct = warmup_pct
+ self.weight_decay = weight_decay
+ self.outputs: EPOCH_OUTPUT = []
+
+ @abstractmethod
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ """Inference
+
+ Args:
+ images: Batch of images. Shape: N, Ch, H, W
+ max_length: Max sequence length of the output. If None, will use default.
+
+ Returns:
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
+ """Like forward(), but also computes the loss (calls forward() internally).
+
+ Args:
+ images: Batch of images. Shape: N, Ch, H, W
+ labels: Text labels of the images
+
+ Returns:
+ logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
+ loss: mean loss for the batch
+ loss_numel: number of elements the loss was calculated from
+ """
+ raise NotImplementedError
+
+ def configure_optimizers(self):
+ agb = self.trainer.accumulate_grad_batches
+ # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
+ lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
+ lr = lr_scale * self.lr
+ optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
+ sched = OneCycleLR(
+ optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
+ )
+ return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}
+
+ def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
+ optimizer.zero_grad(set_to_none=True)
+
+ def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
+ images, labels = batch
+
+ correct = 0
+ total = 0
+ ned = 0
+ confidence = 0
+ label_length = 0
+ if validation:
+ logits, loss, loss_numel = self.forward_logits_loss(images, labels)
+ else:
+ # At test-time, we shouldn't specify a max_label_length because the test-time charset used
+ # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
+ # based on the transformed label, which could be wrong if the actual gt label contains characters existing
+ # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
+ # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
+ # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
+ logits = self.forward(images)
+ loss = loss_numel = None # Only used for validation; not needed at test-time.
+
+ probs = logits.softmax(-1)
+ preds, probs = self.tokenizer.decode(probs)
+ for pred, prob, gt in zip(preds, probs, labels):
+ confidence += prob.prod().item()
+ pred = self.charset_adapter(pred)
+ # Follow ICDAR 2019 definition of N.E.D.
+ ned += edit_distance(pred, gt) / max(len(pred), len(gt))
+ if pred == gt:
+ correct += 1
+ total += 1
+ label_length += len(pred)
+ return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))
+
+ @staticmethod
+ def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]:
+ if not outputs:
+ return 0.0, 0.0, 0.0
+ total_loss = 0
+ total_loss_numel = 0
+ total_n_correct = 0
+ total_norm_ED = 0
+ total_size = 0
+ for result in outputs:
+ result = result['output']
+ total_loss += result.loss_numel * result.loss
+ total_loss_numel += result.loss_numel
+ total_n_correct += result.correct
+ total_norm_ED += result.ned
+ total_size += result.num_samples
+ acc = total_n_correct / total_size
+ ned = 1 - total_norm_ED / total_size
+ loss = total_loss / total_loss_numel
+ return acc, ned, loss
+
+ def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+ result = self._eval_step(batch, True)
+ self.outputs.append(result)
+ return result
+
+ def on_validation_epoch_end(self) -> None:
+ acc, ned, loss = self._aggregate_results(self.outputs)
+ self.outputs.clear()
+ self.log('val_accuracy', 100 * acc, sync_dist=True)
+ self.log('val_NED', 100 * ned, sync_dist=True)
+ self.log('val_loss', loss, sync_dist=True)
+ self.log('hp_metric', acc, sync_dist=True)
+
+ def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+ return self._eval_step(batch, False)
+
+
+class CrossEntropySystem(BaseSystem):
+
+ def __init__(
+ self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
+ ) -> None:
+ tokenizer = Tokenizer(charset_train)
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.bos_id = tokenizer.bos_id
+ self.eos_id = tokenizer.eos_id
+ self.pad_id = tokenizer.pad_id
+
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
+ targets = self.tokenizer.encode(labels, self.device)
+ targets = targets[:, 1:] # Discard <bos>
+ max_len = targets.shape[1] - 1 # exclude <eos> from count
+ logits = self.forward(images, max_len)
+ loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
+ loss_numel = (targets != self.pad_id).sum()
+ return logits, loss, loss_numel
+
+
+class CTCSystem(BaseSystem):
+
+ def __init__(
+ self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
+ ) -> None:
+ tokenizer = CTCTokenizer(charset_train)
+ super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.blank_id = tokenizer.blank_id
+
+ def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
+ targets = self.tokenizer.encode(labels, self.device)
+ logits = self.forward(images)
+ log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq. dims
+ T, N, _ = log_probs.shape
+ input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
+ target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device)
+ loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
+ return logits, loss, N
diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE b/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f98687be392fdce266708e79885aadaa4991b67f
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/crnn/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2017 Jieru Mei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py b/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4535947d9233c8fb0a85e9c22b151697d37f410
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/crnn/__init__.py
@@ -0,0 +1,13 @@
+r"""
+Shi, Baoguang, Xiang Bai, and Cong Yao.
+"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition."
+IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304.
+
+https://arxiv.org/abs/1507.05717
+
+All source files, except `system.py`, are based on the implementation listed below,
+and hence are released under the license of the original.
+
+Source: https://github.com/meijieru/crnn.pytorch
+License: MIT License (see included LICENSE file)
+"""
diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/model.py b/IndicPhotoOCR/utils/strhub/models/crnn/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5c9e8e6a1a2f3d4ed32c976f47a8cbdff22946
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/crnn/model.py
@@ -0,0 +1,62 @@
+import torch.nn as nn
+
+from strhub.models.modules import BidirectionalLSTM
+
+
+class CRNN(nn.Module):
+
+ def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
+ super().__init__()
+ assert img_h % 16 == 0, 'img_h has to be a multiple of 16'
+
+ ks = [3, 3, 3, 3, 3, 3, 2]
+ ps = [1, 1, 1, 1, 1, 1, 0]
+ ss = [1, 1, 1, 1, 1, 1, 1]
+ nm = [64, 128, 256, 256, 512, 512, 512]
+
+ cnn = nn.Sequential()
+
+ def convRelu(i, batchNormalization=False):
+ nIn = nc if i == 0 else nm[i - 1]
+ nOut = nm[i]
+ cnn.add_module(f'conv{i}',
+ nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization))
+ if batchNormalization:
+ cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))
+ if leaky_relu:
+ cnn.add_module(f'relu{i}',
+ nn.LeakyReLU(0.2, inplace=True))
+ else:
+ cnn.add_module(f'relu{i}', nn.ReLU(True))
+
+ convRelu(0)
+ cnn.add_module('pooling0', nn.MaxPool2d(2, 2)) # 64x16x64
+ convRelu(1)
+ cnn.add_module('pooling1', nn.MaxPool2d(2, 2)) # 128x8x32
+ convRelu(2, True)
+ convRelu(3)
+ cnn.add_module('pooling2',
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
+ convRelu(4, True)
+ convRelu(5)
+ cnn.add_module('pooling3',
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
+ convRelu(6, True) # 512x1x16
+
+ self.cnn = cnn
+ self.rnn = nn.Sequential(
+ BidirectionalLSTM(512, nh, nh),
+ BidirectionalLSTM(nh, nh, nclass))
+
+ def forward(self, input):
+ # conv features
+ conv = self.cnn(input)
+ b, c, h, w = conv.size()
+ assert h == 1, 'the height of conv must be 1'
+ conv = conv.squeeze(2)
+ conv = conv.transpose(1, 2) # [b, w, c]
+
+ # rnn features
+ output = self.rnn(conv)
+
+ return output
diff --git a/IndicPhotoOCR/utils/strhub/models/crnn/system.py b/IndicPhotoOCR/utils/strhub/models/crnn/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..a69dfdd131cc3895bfc7b0aaa0832681dbfaab25
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/crnn/system.py
@@ -0,0 +1,56 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence
+
+from torch import Tensor
+
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+from strhub.models.base import CTCSystem
+from strhub.models.utils import init_weights
+
+from .model import CRNN as Model
+
+
+class CRNN(CTCSystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ img_size: Sequence[int],
+ hidden_size: int,
+ leaky_relu: bool,
+ **kwargs,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.save_hyperparameters()
+ self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu)
+ self.model.apply(init_weights)
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ return self.model.forward(images)
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ loss = self.forward_logits_loss(images, labels)[1]
+ self.log('loss', loss)
+ return loss
diff --git a/IndicPhotoOCR/utils/strhub/models/modules.py b/IndicPhotoOCR/utils/strhub/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a89d05f6afd67437f3cfa8aff6d2d8b12df3fafa
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/modules.py
@@ -0,0 +1,20 @@
+r"""Shared modules used by CRNN and TRBA"""
+from torch import nn
+
+
+class BidirectionalLSTM(nn.Module):
+ """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
+
+ def __init__(self, input_size, hidden_size, output_size):
+ super().__init__()
+ self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
+ self.linear = nn.Linear(hidden_size * 2, output_size)
+
+ def forward(self, input):
+ """
+ input : visual feature [batch_size x T x input_size], T = num_steps.
+ output : contextual feature [batch_size x T x output_size]
+ """
+ recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+ output = self.linear(recurrent) # batch_size x T x output_size
+ return output
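+
+# Illustrative usage (a sketch, assuming `import torch`):
+#   rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
+#   feats = torch.zeros(4, 26, 512)   # [batch_size x T x input_size]
+#   out = rnn(feats)                  # [batch_size x T x output_size] -> torch.Size([4, 26, 256])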
diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/__init__.py b/IndicPhotoOCR/utils/strhub/models/parseq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/model.py b/IndicPhotoOCR/utils/strhub/models/parseq/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..28f4f3d06d7ccc005cb40e47a5e86626a54aa04d
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/parseq/model.py
@@ -0,0 +1,169 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from timm.models.helpers import named_apply
+
+from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
+from IndicPhotoOCR.utils.strhub.models.utils import init_weights
+
+from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding
+
+
+class PARSeq(nn.Module):
+
+ def __init__(
+ self,
+ num_tokens: int,
+ max_label_length: int,
+ img_size: Sequence[int],
+ patch_size: Sequence[int],
+ embed_dim: int,
+ enc_num_heads: int,
+ enc_mlp_ratio: int,
+ enc_depth: int,
+ dec_num_heads: int,
+ dec_mlp_ratio: int,
+ dec_depth: int,
+ decode_ar: bool,
+ refine_iters: int,
+ dropout: float,
+ ) -> None:
+ super().__init__()
+
+ self.max_label_length = max_label_length
+ self.decode_ar = decode_ar
+ self.refine_iters = refine_iters
+
+ self.encoder = Encoder(
+ img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
+ )
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
+ self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))
+
+ # We don't predict <bos> nor <pad>
+ self.head = nn.Linear(embed_dim, num_tokens - 2)
+ self.text_embed = TokenEmbedding(num_tokens, embed_dim)
+
+ # +1 for <eos>
+ self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
+ self.dropout = nn.Dropout(p=dropout)
+ # Encoder has its own init.
+ named_apply(partial(init_weights, exclude=['encoder']), self)
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
+
+ @property
+ def _device(self) -> torch.device:
+ return next(self.head.parameters(recurse=False)).device
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
+ enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
+ return param_names.union(enc_param_names)
+
+ def encode(self, img: torch.Tensor):
+ return self.encoder(img)
+
+ def decode(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_padding_mask: Optional[Tensor] = None,
+ tgt_query: Optional[Tensor] = None,
+ tgt_query_mask: Optional[Tensor] = None,
+ ):
+ N, L = tgt.shape
+ # <bos> stands for the null context. We only supply position information for characters after <bos>.
+ null_ctx = self.text_embed(tgt[:, :1])
+ tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
+ if tgt_query is None:
+ tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
+ tgt_query = self.dropout(tgt_query)
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
+
+ def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ testing = max_length is None
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ bs = images.shape[0]
+ # +1 for <eos> at end of sequence.
+ num_steps = max_length + 1
+ memory = self.encode(images)
+
+ # Query positions up to `num_steps`
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
+
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
+ tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)
+
+ if self.decode_ar:
+ tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
+ tgt_in[:, 0] = tokenizer.bos_id
+
+ logits = []
+ for i in range(num_steps):
+ j = i + 1 # next token index
+ # Efficient decoding:
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
+ # Past tokens have no access to future tokens, hence are fixed once computed.
+ tgt_out = self.decode(
+ tgt_in[:, :j],
+ memory,
+ tgt_mask[:j, :j],
+ tgt_query=pos_queries[:, i:j],
+ tgt_query_mask=query_mask[i:j, :j],
+ )
+ # the next token probability is in the output's ith token position
+ p_i = self.head(tgt_out)
+ logits.append(p_i)
+ if j < num_steps:
+ # greedy decode. add the next token index to the target input
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
+ if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
+ break
+
+ logits = torch.cat(logits, dim=1)
+ else:
+ # No prior context, so input is just <bos>. We query all positions.
+ tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
+ tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
+ logits = self.head(tgt_out)
+
+ if self.refine_iters:
+ # For iterative refinement, we always use a 'cloze' mask.
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
+ query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
+ bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
+ for i in range(self.refine_iters):
+ # Prior context is the previous output.
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
+ # Mask tokens beyond the first EOS token.
+ tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
+ tgt_out = self.decode(
+ tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
+ )
+ logits = self.head(tgt_out)
+
+ return logits
diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/modules.py b/IndicPhotoOCR/utils/strhub/models/parseq/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e3f23f6ee52c9de8b21df63efe7299100eb44d
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/parseq/modules.py
@@ -0,0 +1,176 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import torch
+from torch import Tensor, nn as nn
+from torch.nn import functional as F
+from torch.nn.modules import transformer
+
+from timm.models.vision_transformer import PatchEmbed, VisionTransformer
+
+
+class DecoderLayer(nn.Module):
+ """A Transformer decoder layer supporting two-stream attention (XLNet)
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-5):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = transformer._get_activation_fn(activation)
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.gelu
+ super().__setstate__(state)
+
+ def forward_stream(
+ self,
+ tgt: Tensor,
+ tgt_norm: Tensor,
+ tgt_kv: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor],
+ tgt_key_padding_mask: Optional[Tensor],
+ ):
+ """Forward pass for a single stream (i.e. content or query)
+ tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
+ Both tgt_kv and memory are expected to be LayerNorm'd too.
+ memory is LayerNorm'd by ViT.
+ """
+ tgt2, sa_weights = self.self_attn(
+ tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )
+ tgt = tgt + self.dropout1(tgt2)
+
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
+ tgt = tgt + self.dropout2(tgt2)
+
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt, sa_weights, ca_weights
+
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask: Optional[Tensor] = None,
+ content_mask: Optional[Tensor] = None,
+ content_key_padding_mask: Optional[Tensor] = None,
+ update_content: bool = True,
+ ):
+ query_norm = self.norm_q(query)
+ content_norm = self.norm_c(content)
+ query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
+ if update_content:
+ content = self.forward_stream(
+ content, content_norm, content_norm, memory, content_mask, content_key_padding_mask
+ )[0]
+ return query, content
+
+
+class Decoder(nn.Module):
+ __constants__ = ['norm']
+
+ def __init__(self, decoder_layer, num_layers, norm):
+ super().__init__()
+ self.layers = transformer._get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ query,
+ content,
+ memory,
+ query_mask: Optional[Tensor] = None,
+ content_mask: Optional[Tensor] = None,
+ content_key_padding_mask: Optional[Tensor] = None,
+ ):
+ for i, mod in enumerate(self.layers):
+ last = i == len(self.layers) - 1
+ query, content = mod(
+ query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last
+ )
+ query = self.norm(query)
+ return query
+
+
+class Encoder(VisionTransformer):
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ embed_layer=PatchEmbed,
+ ):
+ super().__init__(
+ img_size,
+ patch_size,
+ in_chans,
+ embed_dim=embed_dim,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ embed_layer=embed_layer,
+ num_classes=0, # These
+ global_pool='', # disable the
+ class_token=False, # classifier head.
+ )
+
+ def forward(self, x):
+ # Return all tokens
+ return self.forward_features(x)
+
+
+class TokenEmbedding(nn.Module):
+
+ def __init__(self, charset_size: int, embed_dim: int):
+ super().__init__()
+ self.embedding = nn.Embedding(charset_size, embed_dim)
+ self.embed_dim = embed_dim
+
+ def forward(self, tokens: torch.Tensor):
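+ # Scale the embeddings by sqrt(embed_dim), mirroring the token-embedding scaling
+ # used in the original Transformer ("Attention Is All You Need").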
+ return math.sqrt(self.embed_dim) * self.embedding(tokens)
diff --git a/IndicPhotoOCR/utils/strhub/models/parseq/system.py b/IndicPhotoOCR/utils/strhub/models/parseq/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..217275f8345f84ce88be97fb8932a30e51baf6ca
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/parseq/system.py
@@ -0,0 +1,200 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from itertools import permutations
+from typing import Any, Optional, Sequence
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+from IndicPhotoOCR.utils.strhub.models.base import CrossEntropySystem
+
+from .model import PARSeq as Model
+
+
+class PARSeq(CrossEntropySystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ img_size: Sequence[int],
+ patch_size: Sequence[int],
+ embed_dim: int,
+ enc_num_heads: int,
+ enc_mlp_ratio: int,
+ enc_depth: int,
+ dec_num_heads: int,
+ dec_mlp_ratio: int,
+ dec_depth: int,
+ perm_num: int,
+ perm_forward: bool,
+ perm_mirrored: bool,
+ decode_ar: bool,
+ refine_iters: int,
+ dropout: float,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.save_hyperparameters()
+
+ self.model = Model(
+ len(self.tokenizer),
+ max_label_length,
+ img_size,
+ patch_size,
+ embed_dim,
+ enc_num_heads,
+ enc_mlp_ratio,
+ enc_depth,
+ dec_num_heads,
+ dec_mlp_ratio,
+ dec_depth,
+ decode_ar,
+ refine_iters,
+ dropout,
+ )
+
+ # Perm/attn mask stuff
+ self.rng = np.random.default_rng()
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
+ self.perm_forward = perm_forward
+ self.perm_mirrored = perm_mirrored
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ return self.model.forward(self.tokenizer, images, max_length)
+
+ def gen_tgt_perms(self, tgt):
+ """Generate shared permutations for the whole batch.
+ This works because the same attention mask can be used for the shorter sequences
+ because of the padding mask.
+ """
+ # We don't permute the position of BOS, we permute EOS separately
+ max_num_chars = tgt.shape[1] - 2
+ # Special handling for 1-character sequences
+ if max_num_chars == 1:
+ return torch.arange(3, device=self._device).unsqueeze(0)
+ perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else []
+ # Additional permutations if needed
+ max_perms = math.factorial(max_num_chars)
+ if self.perm_mirrored:
+ max_perms //= 2
+ num_gen_perms = min(self.max_gen_perms, max_perms)
+ # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
+ # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
+ if max_num_chars < 5:
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
+ if max_num_chars == 4 and self.perm_mirrored:
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
+ else:
+ selector = list(range(max_perms))
+ perm_pool = torch.as_tensor(
+ list(permutations(range(max_num_chars), max_num_chars)),
+ device=self._device,
+ )[selector]
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
+ if self.perm_forward:
+ perm_pool = perm_pool[1:]
+ perms = torch.stack(perms)
+ if len(perm_pool):
+ i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
+ perms = torch.cat([perms, perm_pool[i]])
+ else:
+ perms.extend(
+ [torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))]
+ )
+ perms = torch.stack(perms)
+ if self.perm_mirrored:
+ # Add complementary pairs
+ comp = perms.flip(-1)
+ # Stack in such a way that the pairs are next to each other.
+ perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
+ # NOTE:
+ # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
+ # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
+ # positions will always be much less than the number of permutations (unless a low perm_num is set).
+ # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
+ # distribute it across the chosen number of permutations.
+ # Add position indices of BOS and EOS
+ bos_idx = perms.new_zeros((len(perms), 1))
+ eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
+ perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
+ # Special handling for the reverse direction. This does two things:
+ # 1. Reverse context for the characters
+ # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
+ if len(perms) > 1:
+ perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device)
+ return perms
+
+ def generate_attn_masks(self, perm):
+ """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
+ :param perm: the permutation sequence. i = 0 is always the BOS
+ :return: lookahead attention masks
+ """
+ sz = perm.shape[0]
+ mask = torch.zeros((sz, sz), dtype=torch.bool, device=self._device)
+ for i in range(sz):
+ query_idx = perm[i]
+ masked_keys = perm[i + 1 :]
+ mask[query_idx, masked_keys] = True
+ content_mask = mask[:-1, :-1].clone()
+ mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = True # mask "self"
+ query_mask = mask[1:, :-1]
+ return content_mask, query_mask
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ tgt = self.tokenizer.encode(labels, self._device)
+
+ # Encode the source sequence (i.e. the image codes)
+ memory = self.model.encode(images)
+
+ # Prepare the target sequences (input and output)
+ tgt_perms = self.gen_tgt_perms(tgt)
+ tgt_in = tgt[:, :-1]
+ tgt_out = tgt[:, 1:]
+ # The [EOS] token is not depended upon by any other token in any permutation ordering
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
+
+ loss = 0
+ loss_numel = 0
+ n = (tgt_out != self.pad_id).sum().item()
+ for i, perm in enumerate(tgt_perms):
+ tgt_mask, query_mask = self.generate_attn_masks(perm)
+ out = self.model.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
+ logits = self.model.head(out).flatten(end_dim=1)
+ loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
+ loss_numel += n
+ # After the second iteration (i.e. done with canonical and reverse orderings),
+ # remove the [EOS] tokens for the succeeding perms
+ if i == 1:
+ tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out)
+ n = (tgt_out != self.pad_id).sum().item()
+ loss /= loss_numel
+
+ self.log('loss', loss)
+ return loss
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/__init__.py b/IndicPhotoOCR/utils/strhub/models/trba/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a574a8af95e7f1ffaa05c45b4cd22f4a3cc0a5c0
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/__init__.py
@@ -0,0 +1,13 @@
+r"""
+Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee.
+"What is wrong with scene text recognition model comparisons? dataset and model analysis."
+In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019.
+
+https://arxiv.org/abs/1904.01906
+
+All source files, except `system.py`, are based on the implementation listed below,
+and hence are released under the license of the original.
+
+Source: https://github.com/clovaai/deep-text-recognition-benchmark
+License: Apache License 2.0 (see LICENSE file in project root)
+"""
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py b/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..17646e3ff83ad28c1021237824a838e38c3b6345
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/feature_extraction.py
@@ -0,0 +1,110 @@
+import torch.nn as nn
+
+from torchvision.models.resnet import BasicBlock
+
+
+class ResNet_FeatureExtractor(nn.Module):
+ """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
+
+ def __init__(self, input_channel, output_channel=512):
+ super().__init__()
+ self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
+
+ def forward(self, input):
+ return self.ConvNet(input)
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, input_channel, output_channel, block, layers):
+ super().__init__()
+
+ self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
+
+ self.inplanes = int(output_channel / 8)
+ self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
+ kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
+ self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
+ kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn0_2 = nn.BatchNorm2d(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
+ self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
+ 0], kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
+
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
+ self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
+ 1], kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
+
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
+ self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
+ 2], kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
+
+ self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
+ self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
+ 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
+ self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
+ self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
+ 3], kernel_size=2, stride=1, padding=0, bias=False)
+ self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv0_1(x)
+ x = self.bn0_1(x)
+ x = self.relu(x)
+ x = self.conv0_2(x)
+ x = self.bn0_2(x)
+ x = self.relu(x)
+
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.maxpool2(x)
+ x = self.layer2(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ x = self.maxpool3(x)
+ x = self.layer3(x)
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.relu(x)
+
+ x = self.layer4(x)
+ x = self.conv4_1(x)
+ x = self.bn4_1(x)
+ x = self.relu(x)
+ x = self.conv4_2(x)
+ x = self.bn4_2(x)
+ x = self.relu(x)
+
+ return x
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/model.py b/IndicPhotoOCR/utils/strhub/models/trba/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..41161a4df4e2ff368bfe1c62f681c6964510a0c0
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/model.py
@@ -0,0 +1,55 @@
+import torch.nn as nn
+
+from strhub.models.modules import BidirectionalLSTM
+from .feature_extraction import ResNet_FeatureExtractor
+from .prediction import Attention
+from .transformation import TPS_SpatialTransformerNetwork
+
+
+class TRBA(nn.Module):
+
+ def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256,
+ use_ctc=False):
+ super().__init__()
+ """ Transformation """
+ self.Transformation = TPS_SpatialTransformerNetwork(
+ F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w),
+ I_channel_num=input_channel)
+
+ """ FeatureExtraction """
+ self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
+ self.FeatureExtraction_output = output_channel
+ self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
+
+ """ Sequence modeling"""
+ self.SequenceModeling = nn.Sequential(
+ BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
+ BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
+ self.SequenceModeling_output = hidden_size
+
+ """ Prediction """
+ if use_ctc:
+ self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
+ else:
+ self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
+
+ def forward(self, image, max_label_length, text=None):
+ """ Transformation stage """
+ image = self.Transformation(image)
+
+ """ Feature extraction stage """
+ visual_feature = self.FeatureExtraction(image)
+ visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h]
+ visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1]
+ visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c]
+
+ """ Sequence modeling stage """
+ contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size]
+
+ """ Prediction stage """
+ if isinstance(self.Prediction, Attention):
+ prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length)
+ else:
+ prediction = self.Prediction(contextual_feature.contiguous()) # CTC
+
+ return prediction # [b, num_steps, num_class]
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/prediction.py b/IndicPhotoOCR/utils/strhub/models/trba/prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..5609398a28ef5288d3f3971786c2cebc2e574336
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/prediction.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Attention(nn.Module):
+
+ def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256):
+ super().__init__()
+ self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings)
+ self.hidden_size = hidden_size
+ self.num_class = num_class
+ self.generator = nn.Linear(hidden_size, num_class)
+ self.char_embeddings = nn.Embedding(num_class, num_char_embeddings)
+
+ def forward(self, batch_H, text, max_label_length=25):
+ """
+ input:
+ batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class]
+ text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS].
+ output: probability distribution at each step [batch_size x num_steps x num_class]
+ """
+ batch_size = batch_H.size(0)
+ num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence.
+
+ output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float)
+ hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float),
+ batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float))
+
+ if self.training:
+ for i in range(num_steps):
+ char_embeddings = self.char_embeddings(text[:, i])
+ # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1})
+ hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
+ output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell)
+ probs = self.generator(output_hiddens)
+
+ else:
+ targets = text[0].expand(batch_size) # should be filled with the [SOS] token
+ probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float)
+
+ for i in range(num_steps):
+ char_embeddings = self.char_embeddings(targets)
+ hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings)
+ probs_step = self.generator(hidden[0])
+ probs[:, i, :] = probs_step
+ _, next_input = probs_step.max(1)
+ targets = next_input
+
+ return probs # batch_size x num_steps x num_class
+
+
+class AttentionCell(nn.Module):
+
+ def __init__(self, input_size, hidden_size, num_embeddings):
+ super().__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size) # either i2h or h2h should have bias
+ self.score = nn.Linear(hidden_size, 1, bias=False)
+ self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_embeddings):
+ # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
+ e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step x 1
+
+ alpha = F.softmax(e, dim=1)
+ context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
+ concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+ return cur_hidden, alpha
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/system.py b/IndicPhotoOCR/utils/strhub/models/trba/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..eabc5bacf3ef0a61d6b50cb1707c7da9eb2cb930
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/system.py
@@ -0,0 +1,125 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Any, Optional, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from timm.models.helpers import named_apply
+
+from strhub.models.base import CrossEntropySystem, CTCSystem
+from strhub.models.utils import init_weights
+
+from .model import TRBA as Model
+
+
+class TRBA(CrossEntropySystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ img_size: Sequence[int],
+ num_fiducial: int,
+ output_channel: int,
+ hidden_size: int,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.save_hyperparameters()
+ self.max_label_length = max_label_length
+ img_h, img_w = img_size
+ self.model = Model(
+ img_h,
+ img_w,
+ len(self.tokenizer),
+ num_fiducial,
+ output_channel=output_channel,
+ hidden_size=hidden_size,
+ use_ctc=False,
+ )
+ named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'model.Prediction.char_embeddings.weight'}
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ text = images.new_full([1], self.bos_id, dtype=torch.long)
+ return self.model.forward(images, max_length, text)
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ encoded = self.tokenizer.encode(labels, self.device)
+ inputs = encoded[:, :-1] # remove <eos>
+ targets = encoded[:, 1:] # remove <bos>
+ max_length = encoded.shape[1] - 2 # exclude <bos> and <eos> from count
+ logits = self.model.forward(images, max_length, inputs)
+ loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
+ self.log('loss', loss)
+ return loss
+
+
+class TRBC(CTCSystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ img_size: Sequence[int],
+ num_fiducial: int,
+ output_channel: int,
+ hidden_size: int,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.save_hyperparameters()
+ self.max_label_length = max_label_length
+ img_h, img_w = img_size
+ self.model = Model(
+ img_h,
+ img_w,
+ len(self.tokenizer),
+ num_fiducial,
+ output_channel=output_channel,
+ hidden_size=hidden_size,
+ use_ctc=True,
+ )
+ named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model)
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ # max_label_length is unused in CTC prediction
+ return self.model.forward(images, None)
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ loss = self.forward_logits_loss(images, labels)[1]
+ self.log('loss', loss)
+ return loss
diff --git a/IndicPhotoOCR/utils/strhub/models/trba/transformation.py b/IndicPhotoOCR/utils/strhub/models/trba/transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..960419d135ec878aaaa3297c3ff5c22e998ef6be
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/trba/transformation.py
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TPS_SpatialTransformerNetwork(nn.Module):
+ """ Rectification Network of RARE, namely TPS based STN """
+
+ def __init__(self, F, I_size, I_r_size, I_channel_num=1):
+ """ Based on RARE TPS
+ input:
+ batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
+ I_size : (height, width) of the input image I
+ I_r_size : (height, width) of the rectified image I_r
+ I_channel_num : the number of channels of the input image I
+ output:
+ batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
+ """
+ super().__init__()
+ self.F = F
+ self.I_size = I_size
+ self.I_r_size = I_r_size # = (I_r_height, I_r_width)
+ self.I_channel_num = I_channel_num
+ self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
+ self.GridGenerator = GridGenerator(self.F, self.I_r_size)
+
+ def forward(self, batch_I):
+ batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
+ # batch_size x n (= I_r_width x I_r_height) x 2
+ build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)
+ build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
+
+ if torch.__version__ > "1.2.0":
+ batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
+ else:
+ batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
+
+ return batch_I_r
+
+
+class LocalizationNetwork(nn.Module):
+ """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
+
+ def __init__(self, F, I_channel_num):
+ super().__init__()
+ self.F = F
+ self.I_channel_num = I_channel_num
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
+ bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
+ nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
+ nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
+ nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
+ nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
+ nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
+ nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
+ nn.AdaptiveAvgPool2d(1) # batch_size x 512
+ )
+
+ self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
+ self.localization_fc2 = nn.Linear(256, self.F * 2)
+
+ # Init fc2 in LocalizationNetwork
+ self.localization_fc2.weight.data.fill_(0)
+ """ see RARE paper Fig. 6 (a) """
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+ ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
+
+ def forward(self, batch_I):
+ """
+ input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
+ output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
+ """
+ batch_size = batch_I.size(0)
+ features = self.conv(batch_I).view(batch_size, -1)
+ batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
+ return batch_C_prime
+
+
+class GridGenerator(nn.Module):
+ """ Grid Generator of RARE, which produces P_prime by multipling T with P """
+
+ def __init__(self, F, I_r_size):
+ """ Generate P_hat and inv_delta_C for later """
+ super().__init__()
+ self.eps = 1e-6
+ self.I_r_height, self.I_r_width = I_r_size
+ self.F = F
+ self.C = self._build_C(self.F) # F x 2
+ self.P = self._build_P(self.I_r_width, self.I_r_height)
+
+ # num_gpu = torch.cuda.device_count()
+ # if num_gpu > 1:
+ # for multi-gpu, you may need register buffer
+ self.register_buffer("inv_delta_C", torch.tensor(
+ self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
+ self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
+ # else:
+ # # for fine-tuning with different image width, you may use below instead of self.register_buffer
+ # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3
+ # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3
+
+ def _build_C(self, F):
+ """ Return coordinates of fiducial points in I_r; C """
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = -1 * np.ones(int(F / 2))
+ ctrl_pts_y_bottom = np.ones(int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ return C # F x 2
+
+ def _build_inv_delta_C(self, F, C):
+ """ Return inv_delta_C which is needed to calculate T """
+ hat_C = np.zeros((F, F), dtype=float) # F x F
+ for i in range(0, F):
+ for j in range(i, F):
+ r = np.linalg.norm(C[i] - C[j])
+ hat_C[i, j] = r
+ hat_C[j, i] = r
+ np.fill_diagonal(hat_C, 1)
+ hat_C = (hat_C ** 2) * np.log(hat_C)
+ # print(C.shape, hat_C.shape)
+ delta_C = np.concatenate( # F+3 x F+3
+ [
+ np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
+ np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
+ np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
+ ],
+ axis=0
+ )
+ inv_delta_C = np.linalg.inv(delta_C)
+ return inv_delta_C # F+3 x F+3
+
+ def _build_P(self, I_r_width, I_r_height):
+ I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
+ I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
+ P = np.stack( # self.I_r_width x self.I_r_height x 2
+ np.meshgrid(I_r_grid_x, I_r_grid_y),
+ axis=2
+ )
+ return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
+
+ def _build_P_hat(self, F, C, P):
+ n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
+ P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
+ C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
+ P_diff = P_tile - C_tile # n x F x 2
+ rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
+ rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
+ P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
+ return P_hat # n x F+3
+
+ def build_P_prime(self, batch_C_prime):
+ """ Generate Grid from batch_C_prime [batch_size x F x 2] """
+ batch_size = batch_C_prime.size(0)
+ batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
+ batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
+ batch_C_prime_with_zeros = torch.cat((batch_C_prime, batch_C_prime.new_zeros(
+ (batch_size, 3, 2), dtype=torch.float)), dim=1) # batch_size x F+3 x 2
+ batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
+ batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
+ return batch_P_prime # batch_size x n x 2
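+
+# Illustrative usage (a sketch; shapes follow the docstrings above):
+#   tps = TPS_SpatialTransformerNetwork(F=20, I_size=(32, 128), I_r_size=(32, 128), I_channel_num=3)
+#   rectified = tps(images)   # images: [B, 3, 32, 128] -> rectified: [B, 3, 32, 128]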
diff --git a/IndicPhotoOCR/utils/strhub/models/utils.py b/IndicPhotoOCR/utils/strhub/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4debeb25999c2c10e105bb3e32d5ae6f8d04ed9d
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/utils.py
@@ -0,0 +1,125 @@
+from pathlib import PurePath
+from typing import Sequence
+
+import yaml
+
+import torch
+from torch import nn
+
+
+class InvalidModelError(RuntimeError):
+ """Exception raised for any model-related error (creation, loading)"""
+
+
+_WEIGHTS_URL = {
+ 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt',
+ 'parseq-patch16-224': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_small_patch16_224-fcf06f5a.pt',
+ 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt',
+ 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt',
+ 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt',
+ 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt',
+ 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt',
+}
+
+
+def _get_config(experiment: str, **kwargs):
+ """Emulates hydra config resolution"""
+ root = PurePath(__file__).parents[2]
+ with open(root / 'configs/main.yaml', 'r') as f:
+ config = yaml.load(f, yaml.Loader)['model']
+ with open(root / 'configs/charset/94_full.yaml', 'r') as f:
+ config.update(yaml.load(f, yaml.Loader)['model'])
+ with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f:
+ exp = yaml.load(f, yaml.Loader)
+ # Apply base model config
+ model = exp['defaults'][0]['override /model']
+ with open(root / f'configs/model/{model}.yaml', 'r') as f:
+ config.update(yaml.load(f, yaml.Loader))
+ # Apply experiment config
+ if 'model' in exp:
+ config.update(exp['model'])
+ config.update(kwargs)
+ # Workaround for now: manually cast the lr to the correct type.
+ config['lr'] = float(config['lr'])
+ return config
+
+
+def _get_model_class(key):
+ if 'abinet' in key:
+ from .abinet.system import ABINet as ModelClass
+ elif 'crnn' in key:
+ from .crnn.system import CRNN as ModelClass
+ elif 'parseq' in key:
+ from .parseq.system import PARSeq as ModelClass
+ elif 'trba' in key:
+ from .trba.system import TRBA as ModelClass
+ elif 'trbc' in key:
+ from .trba.system import TRBC as ModelClass
+ elif 'vitstr' in key:
+ from .vitstr.system import ViTSTR as ModelClass
+ else:
+ from .parseq.system import PARSeq as ModelClass
+ return ModelClass
+
+
+def get_pretrained_weights(experiment):
+ try:
+ url = _WEIGHTS_URL[experiment]
+ except KeyError:
+ raise InvalidModelError(f"No pretrained weights found for '{experiment}'") from None
+ return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True)
+
+
+def create_model(experiment: str, pretrained: bool = False, **kwargs):
+ try:
+ config = _get_config(experiment, **kwargs)
+ except FileNotFoundError:
+ raise InvalidModelError(f"No configuration found for '{experiment}'") from None
+ ModelClass = _get_model_class(experiment)
+ model = ModelClass(**config)
+ if pretrained:
+ m = model.model if 'parseq' in experiment else model
+ m.load_state_dict(get_pretrained_weights(experiment))
+ return model
+
+
+def load_from_checkpoint(checkpoint_path: str, **kwargs):
+ if checkpoint_path.startswith('pretrained='):
+ model_id = checkpoint_path.split('=', maxsplit=1)[1]
+ model = create_model(model_id, True, **kwargs)
+ else:
+ ModelClass = _get_model_class(checkpoint_path)
+ model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs)
+ return model
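+
+# Illustrative usage (a sketch; assumes the bundled configs and, for the first form,
+# the released pretrained weights are reachable):
+#   model = load_from_checkpoint('pretrained=parseq')          # released weights by model id
+#   model = load_from_checkpoint('outputs/parseq/last.ckpt')   # hypothetical local Lightning checkpoint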
+
+
+def parse_model_args(args):
+ kwargs = {}
+ arg_types = {t.__name__: t for t in [int, float, str]}
+ arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool
+ for arg in args:
+ name, value = arg.split('=', maxsplit=1)
+ name, arg_type = name.split(':', maxsplit=1)
+ kwargs[name] = arg_types[arg_type](value)
+ return kwargs
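+
+# Example (illustrative): overrides of the form `name:type=value` become a kwargs
+# dict with typed values, e.g.
+#   parse_model_args(['batch_size:int=64', 'lr:float=7e-4', 'decode_ar:bool=false'])
+#   -> {'batch_size': 64, 'lr': 0.0007, 'decode_ar': False}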
+
+
+def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()):
+ """Initialize the weights using the typical initialization schemes used in SOTA models."""
+ if any(map(name.startswith, exclude)):
+ return
+ if isinstance(module, nn.Linear):
+ nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
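+
+# Typical usage, as in the model systems above: apply to a whole module tree with
+# `self.model.apply(init_weights)`, or skip submodules that ship their own
+# initialization via `named_apply(partial(init_weights, exclude=['encoder']), self)`
+# (see parseq/model.py).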
diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py b/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e985679da1fcaa6deb306697993fd601892d6c
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/vitstr/__init__.py
@@ -0,0 +1,12 @@
+r"""
+Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition."
+In International Conference on Document Analysis and Recognition (ICDAR). 2021.
+
+https://arxiv.org/abs/2105.08582
+
+All source files, except `system.py`, are based on the implementation listed below,
+and hence are released under the license of the original.
+
+Source: https://github.com/roatienza/deep-text-recognition-benchmark
+License: Apache License 2.0 (see LICENSE file in project root)
+"""
diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/model.py b/IndicPhotoOCR/utils/strhub/models/vitstr/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..62c5d551626c325243a4f0d055869384a59b3910
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/vitstr/model.py
@@ -0,0 +1,28 @@
+"""
+Implementation of ViTSTR based on timm VisionTransformer.
+
+TODO:
+1) distilled deit backbone
+2) base deit backbone
+
+Copyright 2021 Rowel Atienza
+"""
+
+from timm.models.vision_transformer import VisionTransformer
+
+
+class ViTSTR(VisionTransformer):
+ """
+ ViTSTR is basically a ViT that uses DeiT weights.
+ Modified head to support a sequence of characters prediction for STR.
+ """
+
+ def forward(self, x, seqlen: int = 25):
+ x = self.forward_features(x)
+ x = x[:, :seqlen]
+
+ # batch, seqlen, embsize
+ b, s, e = x.size()
+ x = x.reshape(b * s, e)
+ x = self.head(x).view(b, s, self.num_classes)
+ return x
diff --git a/IndicPhotoOCR/utils/strhub/models/vitstr/system.py b/IndicPhotoOCR/utils/strhub/models/vitstr/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b762e1a074055873413655e35ef2605ffa8238
--- /dev/null
+++ b/IndicPhotoOCR/utils/strhub/models/vitstr/system.py
@@ -0,0 +1,79 @@
+# Scene Text Recognition Model Hub
+# Copyright 2022 Darwin Bautista
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Sequence
+
+import torch
+from torch import Tensor
+
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+from strhub.models.base import CrossEntropySystem
+from strhub.models.utils import init_weights
+
+from .model import ViTSTR as Model
+
+
+class ViTSTR(CrossEntropySystem):
+
+ def __init__(
+ self,
+ charset_train: str,
+ charset_test: str,
+ max_label_length: int,
+ batch_size: int,
+ lr: float,
+ warmup_pct: float,
+ weight_decay: float,
+ img_size: Sequence[int],
+ patch_size: Sequence[int],
+ embed_dim: int,
+ num_heads: int,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
+ self.save_hyperparameters()
+ self.max_label_length = max_label_length
+ # We don't predict <bos> nor <pad>
+ self.model = Model(
+ img_size=img_size,
+ patch_size=patch_size,
+ depth=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ num_classes=len(self.tokenizer) - 2,
+ )
+ # Non-zero weight init for the head
+ self.model.head.apply(init_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'model.' + n for n in self.model.no_weight_decay()}
+
+ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s]
+ # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored).
+ # First position corresponds to the class token, which is unused and ignored in the original work.
+ logits = logits[:, 1:]
+ return logits
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ images, labels = batch
+ loss = self.forward_logits_loss(images, labels)[1]
+ self.log('loss', loss)
+ return loss
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e36a8980c85818dc36895426b19ca07d36b357ad
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Bhashini Team@IIT Jodhpur
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index bae7d168cd0929a190590466ad01469fb0bb3d3f..a93025300a749b0561c5b0bbe04260914aaf9696 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,206 @@
---
-title: Scene Text Translator
-emoji: 🏃
-colorFrom: gray
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.12.0
+title: "Scene-Text-Translator"
+colorFrom: "purple"
+colorTo: "pink"
+sdk: "gradio"
+python_version: "3.9"
+sdk_version: "4.44.0"
app_file: app.py
-pinned: false
+pinned: true
+CPU: "cpu-basic"
+suggested_storage: "small"
+app_port: 7865
---
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
+
+
+
+# IndicPhotoOCR - Comprehensive Scene Text Recognition Toolkit across 13 Indian Languages
+
+
+
+
+![Open Source](https://img.shields.io/badge/Open%20Source-Bhashini-FF6C00)
+[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FBhashini-IITJ%2FBharatOCR&count_bg=%233D48C8&title_bg=%23555555&icon=&icon_color=%0C0983&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)
+[![GitHub stars](https://img.shields.io/github/stars/Bhashini-IITJ/BharatOCR.svg?style=social&label=Star&color=orange)](https://github.com/Bhashini-IITJ/BharatOCR/stargazers)
+![GitHub forks](https://img.shields.io/github/forks/Bhashini-IITJ/BharatOCR?style=social)
+[![Hugging Face](https://img.shields.io/badge/Hugging_Face-Demo-FF6C00?logo=Huggingface&logoColor=white)](https://huggingface.co./spaces/anikde/BharatOCR)
+
+
+
+
+
+
+
+
+IndicPhotoOCR is an advanced OCR toolkit designed for detecting, identifying, and recognizing text across 13 Indian languages (Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Meitei, Odia, Punjabi, Tamil, Telugu, and Urdu) as well as English. Built to handle the unique scripts and complex structures of Indian languages, IndicPhotoOCR provides robust detection and recognition capabilities, making it a valuable tool for processing and analyzing multilingual documents in these diverse scripts.
+
+![](static/pics/visualizeIndicPhotoOCR.png)
+
+
+## Table of Contents
+- [Updates](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#updates)
+- [Installation](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#installation)
+- [How to use](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#how-to-use)
+- [Acknowledgement](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#acknowledgement)
+- [Contact us](https://github.com/Bhashini-IITJ/BharatOCR/blob/main/README.md#contact-us)
+
+
+
+
+## Updates
+[November 2024]: Try the demo in this [Hugging Face Space](https://huggingface.co./spaces/anikde/BharatOCR).\
+[November 2024]: Use this package in [Google Colab](https://colab.research.google.com/drive/1BILXjUF2kKKrzUJ_evubgLHl2busPiH2?usp=sharing).\
+[November 2024]: Added support for [10 languages](#config) in the recognition module.\
+[September 2024]: Private repository created.
+
+
+## Installation
+Currently, the virtual environment needs to be created manually.
+```bash
+conda create -n indicphotoocr python=3.9 -y
+conda activate indicphotoocr
+
+
+git clone https://github.com/Bhashini-IITJ/IndicPhotoOCR.git
+cd IndicPhotoOCR
+```
+
+### CPU Installation
+
+ ```bash
+ python setup.py sdist bdist_wheel
+ pip install dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cpu]
+ ```
+
+
+
+### CUDA 11.8 Installation
+
+ ```bash
+ python setup.py sdist bdist_wheel
+ pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
+ ```
+
+
+
+### CUDA 12.1 Installation
+
+ ```bash
+ python setup.py sdist bdist_wheel
+ pip install ./dist/IndicPhotoOCR-1.1.0-py3-none-any.whl[cu121] --extra-index-url https://download.pytorch.org/whl/cu121
+ ```
+
+
+
+If you run into any trouble with the above installation, use the `setup.sh` script.
+```bash
+chmod +x setup.sh
+./setup.sh
+```
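+
+After any of the above installation routes, a quick import check can confirm the package is usable (a minimal sketch; it only relies on the `OCR` constructor arguments shown elsewhere in this README):
+
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+>>> ocr_system = OCR(device="cpu", verbose=False)  # CPU-only sanity check
+```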
+
+## Config
+Currently, script identification distinguishes Hindi from English, and the recognition module supports the languages listed below.
+
+Detection Model: EAST\
+Script Identification Model: Hindi vs. English\
+Recognition Model: Hindi, English, Assamese, Bengali, Gujarati, Marathi, Odia, Punjabi, Tamil, Telugu.
+
+## How to use
+### Detection
+
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+
+# Get detections
+>>> detections = ocr_system.detect("test_images/image_141.jpg")
+
+# Running text detection...
+# 4334 text boxes before nms
+# 1.027989387512207
+
+# Save and visualize the detection results
+>>> ocr_system.visualize_detection("test_images/image_141.jpg", detections)
+# Image saved at: test.png
+```
+
+### Cropped Word Recognition
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+# Get recognitions
+>>> ocr_system.recognise("test_images/cropped_image/image_141_0.jpg", "hindi")
+# Recognizing text in detected area...
+# 'मण्डी'
+```
+
+### End-to-end Scene Text Recognition
+```python
+>>> from IndicPhotoOCR.ocr import OCR
+# Create an object of OCR
+>>> ocr_system = OCR(verbose=True) # for CPU --> OCR(device="cpu")
+# Complete pipeline
+>>> ocr_system.ocr("test_images/image_141.jpg")
+# Running text detection...
+# 4334 text boxes before nms
+# 0.9715704917907715
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: रोड
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: बाराखम्ब
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: barakhaml
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: हाऊस
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: mandi
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: chowk
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: मण्डी
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: road
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: house
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Using cache found in /DATA1/ocrteam/.cache/torch/hub/baudm_parseq_main
+# Recognized word: rajiv
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: राजीव
+# Identifying script for the cropped area...
+# Recognizing text in detected area...
+# Recognized word: चौक
+
+
+```
+
+
+
+## Acknowledgement
+
+Text Recognition - [PARSeq](https://github.com/baudm/parseq)\
+Text Detection - EAST [re-implementation](https://github.com/foamliu/EAST)\
+National Language Translation Mission - [Bhashini](https://bhashini.gov.in/).
+## Contact us
+For any queries, please contact us at:
+- [Anik De](mailto:anekde@gmail.com)
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d86bf6f13d22d756cc593b14c8fe79e2bd46f81
--- /dev/null
+++ b/app.py
@@ -0,0 +1,195 @@
+import gradio as gr
+from PIL import Image
+import os
+from IndicPhotoOCR.ocr import OCR # Ensure OCR class is saved in a file named ocr.py
+from IndicPhotoOCR.theme import Seafoam
+from IndicPhotoOCR.utils.helper import detect_para
+from transformers import (
+ AutoModelForSeq2SeqLM,
+ AutoTokenizer,
+)
+import numpy as np
+import torch
+
+from IndicTransToolkit import IndicProcessor
+
+# Initialize the OCR object for text detection and recognition
+ocr = OCR(device='cpu', verbose=False)
+
+def translate(given_str, lang='hindi'):
+ """Translate `given_str` between Hindi and English; `lang` names the language of the input string."""
+ DEVICE = 'cpu'
+ model_name = "ai4bharat/indictrans2-en-indic-1B" if lang=="english" else "ai4bharat/indictrans2-indic-en-1B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+
+ ip = IndicProcessor(inference=True)
+
+ model = model.to(DEVICE)
+ model.eval()
+ src_lang, tgt_lang = ("eng_Latn", "hin_Deva") if lang=="english" else ("hin_Deva", "eng_Latn" )
+
+ batch = ip.preprocess_batch(
+ [given_str],
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ )
+ inputs = tokenizer(
+ batch,
+ truncation=True,
+ padding="longest",
+ return_tensors="pt",
+ return_attention_mask=True,
+ ).to(DEVICE)
+ with torch.no_grad():
+ generated_tokens = model.generate(
+ **inputs,
+ use_cache=True,
+ min_length=0,
+ max_length=256,
+ num_beams=5,
+ num_return_sequences=1,
+ )
+
+ # Decode the generated tokens into text
+ with tokenizer.as_target_tokenizer():
+ generated_tokens = tokenizer.batch_decode(
+ generated_tokens.detach().cpu().tolist(),
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=True,
+ )
+ translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
+ return translation
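+
+# Illustrative calls (not part of the original source); `lang` names the language of
+# the input string, and the return value is its translation into the other language:
+#   translate("मण्डी", lang="hindi")              # Hindi  -> English
+#   translate("Barakhamba Road", lang="english")  # English -> Hindi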
+
+
+
+
+
+
+def process_image(image):
+ """
+ Processes the uploaded image for text detection and recognition.
+ - Detects bounding boxes in the image
+ - Draws bounding boxes on the image and identifies script in each detected area
+ - Recognizes text in each cropped region and returns the annotated image and recognized text
+
+ Parameters:
+ image (PIL.Image): The input image to be processed.
+
+ Returns:
+ tuple: A PIL.Image with bounding boxes and a string of recognized text.
+ """
+
+ # Save the input image temporarily
+ image_path = "input_image.jpg"
+ image.save(image_path)
+
+ # Detect bounding boxes on the image using OCR
+ detections = ocr.detect(image_path)
+
+ # Draw bounding boxes on the image and save it as output
+ ocr.visualize_detection(image_path, detections, save_path="output_image.png")
+
+ # Load the annotated image with bounding boxes drawn
+ output_image = Image.open("output_image.png")
+
+ # Initialize list to hold recognized text from each detected area
+ recognized_texts = {}
+ pil_image = Image.open(image_path)
+
+ # # Process each detected bounding box for script identification and text recognition
+ # for bbox in detections:
+ # # Identify the script and crop the image to this region
+ # script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
+
+ # if script_lang: # Only proceed if a script language is identified
+ # # Recognize text in the cropped area
+ # recognized_text = ocr.recognise(cropped_path, script_lang)
+ # recognized_texts.append(recognized_text)
+ for id, bbox in enumerate(detections):
+ # Identify the script and crop the image to this region
+ script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
+
+ # Calculate bounding box coordinates
+ x1 = min([bbox[i][0] for i in range(len(bbox))])
+ y1 = min([bbox[i][1] for i in range(len(bbox))])
+ x2 = max([bbox[i][0] for i in range(len(bbox))])
+ y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+ if script_lang:
+ recognized_text = ocr.recognise(cropped_path, script_lang)
+ recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
+
+ # Combine recognized texts into a single string for display
+ # recognized_texts_combined = " ".join(recognized_texts)
+ string = detect_para(recognized_texts)
+
+ recognized_texts_combined = '\n'.join([' '.join(line) for line in string])
+ recognized_texts_combined = translate(recognized_texts_combined,script_lang)
+
+ return output_image, recognized_texts_combined
+
+# Custom HTML for interface header with logos and alignment
+interface_html = """
+
+
+
![IITJ Logo](https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png)
+
+
![Bhashini Logo](https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw)
+
+"""
+
+
+
+# Links to GitHub and Dataset repositories with GitHub icon
+links_html = """
+
+"""
+
+# Custom CSS to style the text box font size
+custom_css = """
+.custom-textbox textarea {
+ font-size: 20px !important;
+}
+"""
+
+# Create an instance of the Seafoam theme for a consistent visual style
+seafoam = Seafoam()
+
+# Define examples for users to try out
+examples = [
+ ["test_images/208.jpg"],
+ ["test_images/1310.jpg"]
+]
+title = "Developed by IITJ
"
+
+# Set up the Gradio Interface with the defined function and customizations
+demo = gr.Interface(
+ fn=process_image,
+ inputs=gr.Image(type="pil", image_mode="RGB"),
+ outputs=[
+ gr.Image(type="pil", label="Detected Bounding Boxes"),
+ gr.Textbox(label="Translated Text", elem_classes="custom-textbox")
+ ],
+ title="Scene Text Translator",
+ description=title+interface_html+links_html,
+ theme=seafoam,
+ css=custom_css,
+ examples=examples
+)
+
+# Server setup and launch configuration
+if __name__ == "__main__":
+ server = "0.0.0.0" # IP address for server
+ port = 7867 # Port to run the server on
+ demo.launch(server_name=server, server_port=port)
+
+# demo.launch()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1faea294ff63f6516fe29b8f0197b4d2c7170a75
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,49 @@
+IndicTransToolkit @ git+https://github.com/VarunGumma/IndicTransToolkit
+
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.10
+aiosignal==1.3.1
+async-timeout==4.0.3
+attrs==24.2.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+filelock==3.16.1
+frozenlist==1.5.0
+fsspec
+huggingface-hub==0.26.1
+idna==3.10
+jinja2==3.1.4
+joblib==1.4.2
+lightning-utilities==0.11.8
+markupsafe==3.0.2
+mpmath==1.3.0
+multidict==6.1.0
+networkx==3.2.1
+nltk==3.9.1
+numpy==1.26.4
+packaging==24.1
+pillow==11.0.0
+propcache==0.2.0
+pytorch-lightning==2.4.0
+pyyaml==6.0.2
+regex==2024.9.11
+requests==2.32.3
+safetensors==0.4.5
+sympy==1.13.1
+timm==1.0.11
+torchmetrics==1.5.1
+tqdm==4.66.5
+typing-extensions==4.12.2
+urllib3==2.2.3
+yarl==1.16.0
+opencv-python==4.10.0.84
+shapely==2.0.6
+openai-clip==1.0.1
+lmdb==1.5.1
+torch==2.5.0
+torchvision==0.20.0
+easydict==1.13
+scipy==1.13.1
+transformers==4.45.1
+datasets==3.1.0
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..f075eec24e314deb35a57251dddde0816a7fcb03
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,79 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="IndicPhotoOCR",
+ version="1.1.0",
+ description="Scene Text Recognition Toolkit across 13 Indian Languages which contains detection, script identification, and text recognition modules",
+ long_description=open("README.md").read() + "\n\n" + open("CHANGELOG.md").read(),
+ long_description_content_type="text/markdown",
+ author="Anik De",
+ author_email="anekde@gmail.com",
+ url="https://github.com/Bhashini-IITJ/IndicPhotoOCR",
+ packages=find_packages(),
+ python_requires='>=3.9',
+ install_requires=[
+ # Your mandatory dependencies here
+ 'aiohappyeyeballs==2.4.3',
+ 'aiohttp==3.10.10',
+ 'aiosignal==1.3.1',
+ 'async-timeout==4.0.3',
+ 'attrs==24.2.0',
+ 'certifi==2024.8.30',
+ 'charset-normalizer==3.4.0',
+ 'click==8.1.7',
+ 'filelock==3.16.1',
+ 'frozenlist==1.5.0',
+ 'fsspec==2024.10.0',
+ 'huggingface-hub==0.26.1',
+ 'idna==3.10',
+ 'jinja2==3.1.4',
+ 'joblib==1.4.2',
+ 'lightning-utilities==0.11.8',
+ 'markupsafe==3.0.2',
+ 'mpmath==1.3.0',
+ 'multidict==6.1.0',
+ 'networkx==3.2.1',
+ 'nltk==3.9.1',
+ 'numpy==1.26.4',
+ 'packaging==24.1',
+ 'pillow==11.0.0',
+ 'propcache==0.2.0',
+ 'pytorch-lightning==2.4.0',
+ 'pyyaml==6.0.2',
+ 'regex==2024.9.11',
+ 'requests==2.32.3',
+ 'safetensors==0.4.5',
+ 'sympy==1.13.1',
+ 'timm==1.0.11',
+ 'torchmetrics==1.5.1',
+ 'tqdm==4.66.5',
+ 'typing-extensions==4.12.2',
+ 'urllib3==2.2.3',
+ 'yarl==1.16.0',
+ 'opencv-python==4.10.0.84',
+ 'shapely==2.0.6',
+ 'openai-clip==1.0.1',
+ 'lmdb==1.5.1'
+
+ ],
+ extras_require={
+ 'cu118': [
+ 'torch==2.5.0+cu118',
+ 'torchvision==0.20.0+cu118',
+ # Any additional packages specific to cu118
+ ],
+ 'cu121': [
+ 'torch==2.5.0+cu121',
+ 'torchvision==0.20.0+cu121',
+ # Any additional packages specific to cu121
+ ],
+ 'cpu': [
+ 'torch==2.5.0',
+ 'torchvision==0.20.0',
+ # Any additional packages specific to CPU
+ ],
+ 'extra': [
+ 'six==1.16.0', # Your other extra requirements
+ ],
+ },
+)
diff --git a/static/pics/README.md b/static/pics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/static/pics/README.md
@@ -0,0 +1 @@
+
diff --git a/static/pics/bharatOCR.png b/static/pics/bharatOCR.png
new file mode 100644
index 0000000000000000000000000000000000000000..92a37fdd463ba03e581f02e3e9860a57b040b072
Binary files /dev/null and b/static/pics/bharatOCR.png differ
diff --git a/static/pics/visualizeIndicPhotoOCR.png b/static/pics/visualizeIndicPhotoOCR.png
new file mode 100644
index 0000000000000000000000000000000000000000..11e8ca765f8ac5354c726b526ec7d3f0c06dad23
Binary files /dev/null and b/static/pics/visualizeIndicPhotoOCR.png differ
diff --git a/test_images/1310.jpg b/test_images/1310.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a04e1e1a7ba9757ce4e5d45bbd0b83eb97507fb
Binary files /dev/null and b/test_images/1310.jpg differ
diff --git a/test_images/208.jpg b/test_images/208.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..098f7a7a491fd100fc8b93d232454700a9a97e97
Binary files /dev/null and b/test_images/208.jpg differ