Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  51
1 file changed, 13 insertions, 38 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index a945b41..d394b4c 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -15,8 +15,9 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
-from torchvision.transforms import Compose
+from text_recognizer import datasets
+from text_recognizer import networks
from text_recognizer.datasets import EmnistMapper
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -27,8 +28,8 @@ class Model(ABC):
def __init__(
self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
+ network_fn: str,
+ dataset: str,
network_args: Optional[Dict] = None,
dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
@@ -44,8 +45,8 @@ class Model(ABC):
"""Base class, to be inherited by model for specific type of data.
Args:
- network_fn (Type[nn.Module]): The PyTorch network.
- dataset (Type[Dataset]): A dataset class.
+ network_fn (str): The name of the network.
+ dataset (str): The name of the dataset class.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
@@ -62,13 +63,15 @@ class Model(ABC):
device (Optional[str]): Name of the device to train on. Defaults to None.
"""
+ self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
# Has to be set in subclass.
self._mapper = None
# Placeholder.
self._input_shape = None
- self.dataset = dataset
+ self.dataset_name = dataset
+ self.dataset = None
self.dataset_args = dataset_args
# Placeholders for datasets.
@@ -92,10 +95,6 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
- self._name = (
- f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
- )
-
self._metrics = metrics if metrics is not None else None
# Set the device.
@@ -132,38 +131,12 @@ class Model(ABC):
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
- def _configure_transforms(self) -> None:
- # Load transforms.
- transforms_module = importlib.import_module(
- "text_recognizer.datasets.transforms"
- )
- if (
- "transform" in self.dataset_args["args"]
- and self.dataset_args["args"]["transform"] is not None
- ):
- transform_ = []
- for t in self.dataset_args["args"]["transform"]:
- args = t["args"] or {}
- transform_.append(getattr(transforms_module, t["type"])(**args))
- self.dataset_args["args"]["transform"] = Compose(transform_)
-
- if (
- "target_transform" in self.dataset_args["args"]
- and self.dataset_args["args"]["target_transform"] is not None
- ):
- target_transform_ = [
- torch.tensor,
- ]
- for t in self.dataset_args["args"]["target_transform"]:
- args = t["args"] or {}
- target_transform_.append(getattr(transforms_module, t["type"])(**args))
- self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
-
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
- self._configure_transforms()
+ # Load the dataset class by name from the datasets module.
+ self.dataset = getattr(datasets, self.dataset_name)
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
@@ -222,6 +195,8 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
# If no network arguments are given, load pretrained weights if they exist.
+ # Load the network class by name from the networks module.
+ network_fn = getattr(networks, network_fn)
if self._network_args is None:
self.load_weights(network_fn)
else:
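
The hunks above replace class-typed arguments with plain strings and resolve the actual classes at runtime via getattr on the text_recognizer.datasets and text_recognizer.networks modules. Below is a minimal, self-contained sketch of that load-by-name pattern; the class names "CNN" and "EmnistDataset" in the usage comments are placeholders for illustration and are not taken from this diff.

# Sketch of resolving a class from a module by its string name, mirroring
# self.dataset = getattr(datasets, self.dataset_name) in the diff above.
import importlib
from typing import Any


def resolve(module_path: str, name: str) -> Any:
    """Return the attribute `name` (typically a class) from `module_path`."""
    module = importlib.import_module(module_path)
    return getattr(module, name)


# Hypothetical usage mirroring the new string-based constructor arguments:
# network_cls = resolve("text_recognizer.networks", "CNN")
# network = network_cls(**(network_args or {}))
# dataset_cls = resolve("text_recognizer.datasets", "EmnistDataset")
# train_dataset = dataset_cls(train=True, **(dataset_args or {}))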