diff options
Diffstat (limited to 'src/notebooks/02b-emnist-lines-dataset.ipynb')
-rw-r--r-- | src/notebooks/02b-emnist-lines-dataset.ipynb | 123 |
1 files changed, 116 insertions, 7 deletions
diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/src/notebooks/02b-emnist-lines-dataset.ipynb index 2ef7da7..84d853b 100644 --- a/src/notebooks/02b-emnist-lines-dataset.ipynb +++ b/src/notebooks/02b-emnist-lines-dataset.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -40,14 +40,14 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2020-09-09 23:07:57.716 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:134 - EmnistLinesDataset loading data from HDF5...\n" + "2020-09-10 20:11:30.358 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:134 - EmnistLinesDataset loading data from HDF5...\n" ] } ], @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "metadata": { "scrolled": false }, @@ -212,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -237,6 +237,115 @@ }, { "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([58, 50, 53, 46, 40, 53, 54, 62, 58, 44, 41, 40, 62, 53, 40, 41, 56, 54,\n", + " 40, 39, 62, 55, 50, 62, 43, 36, 57, 40, 79, 79, 79, 79, 79, 79],\n", + " dtype=torch.uint8)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import LineRecurrentNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "crnn = LineRecurrentNetwork(encoder=\"ResidualNetworkEncoder\",\n", + " \n", + " encoder_args={\n", + " \"in_channels\": 1,\n", + " \"num_classes\": 80,\n", + " \"depths\": [2, 2],\n", + " \"block_sizes\": [64, 128],\n", + " \"activation\": \"leaky_relu\",\n", + " \"stn\": False,})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "output = crnn(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "targets = target.unsqueeze(0).type(torch.long)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "input_lengths = torch.full(\n", + " size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,\n", + ")\n", + "target_lengths = torch.full(\n", + " size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "ctc = torch.nn.CTCLoss(blank=79)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(6.9917, grad_fn=<MeanBackward0>)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ctc(output, targets, input_lengths, target_lengths)" + ] + }, + { + "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], |