Spaces:
Runtime error
Runtime error
import sys | |
import os | |
import torch | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
# from IndicPhotoOCR.detection.east_detector import EASTdetector | |
from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier | |
from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser | |
import IndicPhotoOCR.detection.east_config as cfg | |
from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector | |
class OCR:
    """End-to-end scene-text OCR pipeline: detect, identify script, recognise.

    Wires together a TextBPN++ text detector, a CLIP-based script identifier
    and a PARseq recogniser, all imported from the IndicPhotoOCR package.
    """

    def __init__(self, device='cuda:0', verbose=False):
        """Build the three pipeline models.

        Args:
            device (str): Device string handed to the text detector. The
                recogniser and identifier are constructed without it.
            verbose (bool): If True, print progress messages while running.
        """
        self.device = device
        self.verbose = verbose
        self.detector = TextBPNpp_detector(device=self.device)
        self.recogniser = PARseqrecogniser()
        self.identifier = CLIPidentifier()

    def detect(self, image_path):
        """Run the text detector on an image file.

        The detector's full result is cached on ``self.detections``.

        Args:
            image_path (str): Path to the input image.

        Returns:
            The ``'detections'`` entry of the detector's result. ``ocr`` and
            ``visualize_detection`` treat each item as a sequence of [x, y]
            corner points.
        """
        self.detections = self.detector.detect(image_path)
        return self.detections['detections']

    def visualize_detection(self, image_path, detections, save_path=None, show=False):
        """Draw an axis-aligned box around every detection and save the image.

        Args:
            image_path (str): Path to the source image.
            detections (iterable): Polygons, each a sequence of [x, y] points.
            save_path (str | None): Output file; defaults to "test.png".
            show (bool): If True, also open the annotated image in the
                system's default image viewer.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
            OSError: If OpenCV fails to write the annotated image.
        """
        default_save_path = "test.png"
        path_to_save = save_path if save_path is not None else default_save_path

        # Create the parent directory of the output file if it is missing.
        directory = os.path.dirname(path_to_save)
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
            print(f"Created directory: {directory}")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        for box in detections:
            points = np.array(box, np.int32)
            # Cast to built-in int so OpenCV parses the corner points
            # regardless of the NumPy scalar type.
            top_left = (int(points[:, 0].min()), int(points[:, 1].min()))
            bottom_right = (int(points[:, 0].max()), int(points[:, 1].max()))
            cv2.rectangle(image, top_left, bottom_right, color=(0, 255, 0), thickness=3)

        if show:
            # Display through PIL (already a dependency of this module);
            # OpenCV images are BGR, so convert to RGB first.
            Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).show()

        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(path_to_save, image):
            raise OSError(f"Could not write image: {path_to_save}")
        print(f"Image saved at: {path_to_save}")

    def crop_and_identify_script(self, image, bbox):
        """Crop a text region from the image and identify its script.

        Args:
            image (PIL.Image.Image): The full image.
            bbox (list): Corner points of the region, each an [x, y] pair.

        Returns:
            tuple: ``(script_lang, cropped_path)``. ``script_lang`` is the
            identifier's result and ``cropped_path`` is the JPEG file the crop
            was written to. The caller is responsible for deleting that file.
        """
        x_coords = [point[0] for point in bbox]
        y_coords = [point[1] for point in bbox]
        x_min, y_min = min(x_coords), min(y_coords)
        x_max, y_max = max(x_coords), max(y_coords)

        cropped_image = image.crop((x_min, y_min, x_max, y_max))
        # JPEG cannot store alpha/palette modes (e.g. crops of RGBA PNGs),
        # so normalise anything other than RGB / greyscale before saving.
        if cropped_image.mode not in ("RGB", "L"):
            cropped_image = cropped_image.convert("RGB")

        root_image_dir = "IndicPhotoOCR/script_identification"
        os.makedirs(f"{root_image_dir}/images", exist_ok=True)
        # The identifier takes a file path, so the crop is written to disk.
        # ``ocr`` deletes it once recognition has finished.
        cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
        cropped_image.save(cropped_path)

        if self.verbose:
            print("Identifying script for the cropped area...")
        # "hindi" is passed as the identifier's model name.
        script_lang = self.identifier.identify(cropped_path, "hindi")
        return script_lang, cropped_path

    def recognise(self, cropped_image_path, script_lang):
        """Recognise the text in a cropped region image.

        Args:
            cropped_image_path (str): Path to the cropped image file.
            script_lang (str): Script label from ``crop_and_identify_script``.

        Returns:
            Whatever ``PARseqrecogniser.recognise`` returns for the crop.
        """
        if self.verbose:
            print("Recognizing text in detected area...")
        # NOTE(review): script_lang is passed as both the first and third
        # argument; presumably one selects the checkpoint and the other the
        # language - confirm against PARseqrecogniser.recognise.
        return self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose)

    def ocr(self, image_path):
        """Detect, script-identify and recognise every text region in an image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            list: Recognised words in detection order. Regions for which the
            script identifier returned a falsy value are skipped.
        """
        recognized_words = []
        # The context manager closes the file handle PIL keeps open.
        with Image.open(image_path) as image:
            detections = self.detect(image_path)
            for bbox in detections:
                script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
                try:
                    if script_lang:
                        recognized_word = self.recognise(cropped_path, script_lang)
                        recognized_words.append(recognized_word)
                        if self.verbose:
                            print(f"Recognized word: {recognized_word}")
                finally:
                    # The crop is only needed for this region. Remove it so
                    # repeated runs do not accumulate files; failure to delete
                    # is deliberately non-fatal.
                    try:
                        os.remove(cropped_path)
                    except OSError:
                        pass
        return recognized_words
if __name__ == '__main__':
    # Demo run of the full pipeline on a bundled sample image.
    sample_image_path = 'test_images/image_141.jpg'

    # Fall back to CPU so the demo still starts on machines without a GPU,
    # instead of failing on a hard-coded "cuda" device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ocr = OCR(device=device, verbose=False)

    # Individual stages can also be run on their own, e.g.:
    #   detections = ocr.detect(sample_image_path)
    #   ocr.visualize_detection(sample_image_path, detections)
    #   ocr.recognise('test_images/cropped_image/image_141_0.jpg', "hindi")

    recognised_words = ocr.ocr(sample_image_path)
    print(recognised_words)