summaryrefslogtreecommitdiff
path: root/text_recognizer/paragraph_text_recognizer.py
blob: aa39662863678c35c9d8add325a2ae68b1132503 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Full model.

Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting
each crop of the image corresponding to a line region, and feeding the crops to a LinePredictor model that outputs the
text in each region.
"""
from typing import Dict, List, Tuple, Union

import cv2
import numpy as np
import torch

from text_recognizer.models import SegmentationModel, TransformerModel
from text_recognizer.util import read_image


class ParagraphTextRecognizor:
    """Recognizes the text of every line found in an image of handwritten text.

    A SegmentationModel (the line detector) locates the line regions, each region is cropped from the input image,
    and a TransformerModel (the line predictor) is run on every crop.
    """

    def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
        """Instantiates the line predictor and the line detector and calls ``eval()`` on both.

        Args:
            line_predictor_args (Dict): Keyword arguments forwarded to TransformerModel.
            line_detector_args (Dict): Keyword arguments forwarded to SegmentationModel.
        """
        self._line_predictor = TransformerModel(**line_predictor_args)
        self._line_detector = SegmentationModel(**line_detector_args)
        self._line_detector.eval()
        self._line_predictor.eval()

    def predict(
        self, image_or_filename: Union[str, np.ndarray]
    ) -> Tuple[str, List[np.ndarray]]:
        """Takes an image and returns all text within it.

        Args:
            image_or_filename (Union[str, np.ndarray]): An image array, or a filename that is loaded with read_image.

        Returns:
            Tuple[str, List[np.ndarray]]: The predicted line strings joined by single spaces, and the unprocessed
                line crops taken from the input image.
        """
        image = (
            read_image(image_or_filename)
            if isinstance(image_or_filename, str)
            else image_or_filename
        )

        line_region_crops = self._get_line_region_crops(image)
        processed_line_region_crops = [
            self._process_image_for_line_predictor(image=crop)
            for crop in line_region_crops
        ]
        # The predictor is stored as self._line_predictor in __init__.
        # NOTE(review): element 0 of the prediction is assumed to be the decoded string - confirm against
        # TransformerModel.predict_on_image.
        line_region_strings = [
            self._line_predictor.predict_on_image(crop)[0]
            for crop in processed_line_region_crops
        ]

        return " ".join(line_region_strings), line_region_crops

    def _get_line_region_crops(
        self, image: np.ndarray, min_crop_len_factor: float = 0.02
    ) -> List[np.ndarray]:
        """Returns all the crops of text lines in a square image.

        Args:
            image (np.ndarray): Image to search for text lines.
            min_crop_len_factor (float): Fraction of the shorter image side that both the width and the height of a
                bounding box must reach for the box to be kept.

        Returns:
            List[np.ndarray]: Crops of the input image, one per bounding box that is large enough.
        """
        processed_image, scale_down_factor = self._process_image_for_line_detector(
            image
        )
        line_segmentation = self._line_detector.predict_on_image(processed_image)
        bounding_boxes = _find_line_bounding_boxes(line_segmentation)

        # Boxes were found on the resized image; map them back to the coordinates of the input image.
        bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)

        min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
        line_region_crops = [
            image[y : y + h, x : x + w]
            for x, y, w, h in bounding_boxes
            if w >= min_crop_len and h >= min_crop_len
        ]
        return line_region_crops

    def _process_image_for_line_detector(
        self, image: np.ndarray
    ) -> Tuple[np.ndarray, float]:
        """Convert uint8 image to float image with black background with shape self._line_detector.image_shape.

        Args:
            image (np.ndarray): Image to prepare for the line detector.

        Returns:
            Tuple[np.ndarray, float]: The resized, inverted float32 image and the factor by which the input image
                was scaled down.
        """
        resized_image, scale_down_factor = _resize_image_for_line_detector(
            image=image, max_shape=self._line_detector.image_shape
        )
        # Invert the intensities and rescale by 255 (0-255 input becomes 1-0).
        resized_image = (1.0 - resized_image / 255).astype("float32")
        return resized_image, scale_down_factor

    def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
        """Preprocessing of image before feeding it to the LinePrediction model.

        Convert uint8 image to float image with black background with shape
        self._line_predictor.image_shape while maintaining the image aspect ratio.

        Args:
            image (np.ndarray): Crop of text line.

        Returns:
            np.ndarray: Processed crop for feeding line predictor.
        """
        # The crop is fed to the line predictor, so that model's input shape is the target.
        expected_shape = self._line_predictor.image_shape
        # The smaller of the two per-axis ratios makes the scaled crop fit inside expected_shape.
        scale_factor = (np.array(expected_shape) / np.array(image.shape)).min()
        scaled_image = cv2.resize(
            image,
            dsize=None,
            fx=scale_factor,
            fy=scale_factor,
            interpolation=cv2.INTER_AREA,
        )

        # Pad only at the bottom and on the right, up to the expected shape.
        pad_width = (
            (0, expected_shape[0] - scaled_image.shape[0]),
            (0, expected_shape[1] - scaled_image.shape[1]),
        )

        # Pad with 255 so that the padding becomes 0 after the inversion below.
        padded_image = np.pad(
            scaled_image, pad_width=pad_width, mode="constant", constant_values=255
        )
        return 1 - padded_image / 255


def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray:
    """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.

    Args:
        line_segmentation (np.ndarray): Segmentation output indexed as [row, column, label]; labels 1 and 2 are
            searched for regions.

    Returns:
        np.ndarray: Array of shape (number of boxes, 4) with one (x, y, w, h) row per region, sorted by y. The shape
            is (0, 4) when no region is found.
    """

    def _find_line_bounding_boxes_in_channel(
        line_segmentation_channel: np.ndarray,
    ) -> np.ndarray:
        """Returns the (x, y, w, h) bounding boxes of the external contours in one label channel."""
        line_segmentation_image = cv2.dilate(
            line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1
        )
        line_activation_image = (line_segmentation_image * 255).astype("uint8")
        # With THRESH_OTSU the threshold is chosen by Otsu's method and the 0.5 argument is ignored.
        line_activation_image = cv2.threshold(
            line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU
        )[1]

        # Contours are searched in the binarized uint8 image. Indexing with [-2] selects the contours both for
        # OpenCV 3 (image, contours, hierarchy) and for OpenCV 4 (contours, hierarchy).
        bounding_cnts = cv2.findContours(
            line_activation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        # Reshape so that a channel without contours yields shape (0, 4) and can still be concatenated.
        return np.array(
            [cv2.boundingRect(cnt) for cnt in bounding_cnts], dtype=int
        ).reshape(-1, 4)

    bounding_boxes = np.concatenate(
        [
            _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i])
            for i in [1, 2]
        ],
        axis=0,
    )

    # Order the boxes from top to bottom by their y coordinate.
    return bounding_boxes[np.argsort(bounding_boxes[:, 1])]


def _resize_image_for_line_detector(
    image: np.ndarray, max_shape: Tuple[int, int]
) -> Tuple[np.ndarray, float]:
    """Resize the image to fit within max_shape while maintaining the aspect ratio.

    Args:
        image (np.ndarray): Image to resize; its first two axes are compared against max_shape.
        max_shape (Tuple[int, int]): Largest allowed size along the first two axes of the image.

    Returns:
        Tuple[np.ndarray, float]: The resized image (a copy of the input when no resizing is needed) and the factor
            by which the image was scaled down. A factor below 1 means that the image was enlarged.
    """
    # np.array turns the shape tuples into value arrays; np.ndarray(shape) would allocate uninitialised arrays
    # of that shape instead. Only the first two axes are used so that a trailing channel axis is ignored.
    scale_down_factor = float(
        (np.array(image.shape[:2]) / np.array(max_shape)).max()
    )
    if scale_down_factor == 1:
        return image.copy(), scale_down_factor
    resize_image = cv2.resize(
        image,
        dsize=None,
        fx=1 / scale_down_factor,
        fy=1 / scale_down_factor,
        interpolation=cv2.INTER_AREA,
    )
    return resize_image, scale_down_factor