summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3e02261..dfb4ca4 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer.networks.base import BaseNetwork
-
@attr.s
class BaseLitModel(LightningModule):
@@ -21,7 +19,7 @@ class BaseLitModel(LightningModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- network: Type[BaseNetwork] = attr.ib()
+ network: Type[nn.Module] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)