diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:12:09 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:12:09 +0100 |
commit | 113153671b5ab9ff613a03dbbfcf4266e269bd9f (patch) | |
tree | 8b5f8e1dfcbe13dd55c168d569d6e7816148cc12 /notebooks | |
parent | 86b75f8439cef178cea2cae891afb112ccba1411 (diff) |
Update notebook
Diffstat (limited to 'notebooks')
-rw-r--r-- | notebooks/04-efficientnet-transformer.ipynb | 136 |
1 files changed, 135 insertions, 1 deletions
diff --git a/notebooks/04-efficientnet-transformer.ipynb b/notebooks/04-efficientnet-transformer.ipynb index 4a6fd64..0977487 100644 --- a/notebooks/04-efficientnet-transformer.ipynb +++ b/notebooks/04-efficientnet-transformer.ipynb @@ -284,10 +284,144 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "b13ac47c-322d-47d4-bcee-43e5341f74a7", "metadata": {}, "outputs": [], + "source": [ + "start_tokens = torch.ones(1, 1).long()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "55a16f5d-2b27-4a12-b5bb-eb079784b0ae", + "metadata": {}, + "outputs": [], + "source": [ + "num_dims = len(start_tokens.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "46c65400-fa47-4c10-9edd-8416e6a1185a", + "metadata": {}, + "outputs": [], + "source": [ + "if num_dims == 1:\n", + " start_tokens = start_tokens[None, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1dfa0b95-a075-4121-b2bf-f1a8100b10fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_tokens.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c85a357c-f6af-42c5-b714-89df024c29e3", + "metadata": {}, + "outputs": [], + "source": [ + "b, t = start_tokens.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0ba293f4-e08d-4aaa-94d5-da4899f9b592", + "metadata": {}, + "outputs": [], + "source": [ + "out = start_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a8225f98-c5e9-4da2-b756-75599fa8e044", + "metadata": {}, + "outputs": [], + "source": [ + "input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "63dfcfd7-6b93-49ac-a0ab-59be53fa0853", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[True]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_mask" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a752bfcf-f323-43bc-a910-fec4695150e0", + "metadata": {}, + "outputs": [], + "source": [ + "x = out[:, -200:]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4b1b1989-930a-48c5-a7b3-746289107b97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "513f27bb-2ae1-42a0-8de9-9ae39fdfff32", + "metadata": {}, + "outputs": [], "source": [] } ], |