import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Run on CUDA when PyTorch reports it is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST preprocessing: ToTensor() turns each image into a float tensor
# (scaled to [0, 1] per the torchvision docs), then Normalize with
# mean 0.5 / std 0.5 computes (x - 0.5) / 0.5, i.e. [-1, 1] -- the same
# range as the generator's final Tanh.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split stored under ./data (download=True fetches it if missing),
# served in shuffled mini-batches of 64.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Generator network
class Generator(nn.Module):
    """Fully-connected generator mapping 100-dim noise vectors to 1x28x28 images.

    Hidden widths grow 100 -> 256 -> 512 -> 1024 with LeakyReLU(0.2) after each
    Linear layer; a final Linear(1024, 784) followed by Tanh produces 784 pixel
    values in [-1, 1], which ``forward`` reshapes to (batch, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()
        widths = [100, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]
        # One Sequential named `model`, so state_dict keys ("model.0.weight", ...)
        # and the order in which weights are randomly initialised are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise of shape (batch, 100) to images of shape (batch, 1, 28, 28)."""
        batch = x.shape[0]
        pixels = self.model(x)
        return pixels.view(batch, 1, 28, 28)

# Discriminator network
class Discriminator(nn.Module):
    """MLP that maps each 28x28 image to a single probability of being real.

    Inputs are flattened to 784 features, passed through three
    Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages (784 -> 1024 -> 512 -> 256),
    then Linear(256, 1) and a Sigmoid, giving an output of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        stages = []
        for fan_in, fan_out in ((784, 1024), (1024, 512), (512, 256)):
            stages.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)))
        stages.extend((nn.Linear(256, 1), nn.Sigmoid()))
        # Same Sequential layout/indices as a hand-written list, so state_dict keys
        # and parameter-initialisation order are preserved.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Flatten all non-batch dimensions and return per-sample probabilities."""
        batch = x.shape[0]
        return self.model(x.view(batch, -1))

# Build both networks on the selected device. The generator is constructed
# first, so the order in which their weights are randomly initialised matches
# constructing them on separate lines.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's Sigmoid output. Each network gets
# its own Adam optimizer with lr = 2e-4 and betas = (0.5, 0.999).
loss_function = nn.BCELoss()
g_optimizer, d_optimizer = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Visualization helper
def show_generated_images(epoch, generator, show=False, save=False, path='output.png',
                          *, n_rows=4, n_cols=4, latent_dim=100):
    """Sample a grid of images from ``generator`` and optionally show and/or save it.

    Args:
        epoch: Epoch index. It is substituted into ``path`` via
            ``path.format(epoch=epoch)``, so a template such as
            ``'gen_{epoch}.png'`` writes one file per epoch. The default
            ``'output.png'`` is overwritten on every call.
        generator: Module mapping noise of shape (N, latent_dim) to N images of
            28x28 = 784 pixels each.
        show: If True, display the figure with ``plt.show()``.
        save: If True, write the figure to ``path.format(epoch=epoch)``.
        path: Output file path or ``str.format`` template with an ``{epoch}`` field.
        n_rows, n_cols: Grid shape; ``n_rows * n_cols`` samples are drawn.
        latent_dim: Length of each noise vector fed to the generator.
    """
    # Draw noise on the generator's own device rather than relying on a
    # module-level `device`; a parameterless generator falls back to the CPU.
    first_param = next(generator.parameters(), None)
    noise_device = first_param.device if first_param is not None else torch.device("cpu")
    n_images = n_rows * n_cols

    # Sampling is pure inference, so no autograd graph is needed.
    with torch.no_grad():
        noise = torch.randn(n_images, latent_dim, device=noise_device)
        fake_images = generator(noise).cpu().view(-1, 28, 28)  # reshape for plotting

    # squeeze=False keeps `axs` 2-D even for 1xN or 1x1 grids.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows),
                            sharey=True, sharex=True, squeeze=False)
    for image, ax in zip(fake_images, axs.flat):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    fig.tight_layout()
    if save:
        fig.savefig(path.format(epoch=epoch))
    if show:
        plt.show()
    # Close exactly the figure created here so repeated calls don't accumulate
    # open figures, regardless of which figure pyplot considers "current".
    plt.close(fig)

# Start training.
# Each iteration performs one discriminator update followed by one generator update.
epochs = 1
for epoch in range(epochs):
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        batch_size = images.size(0)

        # BCELoss targets: 1 = "real", 0 = "fake", one per sample, shape (batch_size, 1).
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train the discriminator.
        # Loss on real images, which should be classified as real (target 1).
        outputs = discriminator(images)
        d_loss_real = loss_function(outputs, real_labels)
        real_score = outputs

        # Loss on generated images, which should be classified as fake (target 0).
        # detach() stops this loss from backpropagating into the generator.
        noise = torch.randn(batch_size, 100, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = loss_function(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train the generator.
        # Re-score the same fakes (not detached this time) against target 1:
        # g_loss is low when the discriminator is fooled into calling them real.
        outputs = discriminator(fake_images)
        g_loss = loss_function(outputs, real_labels)

        # NOTE: g_loss.backward() also writes gradients into the discriminator's
        # parameters. Only g_optimizer.step() is applied here, and those stale
        # discriminator gradients are cleared by d_optimizer.zero_grad() before
        # the next discriminator step.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Generate and display sample images once per epoch.
    show_generated_images(epoch, generator, show=True, save=False)
    # Change `show` to False and `save` to True to save the figures
    # (pass a path such as 'epoch_{epoch}.png' to avoid overwriting one file).

    # NOTE(review): d_loss, g_loss, D(x) and D(G(z)) below come from the LAST batch
    # of the epoch only, not an epoch average. If train_loader yields no batches,
    # these names are undefined and this line raises NameError.
    print(f'Epoch [{epoch+1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')

# Metrics printed at the end of each epoch (computed on that epoch's final batch):
# d_loss: discriminator loss (real-image loss + fake-image loss).
# g_loss: generator loss.
# D(x): mean probability the discriminator assigns to real images being real.
#       A value close to 1 means the discriminator recognizes real images well.
# D(G(z)): mean probability the discriminator assigns to the generator's fake
#          images being real.

# Sample a 4x4 grid of images from the trained generator and display it.
# no_grad() avoids building an autograd graph that pure inference never uses,
# so the result can be converted to NumPy without an explicit detach().
with torch.no_grad():
    noise = torch.randn(16, 100, device=device)
    fake_images = generator(noise)
fake_images = fake_images.cpu().numpy()

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flatten()):
    # Each sample has shape (1, 28, 28); index the single channel for imshow.
    ax.imshow(fake_images[i][0], cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

# + Recent posts