summaryrefslogtreecommitdiff
path: root/src/notebooks/02b-emnist-lines-dataset.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'src/notebooks/02b-emnist-lines-dataset.ipynb')
-rw-r--r--src/notebooks/02b-emnist-lines-dataset.ipynb123
1 files changed, 116 insertions, 7 deletions
diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/src/notebooks/02b-emnist-lines-dataset.ipynb
index 2ef7da7..84d853b 100644
--- a/src/notebooks/02b-emnist-lines-dataset.ipynb
+++ b/src/notebooks/02b-emnist-lines-dataset.ipynb
@@ -22,7 +22,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -40,14 +40,14 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-09-09 23:07:57.716 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:134 - EmnistLinesDataset loading data from HDF5...\n"
+ "2020-09-10 20:11:30.358 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:134 - EmnistLinesDataset loading data from HDF5...\n"
]
}
],
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -67,7 +67,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 6,
"metadata": {
"scrolled": false
},
@@ -212,7 +212,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -237,6 +237,115 @@
},
{
"cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([58, 50, 53, 46, 40, 53, 54, 62, 58, 44, 41, 40, 62, 53, 40, 41, 56, 54,\n",
+ " 40, 39, 62, 55, 50, 62, 43, 36, 57, 40, 79, 79, 79, 79, 79, 79],\n",
+ " dtype=torch.uint8)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import LineRecurrentNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "crnn = LineRecurrentNetwork(encoder=\"ResidualNetworkEncoder\",\n",
+ " \n",
+ " encoder_args={\n",
+ " \"in_channels\": 1,\n",
+ " \"num_classes\": 80,\n",
+ " \"depths\": [2, 2],\n",
+ " \"block_sizes\": [64, 128],\n",
+ " \"activation\": \"leaky_relu\",\n",
+ " \"stn\": False,})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output = crnn(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "targets = target.unsqueeze(0).type(torch.long)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_lengths = torch.full(\n",
+ " size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,\n",
+ ")\n",
+ "target_lengths = torch.full(\n",
+ " size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ctc = torch.nn.CTCLoss(blank=79)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(6.9917, grad_fn=<MeanBackward0>)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ctc(output, targets, input_lengths, target_lengths)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],