summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
commit7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/models/base.py
parent92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3e02261..dfb4ca4 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer.networks.base import BaseNetwork
-
@attr.s
class BaseLitModel(LightningModule):
@@ -21,7 +19,7 @@ class BaseLitModel(LightningModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- network: Type[BaseNetwork] = attr.ib()
+ network: Type[nn.Module] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)