import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import save_image
# Preprocessing: PIL image -> float tensor in [0, 1], then (x - 0.5) / 0.5
# so pixel values land in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split (downloaded into ./data on first run), served in
# shuffled mini-batches of 64.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
def to_img(x):
    """Convert a batch of [-1, 1]-normalised pixels back to displayable images.

    Args:
        x: tensor whose first dimension is the batch size and whose remaining
            elements number 28 * 28 per sample (e.g. shape ``(N, 784)``).

    Returns:
        Tensor of shape ``(N, 1, 28, 28)`` with values clamped to ``[0, 1]``.
    """
    # Undo Normalize((0.5,), (0.5,)): x_orig = 0.5 * (x + 1).
    rescaled = (0.5 * (x + 1)).clamp(0, 1)
    batch = rescaled.size(0)
    return rescaled.view(batch, 1, 28, 28)
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps 784 inputs through 128, 64 and 12 hidden units down to a
    3-dimensional latent code. The decoder mirrors those widths back to 784
    outputs, squashed into [-1, 1] by a final ``Tanh``.
    """

    @staticmethod
    def _stack(widths, final=None):
        """Build Linear/ReLU pairs over consecutive ``widths``.

        The last Linear layer gets no ReLU. If ``final`` is given, it is
        appended after that last Linear layer.
        """
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(True))
        layers.pop()  # no ReLU directly after the last Linear layer
        if final is not None:
            layers.append(final)
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Encoder: 784 -> 128 -> 64 -> 12 -> 3 (latent space dimension).
        self.encoder = self._stack((28 * 28, 128, 64, 12, 3))
        # Decoder: 3 -> 12 -> 64 -> 128 -> 784, output bounded by Tanh.
        self.decoder = self._stack((3, 12, 64, 128, 28 * 28), nn.Tanh())

    def forward(self, x):
        """Encode ``x`` into the latent space and decode it back."""
        return self.decoder(self.encoder(x))
# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Autoencoder().to(device)
# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
num_epochs = 1
# Save a reconstruction grid every `save_every` epochs and always after the
# final epoch. With num_epochs = 1, a plain `% 10` test would never fire.
save_every = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for img, _ in train_loader:
        # Flatten each image to a 784-dim vector for the Linear layers.
        img = img.view(img.size(0), -1).to(device)
        # Forward pass: the reconstruction target is the input itself.
        output = model(img)
        loss = criterion(output, img)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Log the mean batch loss over the epoch, not just the last batch's loss.
    mean_loss = running_loss / max(num_batches, 1)
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, mean_loss))
    # Visualization: save the last batch's reconstructions as an image grid.
    # `num_batches` guards against an empty loader, where `output` is unset.
    is_save_epoch = (epoch + 1) % save_every == 0 or epoch + 1 == num_epochs
    if num_batches and is_save_epoch:
        pic = to_img(output.detach().cpu())
        save_image(pic, './data/output_{}.png'.format(epoch + 1))
# Load and transform a batch of test images.
# download=True only fetches files if they are missing under ./data.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=10, shuffle=True)
test_images, _ = next(iter(test_loader))
test_images = test_images.view(test_images.size(0), -1).to(device)

# Reconstruct the images: inference only, so switch to eval mode and skip
# building the autograd graph.
model.eval()
with torch.no_grad():
    reconstructed = model(test_images)

# Map both rows back to [0, 1] so originals and reconstructions share a scale.
originals = to_img(test_images.cpu())
reconstructed = to_img(reconstructed.cpu())

# Visualize originals (top row) and reconstructions (bottom row).
# squeeze=False keeps `axs` 2-D even if the batch has a single image.
num_shown = originals.size(0)
fig, axs = plt.subplots(2, num_shown, figsize=(num_shown, 2), squeeze=False)
for col, (orig_img, recon_img) in enumerate(zip(originals, reconstructed)):
    # Fixed vmin/vmax stops imshow from auto-scaling each image separately.
    axs[0, col].imshow(orig_img.view(28, 28), cmap='gray', vmin=0, vmax=1)
    axs[0, col].axis('off')
    axs[1, col].imshow(recon_img.view(28, 28), cmap='gray', vmin=0, vmax=1)
    axs[1, col].axis('off')
plt.show()