summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
commit25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch)
tree526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/text_recognizer/models/base.py
parent5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff)
Segmentation working!
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d394b4c..f2cd4b8 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -159,7 +159,7 @@ class Model(ABC):
self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
self.test_dataset.load_or_generate_data()
- # Set the flag to true to disable ability to load data agian.
+ # Set the flag to true to disable ability to load data again.
self.data_prepared = True
def train_dataloader(self) -> DataLoader:
@@ -260,7 +260,7 @@ class Model(ABC):
@property
def mapping(self) -> Dict:
"""Returns the mapping between network output and Emnist character."""
- return self._mapper.mapping
+ return self._mapper.mapping if self._mapper is not None else None
def eval(self) -> None:
"""Sets the network to evaluation mode."""
@@ -341,7 +341,7 @@ class Model(ABC):
if input_shape is not None:
summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
- input_shape = (1,) + tuple(self._input_shape)
+ input_shape = tuple(self._input_shape)
summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -427,7 +427,7 @@ class Model(ABC):
)
shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> None:
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -441,7 +441,8 @@ class Model(ABC):
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- self._network = network_fn(**self._network_args)
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
self._network.load_state_dict(weights)
if "swa_network" in state_dict: