summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index d64a991..b0617f5 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger
import numpy as np
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
Args:
- index (Union[int, torch.Tensor]): Either a list or int of indices/index.
+ index (Union[int, Tensor]): Either a list or int of indices/index.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+ Tuple[Tensor, Tensor]: Data target pair.
"""
if torch.is_tensor(index):