diff options
Diffstat (limited to 'src/notebooks/07-try-gtn.ipynb')
-rw-r--r-- | src/notebooks/07-try-gtn.ipynb | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb index d366dec..4ef444b 100644 --- a/src/notebooks/07-try-gtn.ipynb +++ b/src/notebooks/07-try-gtn.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -125,6 +125,53 @@ }, { "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1.0, 0.0, 0.5, 0.5]\n" + ] + } + ], + "source": [ + "import gtn\n", + "\n", + "# Make some graphs:\n", + "g1 = gtn.Graph()\n", + "g1.add_node(True) # Add a start node\n", + "g1.add_node() # Add an internal node\n", + "g1.add_node(False, True) # Add an accepting node\n", + "\n", + "# Add arcs with (src node, dst node, label):\n", + "g1.add_arc(0, 1, 1)\n", + "g1.add_arc(0, 1, 2)\n", + "g1.add_arc(1, 2, 1)\n", + "g1.add_arc(1, 2, 0)\n", + "\n", + "g2 = gtn.Graph()\n", + "g2.add_node(True, True)\n", + "g2.add_arc(0, 0, 1)\n", + "g2.add_arc(0, 0, 0)\n", + "\n", + "# Compute a function of the graphs:\n", + "intersection = gtn.intersect(g1, g2)\n", + "score = gtn.forward_score(intersection)\n", + "\n", + "# Visualize the intersected graph:\n", + "gtn.draw(intersection, \"intersection.pdf\")\n", + "\n", + "# Backprop:\n", + "gtn.backward(score)\n", + "\n", + "# Print gradients of arc weights \n", + "print(g1.grad().weights_to_list()) # [1.0, 0.0, 1.0, 0.0]" + ] + }, + { + "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], |