import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
def get_real_data(data_size):
return torch.randn(data_size, 50) # 표준 정규 분포에서 샘플링
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(50, 16),
nn.ReLU(True),
nn.Linear(16, 32),
nn.ReLU(True),
nn.Linear(32, 50)
)
def forward(self, x):
return self.main(x)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(50, 32),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(32, 16),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(16, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
# 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 모델 인스턴스화 및 디바이스로 옮기기
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 옵티마이저 설정
g_optimizer = optim.Adam(generator.parameters(), lr=0.002)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.002)
# 손실 함수
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()
# 학습 시작
epochs = 1000
for epoch in range(epochs):
for _ in range(5): # 판별자를 더 자주 업데이트
real_data = get_real_data(100).to(device) # 실제 데이터 샘플링 및 GPU로 전송
noise = torch.randn(100, 50).to(device)
fake_data = generator(noise).detach() # 생성자의 출력을 계산하고 연산에서 분리
d_real_decision = discriminator(real_data) # 판별자가 실제 데이터를 진짜로 판단하는 정도
d_fake_decision = discriminator(fake_data) # 판별자가 가짜 데이터를 진짜로 판단하는 정도
d_real_loss = d_loss_func(d_real_decision, torch.ones_like(d_real_decision)) # 실제 데이터 손실 계산
d_fake_loss = d_loss_func(d_fake_decision, torch.zeros_like(d_fake_decision)) # 가짜 데이터 손실 계산
d_loss = d_real_loss + d_fake_loss # 전체 판별자 손실
d_optimizer.zero_grad() # 판별자의 그래디언트 초기화
d_loss.backward() # 손실에 대한 역전파 수행
d_optimizer.step() # 판별자의 가중치 업데이트
# 생성자 업데이트
noise = torch.randn(100, 50).to(device)
fake_data = generator(noise) # 새로운 가짜 데이터 생성
d_fake_decision = discriminator(fake_data) # 판별자가 새로운 가짜 데이터를 진짜로 판단하는 정도
g_loss = g_loss_func(d_fake_decision, torch.ones_like(d_fake_decision)) # 생성자 손실 계산
g_optimizer.zero_grad() # 생성자의 그래디언트 초기화
g_loss.backward() # 손실에 대한 역전파 수행
g_optimizer.step() # 생성자의 가중치 업데이트
if (epoch + 1) % 100 == 0:
print(f'Epoch {epoch+1}/{epochs}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}') # 진행 상황 출력
# 생성된 데이터 시각화
test_samples = generator(torch.randn(1, 50).to(device)).detach().cpu()
print(test_samples)
plt.plot(test_samples.numpy()[0])
plt.title('Generated Sequence Data')
plt.show()