Expectation-maximization for a Markov chain mixture model
Taken from coursework for ECE 7751: Graphical Models in Machine Learning, taught by Faramarz Fekri at Georgia Tech, Spring 2023
Assume that a sequence $v_1,\ldots,v_T \in \{1,\dots,V\}$ is generated by a Markov chain. For a single chain of length $T$, we have $$ p(v_1,\dots,v_T) = p(v_1)\prod_{t=1}^{T-1} p(v_{t+1}|v_t) \newcommand{\EE}{\mathbb{E}} \newcommand{\ind}{\mathbb{1}} \newcommand{\answertext}[1]{\textcolor{Green}{\fbox{#1}}} \newcommand{\answer}[1]{\answertext{$#1$}} \newcommand{\argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{\argmin}[1]{\underset{#1}{\operatorname{argmin}}} \newcommand{\comment}[1]{\textcolor{gray}{\textrm{#1}}} \newcommand{\vec}[1]{\mathbf{#1}} \newcommand{\inv}[1]{\frac{1}{#1}} \newcommand{\abs}[1]{\lvert{#1}\rvert} \newcommand{\norm}[1]{\lVert{#1}\rVert} \newcommand{\lr}[1]{\left(#1\right)} \newcommand{\lrb}[1]{\left[#1\right]} \newcommand{\lrbr}[1]{\lbrace#1\rbrace} \newcommand{\Bx}[0]{\mathbf{x}} $$
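As a minimal sketch, the single-chain likelihood can be evaluated in log space, which avoids underflow when many probabilities are multiplied. It assumes 0-based integer states $\{0,\dots,K-1\}$ rather than the $1,\dots,V$ used in the text; the function and argument names are hypothetical.

```python
import numpy as np

def chain_log_prob(v, p1, A):
    """Return log p(v_1, ..., v_T) for a single Markov chain.

    v  : sequence of integer states in {0, ..., K-1} (0-based)
    p1 : (K,) initial distribution, p1[k] = p(v_1 = k)
    A  : (K, K) transition matrix, A[i, j] = p(v_{t+1} = j | v_t = i)
    """
    v = np.asarray(v)
    # log p(v_1) + sum over t of log p(v_{t+1} | v_t)
    return np.log(p1[v[0]]) + np.log(A[v[:-1], v[1:]]).sum()

p1 = np.array([0.5, 0.5])
A = np.array([[0.9, 0.1],
              [0.2, 0.8]])
print(np.exp(chain_log_prob([0, 0, 1], p1, A)))  # approximately 0.5 * 0.9 * 0.1 = 0.045
```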
For simplicity, we denote the sequence of visible variables as $$ \vec v = \lr{v_1,\dots,v_T} $$
For a single Markov chain labelled by $h$, $$ p(\vec v|h) = p(v_1|h)\prod_{t=1}^{T-1}p(v_{t+1}|v_t,h) $$
In total there are $H$ such Markov chains, $h=1,\dots,H$, each selected with prior probability $p(h)$. The distribution on the visible variables is therefore $$ p(\vec v) = \sum_{h=1}^H p(\vec v|h)p(h) $$
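A sketch of the mixture likelihood $\log p(\vec v)$, assuming 0-based states and parameter arrays `theta[h]`, `pi[k, h]` and `A[i, j, h]` (the same layout the solution code later in this post uses); the function name is hypothetical. The sum over $h$ is done with `scipy.special.logsumexp` so that tiny per-chain likelihoods do not underflow.

```python
import numpy as np
from scipy.special import logsumexp

def mixture_log_prob(v, theta, pi, A):
    """log p(v) = log sum_h p(v | h) p(h) for a mixture of H Markov chains.

    theta : (H,) chain prior, theta[h] = p(h)
    pi    : (K, H) initial probabilities, pi[k, h] = p(v_1 = k | h)
    A     : (K, K, H) transitions, A[i, j, h] = p(v_{t+1} = j | v_t = i, h)
    """
    v = np.asarray(v)
    # log p(v, h) for every chain h, shape (H,)
    log_joint = (np.log(theta) + np.log(pi[v[0], :])
                 + np.log(A[v[:-1], v[1:], :]).sum(axis=0))
    # log-sum-exp over h
    return logsumexp(log_joint)

theta = np.array([0.3, 0.7])
pi = np.array([[0.6, 0.1],
               [0.4, 0.9]])
A = np.stack([np.array([[0.9, 0.1], [0.2, 0.8]]),
              np.array([[0.5, 0.5], [0.5, 0.5]])], axis=-1)  # shape (K, K, H)
v = [0, 0, 1]
# 0.3*0.6*0.9*0.1 + 0.7*0.1*0.5*0.5 = 0.0162 + 0.0175 = 0.0337
print(np.exp(mixture_log_prob(v, theta, pi, A)))
```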
Deriving EM update equations
There is a set of training sequences, $\vec v^n,n=1,\dots,N$. Assuming that each sequence $\vec v^n$ is independently and identically drawn from a Markov chain mixture model with $H$ components, derive the Expectation Maximization algorithm for training this model.
Solution
E step
We need to compute the posterior distribution of the hidden chain label $h$ for each training sequence under the current parameters $\theta$. These posterior probabilities are the expected “counts” of chain $h$ for sample $n$, i.e. the sufficient statistics of the categorical distribution $p(h|\vec v^n,\theta)$. During the M step this distribution is held fixed (it is computed from the current parameters, not from the parameters being optimized), so we denote it $p(h|\vec v^n,\theta)=q^n(h)=\tau_h^n$: $$ \begin{aligned} \tau_h^n &= \EE_{p(h^n | \vec v^n,\theta)} [\ind(h^n=h)] \\ &= p(h | \vec v^n,\theta) \\ &= \frac{p(h, \vec v^n | \theta)}{p(\vec v^n|\theta)} \\ &= \frac{p(h, \vec v^n | \theta)}{\sum_{h'} p(h', \vec v^n|\theta)} \\ &= \answer{\frac{p(h)p(v_1^n|h)\prod_{t=2}^{T}p(v_t^n|v_{t-1}^n,h)} {\sum_{h'} p(h')p(v_1^n|h')\prod_{t=2}^{T}p(v_t^n|v_{t-1}^n,h')}} \\ \end{aligned} $$
At each iteration, these responsibilities are recomputed with the latest parameters and then serve as the fixed weights $q^n(h)$ in the expected complete-data log likelihood, which is maximized in the M step.
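A minimal sketch of the E step for all sequences at once, assuming equal-length sequences stored as an `(N, T)` array of 0-based states and parameters laid out as `theta[h]`, `pi[k, h]`, `A[i, j, h]`; the function name `responsibilities` and the random test data are hypothetical. The denominator $\sum_{h'} p(h', \vec v^n|\theta)$ is computed in log space, which guards against the underflow that a plain product of $T-1$ transition probabilities can hit for long sequences.

```python
import numpy as np
from scipy.special import logsumexp

def responsibilities(seqs, theta, pi, A):
    """tau[n, h] = p(h | v^n, theta) for an (N, T) array of integer sequences.

    theta : (H,), pi : (K, H), A : (K, K, H) with A[i, j, h] = p(v_t = j | v_{t-1} = i, h)
    """
    # log p(v^n, h) = log p(h) + log p(v_1^n | h) + sum_t log p(v_t^n | v_{t-1}^n, h), shape (N, H)
    log_joint = (np.log(theta)[None, :]
                 + np.log(pi[seqs[:, 0], :])
                 + np.log(A[seqs[:, :-1], seqs[:, 1:], :]).sum(axis=1))
    # divide by sum_h' p(v^n, h') in log space
    return np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))

# hypothetical random data and parameters, only to exercise the function
rng = np.random.default_rng(0)
N, T, K, H = 5, 20, 4, 2
seqs = rng.integers(0, K, size=(N, T))
theta = np.full(H, 1 / H)
pi = np.full((K, H), 1 / K)
A = rng.dirichlet(np.ones(K), size=(K, H)).transpose(0, 2, 1)  # each A[i, :, h] sums to 1
tau = responsibilities(seqs, theta, pi, A)
print(tau.round(3))  # each row sums to 1
```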
M step
In the M step, we maximize the expected complete-data log likelihood $f(\theta)$ with respect to $\theta$: $$ \begin{aligned} f(\theta) &= \EE_{q} \log p(\vec v^1,\dots,\vec v^N,h^1,\dots,h^N|\theta) \\ &= \EE_{q} \sum_n \log p(\vec v^n,h^n|\theta) \\ &= \sum_n \EE_{h^n \sim q^n} \log p(\vec v^n,h^n|\theta) \\ &= \sum_n \sum_h q^n(h) \log p(\vec v^n,h|\theta) \\ &= \sum_{n,h} \tau_h^n \log p(\vec v^n,h|\theta) \\ &= \sum_{n,h} \tau_h^n \lrb{\log p(h) + \log p(v_1^n|h) + \sum_{t=2}^T \log p(v_t^n|v_{t-1}^n,h)} \\ &= \sum_{n,h} \tau_h^n \lrb{\log \theta_h + \log \pi_{v_1^n|h} + \sum_{t=2}^T \log a_{v_t^n|v_{t-1}^n,h}} \\ \end{aligned} $$
Here $\theta_h = p(h)$, $\pi_{k|h} = p(v_1=k|h)$ and $a_{j|i,h} = p(v_t=j|v_{t-1}=i,h)$ are the chain prior, initial-state, and transition probabilities, respectively. We now maximize $f(\theta)$ to derive the update for each.
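For reference, a sketch that evaluates $f(\theta)$ directly for given responsibilities, assuming 0-based integer states and arrays `theta[h]`, `pi[k, h]`, `A[i, j, h]` laid out as in the solution code later in this post; the function name is hypothetical. Because `tau` is held fixed, the three bracketed terms separate, which is why each parameter group can be updated on its own below.

```python
import numpy as np

def expected_complete_loglik(seqs, tau, theta, pi, A):
    """f(theta) = sum_{n,h} tau[n, h] * log p(v^n, h | theta).

    seqs  : (N, T) integer sequences
    tau   : (N, H) responsibilities from the E step, held fixed
    theta : (H,), pi : (K, H), A : (K, K, H) with A[i, j, h] = a_{j|i,h}
    """
    log_joint = (np.log(theta)[None, :]                                   # log theta_h
                 + np.log(pi[seqs[:, 0], :])                             # log pi_{v_1^n|h}
                 + np.log(A[seqs[:, :-1], seqs[:, 1:], :]).sum(axis=1))  # sum_t log a_{v_t^n|v_{t-1}^n,h}
    return float(np.sum(tau * log_joint))
```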
$$ \newcommand{\pdd}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\pdfd}[1]{\pdd{f}{#1}} \newcommand{\pdld}[1]{\pdd{\mathcal{L}}{#1}} $$
$\theta_h$ update
We need to introduce a Lagrange multiplier to enforce the constraint that $\sum_h \theta_h = 1$, and set $\pdld{\theta_h}=\pdld{\lambda}=0$: $$ \hat{\theta}_h = \argmax{\theta_h} \mathcal{L}(\theta,\lambda) $$ $$ \mathcal{L}(\theta,\lambda) = f(\theta) - \lambda \lr{\sum_h \theta_h - 1} $$
Then we set $\pdld{\theta_h}=0$; only the $\log\theta_h$ terms of $f$ depend on $\theta_h$: $$ \pdld{\theta_h} = \frac{\partial}{\partial \theta_h} \lr{\sum_{n} \tau_h^n \log \theta_h - \lambda \theta_h} = \frac{\sum_n \tau_h^n}{\theta_h} - \lambda = 0 $$ $$ \lambda = \frac{\sum_n \tau_h^n}{\hat{\theta}_h} \qquad \hat{\theta}_h = \frac{\sum_n \tau_h^n}{\lambda}$$
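How do we get $\lambda$? The stationarity condition gives $\lambda\hat{\theta}_h = \sum_n \tau_h^n$. Summing both sides over $h$, and using $\sum_h \hat{\theta}_h = 1$ together with $\sum_h \tau_h^n = 1$ for every $n$: $$ \lambda = \frac{\sum_{n,h} \tau_h^n}{\sum_h \hat{\theta}_h} = \sum_{n,h} \tau_h^n = N $$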
Hence, $$ \answer{\hat{\theta}_h = \frac{\sum_n \tau_h^n}{N}} $$
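A quick numerical check of this closed form, under hypothetical random responsibilities: no randomly drawn point on the probability simplex should achieve a higher value of the $\theta$-dependent part of $f$, $\sum_{n,h}\tau_h^n\log\theta_h$, than $\hat\theta_h = \sum_n\tau_h^n/N$. This is only a sanity check of the Lagrange-multiplier result, not part of the algorithm.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H = 20, 2
tau = rng.dirichlet(np.ones(H), size=N)  # hypothetical responsibilities; each row sums to 1

def theta_objective(theta):
    # the only part of f(theta) that depends on theta: sum_{n,h} tau_h^n log theta_h
    return np.sum(tau * np.log(theta))

theta_hat = tau.sum(axis=0) / N  # closed-form M-step update

# compare against many random points on the probability simplex
candidates = rng.dirichlet(np.ones(H), size=1000)
best_random = max(theta_objective(t) for t in candidates)
print(theta_hat, theta_objective(theta_hat) >= best_random)  # second value should be True
```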
$\pi_{k|h}$ update
We follow the same pattern as before, using a Lagrange multiplier:
$$ \hat{\pi}_{k|h} = \argmax{\pi_{k|h}} \mathcal{L}(\theta,\lambda) $$ $$ \mathcal{L}(\theta,\lambda) = f(\theta) - \lambda \lr{\sum_{k} \pi_{k|h} - 1} $$ $$ \pdld{\pi_{k|h}} = \pdd{}{\pi_{k|h}} \lr{\sum_{n} \tau_h^n \ind(v_1^n=k) \log \pi_{k|h} - \lambda \pi_{k|h}} = \frac{\sum_n \tau_h^n \ind(v_1^n=k)}{\pi_{k|h}} - \lambda = 0 $$ $$ \lambda = \frac{\sum_n \tau_h^n \ind(v_1^n=k) }{\hat{\pi}_{k|h}} \qquad \hat{\pi}_{k|h} = \frac{\sum_n \tau_h^n\ind(v_1^n=k) }{\lambda}$$
Similar to before, we multiply through, sum both sides over $k$, and use the constraint $\sum_k \hat{\pi}_{k|h} = 1$ to solve for $\lambda$ (for each $n$, exactly one indicator $\ind(v_1^n=k)$ equals one): $$ \lambda = \frac{\sum_{n,k} \tau_h^n \ind(v_1^n=k) }{\sum_k \hat{\pi}_{k|h}} = \sum_{n,k} \tau_h^n \ind(v_1^n=k) = \sum_n \tau_h^n $$ $$ \answer{\hat{\pi}_{k|h} = \frac{\sum_{n} \tau_h^n \ind(v_1^n=k)}{\sum_{n} \tau_h^n}} $$
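The indicator sum in this update is a weighted count of first states, which can be written as a one-hot matrix product. A minimal sketch, assuming 0-based states and the `pi[k, h]` layout; the function name and the small example data are hypothetical.

```python
import numpy as np

def update_pi(seqs, tau, K):
    """pi[k, h] = sum_n tau[n, h] 1(v_1^n = k) / sum_n tau[n, h]."""
    first = np.eye(K)[seqs[:, 0]]           # (N, K) one-hot rows: first[n, k] = 1(v_1^n = k)
    return first.T @ tau / tau.sum(axis=0)  # (K, H); every column sums to 1

# small hypothetical example: 3 sequences, K = 3 states, H = 2 chains
seqs = np.array([[0, 1, 1],
                 [2, 0, 1],
                 [0, 0, 0]])
tau = np.array([[1.0, 0.0],
                [0.5, 0.5],
                [0.2, 0.8]])
pi = update_pi(seqs, tau, K=3)
print(pi.round(3))
```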
$a_{j|i,h}$ update
The process is again similar to what we did before:
$$ \newcommand{\transprob}{a_{j|i,h}} \newcommand{\transprobhat}{\hat{a}_{j|i,h}} $$ $$ \transprobhat = \argmax{\transprob} \mathcal{L}(\theta,\lambda) $$ $$ \mathcal{L}(\theta,\lambda) = f(\theta) - \lambda \lr{\sum_{j} \transprob - 1} $$ $$ \pdld{\transprob} = \pdd{}{\transprob} \lr{\sum_{n} \tau_h^n \sum_{t=2}^T \ind(v_t^n=j,v_{t-1}^n=i) \log\transprob - \lambda \transprob} = 0 $$ $$ \lambda = \frac{\sum_n \tau_h^n \sum_{t=2}^T \ind(v_t^n=j,v_{t-1}^n=i)}{\transprobhat} $$
Multiplying through by $\transprobhat$, summing both sides over $j$, and using $\sum_j \transprobhat = 1$: $$ \begin{aligned} \lambda &= \sum_n \tau_h^n \sum_{t=2}^T \sum_j \ind(v_t^n=j,v_{t-1}^n=i) \\ &= \sum_n \tau_h^n \sum_{t=2}^T \ind(v_{t-1}^n=i) \end{aligned} $$ $$ \answer{ \transprobhat = \frac{\sum_n \tau_h^n \sum_{t=2}^T \ind(v_t^n=j, v_{t-1}^n=i)} {\sum_n \tau_h^n \sum_{t=2}^T \ind(v_{t-1}^n=i)} } $$
Python code and application to biological sequences
The file `sequences.mat` contains a set of fictitious bio-sequences in a cell array `sequences{n}(t)`. Thus `sequences{3}(:)` is the third sequence, `GTCTCCTGCCCTCTCTGAAC`, which consists of 20 timesteps. There are 20 such sequences in total.

Your task is to cluster these sequences into two clusters, assuming that each cluster is modelled by a Markov chain. State which of the sequences belong together by assigning each sequence $\mathbf{v}^n$ to the cluster $h$ for which $p(h|\mathbf{v}^n)$ is highest. Your solution must print and report the members of the two clusters (show which of the 20 sequences from the file `sequences.mat` are assigned to Cluster 1 and Cluster 2).
Solution
```python
from scipy.io import loadmat
import numpy as np

# the MATLAB cell array loads as a (1, N) object array; each cell holds one string
seqs = loadmat('sequences.mat')['sequences'][0]
# encode nucleotides as integer states 0..3
encoding = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
seqs = np.array([[encoding[c] for c in seq[0]] for seq in seqs])
N, T = seqs.shape  # number of sequences, sequence length
seqs
```
array([[1, 0, 3, 0, 2, 2, 1, 0, 3, 3, 1, 3, 0, 3, 2, 3, 2, 1, 3, 2],
[1, 1, 0, 2, 3, 3, 0, 1, 2, 2, 0, 1, 2, 1, 1, 2, 0, 0, 0, 2],
[3, 2, 2, 0, 0, 1, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[2, 3, 1, 3, 1, 1, 3, 2, 1, 1, 1, 3, 1, 3, 1, 3, 2, 0, 0, 1],
[2, 3, 2, 1, 1, 3, 2, 2, 0, 1, 1, 3, 2, 0, 0, 0, 0, 2, 1, 1],
[1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 3, 1, 1, 2, 2, 2, 0, 0, 1, 2],
[0, 0, 0, 2, 3, 2, 1, 3, 1, 3, 2, 0, 0, 0, 0, 1, 3, 1, 0, 1],
[0, 1, 0, 3, 2, 0, 0, 1, 3, 0, 1, 0, 3, 0, 2, 3, 0, 3, 0, 0],
[2, 3, 3, 2, 2, 3, 1, 0, 2, 1, 0, 1, 0, 1, 2, 2, 0, 1, 3, 2],
[1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 3, 3, 1, 1, 3, 2, 1],
[1, 0, 1, 3, 0, 1, 2, 2, 1, 3, 0, 1, 1, 3, 2, 2, 2, 1, 0, 0],
[1, 2, 2, 3, 1, 1, 2, 3, 1, 1, 2, 0, 2, 2, 1, 0, 1, 3, 1, 2],
[3, 0, 0, 2, 3, 2, 3, 1, 1, 3, 1, 3, 2, 1, 3, 1, 1, 3, 0, 0],
[1, 0, 1, 1, 0, 3, 1, 0, 1, 1, 1, 3, 3, 2, 1, 3, 0, 0, 2, 2],
[0, 0, 0, 2, 0, 0, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 3, 2, 1, 1],
[1, 0, 0, 0, 3, 2, 1, 1, 3, 1, 0, 1, 2, 1, 2, 3, 1, 3, 1, 0],
[2, 1, 1, 0, 0, 2, 1, 0, 2, 2, 2, 3, 1, 3, 1, 0, 0, 1, 3, 3],
[1, 0, 3, 2, 2, 0, 1, 3, 2, 1, 3, 1, 1, 0, 1, 0, 0, 0, 2, 2],
[0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 1, 1, 3, 0, 0, 2],
[2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 2, 3, 1, 1, 3, 2, 2, 2, 3]])
```python
H = 2  # number of clusters (Markov chains)
K = 4  # number of states: A, C, G, T

def em_cluster():
    # random initialisation of theta[h], pi[k, h] and A[i, j, h]
    theta = np.random.rand(H)
    theta /= theta.sum()
    pi = np.random.rand(K, H)
    pi /= pi.sum(axis=0)
    A = np.random.rand(K, K, H)  # A[i, j, h] = p(v_t = j | v_{t-1} = i, h)
    A /= A.sum(axis=1, keepdims=True)
    assert np.allclose(A.sum(axis=1), 1)

    tau_prev = np.zeros((N, H))
    tau = np.ones_like(tau_prev)
    iters = 0
    # stop once the responsibilities tau no longer change
    while not np.allclose(tau_prev, tau):
        tau_prev = tau.copy()
        iters += 1

        # E-step: p(h, v^n) for all n and h, normalised over h
        # shapes:  (H,)    (N, H)              (N, H)
        p_hv = theta * pi[seqs[:, 0], :] * np.prod(A[seqs[:, :-1], seqs[:, 1:], :], axis=1)
        assert p_hv.shape == (N, H)
        tau = p_hv / p_hv.sum(axis=1, keepdims=True)
        h_hat = np.argmax(tau, axis=1)

        # M-step: closed-form updates for theta, pi and A
        theta = tau.sum(axis=0) / N
        for k in range(K):
            pi[k, :] = np.sum(tau[seqs[:, 0] == k, :], axis=0) / np.sum(tau, axis=0)
        A = np.zeros_like(A)
        for n in range(N):
            for t in range(1, T):
                A[seqs[n, t-1], seqs[n, t], :] += tau[n, :]
        A /= np.sum(A, axis=1, keepdims=True)

    print(f'converged in {iters} iterations')
    # log likelihood of the data together with the hard assignments:
    # sum_n log p(v^n, h_hat^n)
    likelihood = theta[h_hat] * pi[seqs[:, 0], h_hat] * np.prod(
        A[seqs[:, :-1], seqs[:, 1:], h_hat[:, np.newaxis]], axis=1)
    logl = np.log(likelihood).sum()
    return h_hat, logl

h_hat, logl = em_cluster()
h_hat, logl
```
converged in 11 iterations
(array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0]),
-497.82024107176716)
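The double loop over $n$ and $t$ that accumulates the transition counts in the M-step of `em_cluster` can be vectorised. A sketch, assuming the same `A[i, j, h]` layout; the function name `update_A` is hypothetical. It uses `np.add.at`, which is unbuffered, so repeated `(i, j)` pairs all accumulate (plain fancy-index `+=` would count each pair only once). Inside `em_cluster`, `A = update_A(seqs, tau, K)` would stand in for the zero/loop/normalise lines.

```python
import numpy as np

def update_A(seqs, tau, K):
    """A[i, j, h] = sum_n tau[n, h] #{t: v_{t-1}^n = i, v_t^n = j}, normalised over j."""
    N, T = seqs.shape
    counts = np.zeros((K, K, tau.shape[1]))
    prev = seqs[:, :-1].ravel()              # v_{t-1}^n for every (n, t), n-major order
    curr = seqs[:, 1:].ravel()               # matching v_t^n
    weights = np.repeat(tau, T - 1, axis=0)  # tau[n, :] repeated once per transition of sequence n
    # unbuffered accumulation: every occurrence of an (i, j) pair adds its weights
    np.add.at(counts, (prev, curr), weights)
    # normalise over j; as in the loop version, a state never seen as v_{t-1} gives 0/0 = nan
    return counts / counts.sum(axis=1, keepdims=True)
```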
```python
# Repeat EM from 20 random initialisations and keep the highest-likelihood run
best_logl = -np.inf
for i in range(20):
    h_hat_i, logl_i = em_cluster()
    if logl_i > best_logl:
        h_hat = h_hat_i
        best_logl = logl_i
print(f'best log likelihood: {best_logl}')

# Report 1-based sequence numbers (matching sequences{n}) for each cluster
clusters = {f'Cluster {i+1}': np.where(h_hat == i)[0] + 1 for i in range(H)}
for k, v in clusters.items():
    print(f'{k}:', *v, sep=' ')
```
converged in 12 iterations
converged in 26 iterations
converged in 63 iterations
converged in 21 iterations
converged in 8 iterations
converged in 11 iterations
converged in 9 iterations
converged in 12 iterations
converged in 17 iterations
converged in 27 iterations
converged in 20 iterations
converged in 11 iterations
converged in 12 iterations
converged in 19 iterations
converged in 10 iterations
converged in 25 iterations
converged in 14 iterations
converged in 25 iterations
converged in 24 iterations
converged in 24 iterations
best log likelihood: -483.6486774197766
Cluster 1: 1 2 6 8 9 11 12 14 16 17 18
Cluster 2: 3 4 5 7 10 13 15 19 20
```python
# Print the decoded nucleotide sequences in each cluster
for name, members in clusters.items():
    print(f'{name}:')
    for idx in members:
        print(''.join('ACGT'[s] for s in seqs[idx - 1]))
    print('')
```
Cluster 1:
CATAGGCATTCTATGTGCTG
CCAGTTACGGACGCCGAAAG
CGGCCGCGCCTCCGGGAACG
ACATGAACTACATAGTATAA
GTTGGTCAGCACACGGACTG
CACTACGGCTACCTGGGCAA
CGGTCCGTCCGAGGCACTCG
CACCATCACCCTTGCTAAGG
CAAATGCCTCACGCGTCTCA
GCCAAGCAGGGTCTCAACTT
CATGGACTGCTCCACAAAGG
Cluster 2:
TGGAACCTTAAAAAAAAAAA
GTCTCCTGCCCTCTCTGAAC
GTGCCTGGACCTGAAAAGCC
AAAGTGCTCTGAAAACTCAC
CCTCCCCTCCCCTTTCCTGC
TAAGTGTCCTCTGCTCCTAA
AAAGAACTCCCCTCCCTGCC
AAAAAAACGAAAAACCTAAG
GCGTAAAAAAAGTCCTGGGT