"""
EvoAug2: PyTorch DataLoader implementation of EvoAug functionality.

This module provides the same augmentation capabilities as ``RobustModel``, but
as a standalone PyTorch DataLoader that can be used with any model.
The ``RobustLoader`` inherits from ``DataLoader`` and can be used directly in
PyTorch Lightning DataModules or vanilla PyTorch training loops.

Classes
-------
AugmentedGenomicDataset
    Dataset wrapper that applies EvoAug augmentations on-the-fly.
RobustLoader
    DataLoader with built-in EvoAug augmentations.
"""
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Optional, Tuple, Union
from evoaug.augment import AugmentBase
class AugmentedGenomicDataset(Dataset):
    """
    PyTorch Dataset that applies EvoAug-style augmentations to genomic sequences.

    This dataset wraps an existing dataset and applies augmentations on-the-fly
    during training, while optionally disabling them for validation/finetuning.

    Parameters
    ----------
    base_dataset : torch.utils.data.Dataset
        The underlying dataset. Items are either a bare sequence or a
        tuple/list whose first two entries are ``(sequence, target)``.
    augment_list : List[AugmentBase], optional
        Augmentations to choose from; each one is called with a batched
        sequence and must return the augmented batch. Defaults to an empty list.
    max_augs_per_seq : int, optional
        Maximum number of augmentations to apply per sequence. It is capped at
        ``len(augment_list)``. Defaults to 0 (no augmentation).
    hard_aug : bool, optional
        If True, always apply exactly ``max_augs_per_seq`` augmentations.
        If False, sample between 1 and ``max_augs_per_seq``. Defaults to True.
    apply_augmentations : bool, optional
        Whether to apply augmentations. Can be toggled for finetuning.
        Defaults to True.

    Notes
    -----
    - ``insert_max`` holds the largest ``insert_max`` attribute found on any
      augmentation in ``augment_list`` (0 if none has one).
    - Augmentations can be enabled/disabled at runtime using
      ``enable_augmentations()`` and ``disable_augmentations()``.
    - Augmentation subsets are drawn with torch's RNG rather than NumPy's
      global RNG. PyTorch seeds torch differently in each DataLoader worker,
      whereas NumPy's state may be duplicated across workers, which would make
      workers draw identical augmentation combinations.
    """

    def __init__(self,
                 base_dataset: Dataset,
                 augment_list: Optional[List["AugmentBase"]] = None,
                 max_augs_per_seq: int = 0,
                 hard_aug: bool = True,
                 apply_augmentations: bool = True):
        self.base_dataset = base_dataset
        # Build a fresh list per instance; a literal `[]` default would be one
        # object shared by every dataset created without an explicit list.
        self.augment_list = [] if augment_list is None else augment_list
        # Never request more augmentations than are available.
        self.max_augs_per_seq = min(max_augs_per_seq, len(self.augment_list))
        self.hard_aug = hard_aug
        self.apply_augmentations = apply_augmentations
        self.max_num_aug = len(self.augment_list)
        self.insert_max = self._get_insert_max()

    def __len__(self):
        """Return the number of samples in the dataset.

        Returns
        -------
        int
            Number of samples in the base dataset.
        """
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Get a single sample from the dataset.

        Parameters
        ----------
        idx : int
            Index of the sample to retrieve.

        Returns
        -------
        torch.Tensor or tuple
            ``(sequence, target)`` when the base item is a tuple/list with at
            least two entries, otherwise just ``sequence``. The sequence is
            augmented when augmentations are enabled and available.

        Notes
        -----
        - Only the first two entries of a tuple/list item are used.
        - No padding or cropping is done here; the output length is whatever
          the applied augmentations return.
        - Each call may produce different augmentations due to randomness.
        """
        data = self.base_dataset[idx]

        # Accept both (sequence, target, ...) items and bare sequences.
        if isinstance(data, (tuple, list)) and len(data) >= 2:
            sequence, target = data[0], data[1]
        else:
            sequence, target = data, None

        if self.apply_augmentations and self.augment_list:
            sequence = self._apply_augmentations(sequence)

        if target is None:
            return sequence
        return sequence, target

    def _get_insert_max(self) -> int:
        """Get the maximum insertion length from augmentations.

        Returns
        -------
        int
            Largest ``insert_max`` attribute among the augmentations in
            ``augment_list``, or 0 if no augmentation defines one.
        """
        insert_max = 0
        for augment in self.augment_list:
            if hasattr(augment, 'insert_max'):
                # Keep the largest value, not merely the last one encountered.
                insert_max = max(insert_max, augment.insert_max)
        return insert_max

    def _sample_aug_combos(self) -> List[List[int]]:
        """Sample the augmentation indices to apply to a single sequence.

        Returns
        -------
        List[List[int]]
            A list holding one sorted list of distinct augmentation indices,
            or an empty list when no augmentation should be applied.

        Notes
        -----
        - If ``hard_aug`` is True, exactly ``max_augs_per_seq`` indices are drawn.
        - If ``hard_aug`` is False, between 1 and ``max_augs_per_seq`` are drawn.
        - Indices are drawn without replacement.
        """
        # Covers the default max_augs_per_seq=0 (and negative values), for
        # which a "1..max" draw has an empty range and would raise.
        if self.max_augs_per_seq <= 0:
            return []

        if self.hard_aug:
            num_augs = self.max_augs_per_seq
        else:
            num_augs = int(torch.randint(1, self.max_augs_per_seq + 1, (1,)).item())

        # A random permutation prefix is a draw without replacement.
        chosen = torch.randperm(self.max_num_aug)[:num_augs].tolist()
        return [sorted(chosen)]

    def _apply_augmentations(self, sequence: torch.Tensor) -> torch.Tensor:
        """Apply a randomly sampled set of augmentations to a single sequence.

        Parameters
        ----------
        sequence : torch.Tensor
            Input sequence, expected with shape (A, L) where A is the alphabet
            size and L is the sequence length.

        Returns
        -------
        torch.Tensor
            The augmented sequence with the batch dimension removed again, or
            the input unchanged when nothing was sampled.

        Notes
        -----
        - The sequence is temporarily given a leading batch dimension because
          the augmentations are called on batched input.
        - Augmentations are applied sequentially in ascending index order.
        """
        if not self.augment_list:
            return sequence

        aug_combos = self._sample_aug_combos()
        if not aug_combos:
            return sequence

        batched = sequence.unsqueeze(0)  # add batch dimension
        for aug_index in aug_combos[0]:
            batched = self.augment_list[aug_index](batched)
        return batched.squeeze(0)  # remove batch dimension

    def _pad_end(self, sequence: torch.Tensor) -> torch.Tensor:
        """Append ``insert_max`` random one-hot positions to the end of a sequence.

        Parameters
        ----------
        sequence : torch.Tensor
            Input with shape (A, L) or (N, A, L).

        Returns
        -------
        torch.Tensor
            Input extended along the last dimension by ``self.insert_max``
            columns, or the input unchanged when ``insert_max <= 0``.

        Notes
        -----
        - Each padded position is a one-hot column drawn uniformly over the A
          channels.
        - The padding is created with the input's dtype and device.
        - This helper is not called by any other method of this class.
        """
        if self.insert_max <= 0:
            return sequence

        # Treat a single sequence as a batch of one so both cases share code.
        is_batch = sequence.dim() == 3
        batch = sequence if is_batch else sequence.unsqueeze(0)
        num_seqs, alphabet_size, _ = batch.shape

        one_hot_rows = torch.eye(alphabet_size, dtype=batch.dtype, device=batch.device)
        draws = torch.randint(0, alphabet_size, (num_seqs, self.insert_max),
                              device=batch.device)
        # (N, insert_max, A) -> (N, A, insert_max) to match the sequence layout.
        padding = one_hot_rows[draws].transpose(1, 2)

        padded = torch.cat([batch, padding], dim=2)
        return padded if is_batch else padded.squeeze(0)

    def enable_augmentations(self):
        """Enable augmentations, e.g. for training."""
        self.apply_augmentations = True

    def disable_augmentations(self):
        """Disable augmentations, e.g. for validation, testing or finetuning."""
        self.apply_augmentations = False
class RobustLoader(DataLoader):
    """
    EvoAug2 DataLoader that inherits from PyTorch DataLoader.

    This class wraps ``base_dataset`` in an ``AugmentedGenomicDataset`` and
    hands that wrapper to the parent ``DataLoader``, so it can be used with
    ``pl.DataModule`` or directly in vanilla PyTorch.

    Parameters
    ----------
    base_dataset : torch.utils.data.Dataset
        The underlying dataset that provides (sequence, target) pairs.
    augment_list : List[AugmentBase], optional
        List of augmentations to apply. Defaults to an empty list.
    max_augs_per_seq : int, optional
        Maximum augmentations per sequence. Defaults to 0.
    hard_aug : bool, optional
        Whether to use hard augmentation count. Defaults to True.
    batch_size : int, optional
        Batch size for the DataLoader. Defaults to 32.
    shuffle : bool, optional
        Whether to shuffle the data. Defaults to True.
    num_workers : int, optional
        Number of worker processes. Defaults to 4.
    **kwargs
        Additional arguments passed to DataLoader.

    Notes
    -----
    - The wrapper dataset is available as ``self.augmented_dataset``.
    - Augmentations are enabled on construction and can be toggled at runtime
      with ``enable_augmentations()`` / ``disable_augmentations()``.
    - NOTE(review): with ``num_workers > 0`` each worker process works on its
      own copy of the dataset, so changes made while an iterator is active
      (or with ``persistent_workers=True``) may not reach running workers.
      Confirm against the PyTorch DataLoader documentation for your setup.
    """

    def __init__(self,
                 base_dataset: Dataset,
                 augment_list: Optional[List["AugmentBase"]] = None,
                 max_augs_per_seq: int = 0,
                 hard_aug: bool = True,
                 batch_size: int = 32,
                 shuffle: bool = True,
                 num_workers: int = 4,
                 **kwargs):
        # Build a fresh list per loader; a literal `[]` default would be one
        # object shared by every loader created without an explicit list.
        if augment_list is None:
            augment_list = []

        # Wrap the base dataset so samples are augmented as they are fetched.
        self.augmented_dataset = AugmentedGenomicDataset(
            base_dataset=base_dataset,
            augment_list=augment_list,
            max_augs_per_seq=max_augs_per_seq,
            hard_aug=hard_aug,
            apply_augmentations=True
        )

        # Initialize the parent DataLoader with the augmented dataset.
        super().__init__(
            dataset=self.augmented_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **kwargs
        )

    def enable_augmentations(self):
        """Enable augmentations on the underlying dataset, e.g. for training."""
        self.augmented_dataset.enable_augmentations()

    def disable_augmentations(self):
        """Disable augmentations on the underlying dataset.

        Useful for validation, testing, or finetuning on original data.
        """
        self.augmented_dataset.disable_augmentations()

    def set_augmentations(self,
                          augment_list: List["AugmentBase"],
                          max_augs_per_seq: int = 0,
                          hard_aug: bool = True):
        """Update the augmentation settings without recreating the DataLoader.

        Parameters
        ----------
        augment_list : List[AugmentBase]
            New list of augmentations to apply. ``None`` is treated as an
            empty list.
        max_augs_per_seq : int, optional
            New maximum augmentations per sequence; capped at
            ``len(augment_list)``. Defaults to 0.
        hard_aug : bool, optional
            New hard augmentation setting. Defaults to True.

        Notes
        -----
        The dataset's derived fields (``max_num_aug``, ``insert_max``) are
        refreshed so they stay consistent with the new list.
        """
        if augment_list is None:
            augment_list = []

        dataset = self.augmented_dataset
        dataset.augment_list = augment_list
        dataset.max_augs_per_seq = min(max_augs_per_seq, len(augment_list))
        dataset.hard_aug = hard_aug
        dataset.max_num_aug = len(augment_list)
        # Must run after augment_list is replaced: it scans the new list.
        dataset.insert_max = dataset._get_insert_max()