import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Device selection: run on the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data pipeline: convert each MNIST image to a tensor, then normalize it
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Training split, stored under ./data and downloaded if not already present.
train_dataset = datasets.MNIST(
    "./data",
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of 64 images.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=64,
)
# Generator network
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    The network is a stack of Linear + LeakyReLU(0.2) layers
    (latent_dim -> 256 -> 512 -> 1024 -> channels * height * width)
    followed by Tanh, so every output value lies in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100.
        image_shape: ``(channels, height, width)`` of the generated images.
            Defaults to ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim=100, image_shape=(1, 28, 28)):
        super().__init__()
        channels, height, width = image_shape
        # Kept as attributes so callers can size their noise tensors and
        # forward() can restore the image layout.
        self.latent_dim = latent_dim
        self.image_shape = (channels, height, width)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, channels * height * width),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from a batch of noise vectors.

        Args:
            x: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, channels, height, width)``.
        """
        return self.model(x).view(x.size(0), *self.image_shape)
# Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    The network is a stack of Linear + LeakyReLU(0.2) + Dropout layers
    (channels * height * width -> 1024 -> 512 -> 256 -> 1) followed by
    Sigmoid, so each output is a probability in [0, 1].

    Args:
        image_shape: ``(channels, height, width)`` of the input images.
            Defaults to ``(1, 28, 28)``.
        dropout: Dropout probability applied after each hidden layer.
            Defaults to 0.3.
    """

    def __init__(self, image_shape=(1, 28, 28), dropout=0.3):
        super().__init__()
        channels, height, width = image_shape
        self.model = nn.Sequential(
            nn.Linear(channels * height * width, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the probability that each input image is real.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into one feature vector.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # flatten() copies when needed, so non-contiguous inputs (e.g. a
        # transposed tensor) work too; view() would raise on them.
        return self.model(x.flatten(start_dim=1))
# Build both networks and move them onto the selected device.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Binary cross-entropy loss, plus one Adam optimizer per network
# (learning rate 2e-4, betas (0.5, 0.999)).
loss_function = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=2e-4)
g_optimizer = optim.Adam(generator.parameters(), betas=(0.5, 0.999), lr=2e-4)
# Visualization helper
def show_generated_images(
    epoch,
    generator,
    show=False,
    save=False,
    path='output.png',
    latent_dim=100,
):
    """Sample 16 images from ``generator`` and draw them in a 4x4 grid.

    Args:
        epoch: Value substituted for an ``{epoch}`` placeholder in ``path``.
        generator: ``nn.Module`` mapping noise of shape ``(16, latent_dim)``
            to images that can be reshaped to ``(16, 28, 28)``.
        show: If true, display the figure with ``plt.show()``.
        save: If true, write the figure to ``path``.
        path: Output file name, formatted with ``epoch=epoch``. The default
            contains no ``{epoch}`` placeholder, so repeated saves overwrite
            the same file.
        latent_dim: Length of each sampled noise vector. Defaults to 100.
    """
    # Sample on the generator's own device rather than a module-level
    # global, so the noise always lives where the weights do. A generator
    # without parameters falls back to the CPU.
    first_param = next(generator.parameters(), None)
    if first_param is None:
        sample_device = torch.device("cpu")
    else:
        sample_device = first_param.device

    # Use eval mode while sampling so layers such as dropout or batch norm
    # are in inference behavior, then restore the caller's mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(16, latent_dim, device=sample_device)
            fake_images = generator(noise)
    finally:
        generator.train(was_training)
    fake_images = fake_images.cpu().reshape(-1, 28, 28)  # reshape for plotting

    fig, axs = plt.subplots(4, 4, figsize=(4, 4), sharey=True, sharex=True)
    for image, ax in zip(fake_images, axs.flatten()):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    fig.tight_layout()
    # Save before showing: the figure may no longer be drawable afterwards.
    if save:
        fig.savefig(path.format(epoch=epoch))
    if show:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
# Training loop
epochs = 1
for epoch in range(epochs):
    for images, _ in train_loader:
        images = images.to(device)
        batch_size = images.size(0)
        # Target labels: 1 marks "real", 0 marks "fake".
        real_labels = torch.ones((batch_size, 1), device=device)
        fake_labels = torch.zeros((batch_size, 1), device=device)

        # --- Discriminator step ---
        d_optimizer.zero_grad()
        real_score = discriminator(images)
        d_loss_real = loss_function(real_score, real_labels)

        noise = torch.randn(batch_size, 100, device=device)
        fake_images = generator(noise)
        # detach() keeps this backward pass from reaching the generator.
        fake_score = discriminator(fake_images.detach())
        d_loss_fake = loss_function(fake_score, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step ---
        g_optimizer.zero_grad()
        # The fakes are scored against the "real" labels: many outputs near 1
        # mean the discriminator was fooled into taking fakes for real.
        g_loss = loss_function(discriminator(fake_images), real_labels)
        g_loss.backward()
        g_optimizer.step()

    # Generate and display sample images after every epoch.
    # Change `show` to False and `save` to True to save the figures
    show_generated_images(epoch, generator, save=False, show=True)
    # The reported values come from the last batch of the epoch.
    print(f'Epoch [{epoch+1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')
    # d_loss: loss value of the discriminator
    # g_loss: loss value of the generator
    # D(x): mean probability with which the discriminator judges real images
    #       to be real; a value close to 1 means it recognizes real images well
    # D(G(z)): mean probability with which the discriminator judges the
    #          generator's fake images to be real
# Generate a batch of random images and visualize them in a 4x4 grid.
noise = torch.randn(16, 100, device=device)
fake_images = generator(noise).detach().cpu().numpy()
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, sample in zip(axes.flatten(), fake_images):
    # Index 0 selects the first channel of each generated sample.
    ax.imshow(sample[0], cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()