Build a GAN with TensorFlow & Keras

Building a GAN from Scratch using TensorFlow & Keras

Generative Adversarial Networks (GANs) learn to produce new samples that resemble a training set. In this tutorial, we'll guide you through building a small fully connected GAN from scratch that generates handwritten digits similar to those in MNIST.

🧠 What is a GAN?

A GAN consists of two neural networks:

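  • Generator: maps random noise to synthetic samples (here, 28×28 digit images)
  • Discriminator: estimates the probability that a sample is real rather than generated

The two networks are trained against each other: the discriminator learns to tell real samples from generated ones, while the generator learns to produce samples that the discriminator accepts as real.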
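To make the game concrete, here is a minimal sketch of the two objectives, assuming the binary cross-entropy setup used later in this tutorial. The helper names (`bce`, `discriminator_loss`, `generator_loss`) are illustrative and not part of Keras.

```python
import math

def bce(label, prob, eps=1e-7):
    """Binary cross-entropy for a single prediction; eps guards against log(0)."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))

def discriminator_loss(d_real, d_fake):
    # The discriminator wants D(real) close to 1 and D(fake) close to 0.
    return bce(1, d_real) + bce(0, d_fake)

def generator_loss(d_fake):
    # The generator wants the discriminator to score its fakes as real (label 1).
    return bce(1, d_fake)

# Generator loss when the discriminator is fooled (0.9) versus when it is not (0.1).
fooled = generator_loss(0.9)
caught = generator_loss(0.1)
```

The generator's loss falls as the discriminator's score on fake samples rises, which is what pushes the two networks against each other.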

🔧 Step-by-Step Implementation

Step 1: Import Libraries

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np

Step 2: Load and Normalize Data

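(x_train, _), _ = tf.keras.datasets.mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
x_train = x_train.astype("float32") / 127.5 - 1.0
# Keep the (28, 28) image shape: the discriminator expects it and flattens internally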
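The scaling step can be checked in isolation, and its inverse is handy later for viewing generated digits. This is a small NumPy sketch; `to_model_range` and `to_pixel_range` are hypothetical helper names, not TensorFlow functions.

```python
import numpy as np

def to_model_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return images_uint8.astype("float32") / 127.5 - 1.0

def to_pixel_range(images):
    """Invert the mapping so generated images can be displayed or saved."""
    pixels = np.rint((images + 1.0) * 127.5)
    return np.clip(pixels, 0, 255).astype("uint8")

sample = np.array([[0, 127, 255]], dtype="uint8")
scaled = to_model_range(sample)      # black maps to -1, white maps to +1
restored = to_pixel_range(scaled)    # back to the original uint8 values
```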

Step 3: Build Generator

def build_generator():
  model = Sequential()
  model.add(Dense(128, input_dim=100))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dense(784, activation='tanh'))
  model.add(Reshape((28, 28)))
  return model
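To see what each generator layer does to the tensor shape, the same forward pass can be sketched in plain NumPy. The weights below are random stand-ins chosen for illustration (an assumption, not trained or Keras-initialised values), so the output is noise; only the shapes and the value range are meaningful.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins for the two Dense layers (illustrative, not trained weights).
w1, b1 = rng.normal(0, 0.05, (100, 128)), np.zeros(128)
w2, b2 = rng.normal(0, 0.05, (128, 784)), np.zeros(784)

def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)

def generator_forward(z):
    h = leaky_relu(z @ w1 + b1)        # (batch, 100) -> (batch, 128)
    flat = np.tanh(h @ w2 + b2)        # (batch, 128) -> (batch, 784), values in [-1, 1]
    return flat.reshape(-1, 28, 28)    # (batch, 784) -> (batch, 28, 28)

z = rng.normal(0, 1, (16, 100))        # 16 latent vectors of length 100
fake_images = generator_forward(z)
```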

Step 4: Build Discriminator

def build_discriminator():
  model = Sequential()
  model.add(Flatten(input_shape=(28, 28)))
  model.add(Dense(128))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dense(1, activation='sigmoid'))
  return model
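As a quick sanity check on both architectures, the number of trainable parameters can be worked out by hand. A Dense layer has one weight per input-output pair plus one bias per output, while Flatten, Reshape, and LeakyReLU add none. `dense_params` is a hypothetical helper for this sketch; the totals should agree with what `model.summary()` reports.

```python
def dense_params(n_in, n_out):
    """A Dense layer holds an (n_in x n_out) weight matrix plus n_out biases."""
    return n_in * n_out + n_out

# Generator: Dense(100 -> 128), then Dense(128 -> 784)
generator_params = dense_params(100, 128) + dense_params(128, 784)

# Discriminator: Flatten(28 x 28 = 784), Dense(784 -> 128), then Dense(128 -> 1)
discriminator_params = dense_params(28 * 28, 128) + dense_params(128, 1)
```

This gives 114,064 parameters for the generator and 100,609 for the discriminator, so the two networks are roughly matched in size.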

Step 5: Compile Models

discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

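generator = build_generator()
# Chain generator and discriminator so the generator is trained through the discriminator's verdict
z = tf.keras.Input(shape=(100,))
img = generator(z)
# Freeze the discriminator for the combined model: this step is meant to update only the generator
discriminator.trainable = False
valid = discriminator(img)
combined = tf.keras.Model(z, valid)
combined.compile(optimizer='adam', loss='binary_crossentropy')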

Step 6: Training the GAN

epochs = 10000
batch_size = 64
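for epoch in range(epochs):  # each iteration trains on one batch, not a full pass over the data
  # Sample a batch of real images and generate a batch of fakes
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_imgs = x_train[idx]
  noise = np.random.normal(0, 1, (batch_size, 100))
  gen_imgs = generator.predict(noise, verbose=0)  # verbose=0 hides the per-call progress bar

  # Train discriminator: real images are labelled 1, generated images 0
  d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
  d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))

  # Train generator: label its fakes as 1 so it is rewarded for fooling the discriminator
  g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))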

  if epoch % 1000 == 0:
    print(f"Epoch {epoch} | D loss: {0.5 * (d_loss_real[0] + d_loss_fake[0])} | G loss: {g_loss}")
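Loss values alone say little about sample quality, so it helps to look at generated digits during training. One way, sketched here in NumPy, is to tile a batch of images into a single grid that can then be saved or plotted with any image library. `make_grid` is a hypothetical helper, and the constant-valued demo images stand in for real generator output.

```python
import numpy as np

def make_grid(images, rows, cols):
    """Tile rows*cols images of shape (H, W) into one (rows*H, cols*W) array."""
    n, h, w = images.shape
    if n < rows * cols:
        raise ValueError("not enough images to fill the grid")
    grid = images[: rows * cols].reshape(rows, cols, h, w)
    # Reorder to (rows, H, cols, W) so the final reshape places images side by side.
    return grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# Stand-in for generator output: 16 constant "images" holding the values 0..15.
demo = np.stack([np.full((28, 28), i, dtype="float32") for i in range(16)])
grid = make_grid(demo, rows=4, cols=4)
```

Inside the training loop, something like `make_grid(gen_imgs, 4, 4)` every 1000 iterations would give a 112×112 snapshot of the first 16 generated digits.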

📌 Tips for Better GANs

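  • Use LeakyReLU instead of ReLU so units with negative inputs still pass a small gradient
  • Avoid using Dropout in the generator
  • Normalize input data to [-1, 1] so real images share the range of the generator's tanh output
  • Start with low resolution and scale up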
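The first tip can be illustrated numerically. The sketch below compares the gradients of ReLU and LeakyReLU on a few inputs, assuming the same slope of 0.2 used in the models above; the function names are local to this example.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)

def relu_grad(x):
    return (x > 0).astype("float64")

def leaky_relu_grad(x, alpha=0.2):
    return np.where(x > 0, 1.0, alpha)

x = np.array([-2.0, -0.5, 0.5, 2.0])
dead = relu_grad(x)          # zero gradient wherever the input is negative
alive = leaky_relu_grad(x)   # small but non-zero gradient (alpha) for negative inputs
```

With ReLU, a unit whose input stays negative receives no gradient and stops learning; LeakyReLU keeps a small signal flowing in that case.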

✅ Conclusion

This basic example is just the start. With deeper (typically convolutional) architectures and the right tuning, GANs can create faces, artwork, and even synthetic datasets.

📣 What's Next?

Next in our series: Transformers and NLP with Hugging Face. Stay tuned, and follow Darchums Tech for weekly expert AI content!
