Simulated one-dimensional latent space¶
[1]:
import string
import pandas as pd
from itertools import combinations
from gpytorch.kernels import RQKernel
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
Simulate data¶
[2]:
def sim(seed, p=5):
N = 2 ** p
torch.random.manual_seed(seed)
W = (torch.randn(p, 1) - 0.2) * 2
X = torch.zeros(N, p)
ind = 1
# for all # of mutations
for mutations in range(1, p + 1):
# for selected combination of mutations for a variant
for variant in combinations(range(p), mutations):
# for each selected
for s in variant:
X[ind, s] = 1
# update after variant
ind += 1
z = torch.mm(X, W)
Z = torch.linspace(z.min(), z.max(), 100)[:, None]
z_samp = torch.cat((z, Z), 0)
kernel = RQKernel()
with torch.no_grad():
K = kernel(z_samp).evaluate()
f = torch.distributions.MultivariateNormal(
torch.zeros(N + 100), 0.0025 * K + torch.eye(N + 100) * 1e-7
).rsample() + torch.sigmoid(0.9+ z_samp[:, 0])
y = f[:N] + torch.randn(N) * 0.05
return W, X, z, y, Z, f[N:]
p = 5
W, X, z, y, Z, f = sim(100, p=p)
plt.figure(figsize=(4, 3), dpi=300)
plt.plot(Z, f)
plt.scatter(z, y, c="C2", alpha=0.8)
plt.axvline(0, c="k", ls="--")
for i in range(p):
plt.arrow(0, -.05*i, W[i].item(), 0, color=f"C{3+i}", width=0.01)
plt.ylabel("phenotype")
plt.xlabel("$z_1$")
None
Convert to dataframe¶
We convert the generated dataset to a pandas DataFrame for use with LANTERN.
[3]:
df = pd.DataFrame(
{
"substitutions": [
":".join(
[
# encode each mutation as one of +a, +b, ...
"+{}".format(string.ascii_lowercase[i])
for i in np.where(X[j, :].numpy())[0]
]
)
for j in range(X.shape[0])
],
"phenotype": ((y-y.mean())/y.std()).numpy(),
},
)
df.head()
[3]:
substitutions | phenotype | |
---|---|---|
0 | 1.711029 | |
1 | +a | 2.069441 |
2 | +b | 0.693510 |
3 | +c | 0.269302 |
4 | +d | 1.718302 |
Build LANTERN dataset¶
[4]:
from lantern.dataset import Dataset
ds = Dataset(df)
ds
[4]:
Dataset(substitutions='substitutions', phenotypes=['phenotype'], errors=None)
[5]:
# 32 observations
len(ds)
[5]:
32
[6]:
# get the first element (a tuple of x_0, y_0)
ds[0]
[6]:
(tensor([0., 0., 0., 0., 0.]), tensor([1.7110]))
Build model¶
[7]:
K = 4
[8]:
from lantern.model.basis import VariationalBasis
basis = VariationalBasis.fromDataset(ds, K=K)
[9]:
from lantern.model.surface import Phenotype
surface = Phenotype.fromDataset(ds, K=K, Ni=200, inducScale=1.0)
[10]:
from lantern.model import Model
from lantern.model.likelihood import GaussianLikelihood
model = Model(basis, surface, GaussianLikelihood())
Train model¶
[11]:
from torch.optim import Adam
loss = model.loss(N=len(ds))
Xtrain, ytrain = ds[: len(ds)]
E = 3000
optimizer = Adam(loss.parameters(), lr=0.01)
hist = []
halpha = np.zeros((E, K))
for i in range(E):
optimizer.zero_grad()
yhat = model(Xtrain)
lss = loss(yhat, ytrain)
total = sum(lss.values())
total.backward()
optimizer.step()
hist.append(total.item())
halpha[i, :] = basis.qalpha(detach=True).mean.numpy()
plt.figure(figsize=(4, 3), dpi=300)
plt.plot(hist)
plt.xlabel("epoch")
plt.ylabel("loss")
[11]:
Text(0, 0.5, 'loss')
We also review the learned variance for each dimension at each iteration:
[12]:
plt.figure(figsize=(4, 3), dpi=300)
plt.plot(1/halpha)
plt.xlabel("epoch")
plt.ylabel("variance")
plt.semilogy()
None
Analyze model¶
Model dimensionality¶
There is a built-in utility for computing the dimensionality learned by LANTERN:
[13]:
from lantern.model import dimensionality
The dimensionality calculations require the trained model and dataset. The number displayed as output is the total number of dimensions found in the model (this is also available as the attribute K
of the returned Dimensionality
object
[14]:
dim = dimensionality(model, ds)
dim
[14]:
Dimensionality(1)
To view the statistics used to determine the dimensionality (see LANTERN’s associated manuscript for more details), there is a diagnostic plot available:
[15]:
dim.plotStatistics(nrow=1)
Finally, to see the variance learned for each dimension (with circles representing dimensions included according to the determined dimensionality), run:
[16]:
dim.plotVariance(model.basis)
Compared learned latent effects¶
The order of dimensions by relevance (e.g. their variance) is stochastic, and can change on different re-runs. We provide an order
attribute on the basis
component to address this problem. The order
is the indexes of the latent dimensions sorted by their variance. We store the index of the highest relevance dimension here for further analysis.
[17]:
z1 = basis.order[0]
When comparing the latent mutational effects learned by LANTERN to the true effects, there is strong correspondence between the two. In this case, LANTERN has learned a reflection of the true effects (i.e. \(z_1\) learned by LANTERN corresponds to \(-z_1\) used to simulate the data). So, the learned effects are strongly negatively correlated with their true value.
[18]:
Wapprox = basis.W_mu[:, z1].detach().numpy()
plt.figure(figsize=(4, 3), dpi=300)
plt.scatter(Wapprox, W)
plt.xlabel("$z_1$ (from LANTERN)")
plt.ylabel("$z_1$ (actual)")
None
Plot learned surface¶
We compare the learned surface to the true, underlying surface. Since LANTERN is unaware of the underlying scale of the latent mutational effect dimension, we rescale the true \(z_1\) to match that of the \(z_1\) learned by LANTERN.
[19]:
with torch.no_grad():
Zapprox = basis(X)
Zpred = torch.zeros(100, K)
Zpred[:, z1] = torch.linspace(Zapprox[:, z1].min(), Zapprox[:, z1].max(), 100)
fpred = surface(Zpred)
lo, hi = fpred.confidence_region()
plt.figure(figsize=(4, 3), dpi=300)
plt.plot(Zpred[:, z1].numpy(), fpred.mean.numpy(), label="surface (learned)")
plt.fill_between(Zpred[:, z1].numpy(), lo.numpy(), hi.numpy(), alpha=0.6)
# this scales the actual mutational effect space to the learned one
scale = Zapprox[:, z1].mean() / z.mean()
plt.plot(Z*scale, (f - y.mean())/y.std(), label="surface (actual)")
plt.scatter(z*scale, (y - y.mean())/y.std(), c="C2", alpha=0.8, label="data")
plt.axvline(0, c="k", ls="--", label="wild-type")
plt.xlabel("$z_1$")
plt.ylabel("phenotype")
plt.legend()
None