diff options
-rw-r--r-- | notebooks/04-efficientnet-transformer.ipynb | 136 |
1 files changed, 135 insertions, 1 deletions
diff --git a/notebooks/04-efficientnet-transformer.ipynb b/notebooks/04-efficientnet-transformer.ipynb index 4a6fd64..0977487 100644 --- a/notebooks/04-efficientnet-transformer.ipynb +++ b/notebooks/04-efficientnet-transformer.ipynb @@ -284,10 +284,144 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "b13ac47c-322d-47d4-bcee-43e5341f74a7", "metadata": {}, "outputs": [], + "source": [ + "start_tokens = torch.ones(1, 1).long()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "55a16f5d-2b27-4a12-b5bb-eb079784b0ae", + "metadata": {}, + "outputs": [], + "source": [ + "num_dims = len(start_tokens.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "46c65400-fa47-4c10-9edd-8416e6a1185a", + "metadata": {}, + "outputs": [], + "source": [ + "if num_dims == 1:\n", + " start_tokens = start_tokens[None, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1dfa0b95-a075-4121-b2bf-f1a8100b10fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_tokens.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c85a357c-f6af-42c5-b714-89df024c29e3", + "metadata": {}, + "outputs": [], + "source": [ + "b, t = start_tokens.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0ba293f4-e08d-4aaa-94d5-da4899f9b592", + "metadata": {}, + "outputs": [], + "source": [ + "out = start_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a8225f98-c5e9-4da2-b756-75599fa8e044", + "metadata": {}, + "outputs": [], + "source": [ + "input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "63dfcfd7-6b93-49ac-a0ab-59be53fa0853", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[True]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_mask" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a752bfcf-f323-43bc-a910-fec4695150e0", + "metadata": {}, + "outputs": [], + "source": [ + "x = out[:, -200:]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4b1b1989-930a-48c5-a7b3-746289107b97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "513f27bb-2ae1-42a0-8de9-9ae39fdfff32", + "metadata": {}, + "outputs": [], "source": [] } ], |