summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/util.py')
-rw-r--r--text_recognizer/networks/util.py77
1 files changed, 38 insertions, 39 deletions
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index d292680..9c6b151 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,7 +1,7 @@
"""Miscellaneous neural network functionality."""
import importlib
from pathlib import Path
-from typing import Dict, Type
+from typing import Dict, NamedTuple, Union, Type
from loguru import logger
import torch
@@ -24,41 +24,40 @@ def activation_function(activation: str) -> Type[nn.Module]:
return activation_fns[activation.lower()]
-def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
- """Loads a backbone network."""
- network_module = importlib.import_module("text_recognizer.networks")
- backbone_ = getattr(network_module, backbone)
-
- if "pretrained" in backbone_args:
- logger.info("Loading pretrained backbone.")
- checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
- "pretrained"
- )
-
- # Loading state directory.
- state_dict = torch.load(checkpoint_file)
- network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- freeze = False
- if "freeze" in backbone_args and backbone_args["freeze"] is True:
- backbone_args.pop("freeze")
- freeze = True
- network_args = backbone_args
-
- # Initializes the network with trained weights.
- backbone = backbone_(**network_args)
- backbone.load_state_dict(weights)
- if freeze:
- for params in backbone.parameters():
- params.requires_grad = False
- else:
- backbone_ = getattr(network_module, backbone)
- backbone = backbone_(**backbone_args)
-
- if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
- backbone = nn.Sequential(
- *list(backbone.children())[:][: -backbone_args["remove_layers"]]
- )
-
- return backbone
+# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+# """Loads a backbone network."""
+# network_module = importlib.import_module("text_recognizer.networks")
+# backbone_class = getattr(network_module, backbone.type)
+#
+# if "pretrained" in backbone.args:
+# logger.info("Loading pretrained backbone.")
+# checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop(
+# "pretrained"
+# )
+#
+# # Loading state directory.
+# state_dict = torch.load(checkpoint_file)
+# network_args = state_dict["network_args"]
+# weights = state_dict["model_state"]
+#
+# freeze = False
+# if "freeze" in backbone.args and backbone.args["freeze"] is True:
+# backbone.args.pop("freeze")
+# freeze = True
+#
+# # Initializes the network with trained weights.
+# backbone_ = backbone_(**backbone.args)
+# backbone_.load_state_dict(weights)
+# if freeze:
+# for params in backbone_.parameters():
+# params.requires_grad = False
+# else:
+# backbone_ = getattr(network_module, backbone.type)
+# backbone_ = backbone_(**backbone.args)
+#
+# if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+# backbone = nn.Sequential(
+# *list(backbone.children())[:][: -backbone_args["remove_layers"]]
+# )
+#
+# return backbone