SRI Model

For a recent talk in my department I talked a little bit about agent-based modeling, and in the process I came across the simple but quite interesting SIR model in epidemiology. The inspiration for this notebook was an older version of Simon Dobson’s post on epidemic spreading processes, which provides a much more detailed scientific background and takes you through some of the code step by step. As a brief introduction: the model places individuals on a network, where each one is susceptible, infected or recovered, and lets the infection spread along the connections between them.

I’ve made some minor tweaks to the model by adding vaccinated and dead states. I’ve also unified the function-based approach into a single Parameterized class, which takes care of initializing, running and visualizing the network.

In this notebook I’ll primarily look at how we can quickly create complex visualizations of this model using HoloViews. In the process I’ll look at some predictions this model can make about herd immunity, but I won’t be giving it any rigorous scientific treatment.

The Code

Here’s the code for the model, relying only on numpy, networkx, param and holoviews, with matplotlib in the background.

import collections
import itertools
import math

import numpy as np
import numpy.random as rnd
import networkx as nx

import param
import holoviews as hv

class SRI_Model(param.Parameterized):
    """
    Implementation of the SRI epidemiology model
    using NetworkX and HoloViews for visualization.
    This code has been adapted from Simon Dobson's
    code here:
    In addition to his basic parameters I've added
    additional states to the model, a node may be
    in one of the following states:
      * Susceptible: Can catch the disease from a connected node.
      * Vaccinated: Immune to infection.
      * Infected: Has the disease and may pass it on to any connected node.
      * Recovered: Immune to infection.
      * Dead: Edges are removed from graph.
    """

    network = param.ClassSelector(class_=nx.Graph, default=None, doc="""
        A custom NetworkX graph, instead of the default Erdos-Renyi graph.""")
    visualize = param.Boolean(default=True, doc="""
        Whether to compute layout of network for visualization.""")
    N = param.Integer(default=1000, doc="""
        Number of nodes to simulate.""")
    mean_connections = param.Number(default=10, doc="""
        Mean number of connections to make to other nodes.""")
    pSick = param.Number(default=0.01, doc="""
        Probability of a node to be initialized in sick state.""", bounds=(0, 1))

    pVaccinated = param.Number(default=0.1, bounds=(0, 1), doc="""
        Probability of a node to be initialized in vaccinated state.""")
    pInfect = param.Number(default=0.3, doc="""
        Probability of infection on each time step.""", bounds=(0, 1))
    pRecover = param.Number(default=0.05, doc="""
        Probability of recovering if infected on each timestep.""", bounds=(0, 1))
    pDeath = param.Number(default=0.1, doc="""
        Probability of death if infected on each timestep.""", bounds=(0, 1))
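    # Single-letter state codes stored on each node (they match the keys used in stats())
    SPREADING_SUSCEPTIBLE = 'S'
    SPREADING_VACCINATED = 'V'
    SPREADING_INFECTED = 'I'
    SPREADING_RECOVERED = 'R'
    DEAD = 'D'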

    def __init__(self, **params):
        super(SRI_Model, self).__init__(**params)
        if not self.network:
            # No custom graph supplied: build the default Erdos-Renyi graph
            self.g = nx.erdos_renyi_graph(self.N, float(self.mean_connections)/self.N)
        else:
            self.g = self.network
        self.vaccinated, self.infected = self.spreading_init()
        self.model = self.spreading_make_sir_model()
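        # Order must match state_labels in animate(): S, V, I, R, D
        self.color_mapping = [self.SPREADING_SUSCEPTIBLE,
                              self.SPREADING_VACCINATED,
                              self.SPREADING_INFECTED,
                              self.SPREADING_RECOVERED, self.DEAD]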
        if self.visualize:
            k = 2/(math.sqrt(self.g.order()))
            self.pos = hv.Graph.from_networkx(self.g, nx.spring_layout, iterations=50, k=k)

    def spreading_init(self):
        """Initialise the network with vaccinated, susceptible and infected states."""
        vaccinated, infected = 0, []
        for i in self.g.nodes.keys():
            self.g.nodes[i]['transmissions'] = 0
            if(rnd.random() <= self.pVaccinated): 
                self.g.nodes[i]['state'] = self.SPREADING_VACCINATED
                vaccinated += 1
            elif(rnd.random() <= self.pSick):
                self.g.nodes[i]['state'] = self.SPREADING_INFECTED
                infected.append(i)
            else:
                self.g.nodes[i]['state'] = self.SPREADING_SUSCEPTIBLE
        return vaccinated, infected

    def spreading_make_sir_model(self):
        """Return an SIR model function for given infection and recovery probabilities."""
        # model (local rule) function
        def model( g, i ):
            if g.nodes[i]['state'] == self.SPREADING_INFECTED:
                # infect susceptible neighbours with probability pInfect
                for m in g.neighbors(i):
                    if g.nodes[m]['state'] == self.SPREADING_SUSCEPTIBLE:
                        if rnd.random() <= self.pInfect:
                            g.nodes[m]['state'] = self.SPREADING_INFECTED
                            # record the new case so stats() can count it
                            self.infected.append(m)
                            g.nodes[i]['transmissions'] += 1

                # recover with probability pRecover
                if rnd.random() <= self.pRecover:
                    g.nodes[i]['state'] = self.SPREADING_RECOVERED
                elif rnd.random() <= self.pDeath:
                    # otherwise die with probability pDeath and drop all edges of the node
                    edges = [edge for edge in self.g.edges() if i in edge]
                    g.nodes[i]['state'] = self.DEAD
                    g.remove_edges_from(edges)

        return model

    def step(self):
        """Run a single step of the model over the graph."""
        for i in self.g.nodes.keys():
            self.model(self.g, i)

    def run(self, steps):
        """Run the network for the specified number of time steps."""
        for i in range(steps):
            self.step()

    def stats(self):
        """Return an ItemTable with statistics on the network data."""
        state_labels = {'S': 'Susceptible', 'V': 'Vaccinated', 'I': 'Infected',
                        'R': 'Recovered', 'D': 'Dead'}
        counts = collections.Counter()
        transmissions = []
        for n in self.g.nodes():
            state = state_labels[self.g.nodes[n]['state']]
            counts[state] += 1
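            if n in self.infected:
                # only nodes that were infected at some point contribute to R_0
                transmissions.append(self.g.nodes[n]['transmissions'])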
        data = {l: counts[l] for l in state_labels.values()}
        infected = len(set(self.infected))
        unvaccinated = float(self.N-self.vaccinated)
        data['$R_0$'] = np.mean(transmissions) if transmissions else 0
        data['Death rate DR'] = np.divide(float(data['Dead']),self.N)
        data['Infection rate IR'] = np.divide(float(infected), self.N)
        if unvaccinated:
            unvaccinated_dr = data['Dead']/unvaccinated
            unvaccinated_ir = infected/unvaccinated
        else:
            unvaccinated_dr = 0
            unvaccinated_ir = 0
        data['Unvaccinated DR'] = unvaccinated_dr
        data['Unvaccinated IR'] = unvaccinated_ir
        return hv.ItemTable(data)

    def animate(self, steps):
        """
        Run the network for the specified number of steps accumulating animations
        of the network nodes and edges changing states and curves tracking the
        spread of the disease.
        """
        if not self.visualize:
            raise Exception("Enable the visualize option to compute network visualizations.")

        # Declare HoloMap for network animation and counts array
        network_hmap = hv.HoloMap(kdims='Time')
        sird = np.zeros((steps, 5))
        # Declare labels
        state_labels = ['Susceptible', 'Vaccinated', 'Infected', 'Recovered', 'Dead']

        # Text annotation
        nlabel = hv.Text(0.9, 0.05, 'N=%d' % self.N)

        for i in range(steps):
            # Get path, point, states and count data
            states = [self.g.nodes[n]['state'] for n in self.g.nodes()]
            state_ints = [self.color_mapping.index(v) for v in states]
            state_array = np.array(state_ints, ndmin=2).T
            (sird[i, :], _) = np.histogram(state_array, bins=list(range(6)))

            # Create network path and node Elements
            nodes = self.pos.nodes.clone(datatype=['dictionary'])
            nodes = nodes.add_dimension('State', 0, states, True)
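            # NOTE: the edge argument was lost from this listing; passing the graph's
            # current edge list is an assumption, not the confirmed original code.
            graph = self.pos.clone((list(self.g.edges()), nodes))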
            # Create overlay and accumulate in network HoloMap
            network_hmap[i] = (graph * nlabel).relabel(group='Network', label='SRI')
            # Advance the simulation by one time step before the next frame
            self.step()

        # Create Overlay of Curves
        #extents = (-1, -1, steps, np.max(sird)+2)
        curves = hv.NdOverlay({label: hv.Curve(zip(range(steps), sird[:, i]),
                                               'Time', 'Count')
                              for i, label in enumerate(state_labels)},
                              kdims=[hv.Dimension('State', values=state_labels)])
        # Animate VLine on top of Curves
        distribution = hv.HoloMap({i: (curves * hv.VLine(i)).relabel(group='Counts', label='SRI')
                                   for i in range(steps)}, kdims='Time')
        return network_hmap + distribution
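
To see the local rule in isolation, here is a minimal dependency-free sketch of one update sweep. It is not the class above: the adjacency dictionary, the `step` function and the six-node ring are hypothetical stand-ins chosen for illustration, and only the order of operations (infect neighbours, then recover, otherwise possibly die and lose all edges) follows the model code.

```python
import random

# Hypothetical stand-ins for the model's single-letter state codes.
S, V, I, R, D = "S", "V", "I", "R", "D"

def step(adj, state, p_infect, p_recover, p_death, rng):
    """Apply the local rule once to every node, in order (modifies in place)."""
    for node in list(adj):
        if state[node] != I:
            continue
        # Try to infect each susceptible neighbour.
        for other in adj[node]:
            if state[other] == S and rng.random() <= p_infect:
                state[other] = I
        # Then recover, or else possibly die and lose all edges.
        if rng.random() <= p_recover:
            state[node] = R
        elif rng.random() <= p_death:
            state[node] = D
            for other in adj[node]:
                adj[other].discard(node)
            adj[node] = set()

# A ring of six nodes with one infected and one vaccinated node.
adj = {n: {(n - 1) % 6, (n + 1) % 6} for n in range(6)}
state = {0: I, 1: S, 2: V, 3: S, 4: S, 5: S}
rng = random.Random(0)  # seeded so repeated runs give the same trajectory
for _ in range(30):
    step(adj, state, p_infect=0.3, p_recover=0.05, p_death=0.1, rng=rng)
```

Because nodes are updated in place and in order, a node infected early in a sweep can already infect others later in the same sweep. The class above behaves the same way, since `step` calls the model on each node in turn.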

The style

HoloViews allows us to define various style options in advance on the Store.options object.

hv.extension('bokeh', 'matplotlib')

# Set colors and style options for the Element types
from holoviews import Store, Options
opts = Store.options()

colormap = {k: v for k, v in zip('SVIRD', hv.Cycle().values)}

opts.Graph     = Options('plot', color_index='State')
opts.Graph     = Options('style', cmap=colormap, node_size=6, edge_line_width=1)
opts.Histogram = Options('plot', show_grid=False)
opts.Overlay   = Options('plot', show_frame=False)
opts.HeatMap   = Options('plot', xrotation=90)
opts.ItemTable = Options('plot', width=900, height=50)

opts.Overlay.Network = Options('plot', xaxis=None, yaxis=None)
opts.Overlay.Counts  = Options('plot', show_grid=True)

opts.VLine     = {'style': Options(color='black', line_width=1),
                  'plot':  Options(show_grid=True)}

Herd Immunity

Experiment 1: Evaluating the effects of a highly infectious and deadly disease in a small population with varying levels of vaccination

Having defined the model and its style options, we can run some real experiments. In particular, we can investigate the effect of vaccination on our model.

We’ll initialize our model with only 50 individuals, who will on average make 10 connections to other individuals. Then we will infect a small part of the population (\(p=0.15\), the `pSick` value below) so we can track how the disease spreads through the population. To really drive the point home we’ll use a very infectious and deadly disease.

experiment1_params = dict(pInfect=0.08, pRecover=0.08, pSick=0.15,
                          N=50, mean_connections=10, pDeath=0.1)
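
As a quick sanity check on these parameters, the sketch below works out what the initial network should look like on average. The helper `expected_setup` is hypothetical and not part of the model. It only restates two facts from the class: the edge probability passed to `erdos_renyi_graph` is `mean_connections / N`, and a node can only start out infected if it was not vaccinated first. The vaccination probability of 0.1 used here is the class default.

```python
def expected_setup(n, mean_connections, p_vaccinated, p_sick):
    """Expected initial quantities (simple expectations, no simulation)."""
    p_edge = mean_connections / n        # edge probability passed to erdos_renyi_graph
    mean_degree = (n - 1) * p_edge       # each node has n - 1 possible neighbours
    # A node is infected only if it was not vaccinated first.
    p_infected = (1 - p_vaccinated) * p_sick
    return {
        "p_edge": p_edge,
        "mean_degree": mean_degree,
        "vaccinated": n * p_vaccinated,
        "infected": n * p_infected,
    }

setup = expected_setup(n=50, mean_connections=10, p_vaccinated=0.1, p_sick=0.15)
for name, value in setup.items():
    print(f"{name}: {value:.3f}")
```

With only 50 nodes the actual counts will scatter noticeably around these expectations from run to run.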

Low vaccination population (10%)

Here we’ll investigate the spread of the disease in a population with a 10% vaccination rate:

sri_model = SRI_Model(pVaccinated=0.1, **experiment1_params)
sri_model.animate(21).redim.range(x=(-1.2, 1.2), y=(-1.2, 1.2))
WARNING:param.GraphPlot: The `color_index` parameter is deprecated in favor of color style mapping, e.g. `color=dim('color')` or `line_color=dim('color')`

In figure A we can observe how the disease quickly spreads across almost the entire unvaccinated population. Additionally, we can track the number of individuals in a particular state in B. As the disease spreads unimpeded, most individuals either die or recover and therefore gain immunity. Individuals that die are no longer part of the network, so their connections to other individuals get deleted. This way we can see the network thin out as the disease wreaks havoc among the population.

Next we can view a breakdown of the final state of the simulation including infection and death rates:

sri_model.stats().opts(hv.opts.ItemTable(width=900, height=50))

As you can see, both the infection and death rates are very high in this population. The disease reached a large percentage of all individuals, causing death in a large fraction of them. Among the unvaccinated population the rates are of course even higher, with >90% infected and >40% dead. The disease spread through our network completely unimpeded. Now let’s see what happens if a large fraction of the population is vaccinated.
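
The rates in the table are plain ratios, which the following sketch reproduces outside the class. The function `rates` is a hypothetical helper that mirrors the arithmetic in `stats()`, and the counts passed to it are made up for illustration. They are not the output of the run above.

```python
def rates(n, vaccinated, ever_infected, dead):
    """Mirror the ratios computed in stats() for given counts."""
    unvaccinated = n - vaccinated
    return {
        "death_rate": dead / n,
        "infection_rate": ever_infected / n,
        # Guard against a fully vaccinated population, as stats() does.
        "unvaccinated_dr": dead / unvaccinated if unvaccinated else 0,
        "unvaccinated_ir": ever_infected / unvaccinated if unvaccinated else 0,
    }

# Made-up counts: 50 individuals, 5 vaccinated, 42 ever infected, 20 dead.
example = rates(n=50, vaccinated=5, ever_infected=42, dead=20)
for name, value in example.items():
    print(f"{name}: {value:.2f}")
```

Since vaccinated individuals never get infected in this model, the unvaccinated rates use the same numerators with a smaller denominator, which is why they are always at least as high as the population-wide rates.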

High vaccination population (65%)

If we increase the initial probability of being vaccinated to \(p=0.65\) we’ll be able to observe how this affects the spread of the disease through the network:

sri_model = SRI_Model(pVaccinated=0.65, **experiment1_params)
sri_model.animate(21).redim.range(x=(-1.2, 1.2), y=(-1.2, 1.2))

Even though we can still see the disease spreading among non-vaccinated individuals, we can also observe how the vaccinated individuals stop the spread. If an infected individual is connected mostly to vaccinated individuals, the disease has few paths along which to spread. Unlike in the low-vaccination population, the disease stops spreading not because too many individuals have died off; rather, it quickly runs out of steam, such that a majority of the initial, susceptible but healthy population remains completely unaffected.

This is what’s known as herd immunity, and it’s very important. A small percentage of any population cannot be vaccinated, usually because they are immunocompromised. However, when a larger percentage of people decide that they do not want to get vaccinated (for various and invariably stupid reasons), they place the rest of the population in danger, particularly those who cannot get vaccinated for health reasons.
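
A rough back-of-the-envelope estimate shows why the two vaccination levels behave so differently. The sketch below is not part of the original notebook and is only a crude approximation: it assumes every unvaccinated neighbour is still susceptible, and it ignores edges removed by deaths and neighbours infected by someone else. The function names are hypothetical. The update order (infection attempts first, then recovery, then death) is taken from the model code.

```python
def leave_probability(p_recover, p_death):
    """Chance per step that an infected node stops being infectious."""
    # The model checks recovery first and death only if recovery failed.
    return p_recover + (1 - p_recover) * p_death

def edge_transmission(p_infect, p_leave):
    """Chance that an infected node ever infects one given susceptible neighbour."""
    # The node makes at least one attempt, then keeps trying each step until it
    # leaves; summing the geometric series gives the chance that all attempts fail.
    never = p_leave * (1 - p_infect) / (1 - (1 - p_leave) * (1 - p_infect))
    return 1 - never

def expected_secondary(mean_degree, p_vaccinated, p_infect, p_recover, p_death):
    """Approximate number of neighbours one infected node goes on to infect."""
    t = edge_transmission(p_infect, leave_probability(p_recover, p_death))
    return mean_degree * (1 - p_vaccinated) * t

# Experiment 1 values: 50 nodes with edge probability 10/50 give a mean degree of 9.8.
disease = dict(p_infect=0.08, p_recover=0.08, p_death=0.1)
low = expected_secondary(9.8, p_vaccinated=0.1, **disease)
high = expected_secondary(9.8, p_vaccinated=0.65, **disease)
print(f"10% vaccinated: {low:.2f}, 65% vaccinated: {high:.2f}")
```

Under these assumptions an infected individual passes the disease on to roughly three others at 10% vaccination, but to only slightly more than one at 65%. Once that number approaches one, chains of infection tend to die out on their own, which matches the behaviour seen in the animation.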

Let’s look at what higher vaccination rates did to our experimental population:

sri_model.stats().opts(hv.opts.ItemTable(width=900, height=50))

The precipitous drop in the whole population’s infection and death rates is easily explained by the fact that a smaller fraction of the population was susceptible to the disease in the first place. However, as herd immunity would predict, a smaller fraction of the unvaccinated population contracted and died of the disease as well. I hope this toy example once again emphasizes how important vaccination and herd immunity are.

Large networks

Before we have a more systematic look at herd immunity, we’ll increase the population size to 1000 individuals and have a look at what our virulent disease does to this population. If nothing else, it’ll produce a pretty plot.

hv.output(holomap='scrubber', size=150)
sri_model_lv = SRI_Model(pVaccinated=0.1, **dict(experiment1_params, N=1000))
sri_layout = sri_model_lv.animate(31)
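
The `dict(experiment1_params, N=1000)` call above reuses the earlier experiment's settings with a single value overridden. A small sketch of that idiom, using a copy of the parameter dictionary under the hypothetical names `base` and `large`:

```python
# Copy of the experiment 1 parameters from earlier in the notebook.
base = dict(pInfect=0.08, pRecover=0.08, pSick=0.15,
            N=50, mean_connections=10, pDeath=0.1)

# dict(mapping, **overrides) builds a new dictionary; `base` is left untouched.
large = dict(base, N=1000)

print(base["N"], large["N"])
```

Keeping `mean_connections` at 10 while `N` grows means the edge probability passed to `erdos_renyi_graph` drops from 10/50 to 10/1000, so the larger network is much sparser in relative terms even though each individual still has about ten connections.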