diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/models/character_model.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r-- | src/text_recognizer/models/character_model.py | 20 |
1 file changed, 16 insertions, 4 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0fd7afd..64ba693 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn +from torch.utils.data import Dataset from torchvision.transforms import ToTensor +from text_recognizer.datasets import EmnistMapper from text_recognizer.models.base import Model @@ -15,8 +17,9 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -24,14 +27,16 @@ class CharacterModel(Model): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Initializes the CharacterModel.""" super().__init__( network_fn, + dataset, network_args, - data_loader_args, + dataset_args, metrics, criterion, criterion_args, @@ -39,8 +44,11 @@ class CharacterModel(Model): optimizer_args, lr_scheduler, lr_scheduler_args, + swa_args, device, ) + if self._mapper is None: + self._mapper = EmnistMapper() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -67,9 +75,13 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. 
image = image.to(self.device) - logits = self.network(image) + logits = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) - prediction = self.softmax(logits.data.squeeze()) + prediction = self.softmax(logits.squeeze(0)) index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] |