{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2020 Urbain Vaes. All rights reserved.\n",
    "#\n",
    "# This work is licensed under the terms of the MIT license.\n",
    "# For a copy, see <https://opensource.org/licenses/MIT>.\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import networkx as nx\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "lines_to_end_of_cell_marker": 0,
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "mpl.rc('font', size=20)\n",
    "mpl.rc('font', family='serif')\n",
    "mpl.rc('figure', figsize=(16, 11))\n",
    "mpl.rc('lines', linewidth=2)\n",
    "mpl.rc('lines', markersize=12)\n",
    "mpl.rc('figure.subplot', hspace=.3)\n",
    "mpl.rc('figure.subplot', wspace=.1)\n",
    "mpl.rc('animation', html='html5')\n",
    "np.random.seed(0)\n",
    "\n",
    "\n",
    "def _build_graph(T):\n",
    "    \"\"\"Return the weighted digraph with an edge i -> j whenever T[i][j] != 0.\n",
    "\n",
    "    Nodes 0, ..., K-1 are inserted before any edge so that the node order of\n",
    "    the graph matches the row order of T; this matters because node colors\n",
    "    are passed to `nx.draw` as a list aligned with that order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    G.add_nodes_from(range(len(T)))\n",
    "    for i, row in enumerate(T):\n",
    "        for j, p in enumerate(row):\n",
    "            if p != 0:\n",
    "                G.add_edge(i, j, weight=p)\n",
    "    return G\n",
    "\n",
    "\n",
    "def _simulate(T, n_particles, n_steps):\n",
    "    \"\"\"Simulate independent particles jumping according to the matrix T.\n",
    "\n",
    "    All particles start at node 0. Returns `(values, exact)`, where\n",
    "    `values[i, k]` is the number of particles at node `k` after `i` steps\n",
    "    and `exact[i]` is the exact probability mass function after `i` steps.\n",
    "    \"\"\"\n",
    "    K = len(T)\n",
    "    tr = np.array(T)\n",
    "\n",
    "    values = np.zeros((n_steps + 1, K), dtype=int)\n",
    "    exact = np.zeros((n_steps + 1, K))\n",
    "    values[0, 0] = n_particles\n",
    "    exact[0, 0] = 1\n",
    "\n",
    "    # Generalized Bernoulli (categorical) law of the next node, for each node\n",
    "    draw_next = [scipy.stats.rv_discrete(values=(range(K), row)) for row in T]\n",
    "\n",
    "    for i in range(n_steps):\n",
    "        for j in range(K):\n",
    "            next_nodes = draw_next[j].rvs(size=values[i, j])\n",
    "            # Count the particles landing on each node\n",
    "            values[i + 1] += np.bincount(\n",
    "                np.asarray(next_nodes, dtype=int).ravel(), minlength=K)\n",
    "        # Row-vector convention: p_{i+1} = p_i T\n",
    "        exact[i + 1] = tr.T.dot(exact[i])\n",
    "\n",
    "    return values, exact\n",
    "\n",
    "\n",
    "def run_tests(T, action='plot_evolution', n_particles=10**4, n_steps=100):\n",
    "    \"\"\"Animate a Monte Carlo simulation of the Markov chain with matrix T.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T : sequence of sequences of float\n",
    "        Transition matrix: T[i][j] is the probability of jumping from node i\n",
    "        to node j, so each row must sum to 1. The drawing layout (`pos`)\n",
    "        assumes 5 nodes.\n",
    "    action : {'plot_evolution', 'plot_pmf'}\n",
    "        'plot_evolution' animates the number of particles at each node of\n",
    "        the graph; 'plot_pmf' animates the empirical probability mass\n",
    "        function together with the exact one.\n",
    "    n_particles : int, optional\n",
    "        Number of independent particles, all starting at node 0.\n",
    "    n_steps : int, optional\n",
    "        Number of time steps to simulate.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.animation.FuncAnimation\n",
    "        Animation with frames 0, ..., n_steps.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `action` is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if action not in ('plot_evolution', 'plot_pmf'):\n",
    "        raise ValueError(\"Unknown action: {!r}\".format(action))\n",
    "\n",
    "    N, n, K = n_particles, n_steps, len(T)\n",
    "    G = _build_graph(T)\n",
    "    values, exact = _simulate(T, N, n)\n",
    "\n",
    "    # Fixed layout of the five nodes\n",
    "    pos = {0: (0, 0), 1: (0, 2), 2: (1, 1), 3: (2, 0), 4: (2, 2)}\n",
    "\n",
    "    # Position (in axes coordinates) of the label of each edge (i, j), chosen\n",
    "    # to match `pos`; edges not listed here are drawn without a label\n",
    "    label_pos = {\n",
    "        (0, 2): (.3, .28), (1, 0): (.05, .5), (1, 2): (.3, .62),\n",
    "        (2, 1): (.3, .79), (2, 3): (.7, .20), (3, 2): (.7, .38),\n",
    "        (3, 4): (.95, .5), (4, 2): (.7, .79),\n",
    "    }\n",
    "\n",
    "    def add_edges_labels(ax):\n",
    "        \"\"\"Write the nonzero transition probabilities next to their edges.\"\"\"\n",
    "        kwargs = {\n",
    "            'fontsize': 18,\n",
    "            'horizontalalignment': 'center',\n",
    "            'verticalalignment': 'center',\n",
    "            'transform': ax.transAxes,\n",
    "        }\n",
    "        for (i, j), (x, y) in label_pos.items():\n",
    "            if T[i][j] != 0:\n",
    "                ax.text(x, y, \"{}\".format(T[i][j]), **kwargs)\n",
    "\n",
    "    def plot_evolution(i):\n",
    "        \"\"\"Draw the number of particles at each node after i steps.\"\"\"\n",
    "        ax.clear()\n",
    "        add_edges_labels(ax)\n",
    "        nx.draw_networkx_labels(G, pos, labels=dict(enumerate(values[i])),\n",
    "                                font_size=16, ax=ax)\n",
    "        # values[i] is aligned with G's node order 0, ..., K-1\n",
    "        nx.draw(G, pos, node_color=values[i], alpha=.5, node_size=3000,\n",
    "                connectionstyle='arc3, rad=0.1', ax=ax,\n",
    "                cmap=plt.get_cmap('viridis'))\n",
    "        ax.set_title(\"Discrete time: ${}$\".format(i))\n",
    "\n",
    "    def plot_pmf(i):\n",
    "        \"\"\"Compare the empirical and exact mass functions after i steps.\"\"\"\n",
    "        ax.clear()\n",
    "        ax.set_title(\"Probability mass function at iteration ${}$\".format(i))\n",
    "        ax.set_xlabel(\"Node index\")\n",
    "        # The two stem plots are shifted by +-0.05 so that both stay visible\n",
    "        ax.stem(np.arange(K) - .05, values[i] / N,\n",
    "                label=\"MC approximation\", linefmt='C0-', markerfmt='C0o')\n",
    "        ax.stem(np.arange(K) + .05, exact[i],\n",
    "                label=\"Exact\", linefmt='C1-', markerfmt='C1o')\n",
    "        ax.set_ylim(0, 1.1)\n",
    "        ax.legend()\n",
    "\n",
    "    # Create animation; the figure size is set locally, not through rcParams\n",
    "    fig, ax = plt.subplots(figsize=(12, 8))\n",
    "    fig.subplots_adjust(left=.1, bottom=.1, right=.98, top=.95)\n",
    "    iterate = plot_evolution if action == 'plot_evolution' else plot_pmf\n",
    "    # Frames 0, ..., n so that the last simulated state is shown as well; the\n",
    "    # init function is a no-op because every frame redraws the axes itself\n",
    "    anim = animation.FuncAnimation(fig, iterate, np.arange(n + 1),\n",
    "                                   init_func=lambda: None, repeat=True)\n",
    "\n",
    "    # In a script, call plt.show() here instead; in a notebook the figure is\n",
    "    # closed so that it is only displayed through the returned animation\n",
    "    plt.close(fig)\n",
    "    return anim\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "T = [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0],\n",
    "     [0, .5, 0, .5, 0], [0, 0, 0, 0, 1], [0, 0, 1, 0, 0]]\n",
    "run_tests(T, action='plot_evolution')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_tests(T, action='plot_pmf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "T = [[0, 0, 1, 0, 0], [.5, 0, .5, 0, 0],\n",
    "     [0, .5, 0, .5, 0], [0, 0, .5, 0, .5], [0, 0, 1, 0, 0]]\n",
    "run_tests(T, action='plot_evolution')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_tests(T, action='plot_pmf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "T = [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0],\n",
    "     [0, .5, 0, .5, 0], [0, 0, .5, 0, .5], [0, 0, 1, 0, 0]]\n",
    "run_tests(T, action='plot_evolution')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_tests(T, action='plot_pmf')"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}