summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/character_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r--src/text_recognizer/models/character_model.py15
1 files changed, 1 insertions, 14 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index f1dabb7..0a0ab2d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -6,10 +6,6 @@ import torch
from torch import nn
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
-)
from text_recognizer.models.base import Model
@@ -20,7 +16,6 @@ class CharacterModel(Model):
self,
network_fn: Type[nn.Module],
network_args: Optional[Dict] = None,
- data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -36,7 +31,6 @@ class CharacterModel(Model):
super().__init__(
network_fn,
network_args,
- data_loader,
data_loader_args,
metrics,
criterion,
@@ -47,16 +41,9 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- if self.mapping is None:
- self.load_mapping()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
- def load_mapping(self) -> None:
- """Mapping between integers and classes."""
- essentials = _load_emnist_essentials()
- self._mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
-
def predict_on_image(
self, image: Union[np.ndarray, torch.Tensor]
) -> Tuple[str, float]:
@@ -86,6 +73,6 @@ class CharacterModel(Model):
index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
- predicted_character = self._mapping[index]
+ predicted_character = self.mapper(index)
return predicted_character, confidence_of_prediction