|
import torch |
|
import cv2 |
|
import numpy as np |
|
from IndicPhotoOCR.detection.textbpn.network.textnet import TextNet |
|
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg |
|
import warnings |
|
import os |
|
import requests |
|
from tqdm import tqdm |
|
|
|
|
|
# Silence library warning noise (e.g. torch/cv2 deprecation messages) in CLI output.
warnings.filterwarnings("ignore")



# Registry of downloadable TextBPN++ checkpoints.
# "path" is relative to the textbpn package root (see ensure_model);
# "url" points at the corresponding GitHub release asset.
model_info = {

    "textbpnpp": {  # ResNet-50 backbone (default)

        "path": "models/TextBPN_resnet50_300.pth",

        "url" : "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet50_300.pth",

    },

    "textbpnpp_deformable": {  # ResNet-50 backbone with deformable convolutions

        "path":"models/TextBPN_deformable_resnet50_300.pth",

        "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_deformable_resnet50_300.pth",

    },

    "textbpn_resnet18" : {  # Lighter ResNet-18 backbone

        "path":"models/TextBPN_resnet18_300.pth",

        "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet18_300.pth",



    }

}
|
|
|
def ensure_model(model_name):
    """
    Ensure the checkpoint for ``model_name`` exists locally, downloading it if missing.

    :param model_name: Key into ``model_info`` (e.g. "textbpnpp").
    :return: Filesystem path to the model checkpoint.
    :raises KeyError: If ``model_name`` is not a known model.
    :raises requests.HTTPError: If the download URL responds with an error status.
    """
    model_path = model_info[model_name]["path"]
    url = model_info[model_name]["url"]
    root_model_dir = "IndicPhotoOCR/detection/textbpn"
    model_path = os.path.join(root_model_dir, model_path)

    if not os.path.exists(model_path):
        print(f"Model not found locally. Downloading {model_name} from {url}...")

        # Stream the download so large checkpoints are never held fully in memory;
        # bound the connection wait so a dead server cannot hang the process forever.
        response = requests.get(url, stream=True, timeout=60)
        # Fail fast on HTTP errors (404, 500, ...) instead of silently writing an
        # HTML error page to disk as if it were the checkpoint.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        # Derive the destination directory from the target path itself rather than
        # assuming a hard-coded "models" subfolder.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        with open(model_path, "wb") as f, tqdm(
            desc=model_name,
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(chunk_size=1024):
                f.write(data)
                bar.update(len(data))

        print(f"Downloaded model for {model_name}.")

    return model_path
|
|
|
class TextBPNpp_detector:
    """Scene-text detector wrapping the TextBPN++ network.

    Loads a pre-trained checkpoint (downloading it on demand) and exposes
    polygon-level text detection plus a simple visualization helper.
    """

    def __init__(self, model_name="textbpnpp", backbone="resnet50", device="cpu"):
        """
        Initialize the TextBPN model.

        :param model_name: Key identifying which pre-trained checkpoint to load.
        :param backbone: Backbone architecture (default: "resnet50").
        :param device: Device to run the model on (default: "cpu").
        """
        self.model_path = ensure_model(model_name)
        self.device = torch.device(device)
        self.model = TextNet(is_training=False, backbone=backbone)
        self.model.load_model(self.model_path)
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def to_device(tensor, device):
        """Move ``tensor`` onto ``device`` (non-blocking) and return it."""
        return tensor.to(device, non_blocking=True)

    @staticmethod
    def pad_image(image, stride=32):
        """Zero-pad ``image`` on the bottom/right so both dimensions are
        multiples of ``stride``.

        :param image: Input image (H x W x C array).
        :param stride: Alignment the network requires (default 32).
        :return: Tuple of (padded image, (original height, original width)).
        """
        height, width = image.shape[:2]
        # (-x) % stride is the amount needed to round x up to a multiple of stride.
        pad_bottom = -height % stride
        pad_right = -width % stride
        padded = cv2.copyMakeBorder(
            image, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        return padded, (height, width)

    @staticmethod
    def rescale_result(image, bbox_contours, original_height, original_width):
        """Map contour coordinates from ``image``'s frame back to the original
        image size.

        Note: each contour array is modified in place and also collected into
        the returned list.

        :param image: Image the contours were predicted on.
        :param bbox_contours: Iterable of (N, 2) contour arrays.
        :param original_height: Height of the original image.
        :param original_width: Width of the original image.
        :return: List of rescaled contours.
        """
        rescaled = []
        for contour in bbox_contours:
            contour[:, 0] = (contour[:, 0] * original_width / image.shape[1]).astype(int)
            contour[:, 1] = (contour[:, 1] * original_height / image.shape[0]).astype(int)
            rescaled.append(contour)
        return rescaled

    def detect(self, image_path):
        """
        Run text detection on an image file.

        :param image_path: Path to the input image.
        :return: ``{"detections": [...]}`` where each entry is a polygon as a
            list of [x, y] points in original-image coordinates.
        :raises ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to read the image at {image_path}")

        padded_image, original_size = self.pad_image(image)
        # HWC uint8 -> 1xCxHxW float in [0, 1].
        padded_tensor = torch.from_numpy(padded_image).permute(2, 0, 1).float()
        padded_tensor = (padded_tensor / 255.0).unsqueeze(0)

        # The network reads its expected input size from the shared config object.
        cfg.test_size = [padded_image.shape[0], padded_image.shape[1]]

        input_dict = {"img": self.to_device(padded_tensor, self.device)}
        with torch.no_grad():
            output_dict = self.model(input_dict, padded_image.shape)

        # Last refinement stage of the boundary-proposal predictions.
        contours = output_dict["py_preds"][-1].int().cpu().numpy()
        contours = self.rescale_result(image, contours, *original_size)

        return {"detections": [contour.tolist() for contour in contours]}

    def visualize_detections(self, image_path, bbox_result_dict, output_path="output.png"):
        """
        Draw detected polygons on the image and save (or display) the result.

        :param image_path: Path to the input image.
        :param bbox_result_dict: Detection results in the format:
            {'detections': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...]}.
        :param output_path: Path to save the visualized image. If falsy, the
            image is shown in a window instead.
        :raises ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to read the image at {image_path}")

        for bbox in bbox_result_dict.get("detections", []):
            polygon = np.array(bbox, dtype=np.int32)
            cv2.polylines(image, [polygon], isClosed=True, color=(0, 255, 0), thickness=2)

        if output_path:
            cv2.imwrite(output_path, image)
            print(f"Visualization saved to {output_path}")
        else:
            cv2.imshow("Detections", image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Text detection using the TextBPN++ model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
    parser.add_argument('--model_name', type=str, required=True, help='Name of the model checkpoint to use (e.g. "textbpnpp")')
    args = parser.parse_args()

    # Honor the user-selected device; previously "cpu" was hard-coded here,
    # silently ignoring --device.
    detector = TextBPNpp_detector(args.model_name, device=args.device)
    result = detector.detect(args.image_path)
    print(result)
|
|
|
|
|
|
|
|
|
|