SRI model
View a running version of this notebook. | Download this project.
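For a recent talk in my department I spoke a little about agent-based modeling, and in the process I came across the simple but quite interesting SIR model from epidemiology. The inspiration for this notebook was Simon Dobson's post on epidemic spreading processes, which provides a much more detailed scientific background and takes you through some of the code step by step. However, as a brief introduction: the SIR model places individuals on a network in which each node is Susceptible, Infected or Recovered. On every time step an infected node may pass the disease to each susceptible neighbour with some probability, and may itself recover with some probability, after which it is immune.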
I've made some minor tweaks to the model by adding vaccinated and dead states. I've also unified the function-based approach into a single Parameterized class, which takes care of initializing, running and visualizing the network.
In this notebook I'll primarily look at how we can quickly create complex visualizations of this model using HoloViews. In the process I'll look at some predictions this model makes about herd immunity, but I won't be giving them any rigorous scientific treatment.
The code
Here's the code for the model, which relies only on NumPy, NetworkX, param and HoloViews (with matplotlib and Bokeh as plotting backends).
import collections
import itertools
import math
import numpy as np
np.seterr(divide='ignore')
import numpy.random as rnd
import networkx as nx
import param
import holoviews as hv
class SRI_Model(param.Parameterized):
    """
    Implementation of the SRI epidemiology model
    using NetworkX and HoloViews for visualization.

    This code has been adapted from Simon Dobson's
    code here:

    http://www.simondobson.org/complex-networks-complex-processes/epidemic-spreading.html

    In addition to his basic parameters I've added
    additional states to the model, a node may be
    in one of the following states:

    * Susceptible: Can catch the disease from a connected node.
    * Vaccinated: Immune to infection.
    * Infected: Has the disease and may pass it on to any connected node.
    * Recovered: Immune to infection.
    * Dead: Edges are removed from graph.

    Per-node data lives in the node attribute dict ``g.nodes[n]`` under the
    keys 'state' (one of the state codes below) and 'transmissions' (number
    of neighbours this node has infected).
    """

    network = param.ClassSelector(class_=nx.Graph, default=None, doc="""
        A custom NetworkX graph, instead of the default Erdos-Renyi graph.""")

    visualize = param.Boolean(default=True, doc="""
        Whether to compute layout of network for visualization.""")

    N = param.Integer(default=1000, doc="""
        Number of nodes to simulate. When a custom network is supplied this
        is overwritten with the number of nodes in that network.""")

    mean_connections = param.Number(default=10, doc="""
        Mean number of connections to make to other nodes.""")

    pSick = param.Number(default=0.01, bounds=(0, 1), doc="""
        Probability of a non-vaccinated node to be initialized in sick state.""")

    pVaccinated = param.Number(default=0.1, bounds=(0, 1), doc="""
        Probability of a node to be initialized in vaccinated state.""")

    pInfect = param.Number(default=0.3, bounds=(0, 1), doc="""
        Probability of infecting each susceptible neighbour on each time step.""")

    pRecover = param.Number(default=0.05, bounds=(0, 1), doc="""
        Probability of recovering if infected on each timestep.""")

    pDeath = param.Number(default=0.1, bounds=(0, 1), doc="""
        Probability of death on each timestep if infected and not recovered.""")

    # Single-letter state codes stored in each node's 'state' attribute.
    SPREADING_SUSCEPTIBLE = 'S'
    SPREADING_VACCINATED = 'V'
    SPREADING_INFECTED = 'I'
    SPREADING_RECOVERED = 'R'
    DEAD = 'D'

    def __init__(self, **params):
        super(SRI_Model, self).__init__(**params)
        # Compare against None explicitly: an empty nx.Graph is falsy
        # (len() == 0) and must not be silently replaced by a random graph.
        if self.network is None:
            self.g = nx.erdos_renyi_graph(self.N, float(self.mean_connections)/self.N)
        else:
            self.g = self.network
            # Keep N in sync with the supplied graph so that the rates in
            # stats() and the 'N=...' label in animate() are correct.
            self.N = self.g.order()
        self.vaccinated, self.infected = self.spreading_init()
        self.model = self.spreading_make_sir_model()
        # Index of each state code; used to map states to integer bins.
        self.color_mapping = [self.SPREADING_SUSCEPTIBLE,
                              self.SPREADING_VACCINATED,
                              self.SPREADING_INFECTED,
                              self.SPREADING_RECOVERED, self.DEAD]
        if self.visualize:
            # Spring-layout spacing scaled by graph size; max() avoids a
            # division by zero for an empty graph.
            k = 2/(math.sqrt(max(self.g.order(), 1)))
            self.pos = hv.Graph.from_networkx(self.g, nx.spring_layout, iterations=50, k=k)

    def spreading_init(self):
        """
        Initialise the network with vaccinated, susceptible and infected states.

        Each node is vaccinated with probability pVaccinated; otherwise it is
        infected with probability pSick, otherwise susceptible. Every node's
        'transmissions' counter is reset to 0.

        Returns a tuple (number of vaccinated nodes, list of infected nodes).
        """
        vaccinated, infected = 0, []
        for i in self.g.nodes():
            attrs = self.g.nodes[i]
            attrs['transmissions'] = 0
            # random() is in [0, 1), so '<' gives probability exactly p
            # (and never fires for p == 0).
            if rnd.random() < self.pVaccinated:
                attrs['state'] = self.SPREADING_VACCINATED
                vaccinated += 1
            elif rnd.random() < self.pSick:
                attrs['state'] = self.SPREADING_INFECTED
                infected.append(i)
            else:
                attrs['state'] = self.SPREADING_SUSCEPTIBLE
        return vaccinated, infected

    def spreading_make_sir_model(self):
        """Return an SIR model function for given infection and recovery probabilities."""
        def model(g, i):
            """Apply the local update rule to node i of graph g (in place)."""
            node = g.nodes[i]
            if node['state'] != self.SPREADING_INFECTED:
                return
            # Infect susceptible neighbours with probability pInfect.
            for m in g.neighbors(i):
                neighbour = g.nodes[m]
                if (neighbour['state'] == self.SPREADING_SUSCEPTIBLE
                        and rnd.random() < self.pInfect):
                    neighbour['state'] = self.SPREADING_INFECTED
                    self.infected.append(m)
                    node['transmissions'] += 1
            # Recover with probability pRecover, otherwise die with pDeath.
            if rnd.random() < self.pRecover:
                node['state'] = self.SPREADING_RECOVERED
            elif rnd.random() < self.pDeath:
                # Materialize incident edges before removing them; directed
                # graphs also need their incoming edges removed.
                edges = list(g.edges(i))
                if g.is_directed():
                    edges.extend(g.in_edges(i))
                node['state'] = self.DEAD
                g.remove_edges_from(edges)
        return model

    def step(self):
        """Run a single step of the model over the graph."""
        for i in self.g.nodes():
            self.model(self.g, i)

    def run(self, steps):
        """
        Run the network for the specified number of time steps
        """
        for _ in range(steps):
            self.step()

    def stats(self):
        """
        Return an ItemTable with statistics on the network data.

        Includes node counts per state, the mean number of transmissions per
        ever-infected node (reported as R_0), and death/infection rates both
        over all N nodes and over the initially unvaccinated nodes.
        """
        state_labels = collections.OrderedDict([('S', 'Susceptible'), ('V', 'Vaccinated'),
                                                ('I', 'Infected'), ('R', 'Recovered'),
                                                ('D', 'Dead')])
        # Set for O(1) membership tests; a node is only ever infected once.
        infected_nodes = set(self.infected)
        counts = collections.Counter()
        transmissions = []
        for n in self.g.nodes():
            attrs = self.g.nodes[n]
            counts[state_labels[attrs['state']]] += 1
            if n in infected_nodes:
                transmissions.append(attrs['transmissions'])
        data = collections.OrderedDict((label, counts[label])
                                       for label in state_labels.values())
        infected = len(infected_nodes)
        unvaccinated = float(self.N - self.vaccinated)
        data['$R_0$'] = np.mean(transmissions) if transmissions else 0
        data['Death rate DR'] = np.divide(float(data['Dead']), self.N)
        data['Infection rate IR'] = np.divide(float(infected), self.N)
        if unvaccinated:
            unvaccinated_dr = data['Dead']/unvaccinated
            unvaccinated_ir = infected/unvaccinated
        else:
            unvaccinated_dr = 0
            unvaccinated_ir = 0
        data['Unvaccinated DR'] = unvaccinated_dr
        data['Unvaccinated IR'] = unvaccinated_ir
        return hv.ItemTable(data)

    def animate(self, steps):
        """
        Run the network for the specified number of steps accumulating animations
        of the network nodes and edges changing states and curves tracking the
        spread of the disease.

        Raises RuntimeError if the model was created with visualize=False,
        since no layout is available.
        """
        if not self.visualize:
            raise RuntimeError("Enable the visualize option to compute network visualizations.")

        # HoloMap of network frames keyed by time, and per-step state counts.
        network_hmap = hv.HoloMap(kdims='Time')
        sird = np.zeros((steps, 5))

        # Curve labels, in the same order as self.color_mapping.
        state_labels = ['Susceptible', 'Vaccinated', 'Infected', 'Recovered', 'Dead']

        # Text annotation
        nlabel = hv.Text(0.9, 0.05, 'N=%d' % self.N)

        for i in range(steps):
            states = [self.g.nodes[n]['state'] for n in self.g.nodes()]
            state_ints = [self.color_mapping.index(v) for v in states]
            # Bins [0,1), [1,2), ..., [4,5] yield one count per state index.
            sird[i, :], _ = np.histogram(state_ints, bins=np.arange(6))

            # Attach the current states to the precomputed node layout.
            nodes = self.pos.nodes.clone(datatype=['dictionary'])
            nodes = nodes.add_dimension('State', 0, states, True)
            graph = self.pos.clone((self.pos.data.copy(), nodes))

            network_hmap[i] = (graph * nlabel).relabel(group='Network', label='SRI')
            self.step()

        # One Curve per state; a tuple of arrays is a stable HoloViews input
        # format (a bare zip object is a one-shot iterator on Python 3).
        time = np.arange(steps)
        curves = hv.NdOverlay({label: hv.Curve((time, sird[:, j]), 'Time', 'Count')
                               for j, label in enumerate(state_labels)},
                              kdims=[hv.Dimension('State', values=state_labels)])

        # Animate VLine on top of Curves
        distribution = hv.HoloMap({i: (curves * hv.VLine(i)).relabel(group='Counts', label='SRI')
                                   for i in range(steps)}, kdims='Time')
        return network_hmap + distribution
The style
HoloViews allows us to define various style options in advance on the Store.options object.
# Activate the Bokeh and matplotlib plotting extensions.
hv.extension('bokeh', 'matplotlib')

# Set colors and style options for the Element types
from holoviews import Store, Options

# Global option tree that the assignments below extend.
opts = Store.options()

# One colour per node state code, taken in order from the default colour cycle.
colormap = dict(zip('SVIRD', hv.Cycle().values))

# Network graph: colour nodes by their 'State' dimension using the map above.
opts.Graph = Options('plot', color_index='State')
opts.Graph = Options('style', cmap=colormap, node_size=6, edge_line_width=1)

# Generic element defaults.
opts.Histogram = Options('plot', show_grid=False)
opts.Overlay = Options('plot', show_frame=False)
opts.HeatMap = Options('plot', xrotation=90)
opts.ItemTable = Options('plot', width=900, height=50)

# Overlays relabelled with group 'Network' and 'Counts' by SRI_Model.animate.
opts.Overlay.Network = Options('plot', xaxis=None, yaxis=None)
opts.Overlay.Counts = Options('plot', show_grid=True)

# Black time-marker line drawn over the count curves.
opts.VLine = {'style': Options(color='black', line_width=1),
              'plot': Options(show_grid=True)}