Analyzing the avGFP landscape

WARNING: this example notebook uses an outdated version of the library API. See the documentation for more details. This example walks through the generation of analysis figures for the avGFP brightness landscape, similar to how they were generated in "Interpretable modeling of genotype-phenotype landscapes with state-of-the-art predictive power".

[1]:
import matplotlib.pyplot as plt

%matplotlib inline

Prepare the model

We begin by loading the dataset and the pre-trained model file. The avGFP dataset can be generated with the manuscript pipeline found at https://github.com/usnistgov/lantern/tree/master/manuscript; the pre-trained model parameters are also provided there.

We load the csv dataset and wrap it as a LANTERN Dataset.

[2]:
import pandas as pd

from lantern.dataset import Dataset

df = pd.read_csv("../manuscript/data/processed/gfp.csv", index_col=0)
df.head()
[2]:
                          aaMutations  uniqueBarcodes  medianBrightness       std                       substitutions  phenotype   cv
0                                 NaN            3645          3.719212  0.106992                                 NaN   1.030462  1.0
1                              SA108D               1          1.301030       NaN                              SA108D  -1.250280  7.0
2  SA108D:SN144D:SI186V:SM231T:SL234P               1          1.301031       NaN  SA108D:SN144D:SI186V:SM231T:SL234P  -1.250279  9.0
3         SA108D:SY141N:SH146L:SE220G               1          1.301189       NaN         SA108D:SY141N:SH146L:SE220G  -1.250130  8.0
4                              SA108G               2          3.708478  0.020613                              SA108G   1.020338  8.0
[3]:
ds = Dataset(df)
ds
[3]:
Dataset(substitutions='substitutions', phenotypes=['phenotype'], errors=None)
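
The Dataset tokenizes each colon-delimited substitution string into a binary indicator over the observed substitutions. As a quick check (a minimal sketch; it relies only on the ds.tokenizer.tokens attribute used later in this notebook), we can inspect that vocabulary:

# inspect the substitution vocabulary built by the dataset
print(len(ds.tokenizer.tokens))  # number of distinct substitutions
print(ds.tokenizer.tokens[:5])   # a few example tokens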

Then we set up the LANTERN model and load the pre-trained model parameters.

[4]:
from lantern.model.basis import VariationalBasis
from lantern.model.surface import Phenotype
from lantern.model import Model

import torch

K=8
model = Model(
    VariationalBasis.fromDataset(ds, K=K),
    Phenotype.fromDataset(ds, K=K)
)

model.load_state_dict(
    torch.load(
        "../manuscript/experiments/gfp-brightness/lantern/full/model.pt",
        map_location="cpu",
    )
)

model.eval();
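
As a quick sanity check on the loaded parameters (a minimal sketch reusing only the basis and surface calls that appear below, and assuming the dataset supports small slices like the full slice used later), we can push a few variants through the model:

# embed a few variants and predict their (standardized) phenotype
Xcheck, ycheck = ds[:5]
with torch.no_grad():
    fcheck = model.surface(model.basis(Xcheck))
print(fcheck.mean.shape, ycheck.shape)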

Plot the avGFP brightness surface

We first plot the non-linear surface across \(z_1\) and \(z_2\) by:

1. Finding the range of values observed for \(z_1\) and \(z_2\) across the dataset.
2. Building a dense mesh of points covering this range in these two dimensions.
3. Predicting the surface at these mesh points.
4. Plotting the surface contours.

[5]:
import numpy as np

# find z1, z2
z1, z2 = model.basis.order[:2]

# get mutations and brightness for all observations
X, y = ds[:len(ds)]

# get the embedding for all data points
with torch.no_grad():
    Z = model.basis(X)

# to filter outliers, only plot the points within [q/2, 1-q/2] quantile of each latent dimension
q = 0.01

# number of surface points
N = 100

# the meshgrid is used for surface plotting
Z1, Z2 = np.meshgrid(
    np.linspace(np.quantile(Z[:, z1], q/2), np.quantile(Z[:, z1], 1 - q/2), N),
    np.linspace(np.quantile(Z[:, z2], q/2), np.quantile(Z[:, z2], 1 - q/2), N )
)

# build latent points at each meshgrid location (other dimensions fixed at zero)
Zpred = torch.zeros(N**2, model.basis.K)
Zpred[:, z1] = torch.from_numpy(Z1.ravel())
Zpred[:, z2] = torch.from_numpy(Z2.ravel())

# predict the surface
with torch.no_grad():
    fpred = model.surface(Zpred)

# scale to original brightness values and reshape for plotting
f = fpred.mean * df["medianBrightness"].std() + df["medianBrightness"].mean()
f = f.reshape(Z1.shape)

# also scale the data for plotting
y = y * df["medianBrightness"].std() + df["medianBrightness"].mean()
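
Since this unscaling step is repeated later in the notebook, a small helper (a convenience function for this walkthrough, not part of the LANTERN API) keeps it in one place:

def to_brightness(values, stats=df["medianBrightness"]):
    # map standardized phenotype values back to medianBrightness units
    return values * stats.std() + stats.mean()
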
[6]:
plt.figure(figsize=(3,2), dpi=300)

im = plt.contourf(Z1, Z2, f, levels=8)
plt.xlabel("$z_1$")
plt.ylabel("$z_2$")
plt.colorbar()
[6]:
<matplotlib.colorbar.Colorbar at 0x7fed123b11d0>
../_images/examples_analyzing-gfp_12_1.png

Next, we add the scatter of measured datapoints for comparison, coloring them by their measured value.

[7]:
plt.figure(figsize=(3,2), dpi=300)

im = plt.contour(Z1, Z2, f, levels=8)
plt.scatter(Z[:, z1].numpy(), Z[:, z2].numpy(), c = y, alpha=0.8, s=0.3, rasterized=True)

# re-apply the limits
plt.xlim(np.quantile(Z[:, z1], q/2), np.quantile(Z[:, z1], 1 - q/2))
plt.ylim(np.quantile(Z[:, z2], q/2), np.quantile(Z[:, z2], 1 - q/2))

plt.xlabel("$z_1$")
plt.ylabel("$z_2$")
plt.colorbar()
[7]:
<matplotlib.colorbar.Colorbar at 0x7fed123edc10>
../_images/examples_analyzing-gfp_14_1.png

Highlighting mutations in the surface

To highlight specific mutations in the surface, we find their mutation-effect vectors in the matrix \(W\) and plot them over the surface contours. We re-calculate the contour predictions on a narrower region focused around the origin. The mutations shown here are previously identified stabilizing substitutions for the blue fluorescent derivative of avGFP (BFP).

[8]:
# BFP stabilizing mutations
subs = ["SY143F", "SS63T", "SH229L", "SY37N", "SN103T", "SI169V", "SN196S", "SA204V"]
labels = [s[1:] for s in subs]
ind = [ds.tokenizer.tokens.index(s) for s in subs]
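
Note that .index() raises a ValueError for any substitution missing from the vocabulary; a small guard (hypothetical, using only the tokenizer attribute above) makes that failure explicit:

# hypothetical guard: confirm every substitution is in the model vocabulary
missing = [s for s in subs if s not in ds.tokenizer.tokens]
assert not missing, "unknown substitutions: {}".format(missing)
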
[9]:
# Get average mutation effects
W = model.basis.W_mu.detach().numpy()
[10]:
# rebuild the surface in a focused region
Z1, Z2 = np.meshgrid(
    np.linspace(-1.7, 2.6, N),
    np.linspace(-0.6, 0.3, N )
)

# build latent points at each meshgrid location (other dimensions fixed at zero)
Zpred = torch.zeros(N**2, model.basis.K)
Zpred[:, z1] = torch.from_numpy(Z1.ravel())
Zpred[:, z2] = torch.from_numpy(Z2.ravel())

# predict the surface
with torch.no_grad():
    fpred = model.surface(Zpred)

# scale to original brightness values and reshape for plotting
f = fpred.mean * df["medianBrightness"].std() + df["medianBrightness"].mean()
f = f.reshape(Z1.shape)
[11]:
fig = plt.figure(figsize=(3, 2), dpi=300)

plt.contour(Z1, Z2, f, levels=8)

for j, i in enumerate(ind):
    _z = W[i, [z1, z2]]
    plt.arrow(
        0,
        0,
        _z[0],
        _z[1],
        color="C{}".format(j),
        label="+{}".format(labels[j]),
        length_includes_head=True,
        width=0.01,
    )

fig.legend(
    ncol=1,
    bbox_to_anchor=(1.01, 0.9),
    loc="upper left",
    borderaxespad=0.0,
)

plt.xlabel("$z_1$")
plt.ylabel("$z_2$")
[11]:
Text(0, 0.5, '$z_2$')
../_images/examples_analyzing-gfp_20_1.png

Mutational effect distribution

To plot the distribution of effects learned by LANTERN in the two most relevant dimensions (\(z_1\) and \(z_2\)), we compute the angle of each mutational effect vector and plot a histogram of these angles.
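
Recall the argument order: np.arctan2 takes the y-component first and returns angles in \((-\pi, \pi]\). A toy check, independent of the model:

import numpy as np

print(np.arctan2(1.0, 0.0))   # effect purely along +z2 -> pi/2 ("up")
print(np.arctan2(0.0, -1.0))  # effect purely along -z1 -> pi ("left")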

[12]:
plt.figure(figsize=(2, 2), dpi=300)
ax = plt.subplot(111, polar=True)

# "y" is the first argument to arctan2
theta = np.arctan2(W[:, z2], W[:, z1])

H, edges = np.histogram(theta, bins=80, density=True)

# offset the bottom of the histogram
bottom = H.max() * 0.5

bars = ax.bar(
    edges[1:],
    H,
    width=edges[1:] - edges[:-1],
    bottom=bottom,
    zorder=100,
)
ax.set_yticklabels([])

for h, b in zip(H, bars):
    b.set_edgecolor("k")
../_images/examples_analyzing-gfp_23_0.png

Model dimensions

To determine the dimensionality of the model, we find the highest \(k\) such that \(\log_{10}(\sigma^2_k / \sigma^2_{k+1}) > 1\), and retain dimensions \(1\) through \(k\). In other words, we keep every dimension down to the last one whose variance is at least ten-fold larger than that of the next dimension. This heuristic was sufficient to recover the correct landscape dimension for simulated data (see the manuscript for details). Below, we plot the fold-change between consecutive dimensions, highlighting those included by this procedure.
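
As a self-contained illustration of this heuristic (with made-up variances, not values from the avGFP model):

import numpy as np

# hypothetical, already-sorted dimension variances for a K=5 model
sigma2 = np.array([10.0, 3.0, 1.5, 0.05, 0.04])

# log10 fold-change between consecutive dimensions
gaps = np.log10(sigma2[:-1] / sigma2[1:])

# keep every dimension up to (and including) the last ten-fold gap
k = np.where(gaps > 1)[0].max() + 1
print(gaps.round(2), k)  # the 1.5 -> 0.05 drop exceeds ten-fold, so k = 3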

[13]:
# get basis dimension variances (inverse precision)
qalpha = model.basis.qalpha(detach=True)
sigmak = 1 / qalpha.mean[model.basis.order].numpy() # sort by the determined z1, z2, ...
[14]:
# fold-change between consecutive dimensions, ordered from the lowest-variance end
fold = (sigmak[:-1] / sigmak[1:])[::-1]
thresh = 1

# first position (from the low-variance end) with at least a ten-fold change;
# every dimension beyond it is retained
select = (np.where(np.log10(fold) > thresh)[0]).min()


plt.figure(figsize=(3, 2), dpi=300)
plt.plot(fold, marker="o")
plt.scatter(range(select, K-1), fold[select:], facecolors="none", s=100, edgecolors="C0")
plt.xticks(range(K-1), [f"z{K-k-1}" for k in range(K-1)])
plt.ylabel("fold-change")
plt.axhline(10**thresh, c="r", ls="--")

plt.semilogy()
[14]:
[]
../_images/examples_analyzing-gfp_27_1.png

We can also review the variance of each dimension directly. First looking at the total variance for each dimension:

[15]:
plt.figure(figsize=(3, 2), dpi=300)
plt.plot(np.arange(1, model.basis.K+1), sigmak, marker="o")
plt.semilogy()

plt.xlabel("dimension")
plt.ylabel(r"$\sigma^2_k$")
[15]:
Text(0, 0.5, '$\\sigma^2_k$')
../_images/examples_analyzing-gfp_29_1.png

And also as a percentage of total variance for the mutational effect space:

[16]:
plt.figure(figsize=(3, 2), dpi=300)
plt.plot(np.arange(1, model.basis.K+1), 100 * sigmak / sigmak.sum(), marker="o")
plt.semilogy()

plt.xlabel("dimension")
plt.ylabel("% total variance")
[16]:
Text(0, 0.5, '% total variance')
../_images/examples_analyzing-gfp_31_1.png