summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/text_recognizer/models/base.py4
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py1
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.ptbin28654593 -> 4562821 bytes
-rw-r--r--src/training/experiments/sample_experiment.yml4
4 files changed, 5 insertions, 4 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 74fd223..3a84a11 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -139,10 +139,10 @@ class Model(ABC):
else:
_optimizer = None
- if self._optimizer and lr_scheduler is not None:
+ if _optimizer and lr_scheduler is not None:
if "OneCycleLR" in str(lr_scheduler):
lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
- _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+ _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
else:
_lr_scheduler = None
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
index d704139..2e2c3a5 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/line_lstm_ctc.py
@@ -2,3 +2,4 @@
import torch
from torch import nn
+from torch import Tensor
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
index 008beb2..a5c6aaf 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
Binary files differ
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index bae02ac..b00bd5a 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -26,8 +26,8 @@ experiments:
network_args:
in_channels: 1
num_classes: 80
- depths: [1, 1]
- block_sizes: [128, 256]
+ depths: [2, 1]
+ block_sizes: [96, 32]
# network: LeNet
# network_args:
# output_size: 62