From 75909723fa2b1f6245d5c5422e4f2e88b8a26052 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 15 Nov 2020 17:40:44 +0100 Subject: Able to generate support files for lines datasets. --- src/text_recognizer/models/base.py | 51 ++++++++++---------------------------- 1 file changed, 13 insertions(+), 38 deletions(-) (limited to 'src/text_recognizer/models') diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index a945b41..d394b4c 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -15,8 +15,9 @@ from torch import Tensor from torch.optim.swa_utils import AveragedModel, SWALR from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary -from torchvision.transforms import Compose +from text_recognizer import datasets +from text_recognizer import networks from text_recognizer.datasets import EmnistMapper WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -27,8 +28,8 @@ class Model(ABC): def __init__( self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], + network_fn: str, + dataset: str, network_args: Optional[Dict] = None, dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, @@ -44,8 +45,8 @@ class Model(ABC): """Base class, to be inherited by model for specific type of data. Args: - network_fn (Type[nn.Module]): The PyTorch network. - dataset (Type[Dataset]): A dataset class. + network_fn (str): The name of network. + dataset (str): The name dataset class. network_args (Optional[Dict]): Arguments for the network. Defaults to None. dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. @@ -62,13 +63,15 @@ class Model(ABC): device (Optional[str]): Name of the device to train on. Defaults to None. """ + self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" # Has to be set in subclass. self._mapper = None # Placeholder. self._input_shape = None - self.dataset = dataset + self.dataset_name = dataset + self.dataset = None self.dataset_args = dataset_args # Placeholders for datasets. @@ -92,10 +95,6 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - self._name = ( - f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}" - ) - self._metrics = metrics if metrics is not None else None # Set the device. @@ -132,38 +131,12 @@ class Model(ABC): # Set this flag to true to prevent the model from configuring again. self.is_configured = True - def _configure_transforms(self) -> None: - # Load transforms. - transforms_module = importlib.import_module( - "text_recognizer.datasets.transforms" - ) - if ( - "transform" in self.dataset_args["args"] - and self.dataset_args["args"]["transform"] is not None - ): - transform_ = [] - for t in self.dataset_args["args"]["transform"]: - args = t["args"] or {} - transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["transform"] = Compose(transform_) - - if ( - "target_transform" in self.dataset_args["args"] - and self.dataset_args["args"]["target_transform"] is not None - ): - target_transform_ = [ - torch.tensor, - ] - for t in self.dataset_args["args"]["target_transform"]: - args = t["args"] or {} - target_transform_.append(getattr(transforms_module, t["type"])(**args)) - self.dataset_args["args"]["target_transform"] = Compose(target_transform_) - def prepare_data(self) -> None: """Prepare data for training.""" # TODO add downloading. if not self.data_prepared: - self._configure_transforms() + # Load dataset module. + self.dataset = getattr(datasets, self.dataset_name) # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) @@ -222,6 +195,8 @@ class Model(ABC): def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" # If no network arguments are given, load pretrained weights if they exist. + # Load network module. + network_fn = getattr(networks, network_fn) if self._network_args is None: self.load_weights(network_fn) else: -- cgit v1.2.3-70-g09d2