From db86cef2d308f58325278061c6aa177a535e7e03 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 1 Jun 2022 23:10:12 +0200
Subject: Replace attr with attrs

---
 text_recognizer/models/base.py        | 24 ++++++++++--------------
 text_recognizer/models/metrics.py     | 10 +++++-----
 text_recognizer/models/transformer.py | 24 ++++++++++++------------
 3 files changed, 27 insertions(+), 31 deletions(-)

(limited to 'text_recognizer/models')

diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 821cb69..bf3bc08 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,7 @@
 """Base PyTorch Lightning model."""
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-import attr
+from attrs import define, field
 import hydra
 from loguru import logger as log
 from omegaconf import DictConfig
@@ -14,7 +14,7 @@ import torchmetrics
 from text_recognizer.data.mappings.base import AbstractMapping
 
 
-@attr.s(eq=False)
+@define(eq=False)
 class BaseLitModel(LightningModule):
     """Abstract PyTorch Lightning class."""
 
@@ -22,22 +22,18 @@ class BaseLitModel(LightningModule):
         """Pre init constructor."""
         super().__init__()
 
-    network: Type[nn.Module] = attr.ib()
-    loss_fn: Type[nn.Module] = attr.ib()
-    optimizer_configs: DictConfig = attr.ib()
-    lr_scheduler_configs: Optional[DictConfig] = attr.ib()
-    mapping: Type[AbstractMapping] = attr.ib()
+    network: Type[nn.Module] = field()
+    loss_fn: Type[nn.Module] = field()
+    optimizer_configs: DictConfig = field()
+    lr_scheduler_configs: Optional[DictConfig] = field()
+    mapping: Type[AbstractMapping] = field()
 
     # Placeholders
-    train_acc: torchmetrics.Accuracy = attr.ib(
-        init=False, default=torchmetrics.Accuracy()
-    )
-    val_acc: torchmetrics.Accuracy = attr.ib(
-        init=False, default=torchmetrics.Accuracy()
-    )
-    test_acc: torchmetrics.Accuracy = attr.ib(
+    train_acc: torchmetrics.Accuracy = field(
         init=False, default=torchmetrics.Accuracy()
     )
+    val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+    test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
 
     def optimizer_zero_grad(
         self,
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index f83c9e4..e59a830 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,20 +1,20 @@
 """Character Error Rate (CER)."""
 from typing import Set
 
-import attr
+from attrs import define, field
 import editdistance
 import torch
 from torch import Tensor
 from torchmetrics import Metric
 
 
-@attr.s(eq=False)
+@define(eq=False)
 class CharacterErrorRate(Metric):
     """Character error rate metric, computed using Levenshtein distance."""
 
-    ignore_indices: Set[Tensor] = attr.ib(converter=set)
-    error: Tensor = attr.ib(init=False)
-    total: Tensor = attr.ib(init=False)
+    ignore_indices: Set[Tensor] = field(converter=set)
+    error: Tensor = field(init=False)
+    total: Tensor = field(init=False)
 
     def __attrs_post_init__(self) -> None:
         super().__init__()
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7272f46..c5120fe 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,7 +1,7 @@
 """PyTorch Lightning model for base Transformers."""
 from typing import Set, Tuple
 
-import attr
+from attrs import define, field
 import torch
 from torch import Tensor
 
@@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel
 from text_recognizer.models.metrics import CharacterErrorRate
 
 
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    max_output_len: int = attr.ib(default=451)
-    start_token: str = attr.ib(default="<s>")
-    end_token: str = attr.ib(default="<e>")
-    pad_token: str = attr.ib(default="<p>")
+    max_output_len: int = field(default=451)
+    start_token: str = field(default="<s>")
+    end_token: str = field(default="<e>")
+    pad_token: str = field(default="<p>")
 
-    start_index: int = attr.ib(init=False)
-    end_index: int = attr.ib(init=False)
-    pad_index: int = attr.ib(init=False)
+    start_index: int = field(init=False)
+    end_index: int = field(init=False)
+    pad_index: int = field(init=False)
 
-    ignore_indices: Set[Tensor] = attr.ib(init=False)
-    val_cer: CharacterErrorRate = attr.ib(init=False)
-    test_cer: CharacterErrorRate = attr.ib(init=False)
+    ignore_indices: Set[Tensor] = field(init=False)
+    val_cer: CharacterErrorRate = field(init=False)
+    test_cer: CharacterErrorRate = field(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-- 
cgit v1.2.3-70-g09d2