summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/models/character_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/character_model.py')
-rw-r--r-- src/text_recognizer/models/character_model.py 20
1 files changed, 16 insertions, 4 deletions
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0fd7afd..64ba693 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch import nn
+from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
+from text_recognizer.datasets import EmnistMapper
from text_recognizer.models.base import Model
@@ -15,8 +17,9 @@ class CharacterModel(Model):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -24,14 +27,16 @@ class CharacterModel(Model):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Initializes the CharacterModel."""
super().__init__(
network_fn,
+ dataset,
network_args,
- data_loader_args,
+ dataset_args,
metrics,
criterion,
criterion_args,
@@ -39,8 +44,11 @@ class CharacterModel(Model):
optimizer_args,
lr_scheduler,
lr_scheduler_args,
+ swa_args,
device,
)
+ if self._mapper is None:
+ self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
@@ -67,9 +75,13 @@ class CharacterModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- logits = self.network(image)
+ logits = (
+ self.swa_network(image)
+ if self.swa_network is not None
+ else self.network(image)
+ )
- prediction = self.softmax(logits.data.squeeze())
+ prediction = self.softmax(logits.squeeze(0))
index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]