Models Module
This module defines the neural network architectures for the WGAN-GP implementation.
Architecture Overview
The models are designed for tabular data generation with residual connections to improve training stability and gradient flow.
Classes
Residual
Building-block layer that applies a linear transformation, batch normalization, and ReLU activation, then joins the result with its input through a residual (concatenation) connection.
Generator
Residual network-based generator that transforms random noise into synthetic tabular data samples.
Discriminator
Multi-layer discriminator (critic) network that scores real and generated samples; these scores are used to estimate the Wasserstein distance between the two distributions.
Key Features
- Residual Connections: Improved gradient flow and training stability
- Configurable Architecture: Flexible layer dimensions for different dataset sizes
- Batch Normalization: Stabilized training dynamics
- Leaky ReLU Activations: Better gradient propagation in discriminator
wgan_gp.models
Discriminator
Bases: Module
Discriminator.
Source code in wgan_gp/models.py
__init__(data_dim, discriminator_dim)
Initializes the Discriminator network.
The discriminator is a sequential neural network designed to differentiate between real and synthetic data samples. It is built using linear layers followed by LeakyReLU activations to introduce non-linearity, enabling it to learn complex data distributions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_dim | int | The dimension of the input data. This defines the number of features in each data sample that the discriminator will evaluate. | required |
discriminator_dim | list | A list of integers specifying the number of neurons in each hidden layer of the discriminator network. This determines the capacity and complexity of the discriminator. | required |
Attributes:
Name | Type | Description |
---|---|---|
data_dim | int | Stores the dimension of the input data for use in the network's forward pass. |
seq | Sequential | An nn.Sequential container that holds the discriminator's layers. This allows for easy forward propagation through the entire network. |
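The construction described above can be sketched roughly as follows. This is a minimal illustration, not the code in wgan_gp/models.py: the class name `SketchDiscriminator`, the LeakyReLU slope of 0.2, and the final single-output linear layer are all assumptions made for the example.

```python
import torch
from torch import nn


class SketchDiscriminator(nn.Module):
    """Illustrative critic: Linear + LeakyReLU per hidden layer, then one score."""

    def __init__(self, data_dim, discriminator_dim):
        super().__init__()
        self.data_dim = data_dim
        layers = []
        dim = data_dim
        for width in discriminator_dim:
            # One hidden layer per entry in discriminator_dim.
            # The negative slope of 0.2 is an assumption, not taken from the source.
            layers += [nn.Linear(dim, width), nn.LeakyReLU(0.2)]
            dim = width
        # Assumed output head: a single score per sample.
        layers.append(nn.Linear(dim, 1))
        self.seq = nn.Sequential(*layers)

    def forward(self, input_):
        # Reshape to (batch, data_dim) before passing through the layers.
        return self.seq(input_.view(-1, self.data_dim))


critic = SketchDiscriminator(data_dim=10, discriminator_dim=[64, 32])
scores = critic(torch.randn(8, 10))
print(scores.shape)  # torch.Size([8, 1])
```

Each entry in `discriminator_dim` adds one hidden layer, so a longer list yields a deeper network and larger values yield a wider one.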
forward(input_)
Applies the discriminator network to the input data.
The discriminator aims to distinguish between real and synthetic data samples. This method reshapes the input and passes it through the sequential model to produce a score for each sample, with higher scores indicating inputs judged more likely to be real. In a Wasserstein setup this score is typically an unbounded critic value rather than a probability. The scores drive training: they let the discriminator learn the characteristics of real data and provide the signal that guides the generator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ | Tensor | The input data to be evaluated by the discriminator. | required |
Returns:
Name | Type | Description |
---|---|---|
Output | Tensor | The discriminator's output, representing the classification score for the input data. |
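To illustrate how the scores returned by forward are commonly consumed in a Wasserstein GAN, the sketch below computes the standard critic loss without the gradient penalty term. The stand-in `critic` network, the batch shapes, and the placeholder `real` and `fake` tensors are assumptions for the example; the training loop in this package may differ.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in critic for illustration; in practice the Discriminator class would be used.
critic = nn.Sequential(nn.Linear(10, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))

real = torch.randn(16, 10)        # placeholder for a batch of real rows
fake = torch.randn(16, 10) + 1.0  # placeholder for generator output

real_scores = critic(real)  # shape (16, 1)
fake_scores = critic(fake)

# Wasserstein critic loss (gradient penalty omitted):
# minimizing it pushes scores for real rows above scores for generated rows.
critic_loss = fake_scores.mean() - real_scores.mean()
critic_loss.backward()
```

The loss is a difference of mean scores, which is why the output is treated as an unbounded value rather than passed through a sigmoid.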
Generator
Bases: Module
Generator.
__init__(embedding_dim, generator_dim, data_dim)
Initializes a Generator object.
The generator network consists of a series of residual blocks followed by a linear layer. It takes a latent vector as input and outputs a data sample that mimics the real data distribution. This architecture allows the generator to learn complex data patterns and generate synthetic samples with high fidelity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embedding_dim | int | Dimension of the latent embedding. This determines the size of the input noise vector. | required |
generator_dim | list | A list of dimensions for the residual blocks. Each value specifies the output dimension of a residual block, progressively increasing the feature space. | required |
data_dim | int | Dimension of the output data. This should match the dimensionality of the real data samples that the generator is trying to emulate. | required |
Attributes:
Name | Type | Description |
---|---|---|
latent_dim | int | Dimension of the latent embedding. |
seq | Sequential | A sequential container holding the generator network layers. |
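Because the Residual block is documented as concatenating its output with its input, the feature width grows after every block. The pure-Python sketch below works through that bookkeeping. The helper `generator_layer_shapes` is hypothetical and not part of wgan_gp.models, and it assumes the generator chains the blocks in the order given by `generator_dim`.

```python
def generator_layer_shapes(embedding_dim, generator_dim, data_dim):
    """Return (in_features, out_features) for each linear layer, in order."""
    shapes = []
    dim = embedding_dim
    for width in generator_dim:
        shapes.append((dim, width))  # linear layer inside one residual block
        dim += width                 # concatenation widens the features by `width`
    shapes.append((dim, data_dim))   # final projection to the data dimension
    return shapes


print(generator_layer_shapes(128, [256, 256], 10))
# [(128, 256), (384, 256), (640, 10)]
```

Under this assumption, the final linear layer sees the latent vector plus the output of every block, which is how the feature space is progressively increased.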
forward(input_)
Transforms the input data using a sequence of layers.
This process aims to generate synthetic data that mirrors the characteristics of real data, making it suitable for tasks requiring realistic datasets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ | Tensor | The latent input to be transformed, typically random noise vectors such as those returned by sample_latent. | required |
Returns:
Type | Description |
---|---|
Tensor | The transformed data after passing through the generator's layers. This represents the synthetic data generated by the network. |
sample_latent(num_samples)
Samples latent vectors from a standard normal distribution.
These vectors serve as the input to the generator and seed the creation of synthetic samples. Because each vector is drawn independently from a standard normal distribution, different draws map to different generated samples, which lets the generator cover different regions of the data space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_samples | int | The number of latent vectors to sample. | required |
Returns:
Name | Type | Description |
---|---|---|
Output | Tensor | A tensor of shape (num_samples, latent_dim) containing the sampled latent vectors. |
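As a minimal sketch of the sampling step, the free function below draws standard normal vectors with `torch.randn`. Passing `latent_dim` explicitly is an assumption made to keep the example self-contained; the documented method takes only `num_samples` and presumably reads the dimension from the `latent_dim` attribute.

```python
import torch


def sample_latent(num_samples, latent_dim):
    """Draw latent vectors from N(0, I); returns shape (num_samples, latent_dim)."""
    return torch.randn(num_samples, latent_dim)


z = sample_latent(num_samples=4, latent_dim=128)
print(z.shape)  # torch.Size([4, 128])
```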
Residual
Bases: Module
Residual layer.
__init__(i, o)
Initializes a Residual block for feature transformation within the GAN.
This block applies a linear transformation, batch normalization, and ReLU activation to refine features during the generation or discrimination process. It helps the network learn more complex representations by introducing non-linearities and stabilizing the learning process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
i | int | The input feature size. | required |
o | int | The output feature size. | required |
Attributes:
Name | Type | Description |
---|---|---|
fc | Linear | A linear layer that transforms the input from size 'i' to size 'o'. |
bn | BatchNorm1d | A batch normalization layer applied to the output of the linear layer, stabilizing the activations. |
relu | ReLU | A ReLU activation function applied after batch normalization, introducing non-linearity. |
forward(input_)
Passes the input through a fully connected layer and a ReLU activation, then concatenates the result with the original input. The batch normalization step is noted as commented out.
Concatenating the processed features with the unmodified input preserves the original data characteristics while adding non-linear transformations, so the output has more features than the input.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ | Tensor | The input tensor to the residual layer. | required |
Returns:
Name | Type | Description |
---|---|---|
Output | Tensor | The concatenated tensor of the processed input and the original input. |
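The concatenating forward pass can be sketched as below. `SketchResidual` is an illustrative name rather than the class in wgan_gp/models.py. Following the note that batch normalization is commented out, the sketch defines `bn` but does not apply it, and the concatenation order (processed features first, original input second) is an assumption based on the wording of the return description.

```python
import torch
from torch import nn


class SketchResidual(nn.Module):
    """Illustrative block: Linear -> ReLU, then concatenate with the input."""

    def __init__(self, i, o):
        super().__init__()
        self.fc = nn.Linear(i, o)
        self.bn = nn.BatchNorm1d(o)  # defined as documented, but unused in forward here
        self.relu = nn.ReLU()

    def forward(self, input_):
        out = self.fc(input_)
        # Batch normalization is skipped, mirroring the "commented out" note.
        out = self.relu(out)
        # Output width is o + i: processed features followed by the original input.
        return torch.cat([out, input_], dim=1)


x = torch.randn(4, 16)
block = SketchResidual(16, 32)
y = block(x)
print(y.shape)  # torch.Size([4, 48])
```

Unlike an additive skip connection, concatenation does not require `i` and `o` to match, at the cost of a wider input for whatever layer comes next.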