summaryrefslogtreecommitdiff
path: root/notebooks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:09 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:09 +0100
commit113153671b5ab9ff613a03dbbfcf4266e269bd9f (patch)
tree8b5f8e1dfcbe13dd55c168d569d6e7816148cc12 /notebooks
parent86b75f8439cef178cea2cae891afb112ccba1411 (diff)
Update notebook
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/04-efficientnet-transformer.ipynb136
1 files changed, 135 insertions, 1 deletions
diff --git a/notebooks/04-efficientnet-transformer.ipynb b/notebooks/04-efficientnet-transformer.ipynb
index 4a6fd64..0977487 100644
--- a/notebooks/04-efficientnet-transformer.ipynb
+++ b/notebooks/04-efficientnet-transformer.ipynb
@@ -284,10 +284,144 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "b13ac47c-322d-47d4-bcee-43e5341f74a7",
"metadata": {},
"outputs": [],
+ "source": [
+ "start_tokens = torch.ones(1, 1).long()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "55a16f5d-2b27-4a12-b5bb-eb079784b0ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "num_dims = len(start_tokens.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "46c65400-fa47-4c10-9edd-8416e6a1185a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if num_dims == 1:\n",
+ " start_tokens = start_tokens[None, :]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "1dfa0b95-a075-4121-b2bf-f1a8100b10fd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "start_tokens.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c85a357c-f6af-42c5-b714-89df024c29e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b, t = start_tokens.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "0ba293f4-e08d-4aaa-94d5-da4899f9b592",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = start_tokens"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "a8225f98-c5e9-4da2-b756-75599fa8e044",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "63dfcfd7-6b93-49ac-a0ab-59be53fa0853",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[True]])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "a752bfcf-f323-43bc-a910-fec4695150e0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = out[:, -200:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "4b1b1989-930a-48c5-a7b3-746289107b97",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1]])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "513f27bb-2ae1-42a0-8de9-9ae39fdfff32",
+ "metadata": {},
+ "outputs": [],
"source": []
}
],