diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/text_recognizer/datasets/emnist_lines_dataset.py | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index d64a991..b0617f5 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,6 +8,7 @@ import h5py from loguru import logger import numpy as np import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor @@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. Args: - index (Union[int, torch.Tensor]): Either a list or int of indices/index. + index (Union[int, Tensor]): Either a list or int of indices/index. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data target pair. + Tuple[Tensor, Tensor]: Data target pair. """ if torch.is_tensor(index): |