Building a GAN from Scratch using TensorFlow & Keras
Generative Adversarial Networks (GANs) are one of the most exciting innovations in modern deep learning. In this tutorial, we’ll build a simple GAN from scratch with TensorFlow and Keras and train it to generate handwritten digits resembling those in the MNIST dataset.
🧠 What is a GAN?
A GAN consists of two neural networks:
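- Generator: Turns random noise into fake data
- Discriminator: Tries to distinguish the generator's fake data from real data
These networks train against each other in a game-like process: the generator improves by learning to fool the discriminator, while the discriminator improves by learning to catch the fakes.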
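This game is usually written as a minimax objective. The formulation below is the standard one from the original GAN paper; the notation is introduced here for illustration and is not used elsewhere in this tutorial. $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the generator's output for a noise vector $z$, and $p_{\text{data}}$ and $p_z$ are the data and noise distributions.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice, the training loop in Step 6 trains the generator with binary cross-entropy against labels of 1, which minimizes $-\log D(G(z))$ rather than $\log(1 - D(G(z)))$. This is the common "non-saturating" variant of the generator loss.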
🔧 Step-by-Step Implementation
Step 1: Import Libraries
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
Step 2: Load and Normalize Data
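# Load the training images; the labels and the test split are not needed
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
x_train = x_train.astype("float32") / 127.5 - 1.0
# Keep the (28, 28) image shape: the discriminator flattens its own input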
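To make the scaling concrete, here is a minimal NumPy sketch of the mapping from [0, 255] to [-1, 1] and its inverse, which is handy later when turning generated images back into viewable pixels. The helper names `to_model_range` and `to_pixel_range` and the tiny `sample` array are hypothetical, chosen only for illustration.

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return pixels.astype("float32") / 127.5 - 1.0

def to_pixel_range(scaled):
    """Invert to_model_range: map floats in [-1, 1] back to uint8 in [0, 255]."""
    return np.clip(np.rint((scaled + 1.0) * 127.5), 0, 255).astype("uint8")

# A tiny stand-in for an MNIST image (hypothetical values, not real data)
sample = np.array([[0, 64], [128, 255]], dtype="uint8")

scaled = to_model_range(sample)      # darkest pixel -> -1.0, brightest -> 1.0
restored = to_pixel_range(scaled)    # round trip back to the original pixels
```

The range [-1, 1] is chosen because the generator's last layer uses `tanh`, so real and generated images end up on the same scale.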
Step 3: Build Generator
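def build_generator():
    model = Sequential()
    # Map a 100-dimensional noise vector to a hidden representation
    model.add(Dense(128, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    # tanh keeps outputs in [-1, 1], matching the normalized training images
    model.add(Dense(784, activation='tanh'))
    # Reshape the 784 values into a 28x28 image
    model.add(Reshape((28, 28)))
    return model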
Step 4: Build Discriminator
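def build_discriminator():
    model = Sequential()
    # Flatten each 28x28 image into a 784-dimensional vector
    model.add(Flatten(input_shape=(28, 28)))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.2))
    # Sigmoid output: estimated probability that the input image is real
    model.add(Dense(1, activation='sigmoid'))
    return model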
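To see how data flows through the two networks, the following is a rough NumPy sketch of the same layer stack with randomly initialised weights. It is not the Keras code above and learns nothing; it only traces shapes and value ranges. All weight arrays and helper names here are hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, alpha=0.2):
    # Pass positive values through, scale negative values by alpha
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Random weights standing in for the Dense layers (hypothetical values)
g_w1, g_b1 = rng.normal(0, 0.05, (100, 128)), np.zeros(128)
g_w2, g_b2 = rng.normal(0, 0.05, (128, 784)), np.zeros(784)
d_w1, d_b1 = rng.normal(0, 0.05, (784, 128)), np.zeros(128)
d_w2, d_b2 = rng.normal(0, 0.05, (128, 1)), np.zeros(1)

def generator_forward(z):
    """Noise (batch, 100) -> images (batch, 28, 28) with values in [-1, 1]."""
    hidden = leaky_relu(z @ g_w1 + g_b1)
    out = np.tanh(hidden @ g_w2 + g_b2)
    return out.reshape(-1, 28, 28)

def discriminator_forward(imgs):
    """Images (batch, 28, 28) -> probabilities (batch, 1) in (0, 1)."""
    flat = imgs.reshape(imgs.shape[0], -1)
    hidden = leaky_relu(flat @ d_w1 + d_b1)
    return sigmoid(hidden @ d_w2 + d_b2)

z = rng.normal(0, 1, (4, 100))
fake = generator_forward(z)
scores = discriminator_forward(fake)
print(fake.shape, scores.shape)  # (4, 28, 28) (4, 1)
```

The key point is that the generator's output shape matches the discriminator's input shape, which is what lets the two models be chained together in the next step.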
Step 5: Compile Models
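# The discriminator is compiled on its own so it can be trained directly
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

generator = build_generator()

# Combined model: noise -> generator -> discriminator
z = tf.keras.Input(shape=(100,))
img = generator(z)
# Freeze the discriminator inside the combined model so that only the generator is updated
discriminator.trainable = False
valid = discriminator(img)
combined = tf.keras.Model(z, valid)
combined.compile(optimizer='adam', loss='binary_crossentropy')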
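Both models are compiled with binary cross-entropy, and the only difference between the discriminator's and the generator's objective is the label attached to the fake images. Below is a plain NumPy sketch of the formula (not Keras's own implementation); the `d_on_fake` values are hypothetical discriminator outputs made up for the example.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

# Hypothetical discriminator outputs for a batch of generated images:
# low values mean the discriminator is confident they are fake
d_on_fake = np.array([[0.1], [0.2], [0.05]])

# Discriminator update: fakes are labelled 0, so confident rejections give a low loss
d_loss_fake = binary_crossentropy(np.zeros_like(d_on_fake), d_on_fake)

# Generator update: the same outputs are scored against label 1, so the loss is high
g_loss = binary_crossentropy(np.ones_like(d_on_fake), d_on_fake)

print(f"D loss on fakes: {d_loss_fake:.3f} | G loss: {g_loss:.3f}")
```

When the discriminator is winning, its loss on fakes is small and the generator's loss is large, and vice versa. This tug-of-war is why GAN losses tend to oscillate rather than steadily decrease.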
Step 6: Training the GAN
epochs = 10000  # each "epoch" here is one training step on a single batch
batch_size = 64

for epoch in range(epochs):
    # Sample a random batch of real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]

    # Generate a batch of fake images from random noise
    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise, verbose=0)

    # Train discriminator: real images are labelled 1, fake images 0
    d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))

    # Train generator: reward it when the discriminator labels its output as real
    g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

    if epoch % 1000 == 0:
        print(f"Epoch {epoch} | D loss: {0.5 * (d_loss_real[0] + d_loss_fake[0])} | G loss: {g_loss}")
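Loss values alone say little about image quality, so it helps to look at samples during training. Here is a small NumPy sketch that arranges a batch of generated images into a single grid and converts it back to 8-bit pixels. The function name `make_image_grid` is hypothetical, and the random `fake_batch` merely stands in for the output of `generator.predict`.

```python
import numpy as np

def make_image_grid(imgs, rows, cols):
    """Arrange rows*cols images of shape (H, W) in [-1, 1] into one uint8 grid."""
    if imgs.shape[0] < rows * cols:
        raise ValueError("not enough images to fill the grid")
    h, w = imgs.shape[1], imgs.shape[2]
    # Undo the [-1, 1] scaling from Step 2
    pixels = np.clip((imgs[: rows * cols] + 1.0) * 127.5, 0, 255).astype("uint8")
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows*H, cols*W)
    return pixels.reshape(rows, cols, h, w).transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# Stand-in for a batch of generated images: random values in [-1, 1] (not real samples)
fake_batch = np.random.uniform(-1.0, 1.0, size=(16, 28, 28))
grid = make_image_grid(fake_batch, rows=4, cols=4)
print(grid.shape, grid.dtype)  # (112, 112) uint8
```

One way to use this is to call it inside the `if epoch % 1000 == 0:` branch on freshly generated images and save the resulting array with an image library of your choice.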
📌 Tips for Better GANs
- Use LeakyReLU instead of ReLU
- Avoid using Dropout in the generator
- Normalize input data between -1 and 1
- Start with low resolution and scale up
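The LeakyReLU tip can be illustrated numerically. The sketch below compares the two activations and their gradients in plain NumPy, with hypothetical helper names and the same slope of 0.2 used in the models above. ReLU has zero gradient for negative inputs, while LeakyReLU keeps a small gradient so those units can still learn.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)

def relu_grad(x):
    # Gradient is 1 for positive inputs and 0 otherwise
    return (x > 0).astype(float)

def leaky_relu_grad(x, alpha=0.2):
    # Gradient is 1 for positive inputs and alpha otherwise
    return np.where(x > 0, 1.0, alpha)

x = np.array([-2.0, -0.5, 1.5])

print(relu(x), relu_grad(x))              # negative inputs: output 0, gradient 0
print(leaky_relu(x), leaky_relu_grad(x))  # negative inputs: scaled by 0.2, gradient 0.2
```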
✅ Conclusion
GANs are incredibly powerful for image generation, and this basic example is just the start. With deeper convolutional architectures and careful tuning, the same adversarial setup can create faces, artwork, and even synthetic datasets.
📣 What's Next?
Next in our series: Transformers and NLP with Hugging Face. Stay tuned, and follow Darchums Tech for weekly expert AI content!