summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/base.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
commit34098ccbbbf6379c0bd29a987440b8479c743746 (patch)
treea8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer/networks/base.py
parentc032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff)
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer/networks/base.py')
-rw-r--r--text_recognizer/networks/base.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
new file mode 100644
index 0000000..07b6a32
--- /dev/null
+++ b/text_recognizer/networks/base.py
@@ -0,0 +1,18 @@
+"""Base network with required methods."""
+from abc import abstractmethod
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class BaseNetwork(nn.Module):
+ """Base network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def predict(self, x: Tensor) -> Tensor:
+ """Return token indices for predictions."""
+ ...