summaryrefslogtreecommitdiff
path: root/src/notebooks/07-try-gtn.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'src/notebooks/07-try-gtn.ipynb')
-rw-r--r--src/notebooks/07-try-gtn.ipynb49
1 files changed, 48 insertions, 1 deletions
diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb
index d366dec..4ef444b 100644
--- a/src/notebooks/07-try-gtn.ipynb
+++ b/src/notebooks/07-try-gtn.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -125,6 +125,53 @@
},
{
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1.0, 0.0, 0.5, 0.5]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import gtn\n",
+ "\n",
+ "# Make some graphs:\n",
+ "g1 = gtn.Graph()\n",
+ "g1.add_node(True) # Add a start node\n",
+ "g1.add_node() # Add an internal node\n",
+ "g1.add_node(False, True) # Add an accepting node\n",
+ "\n",
+ "# Add arcs with (src node, dst node, label):\n",
+ "g1.add_arc(0, 1, 1)\n",
+ "g1.add_arc(0, 1, 2)\n",
+ "g1.add_arc(1, 2, 1)\n",
+ "g1.add_arc(1, 2, 0)\n",
+ "\n",
+ "g2 = gtn.Graph()\n",
+ "g2.add_node(True, True)\n",
+ "g2.add_arc(0, 0, 1)\n",
+ "g2.add_arc(0, 0, 0)\n",
+ "\n",
+ "# Compute a function of the graphs:\n",
+ "intersection = gtn.intersect(g1, g2)\n",
+ "score = gtn.forward_score(intersection)\n",
+ "\n",
+ "# Visualize the intersected graph:\n",
+ "gtn.draw(intersection, \"intersection.pdf\")\n",
+ "\n",
+ "# Backprop:\n",
+ "gtn.backward(score)\n",
+ "\n",
+ "# Print gradients of arc weights \n",
+ "print(g1.grad().weights_to_list()) # [1.0, 0.0, 1.0, 0.0]"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],