"""
Library of data augmentations for genomic sequence data.
This module provides evolution-inspired data augmentation techniques for genomic sequences,
ensuring that all augmentations preserve the input sequence length L.
To contribute a custom augmentation, use the following syntax:
.. code-block:: python
class CustomAugmentation(AugmentBase):
def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2
def __call__(self, x: torch.Tensor) -> torch.Tensor:
# Perform augmentation
return x_aug
"""
import torch
class AugmentBase:
    """
    Common parent for every EvoAug sequence augmentation.

    Subclasses are expected to override :meth:`__call__` so that all
    augmentations can be applied through the same callable interface.
    """

    def __call__(self, x):
        """Produce an augmented copy of the batch `x`.

        Parameters
        ----------
        x : torch.Tensor
            Batch of one-hot sequences with shape (N, A, L) where:

            - N is the batch size
            - A is the number of nucleotides (4 for DNA)
            - L is the sequence length

        Returns
        -------
        torch.Tensor
            Augmented batch of one-hot sequences. Implementations must return
            shape (N, A, L) so the sequence length is unchanged.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must provide the
            implementation.
        """
        raise NotImplementedError
[docs]
class RandomDeletion(AugmentBase):
"""
Randomly deletes contiguous stretches of nucleotides from sequences.
This augmentation randomly selects deletion lengths and positions for each sequence
in a batch, then pads the deleted regions with random DNA to maintain the original
sequence length L.
Parameters
----------
delete_min : int, optional
Minimum size for random deletion. Defaults to 0.
delete_max : int, optional
Maximum size for random deletion. Defaults to 20.
Notes
-----
- Deletion positions are constrained to ensure the deletion window fits within
the sequence boundaries
- Random DNA padding is added equally to both ends of the deletion to maintain
sequence length L
- Each sequence in the batch receives a different random deletion
"""
[docs]
def __init__(self, delete_min=0, delete_max=20):
self.delete_min = delete_min
self.delete_max = delete_max
[docs]
def __call__(self, x):
"""Randomly delete segments in a set of one-hot DNA sequences.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with randomly deleted segments, padded with random DNA
to maintain shape (N, A, L).
"""
N, A, L = x.shape
# sample random DNA
a = torch.eye(A)
p = torch.tensor([1/A for _ in range(A)])
padding = torch.stack([a[p.multinomial(self.delete_max, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device)
# sample deletion length for each sequence
delete_lens = torch.randint(self.delete_min, self.delete_max + 1, (N,))
# sample locations to delete for each sequence
delete_inds = torch.randint(L - self.delete_max + 1, (N,)) # deletion must be in boundaries of seq.
# loop over each sequence
x_aug = []
for seq, pad, delete_len, delete_ind in zip(x, padding, delete_lens, delete_inds):
# get index of half delete_len (to pad random DNA at beginning of sequence)
pad_begin_index = torch.div(delete_len, 2, rounding_mode='floor').item()
# index for other half (to pad random DNA at end of sequence)
pad_end_index = delete_len - pad_begin_index
# removes deletion and pads beginning and end of sequence with random DNA to ensure same length
x_aug.append( torch.cat([pad[:,:pad_begin_index], # random dna padding
seq[:,:delete_ind], # sequence up to deletion start index
seq[:,delete_ind+delete_len:], # sequence after deletion end index
pad[:,self.delete_max-pad_end_index:]], # random dna padding
-1)) # concatenation axis
return torch.stack(x_aug)
[docs]
class RandomInsertion(AugmentBase):
"""
Randomly inserts contiguous stretches of random DNA into sequences.
This augmentation randomly selects insertion lengths and positions for each sequence
in a batch, then trims the resulting sequences equally from both ends to maintain
the original sequence length L.
Parameters
----------
insert_min : int, optional
Minimum size for random insertion. Defaults to 0.
insert_max : int, optional
Maximum size for random insertion. Defaults to 20.
Notes
-----
- Insertion positions are randomly selected across the sequence length
- Random DNA is generated using uniform nucleotide distribution
- After insertion, sequences are trimmed equally from both ends to maintain
sequence length L
- Each sequence in the batch receives a different random insertion
"""
[docs]
def __init__(self, insert_min=0, insert_max=20):
self.insert_min = insert_min
self.insert_max = insert_max
[docs]
def __call__(self, x):
"""Randomly insert segments of random DNA into DNA sequences.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with randomly inserted segments of random DNA, trimmed
to maintain shape (N, A, L).
"""
N, A, L = x.shape
# If insert_max is 0, return original sequences without modification
if self.insert_max <= 0:
return x
# sample random DNA
a = torch.eye(A)
p = torch.tensor([1/A for _ in range(A)])
insertions = torch.stack([a[p.multinomial(self.insert_max, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device)
# sample insertion length for each sequence
insert_lens = torch.randint(self.insert_min, self.insert_max + 1, (N,))
# sample locations to insertion for each sequence
insert_inds = torch.randint(L, (N,))
# loop over each sequence
x_aug = []
for seq, insertion, insert_len, insert_ind in zip(x, insertions, insert_lens, insert_inds):
# Convert to Python integers for safe indexing
il = insert_len.item()
ii = insert_ind.item()
# Insert the random DNA
inserted = torch.cat([seq[:, :ii], insertion[:, :il], seq[:, ii:]], -1)
# Calculate how much to trim to get back to length L
current_len = inserted.shape[-1]
excess = current_len - L
if excess > 0:
# Trim equally from both ends
trim_left = excess // 2
trim_right = excess - trim_left
final_seq = inserted[:, trim_left:current_len - trim_right]
else:
# No trimming needed
final_seq = inserted
# Ensure the final sequence has exactly length L
if final_seq.shape[-1] != L:
# If still wrong length, pad or trim to exactly L
if final_seq.shape[-1] > L:
final_seq = final_seq[:, :L]
else:
# Pad with random DNA to reach length L
padding_needed = L - final_seq.shape[-1]
padding = a[p.multinomial(padding_needed, replacement=True)].transpose(0,1).to(x.device)
final_seq = torch.cat([final_seq, padding], -1)
x_aug.append(final_seq)
# Stack all sequences and ensure they all have the same shape
stacked = torch.stack(x_aug)
# Final safety check - ensure all sequences have exactly length L
if stacked.shape[-1] != L:
print(f"Warning: RandomInsertion output shape {stacked.shape} doesn't match expected length {L}")
# Force all sequences to length L by trimming or padding
if stacked.shape[-1] > L:
stacked = stacked[:, :, :L]
else:
# Pad all sequences to length L
padding_needed = L - stacked.shape[-1]
padding = torch.stack([a[p.multinomial(padding_needed, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device)
stacked = torch.cat([stacked, padding], -1)
return stacked
[docs]
class RandomTranslocation(AugmentBase):
"""
Randomly shifts sequences using circular roll transformations.
This augmentation applies random positive or negative shifts to each sequence
in a batch, effectively cutting the sequence and reordering the pieces
while maintaining the original sequence length L.
Parameters
----------
shift_min : int, optional
Minimum size for random shift. Defaults to 0.
shift_max : int, optional
Maximum size for random shift. Defaults to 20.
Notes
-----
- Shifts are randomly chosen between shift_min and shift_max
- Approximately half of the shifts are made negative to create
both left and right circular shifts
- Uses torch.roll for efficient implementation
- Each sequence in the batch receives a different random shift
"""
[docs]
def __init__(self, shift_min=0, shift_max=20):
self.shift_min = shift_min
self.shift_max = shift_max
[docs]
def __call__(self, x):
"""Randomly shift sequences in a batch using circular roll.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with random circular shifts applied, maintaining
shape (N, A, L).
"""
N = x.shape[0]
# determine size of shifts for each sequence
shifts = torch.randint(self.shift_min, self.shift_max + 1, (N,))
# make some of the shifts negative
ind_neg = torch.rand(N) < 0.5
shifts[ind_neg] = -1 * shifts[ind_neg]
# apply random shift to each sequence
x_rolled = []
for i, shift in enumerate(shifts):
x_rolled.append( torch.roll(x[i], shift.item(), -1) )
x_rolled = torch.stack(x_rolled).to(x.device)
return x_rolled
[docs]
class RandomInversion(AugmentBase):
"""
Randomly inverts contiguous stretches of nucleotides in sequences.
This augmentation randomly selects inversion lengths and positions for each sequence
in a batch, then applies a reverse-complement transformation to the selected region
while maintaining the original sequence length L.
Parameters
----------
invert_min : int, optional
Minimum size for random inversion. Defaults to 0.
invert_max : int, optional
Maximum size for random inversion. Defaults to 20.
Notes
-----
- Inversion positions are constrained to ensure the inversion window fits within
the sequence boundaries
- Applies reverse-complement transformation (flip both sequence and nucleotide dimensions)
- Each sequence in the batch receives a different random inversion
"""
[docs]
def __init__(self, invert_min=0, invert_max=20):
self.invert_min = invert_min
self.invert_max = invert_max
[docs]
def __call__(self, x):
"""Randomly invert segments of DNA sequences.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with randomly inverted segments, maintaining
shape (N, A, L).
"""
N, A, L = x.shape
# set random inversion size for each sequence
inversion_lens = torch.randint(self.invert_min, self.invert_max + 1, (N,))
# randomly select start location for each inversion
inversion_inds = torch.randint(L - self.invert_max + 1, (N,)) # inversion must be in boundaries of seq.
# apply random inversion to each sequence
x_aug = []
for seq, inversion_len, inversion_ind in zip(x, inversion_lens, inversion_inds):
x_aug.append( torch.cat([seq[:,:inversion_ind], # sequence up to inversion start index
torch.flip(seq[:,inversion_ind:inversion_ind+inversion_len], dims=[0,1]), # reverse-complement transformation
seq[:,inversion_ind+inversion_len:]], # sequence after inversion
-1)) # concatenation axis
return torch.stack(x_aug)
[docs]
class RandomMutation(AugmentBase):
"""
Randomly mutates nucleotides in sequences according to a mutation fraction.
This augmentation randomly selects positions in each sequence and replaces
the nucleotides with random DNA, effectively introducing point mutations
while maintaining the original sequence length L.
Parameters
----------
mutate_frac : float, optional
Probability of mutation for each nucleotide. Defaults to 0.05.
Notes
-----
- The actual number of mutations is calculated as: round(mutate_frac / 0.75 * L)
- The division by 0.75 accounts for silent mutations (nucleotides that don't change)
- Random DNA is generated using uniform nucleotide distribution
- Each sequence in the batch receives a different set of random mutations
"""
[docs]
def __init__(self, mut_frac=0.05):
self.mutate_frac = mut_frac
[docs]
def __call__(self, x):
"""Randomly introduce mutations to a set of one-hot DNA sequences.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with randomly mutated DNA, maintaining
shape (N, A, L).
"""
N, A, L = x.shape
# determine the number of mutations per sequence
num_mutations = round(self.mutate_frac / 0.75 * L) # num. mutations per sequence (accounting for silent mutations)
# If no mutations, return original sequences
if num_mutations <= 0:
return x
# randomly determine the indices to apply mutations
mutation_inds = torch.argsort(torch.rand(N,L))[:, :num_mutations] # see <https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146>0
# create random DNA (to serve as random mutations)
a = torch.eye(A)
p = torch.tensor([1/A for _ in range(A)])
mutations = torch.stack([a[p.multinomial(num_mutations, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device)
# make a copy of the batch of sequences
x_aug = torch.clone(x)
# loop over sequences and apply mutations
for i in range(N):
x_aug[i,:,mutation_inds[i]] = mutations[i]
return x_aug
[docs]
class RandomRC(AugmentBase):
"""
Randomly applies reverse-complement transformations to sequences.
This augmentation randomly selects sequences in a batch and applies
a reverse-complement transformation with a specified probability.
The transformation reverses both the sequence order and nucleotide
identity while maintaining the original sequence length L.
Parameters
----------
rc_prob : float, optional
Probability to apply a reverse-complement transformation. Defaults to 0.5.
Notes
-----
- Each sequence is independently selected for transformation
- Uses torch.flip with dims=[1,2] to reverse both sequence and nucleotide dimensions
- Maintains original sequence length L
- Useful for learning strand-invariant representations
"""
[docs]
def __init__(self, rc_prob=0.5):
"""Create random reverse-complement augmentation object.
Parameters
----------
rc_prob : float
Probability to apply reverse-complement transformation.
"""
self.rc_prob = rc_prob
[docs]
def __call__(self, x):
"""Randomly transform sequences with reverse-complement transformations.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with random reverse-complements applied, maintaining
shape (N, A, L).
"""
# make a copy of the sequence
x_aug = torch.clone(x)
# randomly select sequences to apply rc transformation
ind_rc = torch.rand(x_aug.shape[0]) < self.rc_prob
# apply reverse-complement transformation
x_aug[ind_rc] = torch.flip(x_aug[ind_rc], dims=[1,2])
return x_aug
[docs]
class RandomNoise(AugmentBase):
"""
Randomly adds Gaussian noise to sequences.
This augmentation adds random Gaussian noise to each sequence in a batch,
effectively introducing small perturbations to the one-hot encodings
while maintaining the original sequence length L.
Parameters
----------
noise_mean : float, optional
Mean of the Gaussian noise. Defaults to 0.0.
noise_std : float, optional
Standard deviation of the Gaussian noise. Defaults to 0.2.
Notes
-----
- Noise is sampled from a normal distribution with specified mean and standard deviation
- Noise is added element-wise to the input tensor
- Useful for improving model robustness to small perturbations
- Each sequence in the batch receives different random noise
"""
[docs]
def __init__(self, noise_mean=0.0, noise_std=0.2):
self.noise_mean = noise_mean
self.noise_std = noise_std
[docs]
def __call__(self, x):
"""Randomly add Gaussian noise to a set of one-hot DNA sequences.
Parameters
----------
x : torch.Tensor
Batch of one-hot sequences with shape (N, A, L).
Returns
-------
torch.Tensor
Sequences with random noise added, maintaining
shape (N, A, L).
"""
return x + torch.normal(self.noise_mean, self.noise_std, x.shape).to(x.device)