
DCGAN Tutorial

Created On: Jul 31, 2018 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024

Author: Nathan Inkawhich

Introduction

This tutorial gives an introduction to DCGANs through an example. We will train a generative adversarial network (GAN) to generate new celebrities after showing it pictures of many real celebrities. Most of the code here comes from the DCGAN implementation in pytorch/examples, and this document will give a thorough explanation of the implementation and shed light on how and why this model works. But don't worry, no prior knowledge of GANs is required, though a first-timer may need to spend some time reasoning about what is actually happening under the hood. Also, for the sake of time it will help to have one or two GPUs. Let's start from the beginning.

Generative Adversarial Networks

What is a GAN?

GANs are a framework for teaching a deep learning model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models: a *generator* and a *discriminator*. The job of the generator is to spawn "fake" images that look like the training images. The job of the discriminator is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by producing better and better fakes, while the discriminator works to become a better detective and correctly classify real and fake images. The equilibrium of this game is when the generator produces perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence whether the generator's output is real or fake.

Now, let's define some notation to be used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network, which outputs the (scalar) probability that \(x\) came from the training data rather than the generator. Here, since we are dealing with images, the input to \(D(x)\) is an image of CHW size 3x64x64. Intuitively, \(D(x)\) should be high when \(x\) comes from the training data and low when \(x\) comes from the generator. \(D(x)\) can also be thought of as a traditional binary classifier.

For the generator's notation, let \(z\) be a latent space vector sampled from a standard normal distribution. \(G(z)\) represents the generator function, which maps the latent vector \(z\) to data space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so that it can generate fake samples from that estimated distribution (\(p_g\)).

So, \(D(G(z))\) is the (scalar) probability that the output of the generator \(G\) is a real image. As described in Goodfellow's paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability that it correctly classifies reals and fakes (\(\log D(x)\)), and \(G\) tries to minimize the probability that \(D\) will predict its outputs are fake (\(\log(1-D(G(z)))\)). From the paper, the GAN loss function is

\[\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big] \]

In theory, the solution to this minimax game is where \(p_g = p_{data}\), and the discriminator guesses randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.
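Where does the 50% figure come from? From Goodfellow's paper, for a fixed generator the optimal discriminator has a closed form:

\[D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

so at the solution \(p_g = p_{data}\) we get \(D^*(x) = \frac{1}{2}\) for every \(x\): the best the discriminator can do is a coin flip.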

What is a DCGAN?

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. The input is a 3x64x64 image and the output is the scalar probability that the input came from the real data distribution. The generator is comprised of convolutional-transpose layers, batch norm layers, and ReLU activations. The input is a latent vector \(z\) drawn from a standard normal distribution, and the output is a 3x64x64 RGB image. The strided conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips about how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.
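As a quick aside (our sketch, not from the paper), the 64x64 output size falls directly out of the standard transposed-convolution size arithmetic. Below is that arithmetic for the layer settings used by the Generator defined later; convtranspose2d_out is a hypothetical helper name:

# Transposed-conv output size, assuming dilation=1 and output_padding=0:
#   out = (in - 1) * stride - 2 * padding + kernel
def convtranspose2d_out(size, kernel, stride, padding):
    return (size - 1) * stride - 2 * padding + kernel

size = 1  # the latent vector z is treated as a 1x1 spatial map
for kernel, stride, padding in [(4, 1, 0)] + [(4, 2, 1)] * 4:
    size = convtranspose2d_out(size, kernel, stride, padding)
    print(size)  # prints 4, 8, 16, 32, 64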

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
Random Seed:  999
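One caveat for GPU users (our note, not part of the original code): with torch.use_deterministic_algorithms(True), some CUDA operations require a cuBLAS workspace configuration and will raise an error otherwise. A minimal sketch of the extra step, needed only when running on CUDA:

# Required by cuBLAS for deterministic behavior on CUDA; set this
# before any CUDA work happens.
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"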

Inputs

Let's define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.

  • workers - the number of worker threads for loading the data with the DataLoader.

  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128.

  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details.

  • nc - the number of color channels in the input images. For color images this is 3.

  • nz - the length of the latent vector.

  • ngf - relates to the depth of feature maps carried through the generator.

  • ndf - sets the depth of feature maps propagated through the discriminator.

  • num_epochs - the number of training epochs to run. Training for longer will probably lead to better results but will also take much longer.

  • lr - the learning rate for training. As described in the DCGAN paper, this number should be 0.0002.

  • beta1 - the beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5.

  • ngpu - the number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs.

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded at the linked site, or in Google Drive. The dataset will download as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...
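The extraction step can also be scripted; a minimal sketch, assuming img_align_celeba.zip was downloaded into the current working directory (zip_path is a placeholder to adjust):

import os
import zipfile

zip_path = "img_align_celeba.zip"  # hypothetical location of the download
os.makedirs("data/celeba", exist_ok=True)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("data/celeba")  # yields data/celeba/img_align_celeba/*.jpg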

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now, we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
Training Images

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
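Optionally, you can confirm the effect after netG.apply(weights_init) is called below. check_init is a hypothetical helper of ours, not part of the tutorial; conv weights should come back with mean near 0 and std near 0.02:

# Hypothetical helper: print per-layer weight statistics of conv layers.
def check_init(model):
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            w = module.weight.data
            print(f"{name}: mean={w.mean().item():+.4f}, std={w.std().item():.4f}")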

Generator

The generator, \(G\), is designed to map the latent space vector (\(z\)) to data space. Since our data are images, converting \(z\) to data space ultimately means creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two-dimensional convolutional-transpose layers, each paired with a 2d batch norm layer and a relu activation. The output of the generator is fed through a tanh function to return it to the input data range of \([-1,1]\). It is worth noting the existence of the batch norm functions after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

dcgan_generator

Notice how the inputs we set in the inputs section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the z input vector, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
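As an optional sanity check (our addition, relying on the nz, ngf, and nc values set earlier), a batch of latent vectors should map to 64x64 RGB images:

# Quick shape check on CPU: 4 latent vectors -> 4 RGB images of 64x64.
with torch.no_grad():
    z = torch.randn(4, nz, 1, 1)
    print(Generator(ngpu=0)(z).shape)  # torch.Size([4, 3, 64, 64])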

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator

As mentioned, the discriminator, \(D\), is a binary classification network that takes an image as input and outputs the scalar probability that the input image is real (as opposed to fake). Here, \(D\) takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but there is significance to the use of the strided convolution, batch norm, and LeakyReLU. The DCGAN paper mentions that it is good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. Also, the batch norm and LeakyReLU functions promote healthy gradient flow, which is critical for the learning process of both \(G\) and \(D\).

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
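As with the generator, a quick optional shape check (our addition, assuming nc and ndf as set earlier): a batch of 3x64x64 images should reduce to one scalar probability each:

# Quick shape check on CPU: 4 images -> 4 sigmoid outputs of size 1x1x1.
with torch.no_grad():
    probe = torch.randn(4, nc, 64, 64)
    print(Discriminator(ngpu=0)(probe).shape)  # torch.Size([4, 1, 1, 1])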

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
# like this: ``to mean=0, stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers

With \(D\) and \(G\) set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function, which is defined in PyTorch as

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] \]

Notice how this function provides the calculation of both log components in the objective function (i.e. \(\log(D(x))\) and \(\log(1-D(G(z)))\)). We can specify which part of the BCE equation to use with the \(y\) input. This is accomplished in the training loop, which is coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing \(y\) (i.e. the ground-truth labels).
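To make this concrete, here is a tiny sketch (our addition) of how the target selects the component: with \(y=1\) the loss reduces to \(-\log(x)\), and with \(y=0\) it reduces to \(-\log(1-x)\):

import torch
import torch.nn as nn

bce = nn.BCELoss()
x = torch.tensor([0.9])        # a "probability" output
print(bce(x, torch.ones(1)))   # -log(0.9)  ~= 0.1054
print(bce(x, torch.zeros(1)))  # -log(0.1)  ~= 2.3026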

Next, we define our real label as 1 and the fake label as 0. These labels will be used when calculating the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for \(D\) and one for \(G\). As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progression, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into \(G\), and over the iterations we will see images form out of the noise.

# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize \(\log(D(G(z)))\). Training is split up into two main parts. Part 1 updates the discriminator and Part 2 updates the generator.

Part 1 - Train the Discriminator

Recall, the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In terms of Goodfellow, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize \(\log(D(x)) + \log(1-D(G(z)))\). Due to the separate mini-batch suggestion from ganhacks, we will calculate this in two steps. First, we will construct a batch of real samples from the training set, forward pass through \(D\), calculate the loss (\(\log(D(x))\)), then calculate the gradients in a backward pass. Second, we will construct a batch of fake samples with the current generator, forward pass this batch through \(D\), calculate the loss (\(\log(1-D(G(z)))\)), and *accumulate* the gradients with a backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer.

Part 2 - Train the Generator

As stated in the original paper, we want to train the generator by minimizing \(\log(1-D(G(z)))\) in an effort to generate better fakes. As mentioned, Goodfellow showed that this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead wish to maximize \(\log(D(G(z)))\). In the code we accomplish this by: classifying the generator output from Part 1 with the discriminator, computing G's loss *using real labels as ground truth*, computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as ground-truth labels for the loss function, but this allows us to use the \(\log(x)\) part of the BCELoss (rather than the \(\log(1-x)\) part), which is exactly what we want.
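A small numeric sketch (our addition) of why the swap matters: early in training \(D(G(z))\) is close to 0, where \(\log(1-x)\) is nearly flat but \(\log(x)\) is steep, so the modified objective gives the generator a much stronger learning signal:

import torch

x = torch.tensor(0.01, requires_grad=True)  # D(G(z)) for a poor early fake
torch.log(1 - x).backward()
print(x.grad)  # ~ -1.01: minimizing log(1 - D(G(z))) barely moves G

x.grad = None
torch.log(x).backward()
print(x.grad)  # 100.0: maximizing log(D(G(z))) gives a steep slope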

Finally, we will do some statistics reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

  • Loss_D - discriminator loss calculated as the sum of losses for the all-real and all-fake batches (\(\log(D(x)) + \log(1 - D(G(z)))\)).

  • Loss_G - generator loss calculated as \(\log(D(G(z)))\).

  • D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1 and then theoretically converge to 0.5 as G gets better. Think about why this is.

  • D(G(z)) - the average discriminator output for the all-fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

Note: This step might take a while, depending on how many epochs you run and whether you removed some data from the dataset.

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.4640  Loss_G: 6.9366  D(x): 0.7143    D(G(z)): 0.5877 / 0.0017
[0/5][50/1583]  Loss_D: 1.3473  Loss_G: 26.0342 D(x): 0.9521    D(G(z)): 0.6274 / 0.0000
[0/5][100/1583] Loss_D: 0.4592  Loss_G: 5.4997  D(x): 0.8385    D(G(z)): 0.1582 / 0.0096
[0/5][150/1583] Loss_D: 0.4003  Loss_G: 4.9369  D(x): 0.9549    D(G(z)): 0.2696 / 0.0112
[0/5][200/1583] Loss_D: 1.0697  Loss_G: 7.1552  D(x): 0.9390    D(G(z)): 0.5277 / 0.0016
[0/5][250/1583] Loss_D: 0.2654  Loss_G: 3.7994  D(x): 0.8486    D(G(z)): 0.0386 / 0.0498
[0/5][300/1583] Loss_D: 0.8607  Loss_G: 1.7948  D(x): 0.5738    D(G(z)): 0.0922 / 0.2307
[0/5][350/1583] Loss_D: 0.4578  Loss_G: 3.7834  D(x): 0.8936    D(G(z)): 0.2463 / 0.0383
[0/5][400/1583] Loss_D: 0.4972  Loss_G: 5.1323  D(x): 0.9523    D(G(z)): 0.3053 / 0.0126
[0/5][450/1583] Loss_D: 1.3933  Loss_G: 1.9958  D(x): 0.4034    D(G(z)): 0.0379 / 0.1982
[0/5][500/1583] Loss_D: 0.9834  Loss_G: 2.8194  D(x): 0.5936    D(G(z)): 0.1847 / 0.1077
[0/5][550/1583] Loss_D: 0.5884  Loss_G: 7.6739  D(x): 0.8827    D(G(z)): 0.3119 / 0.0013
[0/5][600/1583] Loss_D: 0.5217  Loss_G: 5.6103  D(x): 0.9342    D(G(z)): 0.3160 / 0.0085
[0/5][650/1583] Loss_D: 0.4254  Loss_G: 3.8598  D(x): 0.8154    D(G(z)): 0.1438 / 0.0317
[0/5][700/1583] Loss_D: 0.3483  Loss_G: 4.2089  D(x): 0.7952    D(G(z)): 0.0355 / 0.0240
[0/5][750/1583] Loss_D: 0.5566  Loss_G: 6.2280  D(x): 0.9148    D(G(z)): 0.3040 / 0.0042
[0/5][800/1583] Loss_D: 0.2617  Loss_G: 5.5604  D(x): 0.8322    D(G(z)): 0.0383 / 0.0080
[0/5][850/1583] Loss_D: 1.6397  Loss_G: 10.7162 D(x): 0.9620    D(G(z)): 0.6981 / 0.0002
[0/5][900/1583] Loss_D: 1.0194  Loss_G: 5.4787  D(x): 0.8427    D(G(z)): 0.4678 / 0.0094
[0/5][950/1583] Loss_D: 0.4182  Loss_G: 4.7106  D(x): 0.8578    D(G(z)): 0.1802 / 0.0222
[0/5][1000/1583]        Loss_D: 0.4757  Loss_G: 3.8595  D(x): 0.8051    D(G(z)): 0.1514 / 0.0416
[0/5][1050/1583]        Loss_D: 0.6044  Loss_G: 2.9149  D(x): 0.7372    D(G(z)): 0.1696 / 0.0809
[0/5][1100/1583]        Loss_D: 0.7655  Loss_G: 2.3512  D(x): 0.6174    D(G(z)): 0.0593 / 0.1484
[0/5][1150/1583]        Loss_D: 0.7374  Loss_G: 3.1968  D(x): 0.6097    D(G(z)): 0.0709 / 0.0777
[0/5][1200/1583]        Loss_D: 0.6484  Loss_G: 4.1837  D(x): 0.8723    D(G(z)): 0.3046 / 0.0323
[0/5][1250/1583]        Loss_D: 0.6404  Loss_G: 4.9987  D(x): 0.8959    D(G(z)): 0.3395 / 0.0124
[0/5][1300/1583]        Loss_D: 0.7700  Loss_G: 7.7520  D(x): 0.9699    D(G(z)): 0.4454 / 0.0011
[0/5][1350/1583]        Loss_D: 0.4115  Loss_G: 3.8996  D(x): 0.8038    D(G(z)): 0.1153 / 0.0301
[0/5][1400/1583]        Loss_D: 0.5865  Loss_G: 3.3128  D(x): 0.8186    D(G(z)): 0.2586 / 0.0521
[0/5][1450/1583]        Loss_D: 0.7625  Loss_G: 2.5499  D(x): 0.6857    D(G(z)): 0.2169 / 0.1131
[0/5][1500/1583]        Loss_D: 1.3006  Loss_G: 3.9234  D(x): 0.4019    D(G(z)): 0.0053 / 0.0425
[0/5][1550/1583]        Loss_D: 1.0234  Loss_G: 2.1976  D(x): 0.4556    D(G(z)): 0.0291 / 0.1659
[1/5][0/1583]   Loss_D: 0.3606  Loss_G: 3.7421  D(x): 0.8785    D(G(z)): 0.1770 / 0.0377
[1/5][50/1583]  Loss_D: 0.6186  Loss_G: 2.6328  D(x): 0.6461    D(G(z)): 0.0559 / 0.1141
[1/5][100/1583] Loss_D: 0.6551  Loss_G: 3.9456  D(x): 0.6392    D(G(z)): 0.0641 / 0.0429
[1/5][150/1583] Loss_D: 0.7882  Loss_G: 6.6105  D(x): 0.9553    D(G(z)): 0.4592 / 0.0031
[1/5][200/1583] Loss_D: 0.5069  Loss_G: 2.1326  D(x): 0.7197    D(G(z)): 0.0957 / 0.1621
[1/5][250/1583] Loss_D: 0.4229  Loss_G: 2.8329  D(x): 0.7680    D(G(z)): 0.0920 / 0.0915
[1/5][300/1583] Loss_D: 0.3388  Loss_G: 3.2621  D(x): 0.8501    D(G(z)): 0.1096 / 0.0758
[1/5][350/1583] Loss_D: 0.2864  Loss_G: 4.5487  D(x): 0.9182    D(G(z)): 0.1608 / 0.0184
[1/5][400/1583] Loss_D: 0.3158  Loss_G: 3.3892  D(x): 0.8432    D(G(z)): 0.1100 / 0.0554
[1/5][450/1583] Loss_D: 1.2332  Loss_G: 8.1937  D(x): 0.9940    D(G(z)): 0.6184 / 0.0008
[1/5][500/1583] Loss_D: 0.4001  Loss_G: 3.4084  D(x): 0.8584    D(G(z)): 0.1890 / 0.0472
[1/5][550/1583] Loss_D: 1.5110  Loss_G: 2.5652  D(x): 0.3283    D(G(z)): 0.0121 / 0.1440
[1/5][600/1583] Loss_D: 0.5324  Loss_G: 2.1393  D(x): 0.6765    D(G(z)): 0.0592 / 0.1596
[1/5][650/1583] Loss_D: 0.5493  Loss_G: 1.9572  D(x): 0.6725    D(G(z)): 0.0439 / 0.1998
[1/5][700/1583] Loss_D: 0.6842  Loss_G: 3.5358  D(x): 0.7578    D(G(z)): 0.2744 / 0.0450
[1/5][750/1583] Loss_D: 1.5829  Loss_G: 0.7034  D(x): 0.3024    D(G(z)): 0.0307 / 0.5605
[1/5][800/1583] Loss_D: 0.6566  Loss_G: 1.7996  D(x): 0.6073    D(G(z)): 0.0531 / 0.2245
[1/5][850/1583] Loss_D: 0.4141  Loss_G: 2.7758  D(x): 0.8372    D(G(z)): 0.1650 / 0.0919
[1/5][900/1583] Loss_D: 0.7488  Loss_G: 4.1499  D(x): 0.8385    D(G(z)): 0.3698 / 0.0261
[1/5][950/1583] Loss_D: 1.0031  Loss_G: 1.6256  D(x): 0.4876    D(G(z)): 0.0742 / 0.2805
[1/5][1000/1583]        Loss_D: 0.3197  Loss_G: 3.5365  D(x): 0.9197    D(G(z)): 0.1881 / 0.0426
[1/5][1050/1583]        Loss_D: 0.4852  Loss_G: 2.6088  D(x): 0.7459    D(G(z)): 0.1400 / 0.0992
[1/5][1100/1583]        Loss_D: 1.4441  Loss_G: 4.7499  D(x): 0.9102    D(G(z)): 0.6548 / 0.0160
[1/5][1150/1583]        Loss_D: 0.8372  Loss_G: 3.3722  D(x): 0.8911    D(G(z)): 0.4280 / 0.0614
[1/5][1200/1583]        Loss_D: 0.3625  Loss_G: 3.4286  D(x): 0.7971    D(G(z)): 0.0875 / 0.0520
[1/5][1250/1583]        Loss_D: 1.7122  Loss_G: 1.5450  D(x): 0.2588    D(G(z)): 0.0382 / 0.3596
[1/5][1300/1583]        Loss_D: 0.3812  Loss_G: 2.9381  D(x): 0.9070    D(G(z)): 0.2145 / 0.0863
[1/5][1350/1583]        Loss_D: 0.8282  Loss_G: 2.3004  D(x): 0.7097    D(G(z)): 0.2855 / 0.1390
[1/5][1400/1583]        Loss_D: 0.6341  Loss_G: 2.8587  D(x): 0.8392    D(G(z)): 0.3181 / 0.0790
[1/5][1450/1583]        Loss_D: 0.6178  Loss_G: 1.4617  D(x): 0.6149    D(G(z)): 0.0315 / 0.2863
[1/5][1500/1583]        Loss_D: 0.5564  Loss_G: 2.6619  D(x): 0.7793    D(G(z)): 0.2217 / 0.0940
[1/5][1550/1583]        Loss_D: 0.6675  Loss_G: 1.9683  D(x): 0.6435    D(G(z)): 0.1213 / 0.1833
[2/5][0/1583]   Loss_D: 0.5963  Loss_G: 2.2106  D(x): 0.6437    D(G(z)): 0.0701 / 0.1452
[2/5][50/1583]  Loss_D: 1.5170  Loss_G: 4.4082  D(x): 0.9217    D(G(z)): 0.7074 / 0.0197
[2/5][100/1583] Loss_D: 0.4643  Loss_G: 2.4072  D(x): 0.7568    D(G(z)): 0.1341 / 0.1100
[2/5][150/1583] Loss_D: 0.6131  Loss_G: 3.7405  D(x): 0.9036    D(G(z)): 0.3549 / 0.0355
[2/5][200/1583] Loss_D: 0.5679  Loss_G: 2.1571  D(x): 0.6892    D(G(z)): 0.1197 / 0.1484
[2/5][250/1583] Loss_D: 0.6073  Loss_G: 1.5544  D(x): 0.6573    D(G(z)): 0.1079 / 0.2677
[2/5][300/1583] Loss_D: 0.6738  Loss_G: 3.6060  D(x): 0.8955    D(G(z)): 0.3995 / 0.0362
[2/5][350/1583] Loss_D: 0.5477  Loss_G: 2.9593  D(x): 0.8822    D(G(z)): 0.3133 / 0.0710
[2/5][400/1583] Loss_D: 0.4689  Loss_G: 2.2539  D(x): 0.7419    D(G(z)): 0.1151 / 0.1397
[2/5][450/1583] Loss_D: 0.4517  Loss_G: 2.5200  D(x): 0.7845    D(G(z)): 0.1592 / 0.1018
[2/5][500/1583] Loss_D: 0.5757  Loss_G: 2.5563  D(x): 0.7272    D(G(z)): 0.1838 / 0.1009
[2/5][550/1583] Loss_D: 0.5867  Loss_G: 3.2838  D(x): 0.8595    D(G(z)): 0.3113 / 0.0504
[2/5][600/1583] Loss_D: 0.8449  Loss_G: 3.9811  D(x): 0.9381    D(G(z)): 0.4823 / 0.0275
[2/5][650/1583] Loss_D: 0.5224  Loss_G: 1.6869  D(x): 0.7184    D(G(z)): 0.1308 / 0.2290
[2/5][700/1583] Loss_D: 0.7586  Loss_G: 1.4822  D(x): 0.5316    D(G(z)): 0.0298 / 0.2875
[2/5][750/1583] Loss_D: 0.9340  Loss_G: 3.6577  D(x): 0.8168    D(G(z)): 0.4633 / 0.0371
[2/5][800/1583] Loss_D: 0.9857  Loss_G: 4.4083  D(x): 0.9287    D(G(z)): 0.5446 / 0.0168
[2/5][850/1583] Loss_D: 0.5434  Loss_G: 3.0087  D(x): 0.8789    D(G(z)): 0.3096 / 0.0642
[2/5][900/1583] Loss_D: 0.9124  Loss_G: 3.5092  D(x): 0.7808    D(G(z)): 0.4230 / 0.0443
[2/5][950/1583] Loss_D: 0.7267  Loss_G: 4.0834  D(x): 0.8874    D(G(z)): 0.4076 / 0.0246
[2/5][1000/1583]        Loss_D: 0.6258  Loss_G: 1.8459  D(x): 0.7038    D(G(z)): 0.1866 / 0.1969
[2/5][1050/1583]        Loss_D: 0.9129  Loss_G: 1.5444  D(x): 0.5195    D(G(z)): 0.1200 / 0.2599
[2/5][1100/1583]        Loss_D: 0.6557  Loss_G: 3.6323  D(x): 0.9009    D(G(z)): 0.3815 / 0.0375
[2/5][1150/1583]        Loss_D: 0.7832  Loss_G: 0.9305  D(x): 0.5585    D(G(z)): 0.0914 / 0.4358
[2/5][1200/1583]        Loss_D: 0.4719  Loss_G: 3.0284  D(x): 0.8842    D(G(z)): 0.2650 / 0.0649
[2/5][1250/1583]        Loss_D: 0.4804  Loss_G: 2.1393  D(x): 0.7515    D(G(z)): 0.1566 / 0.1427
[2/5][1300/1583]        Loss_D: 0.9386  Loss_G: 3.4696  D(x): 0.7883    D(G(z)): 0.4487 / 0.0442
[2/5][1350/1583]        Loss_D: 0.4987  Loss_G: 1.7055  D(x): 0.7410    D(G(z)): 0.1531 / 0.2218
[2/5][1400/1583]        Loss_D: 0.9054  Loss_G: 4.0416  D(x): 0.9176    D(G(z)): 0.4947 / 0.0277
[2/5][1450/1583]        Loss_D: 0.5133  Loss_G: 2.5319  D(x): 0.7986    D(G(z)): 0.2195 / 0.1015
[2/5][1500/1583]        Loss_D: 0.7425  Loss_G: 3.2894  D(x): 0.8523    D(G(z)): 0.3979 / 0.0532
[2/5][1550/1583]        Loss_D: 0.9294  Loss_G: 1.2275  D(x): 0.4648    D(G(z)): 0.0324 / 0.3599
[3/5][0/1583]   Loss_D: 0.9583  Loss_G: 1.1608  D(x): 0.4547    D(G(z)): 0.0335 / 0.3771
[3/5][50/1583]  Loss_D: 0.8272  Loss_G: 1.7047  D(x): 0.5949    D(G(z)): 0.1800 / 0.2317
[3/5][100/1583] Loss_D: 0.5761  Loss_G: 3.6231  D(x): 0.8937    D(G(z)): 0.3400 / 0.0367
[3/5][150/1583] Loss_D: 0.6144  Loss_G: 1.0569  D(x): 0.6247    D(G(z)): 0.0867 / 0.3969
[3/5][200/1583] Loss_D: 0.6703  Loss_G: 1.9168  D(x): 0.7176    D(G(z)): 0.2415 / 0.1718
[3/5][250/1583] Loss_D: 0.4968  Loss_G: 2.6420  D(x): 0.7417    D(G(z)): 0.1436 / 0.1051
[3/5][300/1583] Loss_D: 0.7349  Loss_G: 0.8902  D(x): 0.5783    D(G(z)): 0.0955 / 0.4540
[3/5][350/1583] Loss_D: 0.7369  Loss_G: 2.7404  D(x): 0.7916    D(G(z)): 0.3492 / 0.0855
[3/5][400/1583] Loss_D: 0.6515  Loss_G: 2.8947  D(x): 0.7512    D(G(z)): 0.2633 / 0.0773
[3/5][450/1583] Loss_D: 0.6572  Loss_G: 1.6984  D(x): 0.6819    D(G(z)): 0.1973 / 0.2194
[3/5][500/1583] Loss_D: 0.6705  Loss_G: 1.9898  D(x): 0.6495    D(G(z)): 0.1540 / 0.1725
[3/5][550/1583] Loss_D: 0.5451  Loss_G: 2.4617  D(x): 0.8146    D(G(z)): 0.2534 / 0.1119
[3/5][600/1583] Loss_D: 0.5778  Loss_G: 2.8757  D(x): 0.7501    D(G(z)): 0.2017 / 0.0799
[3/5][650/1583] Loss_D: 0.5724  Loss_G: 2.1972  D(x): 0.7264    D(G(z)): 0.1839 / 0.1486
[3/5][700/1583] Loss_D: 1.2302  Loss_G: 4.5527  D(x): 0.9450    D(G(z)): 0.6299 / 0.0161
[3/5][750/1583] Loss_D: 0.6716  Loss_G: 2.0258  D(x): 0.6407    D(G(z)): 0.1369 / 0.1712
[3/5][800/1583] Loss_D: 0.5515  Loss_G: 2.1855  D(x): 0.7735    D(G(z)): 0.2209 / 0.1395
[3/5][850/1583] Loss_D: 1.6550  Loss_G: 5.3041  D(x): 0.9557    D(G(z)): 0.7417 / 0.0082
[3/5][900/1583] Loss_D: 1.5012  Loss_G: 6.1913  D(x): 0.9689    D(G(z)): 0.6948 / 0.0041
[3/5][950/1583] Loss_D: 0.4969  Loss_G: 2.7285  D(x): 0.8293    D(G(z)): 0.2401 / 0.0846
[3/5][1000/1583]        Loss_D: 0.6695  Loss_G: 1.8164  D(x): 0.6038    D(G(z)): 0.0651 / 0.2048
[3/5][1050/1583]        Loss_D: 0.5644  Loss_G: 1.7400  D(x): 0.7405    D(G(z)): 0.1959 / 0.2097
[3/5][1100/1583]        Loss_D: 0.8853  Loss_G: 1.6351  D(x): 0.5643    D(G(z)): 0.1673 / 0.2550
[3/5][1150/1583]        Loss_D: 1.6414  Loss_G: 0.4946  D(x): 0.2512    D(G(z)): 0.0278 / 0.6601
[3/5][1200/1583]        Loss_D: 0.9217  Loss_G: 0.7732  D(x): 0.4728    D(G(z)): 0.0525 / 0.5116
[3/5][1250/1583]        Loss_D: 0.8338  Loss_G: 1.5767  D(x): 0.5083    D(G(z)): 0.0630 / 0.2551
[3/5][1300/1583]        Loss_D: 0.7982  Loss_G: 3.7209  D(x): 0.8877    D(G(z)): 0.4442 / 0.0361
[3/5][1350/1583]        Loss_D: 0.4342  Loss_G: 2.7570  D(x): 0.8195    D(G(z)): 0.1871 / 0.0820
[3/5][1400/1583]        Loss_D: 0.5983  Loss_G: 3.2100  D(x): 0.8487    D(G(z)): 0.3273 / 0.0523
[3/5][1450/1583]        Loss_D: 0.6556  Loss_G: 2.2088  D(x): 0.6753    D(G(z)): 0.1843 / 0.1396
[3/5][1500/1583]        Loss_D: 1.4272  Loss_G: 4.3660  D(x): 0.9378    D(G(z)): 0.6743 / 0.0210
[3/5][1550/1583]        Loss_D: 0.6038  Loss_G: 2.4530  D(x): 0.7970    D(G(z)): 0.2745 / 0.1143
[4/5][0/1583]   Loss_D: 1.0254  Loss_G: 3.7756  D(x): 0.8369    D(G(z)): 0.5216 / 0.0385
[4/5][50/1583]  Loss_D: 0.6841  Loss_G: 2.9326  D(x): 0.8038    D(G(z)): 0.3241 / 0.0689
[4/5][100/1583] Loss_D: 0.6353  Loss_G: 1.5868  D(x): 0.6100    D(G(z)): 0.0740 / 0.2480
[4/5][150/1583] Loss_D: 2.2435  Loss_G: 3.7620  D(x): 0.9507    D(G(z)): 0.8368 / 0.0387
[4/5][200/1583] Loss_D: 0.6184  Loss_G: 1.8196  D(x): 0.6856    D(G(z)): 0.1562 / 0.1994
[4/5][250/1583] Loss_D: 0.5574  Loss_G: 1.8185  D(x): 0.6915    D(G(z)): 0.1294 / 0.1960
[4/5][300/1583] Loss_D: 0.5771  Loss_G: 3.4464  D(x): 0.9116    D(G(z)): 0.3473 / 0.0430
[4/5][350/1583] Loss_D: 0.5368  Loss_G: 3.0320  D(x): 0.8551    D(G(z)): 0.2862 / 0.0643
[4/5][400/1583] Loss_D: 0.7641  Loss_G: 1.4842  D(x): 0.5538    D(G(z)): 0.0720 / 0.2773
[4/5][450/1583] Loss_D: 0.8868  Loss_G: 4.3501  D(x): 0.9490    D(G(z)): 0.5257 / 0.0173
[4/5][500/1583] Loss_D: 1.0951  Loss_G: 1.1540  D(x): 0.4149    D(G(z)): 0.0316 / 0.3755
[4/5][550/1583] Loss_D: 0.5921  Loss_G: 3.2704  D(x): 0.8644    D(G(z)): 0.3268 / 0.0504
[4/5][600/1583] Loss_D: 1.9290  Loss_G: 0.0810  D(x): 0.2260    D(G(z)): 0.0389 / 0.9277
[4/5][650/1583] Loss_D: 0.5085  Loss_G: 2.6994  D(x): 0.8242    D(G(z)): 0.2472 / 0.0845
[4/5][700/1583] Loss_D: 0.7072  Loss_G: 1.5190  D(x): 0.5953    D(G(z)): 0.0826 / 0.2650
[4/5][750/1583] Loss_D: 0.5817  Loss_G: 2.7395  D(x): 0.8310    D(G(z)): 0.2830 / 0.0853
[4/5][800/1583] Loss_D: 0.4707  Loss_G: 2.3596  D(x): 0.7818    D(G(z)): 0.1635 / 0.1262
[4/5][850/1583] Loss_D: 1.6073  Loss_G: 0.4274  D(x): 0.2876    D(G(z)): 0.0989 / 0.6886
[4/5][900/1583] Loss_D: 0.5918  Loss_G: 2.6160  D(x): 0.7312    D(G(z)): 0.1983 / 0.0984
[4/5][950/1583] Loss_D: 0.7132  Loss_G: 2.7998  D(x): 0.8739    D(G(z)): 0.3872 / 0.0858
[4/5][1000/1583]        Loss_D: 0.8327  Loss_G: 3.9972  D(x): 0.9455    D(G(z)): 0.4914 / 0.0265
[4/5][1050/1583]        Loss_D: 0.4837  Loss_G: 2.4716  D(x): 0.7829    D(G(z)): 0.1792 / 0.1073
[4/5][1100/1583]        Loss_D: 0.7168  Loss_G: 1.8686  D(x): 0.6250    D(G(z)): 0.1307 / 0.1945
[4/5][1150/1583]        Loss_D: 0.5136  Loss_G: 2.0851  D(x): 0.7486    D(G(z)): 0.1614 / 0.1606
[4/5][1200/1583]        Loss_D: 0.4791  Loss_G: 2.0791  D(x): 0.7381    D(G(z)): 0.1236 / 0.1586
[4/5][1250/1583]        Loss_D: 0.5550  Loss_G: 2.5631  D(x): 0.8379    D(G(z)): 0.2759 / 0.1006
[4/5][1300/1583]        Loss_D: 0.3853  Loss_G: 3.4606  D(x): 0.9458    D(G(z)): 0.2601 / 0.0419
[4/5][1350/1583]        Loss_D: 0.6888  Loss_G: 3.2058  D(x): 0.8515    D(G(z)): 0.3644 / 0.0533
[4/5][1400/1583]        Loss_D: 0.8042  Loss_G: 4.1665  D(x): 0.9471    D(G(z)): 0.4778 / 0.0235
[4/5][1450/1583]        Loss_D: 0.4398  Loss_G: 1.8515  D(x): 0.7708    D(G(z)): 0.1293 / 0.1916
[4/5][1500/1583]        Loss_D: 2.1083  Loss_G: 0.3365  D(x): 0.1914    D(G(z)): 0.0699 / 0.7397
[4/5][1550/1583]        Loss_D: 0.6472  Loss_G: 1.5645  D(x): 0.6363    D(G(z)): 0.1143 / 0.2488

Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch for every epoch. And third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G's losses versus training iterations.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
Generator and Discriminator Loss During Training

Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch after every epoch of training? Now, we can visualize the training progression of G with an animation. Press the play button to start the animation.

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
(Animation: generator output on the fixed_noise batch over the course of training)
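If the inline animation does not render in your environment, a minimal alternative (assuming the Pillow package is installed) is to save it to a file:

# Save the animation as a GIF instead of rendering it inline.
ani.save("dcgan_progress.gif", writer="pillow", fps=1)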


Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
Real Images, Fake Images

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get

  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture

  • Check out some other cool GAN projects here

  • Create GANs that generate music

Total running time of the script: (6 minutes 35.098 seconds)
