Diffstat (limited to 'text_recognizer/paragraph_text_recognizer.py')
-rw-r--r--  text_recognizer/paragraph_text_recognizer.py | 153
1 file changed, 0 insertions(+), 153 deletions(-)
diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py
deleted file mode 100644
index aa39662..0000000
--- a/text_recognizer/paragraph_text_recognizer.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Full model.
-
-Takes an image and returns the text in the image by first segmenting it with a line-detector model, then extracting
-each crop of the image corresponding to a line region, and feeding the crops to a line-predictor model that outputs
-the text in each region.
-"""
-from typing import Dict, List, Tuple, Union
-
-import cv2
-import numpy as np
-import torch
-
-from text_recognizer.models import SegmentationModel, TransformerModel
-from text_recognizer.util import read_image
-
-
-class ParagraphTextRecognizor:
- """Given an image of a single handwritten character, recognizes it."""
-
- def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
- self._line_predictor = TransformerModel(**line_predictor_args)
- self._line_detector = SegmentationModel(**line_detector_args)
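- # Inference only: put both models in eval mode (disables dropout and batch-norm updates).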
- self._line_detector.eval()
- self._line_predictor.eval()
-
- def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple[str, List[np.ndarray]]:
- """Takes an image and returns all text within it."""
- image = (
- read_image(image_or_filename)
- if isinstance(image_or_filename, str)
- else image_or_filename
- )
-
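- # Pipeline: detect line regions, preprocess each crop, then predict the text in each crop.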
- line_region_crops = self._get_line_region_crops(image)
- processed_line_region_crops = [
- self._process_image_for_line_predictor(image=crop)
- for crop in line_region_crops
- ]
- line_region_strings = [
- self._line_predictor.predict_on_image(crop)[0]
- for crop in processed_line_region_crops
- ]
-
- return " ".join(line_region_strings), line_region_crops
-
- def _get_line_region_crops(
- self, image: np.ndarray, min_crop_len_factor: float = 0.02
- ) -> List[np.ndarray]:
- """Returns all the crops of text lines in a square image."""
- processed_image, scale_down_factor = self._process_image_for_line_detector(
- image
- )
- line_segmentation = self._line_detector.predict_on_image(processed_image)
- bounding_boxes = _find_line_bounding_boxes(line_segmentation)
-
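- # Boxes were found on the downscaled detector input; map them back to original-image coordinates.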
- bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)
-
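- # Drop crops below a minimum side length; such tiny regions are likely segmentation noise.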
- min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
- line_region_crops = [
- image[y : y + h, x : x + w]
- for x, y, w, h in bounding_boxes
- if w >= min_crop_len and h >= min_crop_len
- ]
- return line_region_crops
-
- def _process_image_for_line_detector(
- self, image: np.ndarray
- ) -> Tuple[np.ndarray, float]:
- """Convert uint8 image to float image with black background with shape self._line_detector.image_shape."""
- resized_image, scale_down_factor = _resize_image_for_line_detector(
- image=image, max_shape=self._line_detector.image_shape
- )
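- # Scale to [0, 1] and invert so ink is bright on a black background.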
- resized_image = (1.0 - resized_image / 255).astype("float32")
- return resized_image, scale_down_factor
-
- def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
- """Preprocessing of image before feeding it to the LinePrediction model.
-
- Convert uint8 image to float image with black background with shape
- self._line_predictor.image_shape while maintaining the image aspect ratio.
-
- Args:
- image (np.ndarray): Crop of text line.
-
- Returns:
- np.ndarray: Processed crop for feeding line predictor.
- """
- expected_shape = self._line_predictor.image_shape
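- # Uniform scale factor that fits the crop inside expected_shape while keeping its aspect ratio.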
- scale_factor = (np.array(expected_shape) / np.array(image.shape)).min()
- scaled_image = cv2.resize(
- image,
- dsize=None,
- fx=scale_factor,
- fy=scale_factor,
- interpolation=cv2.INTER_AREA,
- )
-
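- # Pad bottom and right with white (255); the inversion below turns the padding black.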
- pad_width = (
- (0, expected_shape[0] - scaled_image.shape[0]),
- (0, expected_shape[1] - scaled_image.shape[1]),
- )
-
- padded_image = np.pad(
- scaled_image, pad_width=pad_width, mode="constant", constant_values=255
- )
- return (1.0 - padded_image / 255).astype("float32")
-
-
-def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray:
- """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels."""
-
- def _find_line_bounding_boxes_in_channel(
- line_segmentation_channel: np.ndarray,
- ) -> np.ndarray:
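- # Dilate so nearby activations within a line merge into a single connected component.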
- line_segmentation_image = cv2.dilate(
- line_segmentation_channel, kernel=np.ones((3, 3), dtype="uint8"), iterations=1
- )
- line_activation_image = (line_segmentation_image * 255).astype("uint8")
- line_activation_image = cv2.threshold(
- line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU
- )[1]
-
- bounding_cnts, _ = cv2.findContours(
- line_activation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
- )
- return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts])
-
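- # Channel 0 is assumed to be background; channels 1 and 2 hold the line-region classes.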
- bounding_boxes = np.concatenate(
- [
- _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i])
- for i in [1, 2]
- ],
- axis=0,
- )
-
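- # Sort the boxes top-to-bottom by y coordinate so the lines come out in reading order.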
- return bounding_boxes[np.argsort(bounding_boxes[:, 1])]
-
-
-def _resize_image_for_line_detector(
- image: np.ndarray, max_shape: Tuple[int, int]
-) -> Tuple[np.ndarray, float]:
- """Resize the image to less than the max_shape while maintaining the aspect ratio."""
- scale_down_factor = (np.array(image.shape) / np.array(max_shape)).max()
- if scale_down_factor == 1:
- return image.copy(), scale_down_factor
- resized_image = cv2.resize(
- image,
- dsize=None,
- fx=1 / scale_down_factor,
- fy=1 / scale_down_factor,
- interpolation=cv2.INTER_AREA,
- )
- return resized_image, scale_down_factor
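-
-
-# Minimal usage sketch (the config dicts below are hypothetical; real values
-# come from the training setup):
-#
-#   recognizer = ParagraphTextRecognizor(
-#       line_predictor_args=line_predictor_config,
-#       line_detector_args=line_detector_config,
-#   )
-#   text, line_crops = recognizer.predict("page.png")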