From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- text_recognizer/data/mappings.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) (limited to 'text_recognizer/data/mappings.py') diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index b69e888..d1c64dd 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -1,18 +1,30 @@ """Mapping to and from word pieces.""" from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union, Set, Sequence +from typing import Dict, List, Optional, Union, Set import attr -import loguru.logger as log import torch +from loguru import logger as log from torch import Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.iam_preprocessor import Preprocessor +@attr.s class AbstractMapping(ABC): + input_size: List[int] = attr.ib(init=False) + mapping: List[str] = attr.ib(init=False) + inverse_mapping: Dict[str, int] = attr.ib(init=False) + + def __len__(self) -> int: + return len(self.mapping) + + @property + def num_classes(self) -> int: + return self.__len__() + @abstractmethod def get_token(self, *args, **kwargs) -> str: ... @@ -30,15 +42,13 @@ class AbstractMapping(ABC): ... -@attr.s +@attr.s(auto_attribs=True) class EmnistMapping(AbstractMapping): - extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set) - mapping: Sequence[str] = attr.ib(init=False) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - input_size: List[int] = attr.ib(init=False) + extra_symbols: Optional[Set[str]] = attr.ib(default=None) def __attrs_post_init__(self) -> None: """Post init configuration.""" + self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( self.extra_symbols ) -- cgit v1.2.3-70-g09d2