Vanilla PyTorch Integration Example
This example demonstrates basic EvoAug2 functionality using vanilla PyTorch, providing a simple and direct approach to integrating evolution-inspired augmentations into your training pipeline.
Overview
The vanilla PyTorch example (example_vanilla_pytorch.py) showcases:
Direct PyTorch integration without Lightning abstractions
Basic augmentation application to genomic sequences
Simple training loop implementation
Core EvoAug2 functionality demonstration
Minimal dependencies for quick prototyping
Key Features
Simple Implementation: Straightforward PyTorch code without external abstractions
Core Augmentations: Demonstrates basic mutation and deletion augmentations
Easy Customization: Simple to modify for different use cases
Minimal Setup: Requires only basic PyTorch knowledge
Quick Prototyping: Ideal for research and experimentation
File Structure
example_vanilla_pytorch.py
├── Imports and setup
├── Augmentation definition
├── Simple model definition
├── Training loop
└── Basic evaluation
Usage
Basic Execution:
python example_vanilla_pytorch.py
Prerequisites:
# Install core dependencies
pip install evoaug2
# Or install from source
git clone https://github.com/aduranu/evoaug.git
cd evoaug
pip install -e .
Dependencies:
PyTorch >= 1.9.0
NumPy >= 1.20.0
EvoAug2 core package
Code Walkthrough
1. Imports and Setup:
import torch
import torch.nn as nn
from evoaug.augment import RandomMutation, RandomDeletion
from evoaug.evoaug import RobustLoader
2. Augmentation Definition:
# Define augmentations
augmentations = [
    RandomMutation(mut_frac=0.1),  # 10% mutation rate
    RandomDeletion(delete_min=0, delete_max=20)  # 0-20 deletions
]
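To make the effect of a mutation rate concrete, here is a framework-free sketch of what mutating 10% of a one-hot sequence could look like. It is an illustration under stated assumptions, not EvoAug2's implementation: the function name mutate_one_hot, the nested-list layout ([length, 4] with columns A, C, G, T), and the choice to redraw bases uniformly are all hypothetical.

```python
import random

def mutate_one_hot(seq, frac, rng):
    """Return a copy of a one-hot sequence with a fraction of positions redrawn.

    seq is a list of length-4 rows (A, C, G, T). Illustration only,
    not EvoAug2's implementation.
    """
    length = len(seq)
    num_mutations = round(frac * length)
    # Choose distinct positions to mutate
    positions = rng.sample(range(length), num_mutations)
    mutated = [row[:] for row in seq]  # copy so the input stays untouched
    for pos in positions:
        new_base = rng.randrange(4)  # may redraw the original base
        mutated[pos] = [1.0 if i == new_base else 0.0 for i in range(4)]
    return mutated

rng = random.Random(0)
sequence = [[1.0, 0.0, 0.0, 0.0] for _ in range(20)]  # twenty "A" bases
mutated = mutate_one_hot(sequence, frac=0.1, rng=rng)
changed = sum(a != b for a, b in zip(sequence, mutated))
```

Because a redrawn base can equal the original, the number of visibly changed positions is at most frac times the sequence length.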
3. Simple Model:
class SimpleModel(nn.Module):
    def __init__(self, input_size=200, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 32, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)  # [batch, channels, length]
        x = torch.relu(self.conv1(x))
        x = self.pool(x).squeeze(-1)
        x = self.fc(x)
        return x
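A quick shape check can confirm the layout this model expects. The sketch below repeats SimpleModel from this walkthrough and feeds it zero tensors shaped [batch, length, channels]; the batch size of 8 and the second length of 180 are arbitrary values chosen for illustration. Because of the adaptive pooling layer, the output shape does not depend on sequence length, and input_size is accepted but not used by any layer.

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self, input_size=200, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 32, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)  # [batch, length, 4] -> [batch, 4, length]
        x = torch.relu(self.conv1(x))
        x = self.pool(x).squeeze(-1)  # average over length -> [batch, 32]
        return self.fc(x)

model = SimpleModel()
model.eval()
output_shapes = []
with torch.no_grad():
    for length in (200, 180):
        batch = torch.zeros(8, length, 4)  # placeholder input, not real data
        output_shapes.append(tuple(model(batch).shape))
```

Both inputs should produce one row of class scores per sequence, that is, a [batch, num_classes] tensor.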
4. Training Loop:
# Training setup
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Training loop
model.train()
for epoch in range(num_epochs):
    for batch_seqs, batch_labels in dataloader:
        # Apply augmentations
        for aug in augmentations:
            batch_seqs = aug(batch_seqs)
        # Forward pass
        outputs = model(batch_seqs)
        loss = criterion(outputs, batch_labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Augmentation Application
Direct Application:
# Apply single augmentation
mutation = RandomMutation(mut_frac=0.1)
augmented_seq = mutation(sequence)
# Apply multiple augmentations sequentially
for aug in augmentations:
    sequence = aug(sequence)
Batch Processing:
# Apply to entire batch
batch_size, seq_length, channels = batch_seqs.shape
# Apply augmentations to each sequence in batch
for i in range(batch_size):
    for aug in augmentations:
        batch_seqs[i] = aug(batch_seqs[i:i+1]).squeeze(0)
RobustLoader Integration:
# Use RobustLoader for efficient batch processing
dataloader = RobustLoader(
    base_dataset=dataset,
    augment_list=augmentations,
    max_augs_per_seq=1,
    hard_aug=False,
    batch_size=32,
    shuffle=True
)
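The max_augs_per_seq and hard_aug arguments control how many augmentations each sequence receives. The sketch below shows one plausible reading, assuming EvoAug2 follows the semantics described for the original EvoAug (hard_aug=True always applies the maximum number, otherwise the number is drawn at random up to the maximum). The function pick_augmentations is hypothetical and is not part of the RobustLoader API.

```python
import random

def pick_augmentations(num_available, max_augs_per_seq, hard_aug, rng):
    """Pick which augmentations (by list index) to apply to one sequence."""
    upper = min(max_augs_per_seq, num_available)
    # hard_aug=True: always use the maximum; otherwise draw a count from 1..upper
    count = upper if hard_aug else rng.randint(1, upper)
    # Sample without replacement, then sort so augmentations run in list order
    return sorted(rng.sample(range(num_available), count))

rng = random.Random(0)
hard = [pick_augmentations(4, 2, True, rng) for _ in range(100)]
soft = [pick_augmentations(4, 2, False, rng) for _ in range(100)]
```

With max_augs_per_seq=1, as in the call above, both settings reduce to a single randomly chosen augmentation per sequence.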
Customization Examples
Modify Augmentation Parameters:
# Adjust mutation rate
mutation = RandomMutation(mut_frac=0.05) # 5% mutations
# Change deletion range
deletion = RandomDeletion(delete_min=5, delete_max=50) # 5-50 deletions
# Add new augmentation types
from evoaug.augment import RandomTranslocation, RandomNoise
augmentations = [
    RandomMutation(mut_frac=0.1),
    RandomDeletion(delete_min=0, delete_max=20),
    RandomTranslocation(shift_min=0, shift_max=15),
    RandomNoise(noise_mean=0, noise_std=0.2)
]
Custom Training Loop:
# Add learning rate scheduling
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# Add validation loop
model.eval()
with torch.no_grad():
    val_loss = 0
    for val_seqs, val_labels in val_dataloader:
        outputs = model(val_seqs)
        val_loss += criterion(outputs, val_labels).item()
print(f"Validation Loss: {val_loss/len(val_dataloader):.4f}")
scheduler.step()
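With step_size=10 and gamma=0.5, StepLR halves the learning rate every 10 epochs. The closed form can be checked with plain arithmetic; step_lr below is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def step_lr(initial_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate in effect after `epoch` calls to scheduler.step()."""
    return initial_lr * gamma ** (epoch // step_size)

# Starting from the lr=0.001 used for the Adam optimizer above
schedule = {epoch: step_lr(0.001, epoch) for epoch in (0, 9, 10, 25, 30)}
```

Over the 50 epochs configured in this example, the rate is halved four times, after epochs 10, 20, 30, and 40.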
Custom Model Architecture:
class CustomModel(nn.Module):
    def __init__(self, input_size=200, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.3)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.dropout(x)
        x = self.pool(x).squeeze(-1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
Data Handling
Basic Dataset:
class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]
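The dataset above expects sequences that are already numeric. As a minimal sketch of how raw DNA strings might be turned into that form, the hypothetical helper below one-hot encodes a string into a [length, 4] nested list. The column order A, C, G, T and the all-zero row for unknown bases are assumptions made for this illustration, not conventions confirmed by EvoAug2.

```python
BASES = "ACGT"

def one_hot_encode(dna):
    """Encode a DNA string as a [length, 4] nested list (columns: A, C, G, T)."""
    index = {base: i for i, base in enumerate(BASES)}
    encoded = []
    for base in dna.upper():
        row = [0.0] * len(BASES)
        if base in index:  # unknown bases such as N stay all-zero
            row[index[base]] = 1.0
        encoded.append(row)
    return encoded

encoded = one_hot_encode("acgtN")
```

Passing a list of such encodings of equal length to torch.tensor gives a float tensor in the [batch, length, channels] layout that the example models transpose in forward.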
Data Loading:
# Create dataset
dataset = SimpleDataset(sequences, labels)
# Create DataLoader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True
)
Data Preprocessing:
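# Convert to the dtypes the model and loss expect
sequences = sequences.float()
labels = labels.long()
# Normalize sequences (after the float conversion: mean() and std() fail on integer tensors)
sequences = (sequences - sequences.mean()) / sequences.std()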
Training Configuration
Basic Training Parameters:
# Training configuration
num_epochs = 50
batch_size = 32
learning_rate = 0.001
weight_decay = 1e-6
# Model parameters
input_size = 200
num_classes = 2
hidden_size = 64
Advanced Training Options:
# Mixed precision training
scaler = torch.cuda.amp.GradScaler()
# Gradient clipping
max_grad_norm = 1.0
# Early stopping
patience = 10
best_loss = float('inf')
patience_counter = 0
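The early stopping variables above can be wrapped in a small helper so the training loop only has to ask whether to stop. This is a minimal sketch: the EarlyStopping class and the loss values are hypothetical and serve only to show how patience, best_loss, and patience_counter interact.

```python
class EarlyStopping:
    """Track the best loss and signal when it has stalled for `patience` epochs."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_loss = float("inf")
        self.patience_counter = 0

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.patience_counter = 0  # improvement resets the counter
        else:
            self.patience_counter += 1
        return self.patience_counter >= self.patience

stopper = EarlyStopping(patience=2)
losses = [0.9, 0.7, 0.8, 0.75, 0.6]  # made-up validation losses
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

With a patience of 2, the loop stops after two consecutive epochs without improvement, so the final loss of 0.6 is never seen. A larger patience trades extra epochs for tolerance of such temporary plateaus.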
Device Configuration:
# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Move model and data to device
model = model.to(device)
batch_seqs = batch_seqs.to(device)
batch_labels = batch_labels.to(device)
Evaluation and Metrics
Basic Evaluation:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_seqs, batch_labels in dataloader:
        outputs = model(batch_seqs)
        predicted = outputs.argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()
accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")
Loss Tracking:
# Track training loss
train_losses = []
for epoch in range(num_epochs):
    epoch_loss = 0
    for batch_seqs, batch_labels in dataloader:
        # ... training code ...
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
Model Saving:
# Save best model
if avg_loss < best_loss:
    best_loss = avg_loss
    torch.save(model.state_dict(), 'best_model.pth')
    print("Saved best model!")
Comparison with Lightning Example
Advantages of Vanilla PyTorch:
Direct Control: Full control over training loop and augmentation application
Simple Debugging: Easier to debug and understand
Minimal Dependencies: Fewer external dependencies
Customization: Easy to modify for specific research needs
Learning: Better for understanding PyTorch fundamentals
Advantages of Lightning Example:
Production Ready: Professional training workflows
Built-in Features: Logging, checkpointing, distributed training
Less Code: More concise implementation
Best Practices: Follows PyTorch Lightning conventions
Scalability: Better for large-scale experiments
When to Use Each:
Use Vanilla PyTorch: For research, prototyping, learning, simple workflows
Use Lightning: For production, complex experiments, team collaboration
Troubleshooting
Common Issues:
Memory Errors:
- Reduce batch size
- Use gradient accumulation
- Clear cache: torch.cuda.empty_cache()
Training Instability:
- Reduce learning rate
- Add gradient clipping
- Check data normalization
Augmentation Problems:
- Verify input tensor shapes
- Check augmentation parameters
- Ensure data types are correct
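Two of the remedies above, gradient accumulation and gradient clipping, can be combined in one loop. The sketch below uses a small linear layer and random tensors as stand-ins for the real model and data; accumulation_steps and the synthetic batches are hypothetical values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 2)  # stand-in for SimpleModel
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

accumulation_steps = 4  # hypothetical: 4 micro-batches per optimizer update
max_grad_norm = 1.0
# Synthetic micro-batches of 8 samples each (effective batch size 32)
batches = [(torch.randn(8, 8), torch.randint(0, 2, (8,))) for _ in range(8)]

updates = 0
optimizer.zero_grad()
for step, (inputs, labels) in enumerate(batches, start=1):
    loss = criterion(model(inputs), labels)
    # Scale so the accumulated gradient matches the mean over the full batch
    (loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        # Clip the accumulated gradient before the optimizer update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        updates += 1
```

Only one micro-batch of activations is held in memory at a time, so this keeps the effective batch size of 32 while reducing peak memory use.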
Debugging Tips:
# Print tensor shapes
print(f"Input shape: {batch_seqs.shape}")
print(f"Label shape: {batch_labels.shape}")
# Check data types
print(f"Input dtype: {batch_seqs.dtype}")
print(f"Label dtype: {batch_labels.dtype}")
# Verify augmentation output
print(f"Original shape: {sequence.shape}")
augmented = mutation(sequence)
print(f"Augmented shape: {augmented.shape}")
Next Steps
After running this example:
Experiment: Try different augmentation combinations
Customize: Modify the model architecture
Scale Up: Apply to larger datasets
Compare: Run the Lightning example for comparison
Further Learning:
Read the user_guide/augmentations for all augmentation types
Explore the api/evoaug for detailed API reference
Check the examples/lightning_module for advanced workflows
Review the user_guide/training for training strategies
This example provides a solid foundation for understanding EvoAug2’s core functionality and can be easily extended for your specific research needs.