"""
Model definitions and PyTorch Lightning modules for EvoAug2.
This module provides the DeepSTARR model architecture and PyTorch Lightning
wrappers for training with EvoAug2 augmentations.
"""
import torch
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np
[docs]
class DeepSTARR(nn.Module):
"""
DeepSTARR model from de Almeida et al., 2022.
This is the original DeepSTARR model architecture as described in the paper.
See https://www.nature.com/articles/s41588-022-01048-5 for details.
Parameters
----------
output_dim : int
Number of output classes for prediction.
d : int, optional
Number of first-layer convolutional filters. Defaults to 256.
conv1_filters : torch.Tensor, optional
Initial filters for the first convolutional layer. If None, random filters are initialized.
learn_conv1_filters : bool, optional
Whether to learn the first convolutional filters. Defaults to True.
conv2_filters : torch.Tensor, optional
Initial filters for the second convolutional layer. If None, random filters are initialized.
learn_conv2_filters : bool, optional
Whether to learn the second convolutional filters. Defaults to True.
conv3_filters : torch.Tensor, optional
Initial filters for the third convolutional layer. If None, random filters are initialized.
learn_conv3_filters : bool, optional
Whether to learn the third convolutional filters. Defaults to True.
conv4_filters : torch.Tensor, optional
Initial filters for the fourth convolutional layer. If None, random filters are initialized.
learn_conv4_filters : bool, optional
Whether to learn the fourth convolutional filters. Defaults to True.
Notes
-----
- The original DeepSTARR model uses 256 first-layer convolutional filters
- Supports transfer learning by initializing with pre-trained filters
- Uses batch normalization and max pooling throughout
- Final layers use LazyLinear for automatic input size inference
"""
[docs]
def __init__(self, output_dim, d=256,
conv1_filters=None, learn_conv1_filters=True,
conv2_filters=None, learn_conv2_filters=True,
conv3_filters=None, learn_conv3_filters=True,
conv4_filters=None, learn_conv4_filters=True):
super().__init__()
if d != 256:
print("NB: number of first-layer convolutional filters in original DeepSTARR model is 256; current number of first-layer convolutional filters is not set to 256")
self.activation = nn.ReLU()
self.dropout4 = nn.Dropout(0.4)
self.flatten = nn.Flatten()
self.init_conv1_filters = conv1_filters
self.init_conv2_filters = conv2_filters
self.init_conv3_filters = conv3_filters
self.init_conv4_filters = conv4_filters
assert (not (conv1_filters is None and not learn_conv1_filters)), "initial conv1_filters cannot be set to None while learn_conv1_filters is set to False"
assert (not (conv2_filters is None and not learn_conv2_filters)), "initial conv2_filters cannot be set to None while learn_conv2_filters is set to False"
assert (not (conv3_filters is None and not learn_conv3_filters)), "initial conv3_filters cannot be set to None while learn_conv3_filters is set to False"
assert (not (conv4_filters is None and not learn_conv4_filters)), "initial conv4_filters cannot be set to None while learn_conv4_filters is set to False"
# Layer 1 (convolutional), constituent parts
if conv1_filters is not None:
if learn_conv1_filters: # continue modifying existing conv1_filters through learning
self.conv1_filters = nn.Parameter( torch.Tensor(conv1_filters) )
else:
self.register_buffer("conv1_filters", torch.Tensor(conv1_filters))
else:
self.conv1_filters = nn.Parameter(torch.zeros(d, 4, 7))
nn.init.kaiming_normal_(self.conv1_filters)
self.batchnorm1 = nn.BatchNorm1d(d)
self.activation1 = nn.ReLU() # name the first-layer activation function for hook purposes
self.maxpool1 = nn.MaxPool1d(2)
# Layer 2 (convolutional), constituent parts
if conv2_filters is not None:
if learn_conv2_filters: # continue modifying existing conv2_filters through learning
self.conv2_filters = nn.Parameter( torch.Tensor(conv2_filters) )
else:
self.register_buffer("conv2_filters", torch.Tensor(conv2_filters))
else:
self.conv2_filters = nn.Parameter(torch.zeros(60, d, 3))
nn.init.kaiming_normal_(self.conv2_filters)
self.batchnorm2 = nn.BatchNorm1d(60)
self.maxpool2 = nn.MaxPool1d(2)
# Layer 3 (convolutional), constituent parts
if conv3_filters is not None:
if learn_conv3_filters: # continue modifying existing conv3_filters through learning
self.conv3_filters = nn.Parameter( torch.Tensor(conv3_filters) )
else:
self.register_buffer("conv3_filters", torch.Tensor(conv3_filters))
else:
self.conv3_filters = nn.Parameter(torch.zeros(60, 60, 5))
nn.init.kaiming_normal_(self.conv3_filters)
self.batchnorm3 = nn.BatchNorm1d(60)
self.maxpool3 = nn.MaxPool1d(2)
# Layer 4 (convolutional), constituent parts
if conv4_filters is not None:
if learn_conv4_filters: # continue modifying existing conv4_filters through learning
self.conv4_filters = nn.Parameter( torch.Tensor(conv4_filters) )
else:
self.register_buffer("conv4_filters", torch.Tensor(conv4_filters))
else:
self.conv4_filters = nn.Parameter(torch.zeros(120, 60, 3))
nn.init.kaiming_normal_(self.conv4_filters)
self.batchnorm4 = nn.BatchNorm1d(120)
self.maxpool4 = nn.MaxPool1d(2)
# Layer 5 (fully connected), constituent parts
self.fc5 = nn.LazyLinear(256, bias=True)
self.batchnorm5 = nn.BatchNorm1d(256)
# Layer 6 (fully connected), constituent parts
self.fc6 = nn.Linear(256, 256, bias=True)
self.batchnorm6 = nn.BatchNorm1d(256)
# Output layer (fully connected), constituent parts
self.fc7 = nn.Linear(256, output_dim)
[docs]
def get_which_conv_layers_transferred(self):
"""
Get list of convolutional layers that were initialized with pre-trained filters.
Returns
-------
list
List of layer indices (1-4) that were initialized with pre-trained filters.
Notes
-----
This method is useful for understanding which layers were transferred
from a pre-trained model during initialization.
"""
layers = []
if self.init_conv1_filters is not None:
layers.append(1)
if self.init_conv2_filters is not None:
layers.append(2)
if self.init_conv3_filters is not None:
layers.append(3)
if self.init_conv4_filters is not None:
layers.append(4)
return layers
[docs]
def forward(self, x):
"""
Forward pass through the DeepSTARR model.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch_size, 4, sequence_length).
Returns
-------
torch.Tensor
Output predictions with shape (batch_size, output_dim).
Notes
-----
The forward pass applies:
1. Four sequential 1D convolutions with batch normalization and max pooling
2. Flattening of convolutional features
3. Two fully connected layers with batch normalization and dropout
4. Final output layer for predictions
"""
# Layer 1
cnn = torch.conv1d(x, self.conv1_filters, stride=1, padding="same")
cnn = self.batchnorm1(cnn)
cnn = self.activation1(cnn)
cnn = self.maxpool1(cnn)
# Layer 2
cnn = torch.conv1d(cnn, self.conv2_filters, stride=1, padding="same")
cnn = self.batchnorm2(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool2(cnn)
# Layer 3
cnn = torch.conv1d(cnn, self.conv3_filters, stride=1, padding="same")
cnn = self.batchnorm3(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool3(cnn)
# Layer 4
cnn = torch.conv1d(cnn, self.conv4_filters, stride=1, padding="same")
cnn = self.batchnorm4(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool4(cnn)
# Layer 5
cnn = self.flatten(cnn)
cnn = self.fc5(cnn)
cnn = self.batchnorm5(cnn)
cnn = self.activation(cnn)
cnn = self.dropout4(cnn)
# Layer 6
cnn = self.fc6(cnn)
cnn = self.batchnorm6(cnn)
cnn = self.activation(cnn)
cnn = self.dropout4(cnn)
# Output layer
y_pred = self.fc7(cnn)
return y_pred
# Note: We'll use H5DataModule directly and create separate data modules for different purposes
[docs]
class DeepSTARRModel(pl.LightningModule):
"""
PyTorch Lightning module for DeepSTARR training.
This class wraps the DeepSTARR model in a PyTorch Lightning module,
providing training, validation, and testing functionality with
automatic logging and checkpointing.
Parameters
----------
model : DeepSTARR
The DeepSTARR model instance.
learning_rate : float, optional
Learning rate for training. Defaults to 0.001.
weight_decay : float, optional
Weight decay (L2 regularization). Defaults to 1e-6.
Notes
-----
- Uses MSE loss for regression tasks
- Adam optimizer with ReduceLROnPlateau scheduler
- Automatic logging of training, validation, and test losses
"""
[docs]
def __init__(self, model, learning_rate=0.001, weight_decay=1e-6):
super().__init__()
self.model = model
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.criterion = nn.MSELoss()
[docs]
def forward(self, x):
"""
Forward pass through the model.
Parameters
----------
x : torch.Tensor
Input tensor.
Returns
-------
torch.Tensor
Model predictions.
"""
return self.model(x)
[docs]
def training_step(self, batch, batch_idx):
"""
Training step for a single batch.
Parameters
----------
batch : tuple
Tuple of (x, y) where x is input and y is target.
batch_idx : int
Index of the current batch.
Returns
-------
torch.Tensor
Training loss for the batch.
"""
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
return loss
[docs]
def validation_step(self, batch, batch_idx):
"""
Validation step for a single batch.
Parameters
----------
batch : tuple
Tuple of (x, y) where x is input and y is target.
batch_idx : int
Index of the current batch.
Returns
-------
torch.Tensor
Validation loss for the batch.
"""
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
return loss
[docs]
def test_step(self, batch, batch_idx):
"""
Test step for a single batch.
Parameters
----------
batch : tuple
Tuple of (x, y) where x is input and y is target.
batch_idx : int
Index of the current batch.
Returns
-------
torch.Tensor
Test loss for the batch.
"""
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('test_loss', loss, prog_bar=True)
return loss
[docs]
class Basset(nn.Module):
"""
Basset model from Kelley et al., 2016.
This is the Basset model architecture as described in the original paper.
See https://genome.cshlp.org/content/early/2016/05/03/gr.200535.115.abstract
and https://github.com/davek44/Basset/blob/master/data/models/pretrained_params.txt
Parameters
----------
output_dim : int
Number of output classes for prediction.
d : int, optional
Number of first-layer convolutional filters. Defaults to 300.
conv1_filters : torch.Tensor, optional
Initial filters for the first convolutional layer. If None, random filters are initialized.
learn_conv1_filters : bool, optional
Whether to learn the first convolutional filters. Defaults to True.
conv2_filters : torch.Tensor, optional
Initial filters for the second convolutional layer. If None, random filters are initialized.
learn_conv2_filters : bool, optional
Whether to learn the second convolutional filters. Defaults to True.
conv3_filters : torch.Tensor, optional
Initial filters for the third convolutional layer. If None, random filters are initialized.
learn_conv3_filters : bool, optional
Whether to learn the third convolutional filters. Defaults to True.
Notes
-----
- The original Basset model uses 300 first-layer convolutional filters
- Supports transfer learning by initializing with pre-trained filters
- Uses batch normalization and max pooling throughout
- Final layers use LazyLinear for automatic input size inference
- Output uses sigmoid activation for binary classification
"""
[docs]
def __init__(self, output_dim, d=300,
conv1_filters=None, learn_conv1_filters=True,
conv2_filters=None, learn_conv2_filters=True,
conv3_filters=None, learn_conv3_filters=True):
super().__init__()
if d != 300:
print("NB: number of first-layer convolutional filters in original Basset model is 300; current number of first-layer convolutional filters is not set to 300")
self.activation = nn.ReLU()
self.dropout3 = nn.Dropout(0.3)
self.flatten = nn.Flatten()
self.init_conv1_filters = conv1_filters
self.init_conv2_filters = conv2_filters
self.init_conv3_filters = conv3_filters
assert (not (conv1_filters is None and not learn_conv1_filters)), "initial conv1_filters cannot be set to None while learn_conv1_filters is set to False"
assert (not (conv2_filters is None and not learn_conv2_filters)), "initial conv2_filters cannot be set to None while learn_conv2_filters is set to False"
assert (not (conv3_filters is None and not learn_conv3_filters)), "initial conv3_filters cannot be set to None while learn_conv3_filters is set to False"
# Layer 1 (convolutional), constituent parts
if conv1_filters is not None:
if learn_conv1_filters: # continue modifying existing conv1_filters through learning
self.conv1_filters = nn.Parameter( torch.Tensor(conv1_filters) )
else:
self.register_buffer("conv1_filters", torch.Tensor(conv1_filters))
else:
self.conv1_filters = nn.Parameter(torch.zeros(d, 4, 19))
nn.init.kaiming_normal_(self.conv1_filters)
self.batchnorm1 = nn.BatchNorm1d(d)
self.activation1 = nn.ReLU() # name the first-layer activation function for hook purposes
self.maxpool1 = nn.MaxPool1d(3)
# Layer 2 (convolutional), constituent parts
if conv2_filters is not None:
if learn_conv2_filters: # continue modifying existing conv2_filters through learning
self.conv2_filters = nn.Parameter( torch.Tensor(conv2_filters) )
else:
self.register_buffer("conv2_filters", torch.Tensor(conv2_filters))
else:
self.conv2_filters = nn.Parameter(torch.zeros(200, d, 11))
nn.init.kaiming_normal_(self.conv2_filters)
self.batchnorm2 = nn.BatchNorm1d(200)
self.maxpool2 = nn.MaxPool1d(4)
# Layer 3 (convolutional), constituent parts
if conv3_filters is not None:
if learn_conv3_filters: # continue modifying existing conv3_filters through learning
self.conv3_filters = nn.Parameter( torch.Tensor(conv3_filters) )
else:
self.register_buffer("conv3_filters", torch.Tensor(conv3_filters))
else:
self.conv3_filters = nn.Parameter(torch.zeros(200, 200, 7))
nn.init.kaiming_normal_(self.conv3_filters)
self.batchnorm3 = nn.BatchNorm1d(200)
self.maxpool3 = nn.MaxPool1d(4)
# Layer 4 (fully connected), constituent parts
self.fc4 = nn.LazyLinear(1000, bias=False)
self.batchnorm4 = nn.BatchNorm1d(1000)
# Layer 5 (fully connected), constituent parts
self.fc5 = nn.Linear(1000, 1000, bias=False)
self.batchnorm5 = nn.BatchNorm1d(1000)
# Output layer (fully connected), constituent parts
self.fc6 = nn.Linear(1000, output_dim)
self.sigmoid = nn.Sigmoid()
[docs]
def get_which_conv_layers_transferred(self):
"""
Get list of convolutional layers that were initialized with pre-trained filters.
Returns
-------
list
List of layer indices (1-3) that were initialized with pre-trained filters.
Notes
-----
This method is useful for understanding which layers were transferred
from a pre-trained model during initialization.
"""
layers = []
if self.init_conv1_filters is not None:
layers.append(1)
if self.init_conv2_filters is not None:
layers.append(2)
if self.init_conv3_filters is not None:
layers.append(3)
return layers
[docs]
def forward(self, x):
"""
Forward pass through the Basset model.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch_size, 4, sequence_length).
Returns
-------
torch.Tensor
Output predictions with shape (batch_size, output_dim).
Notes
-----
The forward pass applies:
1. Three sequential 1D convolutions with batch normalization and max pooling
2. Flattening of convolutional features
3. Two fully connected layers with batch normalization and dropout
4. Final output layer with sigmoid activation
"""
# Layer 1
cnn = torch.conv1d(x, self.conv1_filters, stride=1, padding=(self.conv1_filters.shape[-1]//2))
cnn = self.batchnorm1(cnn)
cnn = self.activation1(cnn)
cnn = self.maxpool1(cnn)
# Layer 2
cnn = torch.conv1d(cnn, self.conv2_filters, stride=1, padding=(self.conv2_filters.shape[-1]//2))
cnn = self.batchnorm2(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool2(cnn)
# Layer 3
cnn = torch.conv1d(cnn, self.conv3_filters, stride=1, padding=(self.conv3_filters.shape[-1]//2))
cnn = self.batchnorm3(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool3(cnn)
# Layer 4
cnn = self.flatten(cnn)
cnn = self.fc4(cnn)
cnn = self.batchnorm4(cnn)
cnn = self.activation(cnn)
cnn = self.dropout3(cnn)
# Layer 5
cnn = self.fc5(cnn)
cnn = self.batchnorm5(cnn)
cnn = self.activation(cnn)
cnn = self.dropout3(cnn)
# Output layer
cnn = self.fc6(cnn)
y_pred = self.sigmoid(cnn)
return y_pred
[docs]
class CNN(nn.Module):
"""
Generic CNN model for genomic sequence classification.
This is a flexible CNN architecture that can be used for various
genomic sequence classification tasks.
Parameters
----------
output_dim : int
Number of output classes for prediction.
Notes
-----
- Uses three convolutional layers with batch normalization and max pooling
- Includes dropout for regularization
- Final layers use LazyLinear for automatic input size inference
- Output uses sigmoid activation for binary classification
"""
[docs]
def __init__(self, output_dim):
super().__init__()
self.activation1 = nn.ReLU()
self.activation = nn.ReLU()
self.dropout1 = nn.Dropout(0.2)
self.dropout2 = nn.Dropout(0.2)
self.dropout3 = nn.Dropout(0.2)
self.dropout4 = nn.Dropout(0.5)
self.flatten = nn.Flatten()
self.output_activation = nn.Sigmoid()
# Layer 1 (convolutional), constituent parts
self.conv1_filters = torch.nn.Parameter(torch.zeros(64, 4, 7))
torch.nn.init.kaiming_uniform_(self.conv1_filters)
self.batchnorm1 = nn.BatchNorm1d(64)
self.maxpool1 = nn.MaxPool1d(4)
# Layer 3 (convolutional), constituent parts
self.conv2_filters = torch.nn.Parameter(torch.zeros(96, 64, 5))
torch.nn.init.kaiming_uniform_(self.conv2_filters)
self.batchnorm2 = nn.BatchNorm1d(96)
self.maxpool2 = nn.MaxPool1d(4)
# Layer 4 (convolutional), constituent parts
self.conv3_filters = torch.nn.Parameter(torch.zeros(128, 96, 5))
torch.nn.init.kaiming_uniform_(self.conv3_filters)
self.batchnorm3 = nn.BatchNorm1d(128)
self.maxpool3 = nn.MaxPool1d(2)
# Layer 5 (fully connected), constituent parts
self.fc4 = nn.LazyLinear(256, bias=True)
self.batchnorm4 = nn.BatchNorm1d(256)
# Output layer (fully connected), constituent parts
self.fc5 = nn.LazyLinear(output_dim, bias=True)
[docs]
def forward(self, x):
"""
Forward pass through the CNN model.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch_size, 4, sequence_length).
Returns
-------
torch.Tensor
Output predictions with shape (batch_size, output_dim).
Notes
-----
The forward pass applies:
1. Three sequential 1D convolutions with batch normalization and max pooling
2. Dropout after each convolutional layer
3. Flattening of convolutional features
4. Fully connected layer with batch normalization and dropout
5. Final output layer with sigmoid activation
"""
# Layer 1
cnn = torch.conv1d(x, self.conv1_filters, stride=1, padding="same")
cnn = self.batchnorm1(cnn)
cnn = self.activation1(cnn)
cnn = self.maxpool1(cnn)
cnn = self.dropout1(cnn)
# Layer 2
cnn = torch.conv1d(cnn, self.conv2_filters, stride=1, padding="same")
cnn = self.batchnorm2(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool2(cnn)
cnn = self.dropout2(cnn)
# Layer 3
cnn = torch.conv1d(cnn, self.conv3_filters, stride=1, padding="same")
cnn = self.batchnorm3(cnn)
cnn = self.activation(cnn)
cnn = self.maxpool3(cnn)
cnn = self.dropout3(cnn)
# Layer 4
cnn = self.flatten(cnn)
cnn = self.fc4(cnn)
cnn = self.batchnorm4(cnn)
cnn = self.activation(cnn)
cnn = self.dropout4(cnn)
# Output layer
logits = self.fc5(cnn)
y_pred = self.output_activation(logits)
return y_pred