Approximate inference via Gibbs sampling

Taken from coursework for ECE 7751: Graphical Models in Machine Learning, taught by Faramarz Fekri at Georgia Tech, Spring 2023

Consider a setting in which there are $D$ diseases and a patient either has ($d_i=1$) or does not have ($d_i=0$) each disease. The hospital can measure $S$ symptoms, where $s_j=1$ when the patient has the symptom and $s_j=0$ otherwise. A simple Bayesian network for this setting is given by:

$$ p\left(s_1, \ldots, s_S, d_1, \ldots, d_D\right)=\prod_{j=1}^S p\left(s_j \mid \mathbf{d}\right) \prod_{i=1}^D p\left(d_i\right) $$

where $\mathbf{d}=\left(d_1, \ldots, d_D\right)^T$ and

$$ p\left(s_j=1 \mid \mathbf{d}\right)=\sigma\left(\mathbf{w}_j^T \mathbf{d}+b_j\right) $$

where $\sigma(x) = 1/(1+e^{-x})$.

In the above, $\mathbf{w}_j$ is a vector of parameters relating symptom $j$ to the diseases and $b_j$ is related to the prevalence of the symptom. The hospital provides the collection of parameters $W$ and $b$, the prior disease probabilities $p$ (with $p(d_i = 1) = p_i$), and a vector $s$ of symptoms for the patient; see SymptomDisease.mat.

Use Gibbs sampling (with a reasonable amount of burn-in and sub-sampling) to estimate the vector

$$ \bigl[p(d_1=1|s),\ldots,p(d_D=1|s)\bigr] $$

Solution

from scipy.io import loadmat
import numpy as np

# Load the model parameters and the patient's symptom vector.  loadmat returns
# MATLAB vectors as 2-D arrays; squeeze() drops their length-1 axes so b, p and
# s become 1-D while W keeps its 2-D (symptoms x diseases) shape.
data = loadmat('SymptomDisease.mat')
W, b, p, s = (data[key].squeeze() for key in ('W', 'b', 'p', 's'))
print(f'W.shape = {W.shape}')
S, D = W.shape  # S symptoms (rows of W), D diseases (columns of W)
# Explicit checks instead of `assert`, which is stripped under `python -O`.
if not (s.size == b.size == S):
    raise ValueError(
        f'expected {S} symptom entries in s and b, got s.size={s.size}, b.size={b.size}'
    )
if p.size != D:
    raise ValueError(f'expected {D} prior probabilities in p, got p.size={p.size}')
p.shape
W.shape = (200, 50)

(50,)

To perform Gibbs sampling, we need $p(d_i|d_{-i},s)$, where $d_{-i}$ refers to $\{d_1,\ldots,d_{i-1},d_{i+1},\ldots,d_D\}$. We can compute this using Bayes’ rule:

$$ p(d_i|d_{-i},s) = \frac{p(s,d)}{p(s,d_{-i})} = \frac{p(d)p(s|d)}{p(d_{-i})p(s|d_{-i})} = \frac{p(d)p(s|d)}{\sum_{d_i} p(d)p(s|d)}$$

We can compute the denominator by summing the numerator over $d_i=1$ and $d_i=0$.

def σ(x):
    """Return the logistic sigmoid σ(x) = 1 / (1 + e^{-x}), elementwise.

    Evaluated as exp(-log(1 + e^{-x})) using ``np.logaddexp``, so very
    negative inputs do not overflow ``np.exp``.  The naive form warns for
    x below about -709.
    """
    return np.exp(-np.logaddexp(0, -x))

def likelihood(d, *, log: bool = False) -> float:
    """Return the joint probability p(s, d) of the observed symptoms and a disease vector.

    p(s, d) = prod_j p(s_j | d) * prod_i p(d_i), where
    p(s_j = 1 | d) = σ(w_j^T d + b_j) and p(d_i = 1) = p_i.  The function reads
    the module-level W, b, p and s.

    The log-factors are summed first.  log σ(x) and log(1 - σ(x)) are evaluated
    as -logaddexp(0, -x) and -logaddexp(0, x), which stay accurate where σ(x)
    itself rounds to 0 or 1.  At those points the naive ``1 - σ(x)`` collapses
    to exactly 0.

    Args:
        d: length-D 0/1 disease vector.
        log: if True, return log p(s, d) instead of p(s, d).

    Returns:
        p(s, d), or its natural logarithm when ``log`` is True.
    """
    d = np.asarray(d)
    z = W @ d + b
    # Symptom term: s_j * log σ(z_j) + (1 - s_j) * log(1 - σ(z_j)), summed over j.
    log_p_s = -np.sum(s * np.logaddexp(0, -z) + (1 - s) * np.logaddexp(0, z))
    # Prior term: np.where picks log p_i or log(1 - p_i) without forming 0 * log(0);
    # a prior of exactly 0 or 1 would otherwise emit a divide-by-zero warning.
    with np.errstate(divide='ignore'):
        log_p_d = np.sum(np.where(d == 1, np.log(p), np.log1p(-p)))
    log_joint = log_p_s + log_p_d
    return log_joint if log else np.exp(log_joint)

def _log_symptom_likelihood(z):
    """Return log p(s | d) given the symptom pre-activations z = W d + b."""
    # log σ(x) = -logaddexp(0, -x) and log(1 - σ(x)) = -logaddexp(0, x);
    # both stay finite and accurate where σ(x) rounds to 0 or 1.
    return -np.sum(s * np.logaddexp(0, -z) + (1 - s) * np.logaddexp(0, z))


def sample_d(d, rng):
    """Run one systematic-scan Gibbs sweep over all D diseases, in place.

    For each i in turn, d_i is redrawn from p(d_i | d_{-i}, s).  The log-odds of
    d_i = 1 are the prior log-odds log(p_i / (1 - p_i)) plus the change in
    log p(s | d) when d_i switches from 0 to 1.  Working with these log-odds
    avoids the 0/0 (NaN) that a ratio of two raw joint likelihoods produces
    once both underflow to zero.

    Args:
        d: length-D array of 0/1 disease indicators (current state); it is
            overwritten in place.
        rng: ``numpy.random.Generator`` used for the Bernoulli draws.

    Returns:
        ``d`` itself (the same array object) after the sweep.
    """
    with np.errstate(divide='ignore'):  # p_i of exactly 0 or 1 gives +/-inf log-odds
        prior_log_odds = np.log(p) - np.log1p(-p)
    # Pre-activations for the current state.  Toggling d_i only shifts them by
    # W[:, i], so each coordinate update is O(S) instead of a full W @ d.
    z = W @ d + b
    for i in range(D):
        w_i = W[:, i]
        z0 = z - w_i * d[i]  # pre-activations with d_i = 0
        z1 = z0 + w_i        # pre-activations with d_i = 1
        log_odds = (
            prior_log_odds[i]
            + _log_symptom_likelihood(z1)
            - _log_symptom_likelihood(z0)
        )
        # σ(log_odds) in overflow-free form; always lies in [0, 1].
        prob_one = np.exp(-np.logaddexp(0, -log_odds))
        d[i] = rng.binomial(1, prob_one)
        z = z1 if d[i] else z0
    return d
# Demo: one Gibbs sweep starting from a state drawn from the disease prior.
# A single Generator drives both the initial draw and the sweep, as in
# gibbs_sample, so seeding it reproduces the whole demo.
rng = np.random.default_rng()
sample_d((rng.uniform(size=D) < p).astype(int), rng)
array([1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1,
       0, 0, 1, 1, 0, 1])

We can use the complete-data likelihood $p(s, d)$ not only to sample, but also as a heuristic for deciding when burn-in is complete: we stop once the likelihood has failed to reach a new maximum for 10 consecutive sweeps.

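After burn-in, we keep one state every $T$ sweeps, where $T$ is the number of burn-in sweeps, and average a fixed number of these thinned samples to estimate the posterior probabilities.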

def gibbs_sample(samples, rng, patience=10, thin=None):
    """Estimate the posterior marginals p(d_i = 1 | s) by Gibbs sampling.

    The chain starts from a draw of the disease prior.  Burn-in runs full
    sweeps (``sample_d``) until the joint likelihood p(s, d) has failed to
    exceed its running maximum for ``patience`` consecutive sweeps.  After
    burn-in, one state is kept every ``thin`` sweeps and the kept states are
    averaged.  By default ``thin`` is the number of burn-in sweeps.

    Args:
        samples: number of kept (thinned) states to average; must be >= 1.
        rng: ``numpy.random.Generator`` used for every random draw.
        patience: consecutive non-improving sweeps that end burn-in; >= 1.
        thin: sweeps between kept states; ``None`` uses the burn-in length.

    Returns:
        Length-D array with the fraction of kept states in which each disease
        is present, i.e. the Monte Carlo estimate of p(d_i = 1 | s).

    Raises:
        ValueError: if ``samples``, ``patience`` or ``thin`` is less than 1.
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")
    if patience < 1:
        raise ValueError(f"patience must be >= 1, got {patience}")
    if thin is not None and thin < 1:
        raise ValueError(f"thin must be >= 1, got {thin}")

    d = (rng.uniform(size=D) < p).astype(int)

    # Burn-in: a heuristic stop once the joint likelihood has plateaued, i.e.
    # `patience` sweeps in a row produced no new maximum.
    lklhd_max = -np.inf
    iters_without_improvement = 0
    burn_in_iters = 0
    while iters_without_improvement < patience:
        burn_in_iters += 1
        d = sample_d(d, rng)
        curr_lklhd = likelihood(d)
        if curr_lklhd > lklhd_max:
            lklhd_max = curr_lklhd
            iters_without_improvement = 0
        else:
            iters_without_improvement += 1
    print(f"burn-in period complete after {burn_in_iters} iterations")

    if thin is None:
        print('Using burn-in time to determine sampling frequency')
        thin = burn_in_iters
    print(f'Taking every {thin}th sample')

    # Keep only every `thin`-th state to reduce correlation between the
    # states that are averaged.
    dtot = np.zeros(D)
    for _ in range(samples):
        for _ in range(thin):
            d = sample_d(d, rng)
        dtot += d
    print(f"Collected {samples} samples (one every {thin} sweeps, {thin * samples} sweeps in total)")
    print(dtot)
    return dtot / samples

# Estimate the posterior marginals from 500 thinned samples on a fresh Generator.
posteriors = gibbs_sample(samples=500, rng=np.random.default_rng())
posteriors
burn-in period complete after 19 iterations
Using burn-in time to determine sampling frequency
Taking every 19th sample
Converged in 500 samples (taken out of 9500 total samples)
[  3. 498.   9. 500. 333.   4.   8.   0.   3. 500.   1. 500. 500. 500.
 492. 481. 491. 453. 500. 496.  46. 359. 500. 499.   5. 500.  12.   0.
 500.   0.   0.  47.   6. 500.   0.   1. 495.   0.   0.   0. 497. 499.
 500. 500.   0.   5. 499. 499.   0. 500.]

array([0.006, 0.996, 0.018, 1.   , 0.666, 0.008, 0.016, 0.   , 0.006,
       1.   , 0.002, 1.   , 1.   , 1.   , 0.984, 0.962, 0.982, 0.906,
       1.   , 0.992, 0.092, 0.718, 1.   , 0.998, 0.01 , 1.   , 0.024,
       0.   , 1.   , 0.   , 0.   , 0.094, 0.012, 1.   , 0.   , 0.002,
       0.99 , 0.   , 0.   , 0.   , 0.994, 0.998, 1.   , 1.   , 0.   ,
       0.01 , 0.998, 0.998, 0.   , 1.   ])
    # Use a loop name other than `p`: rebinding `p` here would overwrite the
    # global prior-probability vector that likelihood/sample_d/gibbs_sample read.
    for i, post in enumerate(posteriors):
        print(f'p(d_{i+1}|s)\t= {post:.3f}')
p(d_1|s)    = 0.006
p(d_2|s)    = 0.996
p(d_3|s)    = 0.018
p(d_4|s)    = 1.000
p(d_5|s)    = 0.666
p(d_6|s)    = 0.008
p(d_7|s)    = 0.016
p(d_8|s)    = 0.000
p(d_9|s)    = 0.006
p(d_10|s)   = 1.000
p(d_11|s)   = 0.002
p(d_12|s)   = 1.000
p(d_13|s)   = 1.000
p(d_14|s)   = 1.000
p(d_15|s)   = 0.984
p(d_16|s)   = 0.962
p(d_17|s)   = 0.982
p(d_18|s)   = 0.906
p(d_19|s)   = 1.000
p(d_20|s)   = 0.992
p(d_21|s)   = 0.092
p(d_22|s)   = 0.718
p(d_23|s)   = 1.000
p(d_24|s)   = 0.998
p(d_25|s)   = 0.010
p(d_26|s)   = 1.000
p(d_27|s)   = 0.024
p(d_28|s)   = 0.000
p(d_29|s)   = 1.000
p(d_30|s)   = 0.000
p(d_31|s)   = 0.000
p(d_32|s)   = 0.094
p(d_33|s)   = 0.012
p(d_34|s)   = 1.000
p(d_35|s)   = 0.000
p(d_36|s)   = 0.002
p(d_37|s)   = 0.990
p(d_38|s)   = 0.000
p(d_39|s)   = 0.000
p(d_40|s)   = 0.000
p(d_41|s)   = 0.994
p(d_42|s)   = 0.998
p(d_43|s)   = 1.000
p(d_44|s)   = 1.000
p(d_45|s)   = 0.000
p(d_46|s)   = 0.010
p(d_47|s)   = 0.998
p(d_48|s)   = 0.998
p(d_49|s)   = 0.000
p(d_50|s)   = 1.000
Kyle Johnsen
Kyle Johnsen
PhD Candidate
Biomedical Engineering

My research interests include applying principles from the brain to improve machine learning.