
Models Module

This module defines the neural network architectures for the WGAN-GP implementation.

Architecture Overview

The models are designed for tabular data generation with residual connections to improve training stability and gradient flow.

Classes

Residual

Building block that applies a linear transformation and ReLU activation with a residual (concatenation) connection. A batch normalization layer is defined but currently disabled in the forward pass.

Generator

Residual network-based generator that transforms random noise into synthetic tabular data samples.

Discriminator

Multi-layer critic network that scores samples; these scores are used to estimate the Wasserstein distance between the real and generated data distributions.

Key Features

  • Residual Connections: Improved gradient flow and training stability
  • Configurable Architecture: Flexible layer dimensions for different dataset sizes
  • Batch Normalization: Defined in the residual blocks to stabilize training dynamics (currently disabled in the forward pass)
  • Leaky ReLU Activations: Better gradient propagation in the discriminator (see the usage sketch below)
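
A minimal end-to-end sketch tying the three classes together (all dimensions below are illustrative, not prescribed by the module):

import torch
from wgan_gp.models import Generator, Discriminator

data_dim = 10  # e.g. a 10-feature tabular dataset (illustrative)
gen = Generator(embedding_dim=128, generator_dim=[256, 256], data_dim=data_dim)
critic = Discriminator(data_dim=data_dim, discriminator_dim=[256, 256])

z = gen.sample_latent(32)   # (32, 128) standard-normal noise
fake = gen(z)               # (32, 10) synthetic samples
scores = critic(fake)       # (32, 1) unbounded critic scores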

wgan_gp.models

Discriminator

Bases: Module

Discriminator.

Source code in wgan_gp/models.py
class Discriminator(nn.Module):
    """
    Discriminator.
    """

    def __init__(self, data_dim, discriminator_dim):
        """
        Initializes the Discriminator network.

        The discriminator is a sequential neural network designed to differentiate
        between real and synthetic data samples. It is built using linear layers
        followed by LeakyReLU activations to introduce non-linearity, enabling it
        to learn complex data distributions.

        Args:
            data_dim (int): The dimension of the input data. This defines the
                number of features in each data sample that the discriminator will
                evaluate.
            discriminator_dim (list): A list of integers specifying the number of
                neurons in each hidden layer of the discriminator network. This
                determines the capacity and complexity of the discriminator.

        Attributes:
            data_dim (int): Stores the dimension of the input data for use in the
                network's forward pass.
            seq (nn.Sequential): An nn.Sequential container that holds the
                discriminator's layers. This allows for easy forward propagation
                through the entire network.
        """
        super(Discriminator, self).__init__()
        seq = []
        self.data_dim = data_dim

        dim = data_dim
        for item in list(discriminator_dim):
            # seq += [nn.Linear(dim, item), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
            seq += [nn.Linear(dim, item), nn.LeakyReLU(0.2)]
            dim = item

        seq += [nn.Linear(dim, 1)]
        self.seq = nn.Sequential(*seq)

    def forward(self, input_):
        """
        Applies the discriminator network to the input data.

        The discriminator (critic) aims to distinguish between real and synthetic data samples.
        This method reshapes the input and processes it through the sequential model to output an
        unbounded scalar score: higher values indicate that the input looks more like real data.
        In WGAN-GP training these scores are used to estimate the Wasserstein distance, which
        guides the generator's learning process.

        Args:
            input_ (torch.Tensor): The input data to be evaluated by the discriminator.

        Returns:
            Output (torch.Tensor): The critic score for each input sample, with shape (batch_size, 1).
        """
        return self.seq(input_.view(-1, self.data_dim))

__init__(data_dim, discriminator_dim)

Initializes the Discriminator network.

The discriminator is a sequential neural network designed to differentiate between real and synthetic data samples. It is built using linear layers followed by LeakyReLU activations to introduce non-linearity, enabling it to learn complex data distributions.

Parameters:

    data_dim (int, required): The dimension of the input data. This defines the number of features in each data sample that the discriminator will evaluate.
    discriminator_dim (list, required): A list of integers specifying the number of neurons in each hidden layer of the discriminator network. This determines the capacity and complexity of the discriminator.

Attributes:

    data_dim (int): Stores the dimension of the input data for use in the network's forward pass.
    seq (nn.Sequential): An nn.Sequential container that holds the discriminator's layers, allowing forward propagation through the entire network in a single call.
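
For example, a critic for 10-feature data with two hidden layers (the layer sizes here are illustrative):

from wgan_gp.models import Discriminator

critic = Discriminator(data_dim=10, discriminator_dim=[256, 256])
# Resulting layers: Linear(10, 256) -> LeakyReLU(0.2)
#                   -> Linear(256, 256) -> LeakyReLU(0.2) -> Linear(256, 1)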

Source code in wgan_gp/models.py
def __init__(self, data_dim, discriminator_dim):
    """
    Initializes the Discriminator network.

    The discriminator is a sequential neural network designed to differentiate
    between real and synthetic data samples. It is built using linear layers
    followed by LeakyReLU activations to introduce non-linearity, enabling it
    to learn complex data distributions.

    Args:
        data_dim (int): The dimension of the input data. This defines the
            number of features in each data sample that the discriminator will
            evaluate.
        discriminator_dim (list): A list of integers specifying the number of
            neurons in each hidden layer of the discriminator network. This
            determines the capacity and complexity of the discriminator.

    Attributes:
        data_dim (int): Stores the dimension of the input data for use in the
            network's forward pass.
        seq (nn.Sequential): An nn.Sequential container that holds the
            discriminator's layers. This allows for easy forward propagation
            through the entire network.
    """
    super(Discriminator, self).__init__()
    seq = []
    self.data_dim = data_dim

    dim = data_dim
    for item in list(discriminator_dim):
        # seq += [nn.Linear(dim, item), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        seq += [nn.Linear(dim, item), nn.LeakyReLU(0.2)]
        dim = item

    seq += [nn.Linear(dim, 1)]
    self.seq = nn.Sequential(*seq)

forward(input_)

Applies the discriminator network to the input data.

The discriminator (critic) aims to distinguish between real and synthetic data samples. This method reshapes the input and processes it through the sequential model to output an unbounded scalar score: higher values indicate that the input looks more like real data. In WGAN-GP training these scores are used to estimate the Wasserstein distance, which guides the generator's learning process.

Parameters:

    input_ (torch.Tensor, required): The input data to be evaluated by the discriminator.

Returns:

    Output (torch.Tensor): The critic score for each input sample, with shape (batch_size, 1).
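
A short sketch of a forward pass (batch size and dimensions are illustrative):

import torch
from wgan_gp.models import Discriminator

critic = Discriminator(data_dim=10, discriminator_dim=[256, 256])
batch = torch.randn(32, 10)   # stand-in for real or generated samples
scores = critic(batch)        # shape (32, 1), unbounded critic scores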

Source code in wgan_gp/models.py
def forward(self, input_):
    """
    Applies the discriminator network to the input data.

    The discriminator (critic) aims to distinguish between real and synthetic data samples.
    This method reshapes the input and processes it through the sequential model to output an
    unbounded scalar score: higher values indicate that the input looks more like real data.
    In WGAN-GP training these scores are used to estimate the Wasserstein distance, which
    guides the generator's learning process.

    Args:
        input_ (torch.Tensor): The input data to be evaluated by the discriminator.

    Returns:
        Output (torch.Tensor): The critic score for each input sample, with shape (batch_size, 1).
    """
    return self.seq(input_.view(-1, self.data_dim))

Generator

Bases: Module

Generator.

Source code in wgan_gp/models.py
class Generator(nn.Module):
    """
    Generator.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        """
        Initializes a Generator object.

        The generator network consists of a series of residual blocks followed by a linear layer.
        It takes a latent vector as input and outputs a data sample that mimics the real data
        distribution. This architecture allows the generator to learn complex data patterns and
        generate synthetic samples with high fidelity.

        Args:
            embedding_dim (int): Dimension of the latent embedding. This determines the size
                of the input noise vector.
            generator_dim (list): A list of dimensions for the residual blocks. Each value
                specifies the output dimension of a residual block; because every block
                concatenates its output with its input, the feature dimension grows by this
                amount at each block.
            data_dim (int): Dimension of the output data. This should match the dimensionality
                of the real data samples that the generator is trying to emulate.

        Attributes:
            latent_dim (int): Dimension of the latent embedding.
            seq (nn.Sequential): A sequential container holding the generator network layers.
        """
        super(Generator, self).__init__()
        self.latent_dim = embedding_dim
        dim = embedding_dim
        seq = []
        for item in list(generator_dim):
            seq += [Residual(dim, item)]
            dim += item
        seq.append(nn.Linear(dim, data_dim))
        self.seq = nn.Sequential(*seq)

    def forward(self, input_):
        """
        Transforms the input data using a sequence of layers.

        This process aims to generate synthetic data that mirrors the characteristics of real data,
        making it suitable for tasks requiring realistic datasets.

        Args:
            input_ (torch.Tensor): A batch of latent noise vectors of shape
                (batch_size, embedding_dim), typically drawn via sample_latent.

        Returns:
            Output (torch.Tensor): The synthetic data generated by passing the latent input
                through the generator's layers.
        """
        data = self.seq(input_)
        return data

    def sample_latent(self, num_samples):
        """
        Samples latent vectors from a standard normal distribution.

        These vectors serve as the initial input to the generator, guiding the creation of
        synthetic samples. Sampling from a standard normal distribution ensures diversity in
        the generated output, allowing the GAN to explore different regions of the data space.

        Args:
            num_samples (int): The number of latent vectors to sample.

        Returns:
            Output (torch.Tensor): A tensor of shape (num_samples, latent_dim) containing the sampled
                latent vectors.
        """
        return torch.randn((num_samples, self.latent_dim))

__init__(embedding_dim, generator_dim, data_dim)

Initializes a Generator object.

The generator network consists of a series of residual blocks followed by a linear layer. It takes a latent vector as input and outputs a data sample that mimics the real data distribution. This architecture allows the generator to learn complex data patterns and generate synthetic samples with high fidelity.

Parameters:

    embedding_dim (int, required): Dimension of the latent embedding. This determines the size of the input noise vector.
    generator_dim (list, required): A list of dimensions for the residual blocks. Each value specifies the output dimension of a residual block; because every block concatenates its output with its input, the feature dimension grows by this amount at each block.
    data_dim (int, required): Dimension of the output data. This should match the dimensionality of the real data samples that the generator is trying to emulate.

Attributes:

    latent_dim (int): Dimension of the latent embedding.
    seq (nn.Sequential): A sequential container holding the generator network layers.
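
For example, with illustrative dimensions, the residual concatenation makes the width grow before the final projection:

from wgan_gp.models import Generator

gen = Generator(embedding_dim=128, generator_dim=[256, 256], data_dim=10)
# Residual(128, 256): output width 128 + 256 = 384 (concatenation)
# Residual(384, 256): output width 384 + 256 = 640
# Linear(640, 10):    projects down to the data dimension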

Source code in wgan_gp/models.py
def __init__(self, embedding_dim, generator_dim, data_dim):
    """
    Initializes a Generator object.

    The generator network consists of a series of residual blocks followed by a linear layer.
    It takes a latent vector as input and outputs a data sample that mimics the real data
    distribution. This architecture allows the generator to learn complex data patterns and
    generate synthetic samples with high fidelity.

    Args:
        embedding_dim (int): Dimension of the latent embedding. This determines the size
            of the input noise vector.
        generator_dim (list): A list of dimensions for the residual blocks. Each value
            specifies the output dimension of a residual block, progressively increasing
            the feature space.
        data_dim (int): Dimension of the output data. This should match the dimensionality
            of the real data samples that the generator is trying to emulate.

    Attributes:
        latent_dim (int): Dimension of the latent embedding.
        seq (nn.Sequential): A sequential container holding the generator network layers.
    """
    super(Generator, self).__init__()
    self.latent_dim = embedding_dim
    dim = embedding_dim
    seq = []
    for item in list(generator_dim):
        seq += [Residual(dim, item)]
        dim += item
    seq.append(nn.Linear(dim, data_dim))
    self.seq = nn.Sequential(*seq)

forward(input_)

Transforms the input data using a sequence of layers.

This process aims to generate synthetic data that mirrors the characteristics of real data, making it suitable for tasks requiring realistic datasets.

Parameters:

    input_ (torch.Tensor, required): A batch of latent noise vectors of shape (batch_size, embedding_dim), typically drawn via sample_latent.

Returns:

    Output (torch.Tensor): The synthetic data generated by passing the latent input through the generator's layers.
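
A short sketch (dimensions are illustrative):

import torch
from wgan_gp.models import Generator

gen = Generator(embedding_dim=128, generator_dim=[256, 256], data_dim=10)
z = gen.sample_latent(32)   # (32, 128)
fake = gen(z)               # (32, 10) synthetic samples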

Source code in wgan_gp/models.py
def forward(self, input_):
    """
    Transforms the input data using a sequence of layers.

    This process aims to generate synthetic data that mirrors the characteristics of real data,
    making it suitable for tasks requiring realistic datasets.

    Args:
        input_ (torch.Tensor): A batch of latent noise vectors of shape
            (batch_size, embedding_dim), typically drawn via sample_latent.

    Returns:
        Output (torch.Tensor): The synthetic data generated by passing the latent input
            through the generator's layers.
    """
    data = self.seq(input_)
    return data

sample_latent(num_samples)

Samples latent vectors from a standard normal distribution.

These vectors serve as the initial input to the generator, guiding the creation of synthetic samples. Sampling from a standard normal distribution ensures diversity in the generated output, allowing the GAN to explore different regions of the data space.

Parameters:

    num_samples (int, required): The number of latent vectors to sample.

Returns:

    Output (torch.Tensor): A tensor of shape (num_samples, latent_dim) containing the sampled latent vectors.
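
Note that torch.randn creates the noise on the CPU by default; when the generator lives on a GPU, move the noise to the model's device first (continuing the illustrative example above):

z = gen.sample_latent(64)                    # (64, 128), on the CPU
z = z.to(next(gen.parameters()).device)      # match the generator's device
fake = gen(z)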

Source code in wgan_gp/models.py
def sample_latent(self, num_samples):
    """
    Samples latent vectors from a standard normal distribution.

    These vectors serve as the initial input to the generator, guiding the creation of
    synthetic samples. Sampling from a standard normal distribution ensures diversity in
    the generated output, allowing the GAN to explore different regions of the data space.

    Args:
        num_samples (int): The number of latent vectors to sample.

    Returns:
        Output (torch.Tensor): A tensor of shape (num_samples, latent_dim) containing the sampled
            latent vectors.
    """
    return torch.randn((num_samples, self.latent_dim))

Residual

Bases: Module

Residual layer.

Source code in wgan_gp/models.py
class Residual(nn.Module):
    """
    Residual layer.
    """

    def __init__(self, i, o):
        """
        Initializes a Residual block for feature transformation within the GAN.

        This block applies a linear transformation followed by a ReLU activation to refine
        features during the generation process (a batch normalization layer is defined but
        currently disabled in the forward pass). It helps the network learn more complex
        representations by introducing non-linearities.

        Args:
            i (int): The input feature size.
            o (int): The output feature size.

        Attributes:
            fc (nn.Linear): A linear layer that transforms the input from size 'i' to size 'o'.
            bn (nn.BatchNorm1d): A batch normalization layer defined for the output of the
                linear layer, but currently unused in the forward pass.
            relu (nn.ReLU): A ReLU activation function applied after the linear layer,
                introducing non-linearity.
        """
        super(Residual, self).__init__()
        self.fc = nn.Linear(i, o)
        self.bn = nn.BatchNorm1d(o)
        self.relu = nn.ReLU()

    def forward(self, input_):
        """
        Passes the input through a fully connected layer and a ReLU activation, then applies
        a residual connection (the batch normalization step is currently disabled).

        This process enhances the model's ability to capture intricate data patterns by
        concatenating the processed input with the original input, thus preserving original
        data characteristics while introducing non-linear transformations.

        Args:
            input_ (torch.Tensor): The input tensor to the residual layer.

        Returns:
            Output (torch.Tensor): The concatenated tensor of the processed input and the original input.
        """
        out = self.fc(input_)
        # out = self.bn(out)
        out = self.relu(out)
        return torch.cat([out, input_], dim=1)

__init__(i, o)

Initializes a Residual block for feature transformation within the GAN.

This block applies a linear transformation followed by a ReLU activation to refine features during the generation process (a batch normalization layer is defined but currently disabled in the forward pass). It helps the network learn more complex representations by introducing non-linearities.

Parameters:

    i (int, required): The input feature size.
    o (int, required): The output feature size.

Attributes:

    fc (nn.Linear): A linear layer that transforms the input from size 'i' to size 'o'.
    bn (nn.BatchNorm1d): A batch normalization layer defined for the output of the linear layer, but currently unused in the forward pass.
    relu (nn.ReLU): A ReLU activation function applied after the linear layer, introducing non-linearity.
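
Because forward() concatenates the block's activation with its input, the output width is i + o rather than o (sizes here are illustrative):

from wgan_gp.models import Residual

block = Residual(128, 256)
# block(x) returns cat([relu(fc(x)), x], dim=1) -> 256 + 128 = 384 features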

Source code in wgan_gp/models.py
def __init__(self, i, o):
    """
    Initializes a Residual block for feature transformation within the GAN.

    This block applies a linear transformation followed by a ReLU activation to refine
    features during the generation process (a batch normalization layer is defined but
    currently disabled in the forward pass). It helps the network learn more complex
    representations by introducing non-linearities.

    Args:
        i (int): The input feature size.
        o (int): The output feature size.

    Attributes:
        fc (nn.Linear): A linear layer that transforms the input from size 'i' to size 'o'.
        bn (nn.BatchNorm1d): A batch normalization layer defined for the output of the
            linear layer, but currently unused in the forward pass.
        relu (nn.ReLU): A ReLU activation function applied after the linear layer,
            introducing non-linearity.
    """
    super(Residual, self).__init__()
    self.fc = nn.Linear(i, o)
    self.bn = nn.BatchNorm1d(o)
    self.relu = nn.ReLU()

forward(input_)

Passes the input through a fully connected layer and a ReLU activation, then applies a residual connection (the batch normalization step is currently disabled).

This process enhances the model's ability to capture intricate data patterns by concatenating the processed input with the original input, thus preserving original data characteristics while introducing non-linear transformations.

Parameters:

    input_ (torch.Tensor, required): The input tensor to the residual layer.

Returns:

    Output (torch.Tensor): The concatenated tensor of the processed input and the original input.
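
A quick shape check (dimensions are illustrative):

import torch
from wgan_gp.models import Residual

block = Residual(128, 256)
x = torch.randn(4, 128)
y = block(x)   # shape (4, 384): cat([relu(fc(x)), x], dim=1)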

Source code in wgan_gp/models.py
def forward(self, input_):
    """
    Passes the input through a fully connected layer and a ReLU activation, then applies
    a residual connection (the batch normalization step is currently disabled).

    This process enhances the model's ability to capture intricate data patterns by
    concatenating the processed input with the original input, thus preserving original
    data characteristics while introducing non-linear transformations.

    Args:
        input_ (torch.Tensor): The input tensor to the residual layer.

    Returns:
        Output (torch.Tensor): The concatenated tensor of the processed input and the original input.
    """
    out = self.fc(input_)
    # out = self.bn(out)
    out = self.relu(out)
    return torch.cat([out, input_], dim=1)