diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/models/character_model.py | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r-- | src/text_recognizer/models/character_model.py | 15 |
1 files changed, 1 insertions, 14 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index f1dabb7..0a0ab2d 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -6,10 +6,6 @@ import torch from torch import nn from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import ( - _augment_emnist_mapping, - _load_emnist_essentials, -) from text_recognizer.models.base import Model @@ -20,7 +16,6 @@ class CharacterModel(Model): self, network_fn: Type[nn.Module], network_args: Optional[Dict] = None, - data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -36,7 +31,6 @@ class CharacterModel(Model): super().__init__( network_fn, network_args, - data_loader, data_loader_args, metrics, criterion, @@ -47,16 +41,9 @@ class CharacterModel(Model): lr_scheduler_args, device, ) - if self.mapping is None: - self.load_mapping() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) - def load_mapping(self) -> None: - """Mapping between integers and classes.""" - essentials = _load_emnist_essentials() - self._mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -86,6 +73,6 @@ class CharacterModel(Model): index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] - predicted_character = self._mapping[index] + predicted_character = self.mapper(index) return predicted_character, confidence_of_prediction |