From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 8 Sep 2020 23:14:23 +0200 Subject: IAM datasets implemented. --- src/text_recognizer/models/character_model.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'src/text_recognizer/models/character_model.py') diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0fd7afd..64ba693 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch from torch import nn +from torch.utils.data import Dataset from torchvision.transforms import ToTensor +from text_recognizer.datasets import EmnistMapper from text_recognizer.models.base import Model @@ -15,8 +17,9 @@ class CharacterModel(Model): def __init__( self, network_fn: Type[nn.Module], + dataset: Type[Dataset], network_args: Optional[Dict] = None, - data_loader_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, criterion_args: Optional[Dict] = None, @@ -24,14 +27,16 @@ class CharacterModel(Model): optimizer_args: Optional[Dict] = None, lr_scheduler: Optional[Callable] = None, lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: """Initializes the CharacterModel.""" super().__init__( network_fn, + dataset, network_args, - data_loader_args, + dataset_args, metrics, criterion, criterion_args, @@ -39,8 +44,11 @@ class CharacterModel(Model): optimizer_args, lr_scheduler, lr_scheduler_args, + swa_args, device, ) + if self._mapper is None: + self._mapper = EmnistMapper() self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -67,9 +75,13 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - logits = self.network(image) + logits = ( + self.swa_network(image) + if self.swa_network is not None + else self.network(image) + ) - prediction = self.softmax(logits.data.squeeze()) + prediction = self.softmax(logits.squeeze(0)) index = int(torch.argmax(prediction, dim=0)) confidence_of_prediction = prediction[index] -- cgit v1.2.3-70-g09d2