diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:13:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:13:34 +0200 |
commit | 370c3eb35f47e64eab7926a5f995947c63c6b208 (patch) | |
tree | 9640c06a06b20ca1a0f17aa6893e9b83d36e450f /text_recognizer/networks | |
parent | 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (diff) |
Configure backbone commented out, transformer init change
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 77 |
2 files changed, 39 insertions, 40 deletions
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 652e82e..627fa7b 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -4,4 +4,4 @@ from .positional_encoding import ( PositionalEncoding2D, target_padding_mask, ) -from .transformer import Decoder, Encoder, EncoderLayer, Transformer +from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index d292680..9c6b151 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,7 +1,7 @@ """Miscellaneous neural network functionality.""" import importlib from pathlib import Path -from typing import Dict, Type +from typing import Dict, NamedTuple, Union, Type from loguru import logger import torch @@ -24,41 +24,40 @@ def activation_function(activation: str) -> Type[nn.Module]: return activation_fns[activation.lower()] -def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: - """Loads a backbone network.""" - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( - "pretrained" - ) - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - freeze = False - if "freeze" in backbone_args and backbone_args["freeze"] is True: - backbone_args.pop("freeze") - freeze = True - network_args = backbone_args - - # Initializes the network with trained weights. - backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if freeze: - for params in backbone.parameters(): - params.requires_grad = False - else: - backbone_ = getattr(network_module, backbone) - backbone = backbone_(**backbone_args) - - if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: - backbone = nn.Sequential( - *list(backbone.children())[:][: -backbone_args["remove_layers"]] - ) - - return backbone +# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: +# """Loads a backbone network.""" +# network_module = importlib.import_module("text_recognizer.networks") +# backbone_class = getattr(network_module, backbone.type) +# +# if "pretrained" in backbone.args: +# logger.info("Loading pretrained backbone.") +# checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop( +# "pretrained" +# ) +# +# # Loading state directory. +# state_dict = torch.load(checkpoint_file) +# network_args = state_dict["network_args"] +# weights = state_dict["model_state"] +# +# freeze = False +# if "freeze" in backbone.args and backbone.args["freeze"] is True: +# backbone.args.pop("freeze") +# freeze = True +# +# # Initializes the network with trained weights. +# backbone_ = backbone_(**backbone.args) +# backbone_.load_state_dict(weights) +# if freeze: +# for params in backbone_.parameters(): +# params.requires_grad = False +# else: +# backbone_ = getattr(network_module, backbone.type) +# backbone_ = backbone_(**backbone.args) +# +# if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: +# backbone = nn.Sequential( +# *list(backbone.children())[:][: -backbone_args["remove_layers"]] +# ) +# +# return backbone |