author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-05 23:13:34 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-05 23:13:34 +0200
commit     370c3eb35f47e64eab7926a5f995947c63c6b208 (patch)
tree       9640c06a06b20ca1a0f17aa6893e9b83d36e450f /text_recognizer/networks
parent     38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (diff)
Comment out configure_backbone; change transformer __init__ exports
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/transformer/__init__.py |  2
-rw-r--r--  text_recognizer/networks/util.py                  | 77
2 files changed, 39 insertions(+), 40 deletions(-)
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 652e82e..627fa7b 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -4,4 +4,4 @@ from .positional_encoding import (
     PositionalEncoding2D,
     target_padding_mask,
 )
-from .transformer import Decoder, Encoder, EncoderLayer, Transformer
+from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer
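
With this change, DecoderLayer is importable from the package root alongside the
other transformer building blocks. A minimal usage sketch (assuming the package
is installed; constructor arguments are omitted because they are not part of
this diff):

    from text_recognizer.networks.transformer import (
        Decoder,
        DecoderLayer,
        Encoder,
        EncoderLayer,
        Transformer,
    )
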
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index d292680..9c6b151 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,7 +1,7 @@
"""Miscellaneous neural network functionality."""
import importlib
from pathlib import Path
-from typing import Dict, Type
+from typing import Dict, NamedTuple, Union, Type
from loguru import logger
import torch
@@ -24,41 +24,40 @@ def activation_function(activation: str) -> Type[nn.Module]:
     return activation_fns[activation.lower()]
 
 
-def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
-    """Loads a backbone network."""
-    network_module = importlib.import_module("text_recognizer.networks")
-    backbone_ = getattr(network_module, backbone)
-
-    if "pretrained" in backbone_args:
-        logger.info("Loading pretrained backbone.")
-        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
-            "pretrained"
-        )
-
-        # Loading state directory.
-        state_dict = torch.load(checkpoint_file)
-        network_args = state_dict["network_args"]
-        weights = state_dict["model_state"]
-
-        freeze = False
-        if "freeze" in backbone_args and backbone_args["freeze"] is True:
-            backbone_args.pop("freeze")
-            freeze = True
-        network_args = backbone_args
-
-        # Initializes the network with trained weights.
-        backbone = backbone_(**network_args)
-        backbone.load_state_dict(weights)
-        if freeze:
-            for params in backbone.parameters():
-                params.requires_grad = False
-    else:
-        backbone_ = getattr(network_module, backbone)
-        backbone = backbone_(**backbone_args)
-
-    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
-        backbone = nn.Sequential(
-            *list(backbone.children())[:][: -backbone_args["remove_layers"]]
-        )
-
-    return backbone
+# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+#     """Loads a backbone network."""
+#     network_module = importlib.import_module("text_recognizer.networks")
+#     backbone_class = getattr(network_module, backbone.type)
+#
+#     if "pretrained" in backbone.args:
+#         logger.info("Loading pretrained backbone.")
+#         checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop(
+#             "pretrained"
+#         )
+#
+#         # Load the checkpoint state dict.
+#         state_dict = torch.load(checkpoint_file)
+#         network_args = state_dict["network_args"]
+#         weights = state_dict["model_state"]
+#
+#         freeze = False
+#         if "freeze" in backbone.args and backbone.args["freeze"] is True:
+#             backbone.args.pop("freeze")
+#             freeze = True
+#
+#         # Initialize the network with its saved args and load the trained weights.
+#         backbone_ = backbone_class(**network_args)
+#         backbone_.load_state_dict(weights)
+#         if freeze:
+#             for params in backbone_.parameters():
+#                 params.requires_grad = False
+#     else:
+#         backbone_ = backbone_class(**backbone.args)
+#
+#     if "remove_layers" in backbone.args and backbone.args["remove_layers"] is not None:
+#         backbone_ = nn.Sequential(
+#             *list(backbone_.children())[: -backbone.args["remove_layers"]]
+#         )
+#
+#     return backbone_
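
For reference, a sketch of how the new single-argument signature might be called
once the function is re-enabled. The backbone type "DenseNet", the checkpoint
path, and the arg values below are hypothetical placeholders, not taken from
this commit; omegaconf is assumed to be available:

    from omegaconf import OmegaConf

    # Hypothetical config: "type" names a class in text_recognizer.networks,
    # "args" holds its keyword arguments plus the optional loader flags
    # ("pretrained", "freeze", "remove_layers") consumed above.
    cfg = OmegaConf.create(
        {
            "type": "DenseNet",  # hypothetical backbone class name
            "args": {"pretrained": "weights/backbone.pt", "freeze": True},
        }
    )
    # backbone = configure_backbone(cfg)  # commented out, like the function itself
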