Denoising Diffusion Probabilistic Models (NeurIPS 2020)Permalink

Denoising Diffusion Probabilistic Models. [Paper] [Github]
Jonathan Ho, Ajay Jain, Pieter Abbeel
Jun 19th 2020

Diffusion Model의 (거의) 시작점을 알린 논문이다. 이 논문을 처음 보았을 때 굉장히 어려웠었고, 그래서 이 논문을 가장 먼저 포스팅해보려고 한다. 혹시 관련 개념을 잘 모른다면, [Intro to Diffusion Models] (1) What is Diffusion?를 보고 오시는 것을 추천한다.

또한, Diffusion Model의 수학을 굉장히 잘 설명해둔 Understanding Diffusion Models: A Unified Perspective 역시 큰 도움이 될 것이다. 이 포스팅의 전개 부분은 위 논문을 많이 참고하였다. 이 포스팅에서 Understanding Diffusion Models: A Unified Perspective에 대해 다루지 않은 내용은 추후에 리뷰할 예정이다.

BackgroundPermalink

Diffusion Model은 다음과 같은 형태를 띈다.

x0:=x(0)pdata(x)q(x1|x0)pθ(x0|x1)x1xT1q(xT|xT1)pθ(xT1|xT)xT:=x(1), x(1)zN(0,I)

이를 모델 pθ의 관점에서 표현하면, Reverse Diffusion Process를 묘사하는 다음의 latent variable model로 표현된다.

pθ(x0:T):=p(xT)Tt=1pθ(xt1|xt),pθ(xt1|xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

이때, pθ(x0:T)는 posterior이자, Forward Diffusion Process인

q(x1:T|x0):=Tt=1q(xt|xt1),q(xt|xt1):=N(xt;1βtxt1,βtI)

를 근사하도록 학습한다.

Reverse Diffusion Process로 왜 Forward Diffusion Process를 근사하는가?
만약 pθq를 잘 근사한다면, pθ(xt)가 묘사하는 분포와 q(xt)가 묘사하는 분포가 동일해야한다. 따라서, Forward Diffusion Process를 묘사하는 q의 분포를 pθ를 통해 근사하는 것이다.


이때, q(xT|x0)N(xT;0,I)여야 하기 때문에, q(xt|x0)tT로 진행될 수록 xt1의 영향력을 줄여나가면서, variance를 I에 가깝게 만들어야 한다. 따라서 variance schedule, β1,,βT0<β1,<β2<<βT1<βT<1 로 정의된다. 따라서, 이 과정을 거치면, xt1을 scaling down 하고 gaussian noise를 더하는 형태로 Diffusion Process를 진행하게 된다. 이때, variance scheduling (βt)의 설정은 논문마다 다르며, 이후 논문들에서는 효과적인 variacne scheduling을 찾기 위한 노력을 하기도 한다.

학습은 variational bound on negative log-likelihood를 최적화하여 pθ를 학습한다.

E[logpθ(x0)]Eq[logpθ(x0:T)q(x1:T|x0)]=Eq[logp(xT)t1logpθ(xt1|xt)q(xt|xt1)]:=L

이때, Jensen’s inequality에 의해 성립된다. L을 최적화하는 것은 logpθ(x0)을 최대화하는 것과 같으므로, data x0을 sampling할 확률을 높이는 방향으로 pθ를 학습한다는 뜻이다.

Reparameterization trickPermalink

왜 gaussian noise를 더한다는 것인가? 이는 gaussian distribution에서의 sampling 과정을 살펴보면 명확해진다. (statistical machine learning에서 주로 쓰이며) 딥러닝에서는 VAE 논문에서 사용하여 유명해진 Reparameterization trick을 사용하여 gaussian distribution에서의 sampling을 진행한다. Back-propagation을 통해 gaussian distribution을 묘사하는 μ,σ2을 학습하기 위해서는, μ,σ에 gradient가 흐르도록 zN(μ,σ2)을 sampling 해야한다. Reparameterization trick은 z를 다음과 같은 형태로 샘플링한다.

z=μ+σϵ,ϵN(0,I)

이 형태를 고려하여 q(xt|xt1):=N(xt;1βtxt1,βtI)를 살펴보면

xt=1βtxt1+βtIϵt,ϵtN(0,I)

형태를 띈다. 0<βt<1 이기 때문에, xt에서 xt1은 scaling down되고, βt만큼 scaling된 gaussian noise ϵt1βtxt1에 더해진 형태로 xt가 sampling 되는 것이다.

Toward Efficients Experssions for Diffusion Model TrainingPermalink

만약 xt에 대해서 학습하기 위해서 모든 x0,x1,,xt1의 diffusion step을 거쳐야한다면 굉장히 오랜 시간이 걸릴 것이다. DDPM에서는 임의의 time step t에 대해 xt를 샘플링할 수 있는 closed form을 보였다.

q(xt|x0)=N(xt;ˉαtx0,(1ˉαt)I)

이때, αt:=1βt, ˉαt:=ts=1αs 이다.

Derivation for q(xt|x0)
[Understanding Diffusion Models: A Unified Perspective](https://arxiv.org/abs/2208.11970)에서도 증명을 해줄테지만, 내가 찾아봤던 방식은 Fourier Transform을 사용해서, 정리해본다. ### Background [3.3 Combining Gaussian variables](https://www.astro.ubc.ca/people/jvw/ASTROSTATS/Answers/Chap3/combining%20gaussians.pdf) Let g1(t)=eat2, the Fourier transform of g1(t) is given as follows: G1(w)=eiwteat2dt=πaew2/4a For g1(t)=12πσ1et2/(2σ21) G1(w)=eσ21w2/2 If ˉg1(t)=12πσ1e(tμ1)2/(2σ21), then ˉG1(w)=eiμ1weσ21w2/2=eiμ1wG1(w) For convolving two gaussian distributions, which is adding two gaussian noises, g1(t):=N(μ1,σ1), g2(t):=N(μ2,σ2) g1(t)g2(t)FTG1(w)G2(w)=eiwμ1eσ21w2/2eiwμ2eσ22w2/2=eiw(μ1μ2)e(σ21+σ22)w2/2 If two gaussian distributions share same mean, μ1=μ2=μ, G1(w)G2=e(σ21+σ22)w2/2 which is just another gaussian distribution with a variance σ21+σ22. Then, g1(t)g2(t)=N(0,σ21+σ22)=12π(σ21+σ22)et2/2(σ21+σ22) ### Derivation for q(xt|x0)=N(xt;ˉαt,(1ˉαt)I) - q(xt|xt1):=N(xt;1βtxt1,βtI) - αt:=1βt - ˉαt:=ts=1αs. From xtq(xt|xt1) and xt1q(xt1|xt2) xt1=1βt1xt2+βt1ϵt1,where ϵt1N(0,I)=αt1xt2+1αt1ϵt1xt=1βtxt1+βtϵt,where ϵtN(0,I)=αt\textcolorredxt1+1αtϵt=αt(\textcolorredαt1xt2+1αt1ϵt1)+1αtϵt=αtαt1xt2+αt1αt1ϵt1+1αtϵt Here, adding two gaussian noises is same as convolving two gaussian distributions. Therefore, αt1αt1ϵt1+1αtϵt=ˉϵN(0,αt(1αt1)I)N(0,(1αt1)I)ˉϵtN(0,(αt(1αt1)+(1αt1))I)ˉϵtN(0,(1αtαt1)I)ˉϵt=1αtαt1ϵN(0,(1αtαt1)I),ϵN(0,I) Therfore, eq (8) becomes xt=αtαt1xt2+αt1αt1ϵt1+1αtϵtˉϵt=αtαt1xt2+1αtαt1ϵ,ϵN(0,I) If we repeat this steps, we can derive following closed from: xt=ts=1αsx0+1ts=1αsϵ,ϵN(0,I)=ˉαtx0+1ˉαtϵ which is equal to xtq(xt|x0)=N(xt;ˉαtx0,(1ˉαt)I)


위 식을 이용하여, data x0에서 임의의 time step txt를 직접적으로 샘플링이 가능하기 때문에 효율적으로 학습이 가능하다.

한가지 기억할 점은, 우리의 모델 pθ(xt1|xt)xt를 input으로 받아 xt1을 예측하는 모델이기 때문에, 이를 학습시키기 위한 target xt1을 forward diffusion distribution q()에서 샘플링할 수 있어야 한다. 위 식을 더욱 전개하여 posteior distribution q(xt1|xt,x0)을 얻을 수 있다.

Bayes’ rule에 따라,

q(xt1|xt,x0)posterior=likelihoodq(xt|xt1,x0)priorq(xt1|x0)q(xt|x0)evidence=q(xt|xt1)q(xt1|x0)q(xt|x0)

q(xt|xt1,x0)은 Markov Process를 가정하였기 때문에 q(xt|xt1,x0)=q(xt|xt1)와 동일하다.

normalization term을 배제하면 위 식은 다음 관계를 가지게 된다.

q(xt|xt1,x0)q(xt1|x0)q(xt|x0)exp(12((xtαtxt1)2βt+(xt1ˉαt1x0)21ˉαt1(xtˉαtx0)21ˉαt))=exp(12(x2t2αtxtxt1+αtx2t1βt+x2t12ˉαt1xt1x0+ˉαt1x201ˉαt1x2t2ˉαtxtx0+ˉαtx201ˉαt))=exp(12((αtβt+11ˉαt)x2t1(2αtβtxt+2ˉαt11ˉαt1x0)xt1+C(xt,x0)))=exp(12(αtβt+11ˉαt1)(x2t12(αtβtxt+ˉαt11ˉαt1x0)(αtβt+11ˉαt1)xt1+C(xt,x0))) 이때, C(xt,x0)xt,x0으로 이루어져 있는 constant term이다. 위 식으로부터 우리는 다음 식을 얻을 수 있다.

q(xt1|xt,x0)=N(xt1;˜μt(xt,x0),˜βtI) ˜βt=(αtβt+11ˉαt1)1=(αt(1ˉαt1+βt)βt(1ˉαt1))1=(αtαtˉαt1+βtβt(1ˉαt1))1=(αtˉαt+1αtβt(1ˉαt1))1=(1ˉαtβt(1ˉαt1))1=(1ˉαt11ˉαt)βt˜μt(xt,x0)=ˉαtβtxt+ˉαt11ˉαt1x0αtβt+11ˉαt1=(αtβt+11ˉαt1)1(ˉαtβtxt+ˉαt11ˉαt1x0)=(1ˉαt11ˉαtβt)(ˉαtβtxt+ˉαt11ˉαt1x0)=(ˉαt1βt1ˉαt)x0+(ˉαt(1ˉαt1)1ˉαt)xt

이 식을 사용하여, 위의 최적화 식을 다음과 같이 수정이 가능하다.

Eq[logpθ(x0:T)q(x1:T|x0)]=Eq[logp(xT)Tt=1pθ(xt1|xt)Tt=1q(xt|xt1)]=Eq[logp(xT)pθ(x0|x1)q(x1|x0)logTt=2pθ(xt1|xt)q(xt|xt1)]=Eq[logp(xT)pθ(x0|x1)q(x1|x0)logTt=2pθ(xt1|xt)q(xt|xt1,x0)]=Eq[logp(xT)pθ(x0|x1)q(x1|x0)logTt=2pθ(xt1|xt)q(xt1|xt,x0)q(xt|x0)q(xt1|x0)]=Eq[logp(xT)pθ(x0|x1)q(x1|x0)Tt=2logq(xt1|x0)q(xt|x0)Tt=2logpθ(xt1|xt)q(xt1|xt,x0)]=Eq[logpθ(x0|x1)logp(xT)q(xT|x0)Tt=2logpθ(xt1|xt)q(xt1|xt,x0)]=Eq(x1|x0)[pθ(x0|x1)]Eq(xT|x0)[logp(xT)q(xT|x0)]Tt=2Eq(xt,xt1|x0)[logpθ(xt1|xt)q(xt1|xt,x0)]=Eq(x1|x0)[logpθ(x0|x1)]reconstruction term,L0+DKL(q(xT|x0)p(xT))prior matching term,LT+Tt=2Eq(xt|x0)[DKL(q(xt1|xt,x0)|)pθ(xt1|xt))]denoising matching term,Lt

How to train?Permalink

LtPermalink

두 gaussian distribution의 KL Divergence는 closed form이 존재한다:

DKL(N(x;μx,Σx)N(y;μy,Σy))=12[log|Σy||Σx|d+tr((Σ1yΣx)+(μyμx)TΣ1y(μyμx))]

모델을 학습하면서 pθ(xt1|xt)=N(xt1;μθ(xt,t),Σθ(xt,t))Σθ(xt,t)을 학습할 수도 있지만, DDPM의 저자는 이를 pre-define된 constant term으로 사용한다. 따라서, pθ(xt1|xt)=N(xt1;μθ(xt,t),σ2tI)를 사용한다. 이로부터 아래의 loss 식을 전개할 수 있다. (Understanding Diffusion Models: A Unified Perspective eq (87~92))

DKL(q(xt1|xt,x0)|)pθ(xt1|xt))=Eq[12σ2t˜μt(xt,x0)μθ(xt,t)2]

μθ(xt,t)를 통해 ˜μ(xt,x0)=(ˉαt1βt1ˉαt)x0+(ˉαt(1ˉαt1)1ˉαt)xt를 바로 예측할 수도 있지만, 동일하게 다음의 형태를 예측할 수도 있다.

μθ(xt,t)=(ˉαt1βt1ˉαt)ˆxθ(xt,t)+(ˉαt(1ˉαt1)1ˉαt)xt

이는 pθ(xt,t)˜μ(xt,x0)을 바로 예측하는 것이 아닌 ground truth x0을 예측하도록 하는 것과 동일하다. 이 식을 적용하면, 위 최적화 식은 다음으로 전개된다. (Understanding Diffusion Models: A Unified Perspective eq (95~99))

Eq[12σ2tˉαt1(1αt)2(1ˉαt)2ˆxθ(xt,t)x02]

또한, xtq(xt|x0)으로부터 xt(x0,ϵ)=ˉαtx0+1ˉαtϵ,ϵN(0,I)의 식을 얻을 수 있고, 이로부터 x0=xt(x0,ϵ)1ˉαtϵˉαt이고, 이를 ˜μ(xt,x0)에 대입하면,

˜μ(xt,x0)=1αtxt1αt1ˉαtαtϵ0,ϵ0N(0,I)

이때, ϵ0xtq(xt|x0)를 샘플링할 때, x0에 더해진 noise이다. (Understanding Diffusion Models: A Unified Perspective eq (116~124))

이로부터 μθ(xt,t)를 다음과 같이 표현할 수 있다.

μθ(xt,t)=1αtxt1αt1ˉαtαtˆϵθ(xt,t)used for sampling!

이를 반영하면 최적화 식은 다음과 같다.

Eq[12σ2t(1αt)2(1ˉαt)αtϵ0ˆϵθ(xt,t)2]

이는 xtq(xt|x0)을 샘플링할 때 x0에 더해진 노이즈 ϵ0을 예측하는 것과 같다. (Understanding Diffusion Models: A Unified Perspective eq (126~130))

최종적으로 DDPM의 저자는 noise를 예측하는 형태의 마지막 최적화 식을 기용하였고, 이 식을 이용하여 x0에 더해진 noise를 제거하는 denoising 형태로 샘플링을 진행한다. 저자들은 해당 loss의 simplified 버전인

Eq[ϵ0ˆϵθ(xt,t)2]

을 최적화하는 것이 샘플 퀄리티에 더 도움이 된다는 것을 발견하였고, 이를 loss로 기용하였다. σ2t의 경우, σ2t=βtσ2t=˜βt 모두 실험해보았을 때, 비슷한 효과를 얻었다고 한다.

LTPermalink

βt를 DDPM 저자들은 constant로 두었기 때문에 prior matching term은 constant이고, 학습에서 배제된다.

L0Permalink

마지막 reconstruction term은 실제 이미지를 복원하는 것을 목표로한다. noise가 존재하지 않는 data인 x0을 만드는 것이 목표기 때문에 x0가 {0,…,255}의 픽셀 값이 [-1,1]로 linear scaling 되어있다고 가정하고, 다음 식을 최적화한다.

pθ(x0|x1)=Di=1δ+(xi0)δ(xi0)N(x;μiθ(x1,1),σ21)dxδ+(x)={if x=1x+1255if x<1δ(x)={if x=1x1255if x>1

이때 D는 예측하는 data x0의 dimensionalty, 즉 차원이며, 픽셀 하나하나의 값을 최적화하는 것과 같다. 무슨 뜻일까? 0~255가 -1~1로 scaling되었기 때문에 [1,253255,251255,,251255,253255,1]의 값을 가지고, 한 integer 값이 -1~1에서 차지하는 범위는 2255이다. 해당 범위 안에 μθ값이 예측되면, 그 값으로 반올림될 것이다. 따라서 target xi0에 따라 해당하는 범위 안을 μθ가 예측될 수 있도록 해당 범위의 확률을 maximize하도는 식으로 볼 수 있다. gaussian distribution의 확률 적분값은 closed form으로 예측할 수 있으니 최적화 식은 충분히 효율적이다.

Settings & Sampling & ExperimentsPermalink

DDPM의 학습 세팅은 다음과 같다.

  • T=1000
  • β1=104,βT=102

Training & Sampling algorithm은 다음과 같다.

DDPM algorithm 1&2

Source: DDPM Project page.

이때 xt1pθ(xt1|xt)=N(xt1;μθ(xt,t),σ2tI)을 따라 샘플링된다. μθ(xt1|xt)?

Negative Log-Likelihood (NLL)Permalink

NLL은 모델이 얼마나 data distribution을 잘 표현하는지 측정하는데 주로 사용한다. real data에 대한 model의 확률이기 때문에 logpθ(x)가 높을수록 좋은 것이다. 그렇기에 NLL=logpθ(x)는 낮을 수록 좋다. 다른 종류의 generative model은 diffusion model처럼 progressive하게 샘플링하지 않기 때문에 NLL을 output 하나에 대해서 측정한다. 또한, NLL을 측정하는 방식은 확률 분포의 가정에 따라 다양한데, 다음 포스팅을 참고하면 좋을 것이다.

DDPM의 NLL 계산식은 L0을 distortion, Lt1을 rate로 취급하여 NLL=L0+L1++LT로 계산하였다.

Rate & Distortion?
rate와 distortion은 정보의 압축과 손실을 정량화하는 지표들이다. rate는 정보를 압축할 때 필요한 bit 수, distortion은 정보의 손실로 표현된다. 생성 모델은 데이터의 압축기로도 표현되고, 정확한 likelihood를 표현할 수 있는 효과가 이러한 compressor의 관점에서 중요하다.
ChatGPT에게 물어보니 다음과 같이 rate와 distortion을 설명해주었다.

PixelCNN: Rate와 Distortion

PixelCNN: Rate와 Distortion

PixelCNN에서 RateDistortion은 데이터 압축 효율성과 데이터 품질 간의 트레이드오프를 설명합니다. 이 두 개념은 Negative Log-Likelihood (NLL)와 밀접한 관련이 있습니다.

1. Rate (R): 표현 비용

정의: 데이터를 표현하는 데 필요한 정보량(비트 수)을 나타냅니다.

NLL과의 관계: Rate는 데이터의 평균 NLL과 직접적으로 연관되며, 다음과 같이 계산됩니다:

R = - (1/n) ∑ log P(xᵢ | x₍<ᵢ₎)
    

특징:

  • Rate가 낮을수록 데이터 표현이 효율적입니다.
  • Rate가 높을수록 데이터를 표현하는 데 더 많은 정보량이 필요합니다.

2. Distortion (D): 데이터 손실

정의: 원본 데이터와 모델이 생성한 데이터 간의 차이를 나타냅니다.

NLL과의 관계: Distortion은 모델의 재구성 품질을 나타내며, 일반적으로 다음과 같은 방식으로 측정됩니다:

D(x, ẋ) = (1/n) ∑ (xᵢ - ẋᵢ)²
    

특징:

  • Distortion이 낮을수록 모델이 생성한 데이터가 원본 데이터에 가깝습니다.
  • Distortion이 높을수록 생성된 데이터 품질이 떨어집니다.

3. Rate-Distortion 트레이드오프

정의: Rate와 Distortion은 서로 반비례 관계에 있습니다:

  • Rate가 높아지면 Distortion이 낮아지고(더 높은 품질),
  • Rate가 낮아지면 Distortion이 높아집니다(더 낮은 품질).

이 관계는 다음과 같이 표현됩니다:

minimize E[D(x, ẋ)] subject to R ≤ R_max
    

4. PixelCNN에서의 Rate와 Distortion

PixelCNN에서는 다음과 같이 Rate와 Distortion을 해석할 수 있습니다:

  • Rate: 데이터셋에 대한 평균 NLL로 계산되며, 모델이 학습한 조건부 확률로부터 결정됩니다.
  • Distortion: 원본 이미지와 모델이 생성한 이미지 간의 차이를 나타내며, 품질 지표(MSE, SSIM 등)로 평가됩니다.
생성 모델 분야에서 자주 등장하는 개념이니 알아두면 좋을 것이다.


실제 계산식은 코드1, 코드2를 참고하면 될 것이다.

DDPM NLL

Source: DDPM Project page.

CIFAR10에서 rate: 1.78 bits/dim, distortion 1.97 bits/dim을 달성하였다.

Rate-Distortion BehaviorPermalink

저자들은 Rate-Distorion Behavior를 조사하였다. (방식은 논문 내 algorithm 3&4를 참고) time step을 t:T0으로 진행하면서 각 time step까지의 rate를 측정하였다. (LTLt) 또한, 예측된 xt로부터 x0ˆx0=(xt1ˉαt0ϵθ(xt))를 이용하여 RMSE를 측정하였다. (x0ˆx02/D).

Rate-Distortion Graph

세 번째 rate vs. distortion 그래프에서 볼 수 있듯이, low-rate 영역에서 distortion이 대부분 감소한다. 반면에 high-rate region에서는 distortion이 작은 것을 알 수 있는데, 이로부터 대부분의 bit (rate)가 인지하기 힘든 distortion을 수정하는데 할당된다는 것을 알 수 있다.

사실 이 관찰로부터 현재 SOTA architecture인 LDM이 시작된다. Diffusion model을 latent space에서 모델링함으로써, 인지하기 힘든 distortion에 diffusion step을 할당하는 것을 막고자 하는 것이다. 이는 후에 LDM을 다룰 때 다시 설명하겠다.


처음 생성 모델을 공부할 때, DDPM이 너무 어려워서 반년 가까이 시간을 투자했었다. 그래서 이번 포스팅을 굉장히 오랜 시간 작성하게 되었는데, 이후 논문들은 조금 간략하게 전달할 예정이다.

Leave a comment