공부자료/Deep Learning

적대적 생성 신경망

mogazi 2022. 9. 26. 05:44
  •   적대적 생성 신경망이란? (General Adversarial Network, GAN)

 

GAN을 경찰과 위조지폐범 사이의 게임에 비유하자면, 

 

위조지폐범은 진짜와 같은 화폐를 만들어 경찰을 속이고, 경찰은 진짜 화폐와 가짜 화폐를 판별하여 위조지폐범을 검거한다. 

위조지폐범과 경찰의 경쟁적인 학습이 지속되면 어느 순간 위조지폐범은 진짜와 같은 위조지폐를 만들 수 있게 되고, 

결국 경찰은 위조지폐와 실제 화폐를 구분할 수 없는 상태에 이르게 된다.

 

 

 

 

 

딥러닝 용어로 설명하자면, 경찰은 진짜 지폐와 위조지폐를 구분하는 판별자가 되며, 위조지폐범은 위조지폐를 생성하는 생성자가 된다. 

생성 모델은 최대한 진짜와 비슷한 데이터를 생성하려는 생성자와 진짜와 가짜를 구별하는 판별자가 각각 존재하여 서로 적대적으로 학습한다.

 

 

 

 

적대적 학습에서는 판별자를 먼저 학습시킨 후 생성자를 학습시키는 과정을 반복한다. 

판별자 학습은 크게 두 단계로 진행됩니다. 

먼저 실제 이미지를 입력해서 네트워크(신경망)가 해당 이미지를 진짜로 분류하도록 학습시킵니다. 

그런 다음 생성자가 생성한 모조 이미지를 입력해서 해당 이미지를 가짜로 분류하도록 학습시킵니다. 

이 과정을 거쳐 판별자는 실제 이미지를 진짜로 분류하고, 모조 이미지를 가짜로 분류합니다.

 

 

 

- 적대적 생성 신경망 학습과정

 

 

 

이와 같은 학습 과정을 반복하면 판별자와 생성자가 서로를 적대적인 경쟁자로 인식하여 모두 발전하게 된다. 

결과적으로 생성자는 진짜 이미지에 완벽히 가까울 정도의 유사한 모조 이미지를 만들고, 이에 따라 판별자는 실제 이미지와 모조 이미지를 구분할 수 없게 된다. 

 

 

즉, 생성자는 분류에 성공할 확률을 낮추고 판별자는 분류에 성공할 확률을 높이면서 서로 경쟁적으로 발전시키는 구조이다.

 

 

 

 

 

 

 

 

  •   GAN의 동작 원리

 

 

적대적 생성 신경망(GAN)은 생성자(generator)와 판별자(discriminator) 네트워크 두 개로 구성되어 있다. 

이름에서 알 수 있듯이 두 네트워크는 서로 적대적으로 경쟁하여 학습을 진행한다. 

 

 

생성자 G는 판별자 D를 속이려고 원래 이미지와 최대한 비슷한 이미지를 만들도록 학습한다. 

반대로 판별자 D는 원래 이미지와 생성자 G가 만든 이미지를 잘 구분하도록 학습을 진행한다.

 

 

 

 

 

먼저 판별자 D부터 살펴보자. 

판별자 D의 역할은 주어진 입력 이미지가 진짜 이미지인지 가짜 이미지인지 구별하는 것이다. 

즉, 이미지 x가 입력으로 주어졌을 때 판별자 D의 출력에 해당하는 D(x)가 진짜 이미지일 확률을 반환한다.

 

 

반면 생성자 G의 역할은 판별자 D가 진짜인지 가짜인지 구별할 수 없을 만큼 진짜와 같은 모조 이미지를 노이즈 데이터를 사용하여 만들어 내는 것이다. 

예를 들어 실제 이미지인 알파벳 z가 입력으로 주어졌을 때 판별자는 z를 학습한다. 

또한, 생성자는 임의의 노이즈 데이터를 사용하여 모조 이미지 z'(G(z))를 생성한다. 

이러한 G(z)를 다시 판별자 D의 입력으로 주면 판별자는 G(z)가 실제 이미지일 확률을 반환한다.

 

 

 

실제 데이터를 판단하려고 판별자 D를 학습시킬 때는 생성자 G를 고정시킨 채 실제 이미지(x∼pdata(x))는 높은 확률을 반환하는 방향으로, 

모조 이미지(z∼pz(z))는 낮은 확률을 반환하는 방향으로 가중치를 업데이트한다.

 

 

 

 

GAN의 손실 함수는 다음과 같다.

 

• x~pdata(x): 실제 데이터에 대한 확률 분포에서 샘플링한 데이터

• z~pz(z): 가우시안 분포를 사용하는 임의의 노이즈에서 샘플링한 데이터

• D(x>): 판별자 D(x)가 1에 가까우면 진짜 데이터로, 0에 가까우면 가짜 데이터로 판단, 0이면 가짜를 의미

• D(G(z)): 생성자 G가 생성한 이미지인 G(z)가 1에 가까우면 진짜 데이터로, 0에 가까우면 가짜 데이터로 판단

 

 

 

 

 

 

수식에서 판별자 D는 실제 이미지 x를 입력받을 경우 D(x)를 1로 예측하고, 생성자가 잠재 벡터에서 생성한 모조 이미지 G(z)를 입력받을 경우 D(G(z))를 0으로 예측한다. 

따라서 판별자가 모조 이미지 G(z)를 입력받을 경우 1로 예측하도록 하는 것이 목표다.

 

 

 

다시 손실 함수 전체로 돌아와서  판별자 D와 생성자 G 부분으로 나누어서 살펴보자. 

판별자 D는 다음 식의 최댓값으로 파라미터를 업데이트하는 것을 목표로 한다.

 

 

이때 판별자 입장에서는 D(x)=1, D(G(z))=0이 최상의 결과(진짜 이미지는 1, 가짜 이미지는 0을 출력할 경우)가 될 것이기 때문에 이 식의 최댓값이 될 것이다.

 

 

 

 

또한, 생성자 G는 다음 식의 최솟값으로 파라미터를 업데이트하는 것을 목표로 한다.

 

 

이때 생성자 입장에서는 D(G(z)) = 1 이 최상의 결과 (판별자가 가짜 이미지를 1로 출력한 경우)가 될 것이기 때문에 이 식의 최솟값이 될 것이다.

 

 

 

따라서 logD(x)와 log(1-D(G(z))) 모두 최대가 되어야 한다. 

즉, D(x)는 1이 되어야 실제 이미지를 진짜라고 분류하며, 1-D(G(z))는 1이 되어야 생성자가 만든 모조 이미지를 가짜라고 분류한다.

 

 

 

참고로 GAN을 학습시키려면 판별자와 생성자의 파라미터를 번갈아 가며 업데이트해야 한다. 

또한, 판별자의 파라미터를 업데이트할 때는 생성자의 파라미터를 고정시키고, 생성자의 파라미터를 업데이트할 때는 판별자의 파라미터를 고정해야 한다.