import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
def get_real_data(data_size):
    return torch.randn(data_size, 50)  # 표준 정규 분포에서 샘플링

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(50, 16),
            nn.ReLU(True),
            nn.Linear(16, 32),
            nn.ReLU(True),
            nn.Linear(32, 50)
        )

    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(50, 32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(32, 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
# 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 모델 인스턴스화 및 디바이스로 옮기기
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 옵티마이저 설정
g_optimizer = optim.Adam(generator.parameters(), lr=0.002)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.002)

# 손실 함수
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()

# 학습 시작
epochs = 1000
for epoch in range(epochs):
    for _ in range(5):  # 판별자를 더 자주 업데이트
        real_data = get_real_data(100).to(device)  # 실제 데이터 샘플링 및 GPU로 전송
        noise = torch.randn(100, 50).to(device)
        fake_data = generator(noise).detach()  # 생성자의 출력을 계산하고 연산에서 분리

        d_real_decision = discriminator(real_data)  # 판별자가 실제 데이터를 진짜로 판단하는 정도
        d_fake_decision = discriminator(fake_data)  # 판별자가 가짜 데이터를 진짜로 판단하는 정도
        d_real_loss = d_loss_func(d_real_decision, torch.ones_like(d_real_decision))  # 실제 데이터 손실 계산
        d_fake_loss = d_loss_func(d_fake_decision, torch.zeros_like(d_fake_decision))  # 가짜 데이터 손실 계산

        d_loss = d_real_loss + d_fake_loss  # 전체 판별자 손실
        d_optimizer.zero_grad()  # 판별자의 그래디언트 초기화
        d_loss.backward()  # 손실에 대한 역전파 수행
        d_optimizer.step()  # 판별자의 가중치 업데이트

    # 생성자 업데이트
    noise = torch.randn(100, 50).to(device)
    fake_data = generator(noise)  # 새로운 가짜 데이터 생성
    d_fake_decision = discriminator(fake_data)  # 판별자가 새로운 가짜 데이터를 진짜로 판단하는 정도
    g_loss = g_loss_func(d_fake_decision, torch.ones_like(d_fake_decision))  # 생성자 손실 계산

    g_optimizer.zero_grad()  # 생성자의 그래디언트 초기화
    g_loss.backward()  # 손실에 대한 역전파 수행
    g_optimizer.step()  # 생성자의 가중치 업데이트

    if (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch+1}/{epochs}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')  # 진행 상황 출력
# 생성된 데이터 시각화
test_samples = generator(torch.randn(1, 50).to(device)).detach().cpu()
print(test_samples)
plt.plot(test_samples.numpy()[0])
plt.title('Generated Sequence Data')
plt.show()

+ Recent posts