# WGAN-GP with Meta-Feature Statistics (MFS)
A PyTorch implementation of Wasserstein GAN with Gradient Penalty (WGAN-GP) enhanced with meta-feature statistics preservation for synthetic tabular data generation.
## Features
- WGAN-GP Implementation: Stable GAN training with Wasserstein distance and gradient penalty
- Meta-Feature Statistics (MFS) Integration: Preserves statistical properties of the original data
- Comprehensive Evaluation: Multiple utility and fidelity metrics for synthetic data assessment
- Experiment Tracking: Built-in Aim tracking for monitoring training progress
- Flexible Architecture: Configurable generator and discriminator architectures
- Multi-Dataset Support: Tested on tabular datasets including Abalone and California Housing
## Architecture
The project implements two main training approaches:
- Vanilla WGAN-GP: Standard WGAN-GP implementation
- MFS-Enhanced WGAN-GP: WGAN-GP with additional meta-feature statistics preservation loss
### Key Components
- Generator: Residual network-based generator with configurable dimensions
- Discriminator: Multi-layer discriminator with LeakyReLU activations
- MFS Manager: PyTorch implementation of meta-feature extraction (correlation, covariance, eigenvalues, etc.)
- Wasserstein Distance: Topological distance computation for MFS alignment (see the sketch after this list)
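For intuition, here is a minimal, hypothetical sketch of an MFS alignment distance using POT (a listed dependency); the repository's actual computation is torch-topological based and may differ:

```python
# Illustrative sketch only: compares two flattened meta-feature vectors with
# an exact optimal-transport (Wasserstein) cost via POT. The repository's
# torch-topological implementation may differ.
import numpy as np
import ot  # POT: Python Optimal Transport

def mfs_wasserstein(mfs_real: np.ndarray, mfs_synth: np.ndarray) -> float:
    """Treat each meta-feature vector as a uniform empirical distribution
    over its entries and compute the transport cost between them."""
    a = np.full(len(mfs_real), 1.0 / len(mfs_real))    # uniform weights
    b = np.full(len(mfs_synth), 1.0 / len(mfs_synth))
    # Pairwise distance between entries as the ground cost
    M = ot.dist(mfs_real.reshape(-1, 1), mfs_synth.reshape(-1, 1), metric='euclidean')
    return ot.emd2(a, b, M)  # exact earth mover's distance
```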
## Installation

1. Clone the repository:

   ```bash
   git clone https://github.com/Roman223/GAN_MFS
   cd GAN_MFS
   ```

2. Install dependencies:

   ```bash
   pip install -r req.txt
   ```
## Usage

### Basic Training

```python
from wgan_gp.models import Generator, Discriminator
from wgan_gp.training import Trainer, TrainerModified

# Configure hyperparameters
learning_params = {
    'epochs': 1000,
    'learning_rate_D': 0.0001,
    'learning_rate_G': 0.0001,
    'batch_size': 64,
    'gp_weight': 10,
    'generator_dim': (64, 128, 64),
    'discriminator_dim': (64, 128, 64),
    'emb_dim': 32,
}

# For MFS-enhanced training
learning_params.update({
    'mfs_lambda': [0.1, 2.0],  # or a single float value
    'subset_mfs': ['mean', 'var', 'eigenvalues'],
    'sample_number': 10,
    'sample_frac': 0.5,
})

# Initialize and train the model
trainer = TrainerModified(...)  # or Trainer(...) for vanilla WGAN-GP
trainer.train(data_loader, learning_params['epochs'], plot_freq=100)
```
### Data Preparation

The project expects tabular data with the target variable included. Example preprocessing:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Load your data
data = pd.read_csv('your_data.csv')
y = data.pop('target').values
X = data.values

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# Combine features and target for training
training_data = np.hstack([X_train, y_train.reshape(-1, 1)])
```
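The `data_loader` passed to `trainer.train()` above is a standard PyTorch `DataLoader`. A minimal sketch of wrapping the prepared array (the batch size and shuffling here are assumptions, not the repository's defaults):

```python
# Sketch: wrap the prepared array in a PyTorch DataLoader for trainer.train().
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.tensor(training_data, dtype=torch.float32))
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
```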
## Meta-Feature Statistics (MFS)

The MFS component preserves important statistical properties of the original data.

### Supported Meta-Features

- Statistical: mean, variance, standard deviation, range, min, max, median
- Distributional: skewness, kurtosis, interquartile range
- Relational: correlation matrix, covariance matrix, eigenvalues
- Advanced: sparsity, MAD (median absolute deviation)
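For illustration, a few of these meta-features computed in plain PyTorch (a sketch only; `pymfe_to_torch.py` holds the repository's actual implementation):

```python
# Illustrative sketch of a handful of the listed meta-features in PyTorch.
import torch

def basic_mfs(X: torch.Tensor) -> dict:
    """X: (n_samples, n_features) float tensor."""
    centered = X - X.mean(dim=0)
    cov = torch.cov(X.T)  # torch.cov expects variables in rows
    return {
        'mean': X.mean(dim=0),
        'var': X.var(dim=0),
        'skewness': (centered / X.std(dim=0)).pow(3).mean(dim=0),
        'eigenvalues': torch.linalg.eigvalsh(cov),  # real eigenvalues of the symmetric covariance
    }
```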
### MFS Loss Function

The generator loss combines the adversarial loss with an MFS preservation term weighted by `mfs_lambda`:

```python
g_loss = d_generated.mean() + mfs_lambda * wasserstein_distance(mfs_real, mfs_synthetic)
```
## Evaluation Metrics

The project includes comprehensive evaluation metrics.

### Utility Metrics
- R² Score: Regression performance preservation
- RMSE: Root mean squared error
- MAPE: Mean absolute percentage error
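As an illustration, these utility metrics could be computed with a train-on-synthetic, test-on-real protocol (a common evaluation approach; the regressor choice and variable names below are assumptions, not the repository's exact setup):

```python
# Sketch of a train-on-synthetic / test-on-real utility evaluation.
# X_synth/y_synth (generated) and X_test/y_test (held-out real) are assumed.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import (r2_score, mean_squared_error,
                             mean_absolute_percentage_error)

model = RandomForestRegressor(random_state=0).fit(X_synth, y_synth)
pred = model.predict(X_test)

r2 = r2_score(y_test, pred)
rmse = np.sqrt(mean_squared_error(y_test, pred))
mape = mean_absolute_percentage_error(y_test, pred)
```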
### Fidelity Metrics
- Correlation Matrix Distance: Cosine distance between correlation matrices
- Jensen-Shannon Divergence: Marginal distribution similarity
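A minimal sketch of both fidelity metrics, assuming `X_real` and `X_synth` are `(n_samples, n_features)` NumPy arrays (the repository's exact binning and aggregation may differ):

```python
# Sketch of the two fidelity metrics; binning choices are illustrative.
import numpy as np
from scipy.spatial.distance import cosine, jensenshannon

# Cosine distance between the flattened correlation matrices
corr_dist = cosine(np.corrcoef(X_real.T).ravel(), np.corrcoef(X_synth.T).ravel())

# Jensen-Shannon divergence between marginal histograms, averaged over features
js_per_feature = []
for j in range(X_real.shape[1]):
    bins = np.histogram_bin_edges(np.concatenate([X_real[:, j], X_synth[:, j]]), bins=50)
    p, _ = np.histogram(X_real[:, j], bins=bins)
    q, _ = np.histogram(X_synth[:, j], bins=bins)
    js_per_feature.append(jensenshannon(p, q) ** 2)  # squared JS distance = divergence
js_div = float(np.mean(js_per_feature))
```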
## Experiment Tracking

Aim tracking is built in and strongly recommended for monitoring:

- Training losses (generator, discriminator, MFS)
- Gradient norms and flow visualization
- Sample quality progression
- Comprehensive metrics logging

To start the Aim UI server, run:

```bash
aim up
```
## Project Structure

```text
wgan_gp/
    __init__.py
    main.py              # Main training script
    models.py            # Generator and Discriminator definitions
    training.py          # Training loops and MFS integration
    pymfe_to_torch.py    # PyTorch meta-feature extraction
    utils.py             # Utility functions and metrics
req.txt                  # Dependencies
.gitignore               # Git ignore file
README.md                # This file
```
## Key Dependencies
- PyTorch: Deep learning framework
- Aim: Experiment tracking
- scikit-learn: Machine learning utilities
- pandas/numpy: Data manipulation
- matplotlib/seaborn: Visualization
- torch-topological: Topological data analysis
- POT: Optimal transport (Wasserstein distance)
## Configuration

Key hyperparameters can be configured in `main.py`:
```python
learning_params = dict(
    epochs=1000,                      # Training epochs
    learning_rate_D=0.0001,           # Discriminator learning rate
    learning_rate_G=0.0001,           # Generator learning rate
    batch_size=64,                    # Batch size
    gp_weight=10,                     # Gradient penalty weight
    generator_dim=(64, 128, 64),      # Generator architecture
    discriminator_dim=(64, 128, 64),  # Discriminator architecture
    mfs_lambda=0.1,                   # MFS loss weight
    subset_mfs=['mean', 'var'],       # Selected meta-features
    sample_number=10,                 # Number of variates for MFS
    sample_frac=0.5,                  # Fraction for variate sampling
)
```
## Results
The model generates high-quality synthetic tabular data that preserves:
- Statistical distributions of original features
- Correlational structure between variables
- Utility for downstream machine learning tasks
## Contributing
Contributions are welcome! Please feel free to submit issues, feature requests, or pull requests.
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Citation

If you use this code in your research, please cite:

```bibtex
@misc{wgan_gp_mfs,
  title={WGAN-GP with Meta-Feature Statistics for Synthetic Tabular Data Generation},
  author={[Your Name]},
  year={2024},
  url={https://github.com/your-username/your-repo}
}
```
## Acknowledgments
- Original WGAN-GP paper: Improved Training of Wasserstein GANs
- PyMFE library for meta-feature extraction inspiration
- Aim team for excellent experiment tracking tools