summaryrefslogtreecommitdiff
path: root/text_recognizer/util.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/util.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/util.py')
-rw-r--r--text_recognizer/util.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/text_recognizer/util.py b/text_recognizer/util.py
new file mode 100644
index 0000000..b431e22
--- /dev/null
+++ b/text_recognizer/util.py
@@ -0,0 +1,52 @@
+"""Utility functions for text_recognizer module."""
+import os
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
+ """Read image_uri."""
+
+ def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray:
+ return cv2.imread(str(image_filename), imread_flag)
+
+ def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray:
+ if image_url.lower().startswith("http"):
+ url_response = urlopen(str(image_url))
+ image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+ return cv2.imdecode(image_array, imread_flag)
+ else:
+ raise ValueError(
+ "Url does not start with http, therefore not safe to open..."
+ ) from None
+
+ imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ local_file = os.path.exists(image_uri)
+ image = None
+
+ if local_file:
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+
+ if image is None:
+ raise ValueError(f"Could not load image at {image_uri}")
+
+ return image
+
+
+def rescale_image(image: np.ndarray) -> np.ndarray:
+ """Rescale image from [0, 1] to [0, 255]."""
+ if image.max() <= 1.0:
+ image = 255 * (image - image.min()) / (image.max() - image.min())
+ return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+ """Write image to file."""
+ image = rescale_image(image)
+ cv2.imwrite(str(filename), image)