From 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 24 Feb 2021 22:00:29 +0100 Subject: updates --- src/notebooks/07-try-gtn.ipynb | 49 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) (limited to 'src/notebooks/07-try-gtn.ipynb') diff --git a/src/notebooks/07-try-gtn.ipynb b/src/notebooks/07-try-gtn.ipynb index d366dec..4ef444b 100644 --- a/src/notebooks/07-try-gtn.ipynb +++ b/src/notebooks/07-try-gtn.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -123,6 +123,53 @@ "print(g1.grad().weights_to_list()) " ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1.0, 0.0, 0.5, 0.5]\n" + ] + } + ], + "source": [ + "import gtn\n", + "\n", + "# Make some graphs:\n", + "g1 = gtn.Graph()\n", + "g1.add_node(True) # Add a start node\n", + "g1.add_node() # Add an internal node\n", + "g1.add_node(False, True) # Add an accepting node\n", + "\n", + "# Add arcs with (src node, dst node, label):\n", + "g1.add_arc(0, 1, 1)\n", + "g1.add_arc(0, 1, 2)\n", + "g1.add_arc(1, 2, 1)\n", + "g1.add_arc(1, 2, 0)\n", + "\n", + "g2 = gtn.Graph()\n", + "g2.add_node(True, True)\n", + "g2.add_arc(0, 0, 1)\n", + "g2.add_arc(0, 0, 0)\n", + "\n", + "# Compute a function of the graphs:\n", + "intersection = gtn.intersect(g1, g2)\n", + "score = gtn.forward_score(intersection)\n", + "\n", + "# Visualize the intersected graph:\n", + "gtn.draw(intersection, \"intersection.pdf\")\n", + "\n", + "# Backprop:\n", + "gtn.backward(score)\n", + "\n", + "# Print gradients of arc weights \n", + "print(g1.grad().weights_to_list()) # [1.0, 0.0, 1.0, 0.0]" + ] + }, { "cell_type": "code", "execution_count": null, -- cgit v1.2.3-70-g09d2