Expectation Maximization (EM)

The Expectation Maximization (EM) is a well-known unsupervised learning algorithm, which is often used to cluster unlabelled data in a Gaussian Mixture Model (GMM). In this example, EM is employed in a multi-variate fashion to identify 2 groups of pixel intensities in an image.



last update: 22/09/2021

Expectation Optimization Model

Likelihood and its logarithm

Let any observed data be denoted by $\mathbf{X}\in\mathrm{R}^{N\times M}$ with $N$ samples and $M$ features so that $\mathbf{X}=\left[x_1^{(m)}, x_2^{(m)}, x_n^{(m)}, \dots, x_N^{(m)}\right]^\intercal$, which can be modelled as a joint probability density (or likelihood) function $\mathcal{N}\left(\cdot\right)$ by

$$ \mathcal{N}\left(\mathbf{X}|\boldsymbol\mu, \boldsymbol\Sigma\right) = \frac{\exp\left(-\frac{1}{2} ({\mathbf X}-{\boldsymbol\mu}){\boldsymbol\Sigma}^{-1}({\mathbf X}-{\boldsymbol\mu})^\intercal\right)}{\sqrt{(2\pi)^k |\boldsymbol\Sigma|}} $$

where $\boldsymbol\mu \in \mathrm{R}^{1\times M}$ and $\boldsymbol\Sigma \in \mathrm{R}^{M\times M}$ represent the mean and covariance across features, respectively.

Note: The terms around $\boldsymbol\Sigma^{-1}$ have a swapped transpose with regards to the conventional notation in order to comply with the above matrix and vector dimensions, which are governed by the multivariate function implementation.

The logarithm comes in handy as it allows to write

$$ \log \mathcal{N}\left(\mathbf{X}|\boldsymbol\mu, \boldsymbol\Sigma\right) = -\frac{1}{2}\log|\boldsymbol\Sigma| -\frac{1}{2} ({\mathbf X}-{\boldsymbol\mu}){\boldsymbol\Sigma}^{-1}({\mathbf X}-{\boldsymbol\mu})^\intercal $$

for the sake of convenience.

Multi-variate Gaussian Mixture Model (GMM)

To account for multiple modes in $\mathbf{X}$, our joint probability density (or likelihood) function is extended to a Gaussian Mixture Model (GMM), which reads

$$ p(\mathbf{X}, \boldsymbol{\theta}) = \sum_k^K \pi_k \mathcal{N}\left(\mathbf{X}|\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\right) $$

where $\boldsymbol{\theta}:=\{\pi_k, \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\}$ represents the unknown mixture parameters with $k \in \{1, 2, \dots, K\}$ as the Gaussian component index and $K$ as its total number. Note that the mixing coefficient is in accordance with $0\le \pi_k \le 1$ and $\sum_{k=1}^K \pi_k=1$.

Latent variable

A key concept of the likelihood function $p(\mathbf{X}, \boldsymbol{\theta})$ can be represented as a conditional distribution $p(\mathbf{X} | \mathbf{z}, \boldsymbol{\theta})$ containing latent variables $\mathbf{z}:=z_n$ which indicate the probability at which observed points $\mathbf{x}_n$ are assigned to a cluster $k$, i.e. describing how probable it is that $\mathbf{x}_n$ is drawn from the real distribution component $k$.

$$ p(\mathbf{X}, \boldsymbol{\theta}) = \int p(\mathbf{X} | \mathbf{z}, \boldsymbol{\theta})p(\mathbf{z}) \,\mathrm{d}\mathbf{z} $$

and marginalizing out the latent variables $\mathbf{z}$ via integration.

Optimization

According to the above statements, the GMM-based log-likelihood writes

$$ L_{\mathbf{X}}(\boldsymbol{\theta}) = \sum_n^N \log p(\mathbf{x}_n, \boldsymbol{\theta}) = \sum_n^N \log \sum_k^K \pi_k \mathcal{N}\left(\mathbf{x}_n|\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\right) $$

where $L_{\mathbf{X}}(\boldsymbol{\theta})$ represents the Maximum Likelihood Estimate (MLE) with $\mathbf{X}$ in the subscript signifying that the MLE is bound to $\mathbf{X}$. In the context of optimization, it is our goal to minimize the negative logarithm of the MLE, i.e. $-L_{\mathbf{x}}(\boldsymbol{\theta})$. The objective function we aim to solve can thus be formulated as

$$ \hat{\boldsymbol{\theta}} = \underset{\boldsymbol{\theta}}{\operatorname{arg max}}L_{\mathbf{X}}(\boldsymbol{\theta}) = \underset{\boldsymbol{\theta}}{\operatorname{arg max}}\left(\sum_n^N \log p\left(\mathbf{x}_n,\boldsymbol{\theta}\right)\right) = \underset{\boldsymbol{\theta}}{\operatorname{arg min}}\left(-\sum_n^N \log p\left(\mathbf{x}_n,\boldsymbol{\theta}\right)\right) $$

which for simplification can be expressed as logarithmic likelihood.

Expectation Maximization (EM) Algorithm

Let $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)})$ represent the expected value of log-likelihood function given by

$$ Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) = \mathbb{E}_{\mathbf{z}|\mathbf{X}, \boldsymbol{\theta}^{(t)}}\left[\log L(\boldsymbol{\theta}; \mathbf{X}, \mathbf{z})\right] $$

at each iteration step $t$.

1. Expectation step

For the expectation update procedure, the likelihood weights $w_n^{(k)}$ are computed by

$$ w_n^{(k)} \leftarrow \frac{\pi_k\mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{k=1}^K \pi_k\mathcal{N}(\mathbf{x}_n | \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}\, , \quad \forall k, \, \forall n $$

which indicate the probability of a sample $\mathbf{x}_n$ belonging to a cluster component $k$. A Python implementation of the expectation is provided hereafter.

In [1]:
from scipy.stats import multivariate_normal

def expectation_step(data, pi, mu, cov):
    
    # likelihood weights (responsibilities)
    w = np.zeros([len(mu), len(data)])

    for k in range(len(mu)):
        # compute probability density function for each component k; KxN
        w[k, :] = pi[k] * multivariate_normal.pdf(data, mean=mu[k], cov=cov[k])

    # normalize all probability density function components
    w = np.divide(w, w.sum(axis=0), out=np.zeros_like(w), where=w.sum(0) != 0)
    
    return w

2. Maximization step

In the second part, the maximization of $\boldsymbol{\theta}$ parameters can be written as

$$ \boldsymbol{\theta}^{(t+1)} = \underset{\boldsymbol{\theta}}{\operatorname{arg max}} Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) $$

for iteration $t$. This general update procedure breaks down to updating the mean $\boldsymbol{\mu}_k$, covariance $\boldsymbol{\Sigma}_k$ and mixing coefficients $\pi_k$ as follows

$$ \boldsymbol{\mu}_k \leftarrow \frac{\sum_{n=1}^{N} w_n^{(k)} \mathbf{x}_n}{\sum_{n=1}^{N} w_n^{(k)}}\, , \quad \forall k $$$$ \boldsymbol{\Sigma}_k \leftarrow \frac{\sum_{n=1}^{N} w_n^{(k)} \left(\mathbf{x}_n-\boldsymbol{\mu}_k\right)^\intercal\left(\mathbf{x}_n-\boldsymbol{\mu}_k\right)}{\sum_{n=1}^{N} w_n^{(k)}} \, , \quad \forall k $$$$ \pi_k \leftarrow \frac{1}{N} \sum_{n=1}^{N} w_n^{(k)} \, , \quad \forall k $$

which mainly relies on the pre-computed $w_n^{(k)}$. At the end of this notebook proofs can be found for the above procedures which are implemented below as part of the maximization step.

In [2]:
def maximization_step(data, w):
    
    # power per component; KxN -> K
    p_comp = np.sum(w, axis=1)
    # update mixing coefficients
    pi = p_comp / len(data)
    
    mu, cov = np.zeros([len(p_comp), data.shape[-1]]), np.zeros([len(p_comp), data.shape[-1], data.shape[-1]])
    for k in range(len(p_comp)):
        # update means
        mu[k] = np.divide(np.dot(w[k, :], data), p_comp[k], out=np.zeros_like(mu[k]), where=p_comp[k] != 0)
        # update standard deviations
        cov[k] = np.dot(w[np.newaxis, k, :] * (data - mu[k]).T, data - mu[k]) / p_comp[k]
        #cov[k] = np.dot((data - mu[k]).T, w[k, :, np.newaxis] * (data - mu[k])) / p_comp[k]
        
    return pi, mu, cov
In [3]:
import numpy as np

def em_algorithm(data, n_comp=None, s_mean=None, s_covs=None, pi=None, max_iter=50, tol=1e-15):
    """
    Expectation Maximization (EM)
    (i) initialize the likelihood
    (ii) estimate the log-likelihood
    (iii) find the partial derivative of the log-likelihood

    :param data: input data
    :param n_comp: number of mixture components
    :param s_mean: list of mean values
    :param s_covs: covariance matrix
    :param pi: mixing coefficients
    :param max_iter: maximum number of iterations
    :param tol: tolerance of log-likelihood as stop condition
    :return: aggregated lists of intermediate means, covariances, mixtures, latent variables and likelihoods
    """

    # dimensions init
    n_comp = len(s_mean) if n_comp is None and len(s_mean) > 1 else 2  # number of mixture components
    n_data = len(data)  # number of data points
    n_dims = data.shape[-1]  # number of multi-variate dimensions
    
    # parameter init
    gen_randcomps = lambda vmax, vmin, K, C: np.array([np.random.uniform(vmin, vmax, size=C) for _ in range(K)])
    pi = (1. / n_comp) * np.ones(n_comp) if pi is None else pi  # mixing coefficients
    mu = gen_randcomps(data.min(), data.max(), n_comp, n_dims)  # means of components
    cov = [np.diag(v) for v in gen_randcomps(data.max() / 50, data.max() / 25, n_comp, n_dims)]
    
    # lists init for intermediate results (log-likelihoods (lls), responsibilities, pi, mu, covariance)
    ll_list, w_list, pi_list, mu_list, cov_list = ([] for _ in range(5))

    for i in range(max_iter):

        w = expectation_step(data, pi, mu, cov)
        
        pi, mu, cov = maximization_step(data, w)
        
        ll_hood = np.sum(np.log(w.sum(axis=0)))
        
        ll_list.append(ll_hood)
        pi_list.append(pi.tolist())
        mu_list.append(mu)
        cov_list.append(cov)
        w_list.append(w)

        if len(ll_list) > 1 and 0 < ll_list[-1] - ll_list[-2] < tol:
            print('EM converged after %s iterations\n' % i)
            break

    return mu_list, cov_list, pi_list, w_list, ll_list

Unlike the k-means algorithm, EM provides a numerical likelihood for each sample indicating its cluster group membership. After convergence of the above two-step-iteration scheme, one can infer the per sample cluster correspondence from respective likelihood weights $w_n^{(k)}$ via

$$ \mathbf{z} = z_n = \underset{k}{\operatorname{arg max}} w_n^{(k)} $$

Image Segmentation using EM-based GMM

In [4]:
from PIL import Image
import matplotlib.pyplot as plt
import requests
import urllib
import numpy as np

def get_image(url: str = None):

    # load image
    try:
        img = np.array(Image.open(requests.get(url, stream=True).raw))
    except (requests.exceptions.ConnectionError, requests.exceptions.MissingSchema) as e:
        fp_alt = os.path.join('.', 'img', url.split('/')[-1])
        if os.path.exists(fp_alt):
            img = np.array(Image.open(fp_alt))
        else:
            raise e

    # downsample
    img = img[::2, ::2, ...]

    # color channel treatment
    if img.shape[-1] != 3:
        if img.shape[-1] == 4:
            img = img[..., :3]
        else:
            img = np.repeat(img[..., np.newaxis], 3, axis=-1)

    # pre-processing (normalization)
    th_min = 0
    th_max = 100 - th_min
    norm = lambda ch: (ch-np.percentile(ch, th_min))/(np.percentile(ch, th_max)-np.percentile(ch, th_min))
    # white balance
    if th_min != 0:
        img = np.round(norm(img)*255)
        img[img > 255] = 255
        img[img < 0] = 0
    else:
        img = np.dstack([norm(ch).T for ch in np.swapaxes(img, 0, -1)]) if True else img

    return img

# load image
url = 'https://pbblogassets.s3.amazonaws.com/uploads/2016/02/Green-Screen-Lighting.jpg'
img = get_image(url)
In [5]:
# run expectation maximization
data = img.reshape(-1, img.shape[-1])
n_comp = 2
s_mean = [[0, 255, 0], [255, 0, 255]]
mu_list, cov_list, pi_list, w_list, ll_hoods = em_algorithm(data=data, n_comp=n_comp, s_mean=s_mean)

# print results
print("means: %s" % mu_list[-1])
print("stds: %s" % cov_list[-1])
print("coeffs: %s" % pi_list[-1])
EM converged after 6 iterations

means: [[0.39780785 0.33046741 0.27309934]
 [0.26046594 0.64587764 0.2444385 ]]
stds: [[[0.08072217 0.05963549 0.04823085]
  [0.05963549 0.04705295 0.03966599]
  [0.04823085 0.03966599 0.03593154]]

 [[0.01840589 0.01372042 0.00741219]
  [0.01372042 0.01120826 0.00577302]
  [0.00741219 0.00577302 0.00318006]]]
coeffs: [0.17223814211694427, 0.8277618578830556]

Convergence Graph

In [6]:
if ll_hoods:
    plt.figure(figsize=(15, 5))
    plt.title('Convergence plot', fontsize=16)
    plt.plot(np.array(ll_hoods), color='gray', linestyle='--')
    plt.plot(np.array(ll_hoods), color='red', linestyle='', marker='.', markersize=9)
    plt.xlabel('Iteration $t$', fontsize=14)
    plt.xticks(np.arange(len(ll_hoods)-1), np.arange(1, len(ll_hoods), 1))
    plt.ylabel('Log-Likelihood', fontsize=14)

Animated Results

In [7]:
def plot_3d_distribution(data, mu, cov, latent_vars, ax=None):

    if ax is None:
        fig = plt.figure()

    # remove previous plot data
    ax.clear()
    ax.set_xlabel('Red', fontsize=18)
    ax.set_ylabel('Green', fontsize=18)
    ax.set_zlabel('Blue', fontsize=18)

    colors = ['green', 'orange']
    
    # iterate through clusters
    for i, (c, m) in enumerate(zip(np.unique(latent_vars), mu)):

        # plot cluster points
        group = data[latent_vars == c]
        ax.scatter(group[..., 0], group[..., 1], zs=group[..., 2], zdir='z', s=5,
                   c=colors[i % len(mu)], alpha=.025, label='Cluster #'+str(i+1))

        # plot mean value
        ax.scatter(m[0], m[1], m[2], s=100, c='red', marker='+')

        # decompose covariance in eigenvalues (variance) and eigenvectors (component directions)
        eig_vals, eig_vecs = np.linalg.eig(cov[i])

        # compute radii (standard deviation) from variance
        r = eig_vals**.5

        # create ellipsoid coordinates
        csamples = complex(0, 31)
        phi, theta = np.mgrid[0:np.pi:csamples, 0:2*np.pi:csamples]
        x = r[0] * np.sin(phi) * np.cos(theta)
        y = r[1] * np.sin(phi) * np.sin(theta)
        z = r[2] * np.cos(phi)

        # comply with right-hand rule and world frame convention
        eig_vecs[:, 2] = np.cross(eig_vecs[:, 0], eig_vecs[:, 1])

        # rotation matrix
        rot_mat = np.array([eig_vecs[:, 0], eig_vecs[:, 1], eig_vecs[:, 2]]).transpose()

        # rotate and translate ellipsoid coordinates
        x, y, z = np.dot(rot_mat, np.array([x.flatten(), y.flatten(), z.flatten()]))
        x, y, z = x + mu[i][0], y + mu[i][1], z + mu[i][2]
        xx, yy, zz = np.array([x.reshape(phi.shape), y.reshape(phi.shape), z.reshape(phi.shape)])

        ax.plot_wireframe(xx, yy, zz, color='purple', alpha=.15, label='Covariance #'+str(i+1))


def plot_2d_image_segmentation(img, r, axs=None):

    fig, axs = plt.subplots(3, 1) if axs is None else (None, axs)
    axs[0].imshow(img)
    axs[1].imshow(img * np.repeat(r.argmax(axis=0), 3, axis=-1).reshape(img.shape))
    axs[2].imshow(img * np.repeat(r.argmin(axis=0), 3, axis=-1).reshape(img.shape))


def em_iterations_anim(data, mu_list, cov_list, r_list, shape, save_opt=False, style_opt=False):

    nrows = 3
    ncols = 3
    from matplotlib import gridspec
    gs = gridspec.GridSpec(nrows=nrows, ncols=ncols)

    fig = plt.figure(figsize=(15, 8))
    axs = []
    titles = ['original', 'pixel cluster #1', 'pixel cluster #2']
    for i in range(nrows):
        axs.append(fig.add_subplot(gs[i, 0]))
        axs[i].tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
        axs[i].set_title(titles[i], y=-0.2)

    axs.append(fig.add_subplot(gs[:, 1:ncols], projection='3d'))
    axs[-1].view_init(elev=30, azim=160)
    axs[-1].set_xlabel('Red')
    axs[-1].set_ylabel('Green')
    axs[-1].set_zlabel('Blue')

    def update(i):

        latent_vars = r_list[i].argmax(axis=0)
        plot_2d_image_segmentation(data.reshape(shape), r_list[i], axs=axs[:-1])
        plot_3d_distribution(data, mu_list[i], cov_list[i], latent_vars, ax=axs[-1])

        return axs,
    
    import matplotlib as mpl
    if style_opt:
        mpl.rcParams['savefig.facecolor'] = '#148ec8'
        fig.set_facecolor('#148ec8')
        for ax in axs:
            ax.set_facecolor('#148ec8')
            ax.set_title(label=ax.get_title(), fontdict={'color': 'white', 'size': 24}, y=1.0)
            ax.spines['bottom'].set_color('white')
            ax.spines['top'].set_color('white')
            ax.spines['left'].set_color('white')
            ax.spines['right'].set_color('white')
            ax.xaxis.label.set_color('white')
            ax.yaxis.label.set_color('white')
            ax.tick_params(colors='white')
            try:
                ax.zaxis.label.set_color('white')
                ax.w_xaxis.line.set_color("white")
                ax.w_yaxis.line.set_color("white")
                ax.w_zaxis.line.set_color("white")
            except:
                pass
    else:
        mpl.rcParams['savefig.facecolor'] = '#ffffff'
        fig.set_facecolor('#ffffff')
        for ax in axs:
            ax.set_title(label=ax.get_title(), fontdict={'color': 'black', 'size': 24}, y=1.0)
    
    from matplotlib import animation
    anim = animation.FuncAnimation(fig, update, frames=len(mu_list), interval=500)
    plt.tight_layout()
    plt.close()
            
    if save_opt:
        anim.save('./em-fit_anim.gif', dpi=50, writer='imagemagick')

    return anim
In [8]:
from IPython.display import HTML
anim = em_iterations_anim(data, mu_list, cov_list, w_list, img.shape)
HTML(anim.to_jshtml())
Out[8]: