summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-07-05 22:27:08 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-07-05 22:27:08 +0200
commit5a78fc2e33c28968a69d033cb10d638f4f63fed1 (patch)
treee11cea8366c848e5500f85968ee5369ff8d96b00 /src/text_recognizer/models
parent7c4de6d88664d2ea1b084f316a11896dde3e1150 (diff)
Working on getting experiment loop.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/base.py12
-rw-r--r--src/text_recognizer/models/character_model.py8
-rw-r--r--src/text_recognizer/models/metrics.py (renamed from src/text_recognizer/models/util.py)2
3 files changed, 13 insertions, 9 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 736af7b..0cc531a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -10,7 +10,6 @@ import torch
from torch import nn
from torchsummary import summary
-from text_recognizer.dataset.data_loader import fetch_data_loader
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -22,6 +21,7 @@ class Model(ABC):
self,
network_fn: Callable,
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -32,12 +32,13 @@ class Model(ABC):
lr_scheduler_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
- """Base class, to be inherited by predictors for specific type of data.
+ """Base class, to be inherited by model for specific type of data.
Args:
network_fn (Callable): The PyTorch network.
network_args (Dict): Arguments for the network.
- data_loader_args (Optional[Dict]): Arguments for the data loader.
+ data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
+ data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
Defaults to None.
@@ -53,8 +54,8 @@ class Model(ABC):
# Fetch data loaders.
if data_loader_args is not None:
- self._data_loaders = fetch_data_loader(**data_loader_args)
- dataset_name = self._data_loaders.items()[0].dataset.__name__
+ self._data_loaders = data_loader(**data_loader_args)
+ dataset_name = self._data_loaders.__name__
else:
dataset_name = ""
self._data_loaders = None
@@ -210,7 +211,6 @@ class Model(ABC):
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- self.save_weights()
shutil.copyfile(filepath, str(path / "best.pt"))
def load_weights(self) -> None:
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 1570344..fd69bf2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -32,17 +32,21 @@ class CharacterModel(Model):
super().__init__(
network_fn,
- data_loader_args,
network_args,
+ data_loader_args,
metrics,
criterion,
+ criterion_args,
optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
device,
)
self.emnist_mapping = self.mapping()
self.eval()
- def mapping(self) -> Dict:
+ def mapping(self) -> Dict[int, str]:
"""Mapping between integers and classes."""
mapping = load_emnist_mapping()
return mapping
diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/metrics.py
index 905fe7b..e2a30a9 100644
--- a/src/text_recognizer/models/util.py
+++ b/src/text_recognizer/models/metrics.py
@@ -4,7 +4,7 @@ import torch
def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float:
- """Short summary.
+ """Computes the accuracy.
Args:
outputs (torch.Tensor): The output from the network.