Approximate inference via Gibbs sampling
Taken from coursework for ECE 7751: Graphical Models in Machine Learning, taught by Faramarz Fekri at Georgia Tech, Spring 2023
Consider a setting in which there are $D$ diseases and a patient either has ($d_i=1$) or does not have ($d_i=0$) each disease. The hospital can measure $S$ symptoms, where $s_j=1$ when the patient has the symptom and $s_j=0$ otherwise. A simple Bayesian network for this setting is given by:
$$ p\left(s_1, \ldots, s_S, d_1, \ldots, d_D\right)=\prod_{j=1}^S p\left(s_j \mid \mathbf{d}\right) \prod_{i=1}^D p\left(d_i\right) $$
where $\mathbf{d}=\left(d_1, \ldots, d_D\right)^T$ and
$$ p\left(s_j=1 \mid \mathbf{d}\right)=\sigma\left(\mathbf{w}_j^T \mathbf{d}+b_j\right) $$
where $\sigma(x) = 1/(1+e^{-x})$.
In the above $\mathbf{w}_j$ is a vector of parameters relating symptom $j$ to the diseases and $b_j$ is related to the prevalence of the symptom.
The hospital provides the collection of parameters $W$ and $b$, the prior disease probabilities $p$ (with $p(d_i = 1) = p_i$), and a vector $s$ of symptoms for the patient (see `SymptomDisease.mat`).
Use Gibbs sampling (with a reasonable amount of burn-in and sub-sampling) to estimate the vector
$$ \bigl[p(d_1=1|s),\ldots,p(d_D=1|s)\bigr] $$
from scipy.io import loadmat
import numpy as np

# SymptomDisease.mat provides the symptom model (W, b), the disease priors p
# and the observed symptom vector s. loadmat returns MATLAB-style arrays with
# at least two dimensions; squeeze() drops singleton axes so that the vectors
# can be used as flat arrays (assumed to be stored as row/column vectors).
data = loadmat('SymptomDisease.mat')
W, b, p, s = (data[name].squeeze() for name in ('W', 'b', 'p', 's'))
print(f'W.shape = {W.shape}')

# W has one row per symptom (S) and one column per disease (D).
S, D = W.shape
assert S == s.size == b.size
assert D == p.size
W.shape = (200, 50)
To perform Gibbs sampling, we need $p(d_i \mid d_{-i},s)$, where $d_{-i}$ refers to $\{d_1,\ldots,d_{i-1},d_{i+1},\ldots,d_D\}$. We can compute this using Bayes’ rule:
$$ p(d_i \mid d_{-i},s) = \frac{p(s,d)}{p(s,d_{-i})} = \frac{p(d)\,p(s \mid d)}{p(d_{-i})\,p(s \mid d_{-i})} = \frac{p(d)\,p(s \mid d)}{\sum_{d_i \in \{0,1\}} p(d)\,p(s \mid d)} $$ We can compute the denominator by evaluating the numerator at $d_i=1$ and at $d_i=0$ and summing the two values.
def σ(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), elementwise.

    Evaluated as exp(-log(1 + exp(-x))) using ``np.logaddexp``. This avoids the
    overflow RuntimeWarning that the direct formula raises when exp(-x)
    exceeds the float range (very negative x). Such inputs simply give 0.0.

    Args:
        x: scalar or array of logits.

    Returns:
        Values in [0, 1] with the same shape as ``x``.
    """
    return np.exp(-np.logaddexp(0.0, -x))
def likelihood(d) -> float:
    """Return the joint probability p(s, d) of the observed symptoms and a disease vector.

    p(s, d) = prod_j p(s_j | d) * prod_i p(d_i), with
    p(s_j = 1 | d) = σ(w_j^T d + b_j) and p(d_i = 1) = p_i.

    Reads the module-level W, b, s (symptom model and observation) and p (priors).

    Args:
        d: disease indicator vector of length D with entries in {0, 1}.

    Returns:
        The joint probability p(s, d).

    NOTE(review): this is a product of S + D probabilities and can underflow
    to 0.0 for large models; a log-space variant would be more robust.
    """
    # Per-symptom probability p(s_j = 1 | d).
    ps = σ(W @ d + b)
    # With s_j in {0, 1}, the exponents select ps_j or (1 - ps_j): Bernoulli pmf.
    L = np.prod(ps**s * (1 - ps)**(1 - s))
    # Bernoulli prior p(d_i) = p_i^{d_i} (1 - p_i)^{1 - d_i}.
    L *= np.prod(p**d * (1 - p)**(1 - d))
    return L
def sample_d(d, rng):
    """Perform one in-place systematic-scan Gibbs sweep over the disease vector.

    For i = 0..D-1 in order, d_i is resampled from
        p(d_i = 1 | d_{-i}, s) = p(s, d_i=1, d_{-i}) / sum_{d_i} p(s, d),
    using the current values of all other entries.

    Only the prior factor p(d_i) and the symptom factors p(s_j | d) depend on
    d_i, so the conditional equals σ(log-odds), with
        log-odds = log p_i - log(1 - p_i)
                   + sum_j [log p(s_j | d, d_i=1) - log p(s_j | d, d_i=0)].
    Working in log space avoids the 0/0 (NaN) produced by a ratio of raw
    joint probabilities once both underflow to 0.0.

    Reads the module-level W, b, s, p and D.

    Args:
        d: integer array of length D with entries in {0, 1}; modified in place.
        rng: NumPy ``Generator`` supplying the Bernoulli draws.

    Returns:
        The same array ``d`` after the sweep.
    """
    def log_lik_symptoms(z):
        # sum_j log p(s_j | d) for logits z = W d + b, using
        # log σ(z) = -log(1 + e^{-z}) and log(1 - σ(z)) = -log(1 + e^{z}),
        # which np.logaddexp evaluates without overflow.
        return -np.sum(s * np.logaddexp(0.0, -z) + (1 - s) * np.logaddexp(0.0, z))

    # Prior log-odds log(p_i / (1 - p_i)). The value is +/-inf when p_i is
    # exactly 1 or 0, which forces d_i to 1 or 0 below.
    with np.errstate(divide='ignore'):
        log_prior_odds = np.log(p) - np.log1p(-p)

    for i in range(D):
        d[i] = 0
        z0 = W @ d + b          # logits with d_i = 0
        z1 = z0 + W[:, i]       # logits with d_i = 1
        log_odds = log_lik_symptoms(z1) - log_lik_symptoms(z0) + log_prior_odds[i]
        # Stable σ(log_odds); maps +/-inf log-odds to exactly 1 or 0.
        condp_di = np.exp(-np.logaddexp(0.0, -log_odds))
        d[i] = rng.binomial(1, condp_di)
    return d
# Smoke test: one Gibbs sweep starting from a state drawn from the prior.
# One Generator drives both the initial draw and the sweep, so seeding
# ``demo_rng`` makes this cell reproducible.
demo_rng = np.random.default_rng()
sample_d((demo_rng.uniform(size=D) < p).astype(int), demo_rng)
array([1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1,
0, 0, 1, 1, 0, 1])
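We can use the complete-data likelihood $p(s, d)$ not only to sample, but also as a heuristic for deciding when burn-in is complete: we stop once $p(s, d)$ has not reached a new maximum for several consecutive sweeps.
We then reuse the burn-in length as the thinning interval, and average a fixed number of thinned samples to estimate the disease posterior probabilities.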
def gibbs_sample(samples, rng, patience=10):
    """Estimate the posterior marginals p(d_i = 1 | s) by Gibbs sampling.

    The chain starts from a state with each d_i drawn independently as
    Bernoulli(p_i).

    Burn-in: full sweeps (``sample_d``) run until the joint likelihood p(s, d)
    has failed to reach a new maximum for ``patience`` consecutive sweeps.
    This is a heuristic stopping rule, not a convergence diagnostic.

    Sampling: the burn-in length is reused as the thinning interval.
    ``samples`` states are kept, one after every ``burn_in_iters`` sweeps, and
    averaged.

    Reads the module-level D and p, and calls ``sample_d`` and ``likelihood``.

    Args:
        samples: number of thinned samples to average; must be >= 1.
        rng: NumPy ``Generator`` used for initialisation and for every sweep.
        patience: number of consecutive non-improving sweeps that ends
            burn-in; must be >= 1 (default 10).

    Returns:
        Length-D array whose entry i is the fraction of kept samples with
        d_i = 1, i.e. the Monte Carlo estimate of p(d_i = 1 | s).

    Raises:
        ValueError: if ``samples`` < 1 (the average would divide by zero) or
            ``patience`` < 1 (burn-in would run zero sweeps and no thinning
            would happen).
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")
    if patience < 1:
        raise ValueError(f"patience must be >= 1, got {patience}")

    d = (rng.uniform(size=D) < p).astype(int)

    # --- burn-in ---
    # NOTE(review): likelihood() returns a raw product of S + D probabilities.
    # If it underflows to 0.0 for every visited state, only the first sweep
    # counts as an improvement and burn-in ends after ``patience`` sweeps.
    # Comparing log-likelihoods would be more robust.
    iters = 0
    iters_without_improvement = 0
    lklhd_max = -np.inf
    while iters_without_improvement < patience:
        iters += 1
        d = sample_d(d, rng)
        curr_lklhd = likelihood(d)
        if curr_lklhd > lklhd_max:
            lklhd_max = curr_lklhd
            iters_without_improvement = 0
        else:
            # Only sweeps that did NOT improve are counted; an improving sweep
            # restarts the count at zero.
            iters_without_improvement += 1
    print(f"burn-in period complete after {iters} iterations")
    burn_in_iters = iters

    # --- thinned sampling ---
    print('Using burn-in time to determine sampling frequency')
    print(f'Taking every {burn_in_iters}th sample')
    iters = 0
    dtot = np.zeros(D)
    while iters < samples:
        iters += 1
        # Advance burn_in_iters sweeps between kept states to reduce the
        # autocorrelation between consecutive samples.
        for _ in range(burn_in_iters):
            d = sample_d(d, rng)
        dtot += d
    print(f"Collected {iters} samples (taken out of {burn_in_iters*iters} total sweeps)")
    return dtot / iters
# Average 500 thinned Gibbs samples drawn with a fresh (unseeded) Generator.
posteriors = gibbs_sample(samples=500, rng=np.random.default_rng())
burn-in period complete after 19 iterations
Using burn-in time to determine sampling frequency
Taking every 19th sample
Converged in 500 samples (taken out of 9500 total samples)
[ 3. 498. 9. 500. 333. 4. 8. 0. 3. 500. 1. 500. 500. 500.
492. 481. 491. 453. 500. 496. 46. 359. 500. 499. 5. 500. 12. 0.
500. 0. 0. 47. 6. 500. 0. 1. 495. 0. 0. 0. 497. 499.
500. 500. 0. 5. 499. 499. 0. 500.]
array([0.006, 0.996, 0.018, 1. , 0.666, 0.008, 0.016, 0. , 0.006,
1. , 0.002, 1. , 1. , 1. , 0.984, 0.962, 0.982, 0.906,
1. , 0.992, 0.092, 0.718, 1. , 0.998, 0.01 , 1. , 0.024,
0. , 1. , 0. , 0. , 0.094, 0.012, 1. , 0. , 0.002,
0.99 , 0. , 0. , 0. , 0.994, 0.998, 1. , 1. , 0. ,
0.01 , 0.998, 0.998, 0. , 1. ])
# Report each estimated marginal p(d_i = 1 | s), using 1-based disease indices.
# The loop variable must not be called ``p``: at module level that would
# overwrite the global prior vector ``p`` read by the sampling functions.
for i, prob in enumerate(posteriors, start=1):
    print(f'p(d_{i}|s)\t= {prob:.3f}')
p(d_1|s) = 0.006
p(d_2|s) = 0.996
p(d_3|s) = 0.018
p(d_4|s) = 1.000
p(d_5|s) = 0.666
p(d_6|s) = 0.008
p(d_7|s) = 0.016
p(d_8|s) = 0.000
p(d_9|s) = 0.006
p(d_10|s) = 1.000
p(d_11|s) = 0.002
p(d_12|s) = 1.000
p(d_13|s) = 1.000
p(d_14|s) = 1.000
p(d_15|s) = 0.984
p(d_16|s) = 0.962
p(d_17|s) = 0.982
p(d_18|s) = 0.906
p(d_19|s) = 1.000
p(d_20|s) = 0.992
p(d_21|s) = 0.092
p(d_22|s) = 0.718
p(d_23|s) = 1.000
p(d_24|s) = 0.998
p(d_25|s) = 0.010
p(d_26|s) = 1.000
p(d_27|s) = 0.024
p(d_28|s) = 0.000
p(d_29|s) = 1.000
p(d_30|s) = 0.000
p(d_31|s) = 0.000
p(d_32|s) = 0.094
p(d_33|s) = 0.012
p(d_34|s) = 1.000
p(d_35|s) = 0.000
p(d_36|s) = 0.002
p(d_37|s) = 0.990
p(d_38|s) = 0.000
p(d_39|s) = 0.000
p(d_40|s) = 0.000
p(d_41|s) = 0.994
p(d_42|s) = 0.998
p(d_43|s) = 1.000
p(d_44|s) = 1.000
p(d_45|s) = 0.000
p(d_46|s) = 0.010
p(d_47|s) = 0.998
p(d_48|s) = 0.998
p(d_49|s) = 0.000
p(d_50|s) = 1.000