import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# ToTensor() yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# rescales them to [-1, 1], the same range the decoder's final Tanh produces.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data on first use and served in
# shuffled mini-batches of 64.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

def to_img(x):
    """Turn flattened, normalized images back into displayable image tensors.

    Inverts the ``Normalize((0.5,), (0.5,))`` step by computing ``(x + 1) * 0.5``,
    clamps the result into ``[0, 1]``, and reshapes it to ``(N, 1, 28, 28)``
    where ``N`` is the size of the first dimension of *x*.
    """
    batch = x.size(0)
    pixels = ((x + 1) * 0.5).clamp(0, 1)
    return pixels.view(batch, 1, 28, 28)

class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder squeezes 784 features down to a 3-dimensional latent code
    (784 -> 128 -> 64 -> 12 -> 3); the decoder mirrors that path back up to
    784 and ends in Tanh, so reconstructions lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Encoder: 784 -> 128 -> 64 -> 12 -> 3 (the latent-space dimension).
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(True),
            nn.Linear(128, 64), nn.ReLU(True),
            nn.Linear(64, 12), nn.ReLU(True),
            nn.Linear(12, 3),
        )
        # Decoder: 3 -> 12 -> 64 -> 128 -> 784, squashed by Tanh at the end.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12), nn.ReLU(True),
            nn.Linear(12, 64), nn.ReLU(True),
            nn.Linear(64, 128), nn.ReLU(True),
            nn.Linear(128, 28 * 28), nn.Tanh(),
        )

    def forward(self, x):
        """Encode *x* to the latent code and return its decoded reconstruction."""
        return self.decoder(self.encoder(x))

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = Autoencoder().to(device)
# Reconstruction loss: mean squared error between output and input pixels.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

num_epochs = 1
for epoch in range(num_epochs):
    # Training mode; a no-op for this Linear/ReLU/Tanh model, but keeps the
    # loop correct if dropout or batch-norm layers are ever added.
    model.train()
    running_loss = 0.0
    num_batches = 0

    for img, _ in train_loader:
        # Flatten (N, 1, 28, 28) -> (N, 784) for the fully connected encoder.
        img = img.view(img.size(0), -1).to(device)

        # Forward pass: the autoencoder is trained to reproduce its input.
        output = model(img)
        loss = criterion(output, img)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Log the mean loss over the whole epoch, not just the last batch.
    avg_loss = running_loss / max(num_batches, 1)
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, avg_loss))

    # Visualization: save the last batch's reconstructions every 10 epochs
    # and always after the final epoch (with a plain `% 10` check a run of
    # fewer than 10 epochs would never write an image). Skipped when the
    # loader yielded no batches, since `output` would not exist.
    if num_batches and ((epoch + 1) % 10 == 0 or epoch + 1 == num_epochs):
        pic = to_img(output.detach().cpu())
        save_image(pic, './data/output_{}.png'.format(epoch + 1))
# Load one batch of test images and flatten them for the fully connected model.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=10, shuffle=True)
test_images, _ = next(iter(test_loader))
test_images = test_images.view(test_images.size(0), -1)
test_images = test_images.to(device)

# Reconstruct in inference mode: eval() plus no_grad() avoids building an
# autograd graph that is never used for a backward pass.
model.eval()
with torch.no_grad():
    reconstructed = model(test_images)
reconstructed = to_img(reconstructed.cpu())
# Map the normalized inputs back to [0, 1] as well, so both rows are
# displayed on the same intensity scale.
originals = to_img(test_images.cpu())

# Top row: original images; bottom row: reconstructions. Fixed vmin/vmax
# stops imshow from auto-stretching each image's contrast independently.
n_show = test_images.size(0)
fig, axs = plt.subplots(2, n_show, figsize=(n_show, 2), squeeze=False)
for i in range(n_show):
    axs[0, i].imshow(originals[i].view(28, 28), cmap='gray', vmin=0, vmax=1)
    axs[0, i].axis('off')
    axs[1, i].imshow(reconstructed[i].view(28, 28), cmap='gray', vmin=0, vmax=1)
    axs[1, i].axis('off')
plt.show()
