DDPM (Denoising Diffusion Probabilistic Models)
Sources for these notes:
1. Denoising Diffusion Probabilistic Models
2. 大白话AI | Image generation model DDPM | Diffusion models | Generative models | Denoising diffusion probabilistic models
3. pytorch-stable-diffusion
Forward Diffusion Process
How noise is added to an image
Deriving the next sample $x_t$ from the previous one $x_{t-1}$
After some derivation (detailed below), $x_t$ at step $t$ can be obtained directly from the initial image $x_0$
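In the notation of the DDPM paper, the two relations just described can be written as below. This is a restatement consistent with the code comments later in these notes, where $\bar{\alpha}_t$ appears as `hat_alpha_t`; the noise symbols $\epsilon_{t-1}$ and $\epsilon$ are the usual standard Gaussian draws.

```latex
\begin{aligned}
\alpha_t &= 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \\
x_t &= \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\, \epsilon_{t-1}, \qquad \epsilon_{t-1} \sim \mathcal{N}(0, I) \\
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
```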
Main functions of DDPM:
(1) Add noise to the clean image $x_0$
(2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) of the distribution $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$
(3) Update $\tilde{\mu}_t$ (mean)
(1) Add noise to the clean image using the function add_noise()
The derivation of the noising formula shown in the figure above is given in the figure below
Implementing add_noise(clean image: $x_0$, timesteps: $t$)
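The key step in that derivation is that independent Gaussian noises merge: repeating the one-step rule $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon$ leaves $x_0$ scaled by $\sqrt{\bar{\alpha}_t}$ and accumulates a total noise variance of $1-\bar{\alpha}_t$. The sketch below checks this numerically rather than proving it. It assumes the scaled-linear beta schedule used in the full listing at the end of these notes; the variable names are illustrative only.

```python
import torch

# Assumed schedule: same construction as in ddpm.py (scaled-linear betas).
betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000, dtype=torch.float64) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Track x_t = mean_coeff * x_0 + (Gaussian noise with variance var)
# while applying the one-step rule for every timestep.
mean_coeff = torch.tensor(1.0, dtype=torch.float64)
var = torch.tensor(0.0, dtype=torch.float64)
for alpha_t in alphas:
    mean_coeff = alpha_t.sqrt() * mean_coeff
    # Old noise is scaled by sqrt(alpha_t); the new independent noise adds (1 - alpha_t).
    var = alpha_t * var + (1.0 - alpha_t)

# Closed-form values from the cumulative product.
closed_mean = alphas_cumprod[-1].sqrt()
closed_var = 1.0 - alphas_cumprod[-1]
print(mean_coeff.item(), closed_mean.item())
print(var.item(), closed_var.item())
```

Both printed pairs should agree up to floating-point error, which is why add_noise() can jump from $x_0$ to any $x_t$ in one step.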
class DDPMSampler:

    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(...):  # Perform one reverse (denoising) step: sample x_{t-1} from x_t.
        ...

    def add_noise(  # Add noise to the original samples according to the diffusion (forward) process.
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to the original samples according to the diffusion process.

        Args:
        - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
        - timesteps (torch.IntTensor): The timesteps at which the noise will be added.

        Returns:
        - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)

        # sqrt{hat_alpha_t}: square root of the cumulative product of alphas for the given timesteps
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it is a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_alpha_prod has as many dimensions as original_samples
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        # sqrt{1 - hat_alpha_t}: square root of (1 - cumulative product of alphas) for the given timesteps
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it is a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_one_minus_alpha_prod has as many dimensions as original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples is the mean: the original samples scaled by sqrt{hat_alpha_t}.
        # sqrt_one_minus_alpha_prod * noise is the noise term: standard normal noise scaled by the
        # standard deviation sqrt{1 - hat_alpha_t}.
        # Their sum is the noisy sample; the balance between image and noise depends on the timestep.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1 - hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
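The broadcasting in add_noise() is easier to see in isolation. The following is a minimal sketch, not the author's method: `q_sample` is a hypothetical standalone helper that takes the noise as an argument, so passing zero noise exposes the per-sample scale $\sqrt{\bar{\alpha}_t}$ directly. It assumes the same beta schedule as ddpm.py and uses `view` instead of the `unsqueeze` loop.

```python
import torch

def q_sample(x0, timesteps, alphas_cumprod, noise):
    """Hypothetical helper: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[timesteps]                 # shape (B,)
    # Reshape (B,) -> (B, 1, ..., 1) so it broadcasts over the remaining dims of x0.
    shape = (-1,) + (1,) * (x0.dim() - 1)
    sqrt_abar = abar.sqrt().view(shape)
    sqrt_one_minus_abar = (1.0 - abar).sqrt().view(shape)
    return sqrt_abar * x0 + sqrt_one_minus_abar * noise

# Assumed schedule, as in ddpm.py.
betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.ones(2, 3, 4, 4)        # two dummy "images"
t = torch.tensor([0, 999])         # one timestep per image
x_t = q_sample(x0, t, alphas_cumprod, torch.zeros_like(x0))

# With zero noise, each image is just x0 scaled by sqrt(abar_t):
# almost unchanged at t = 0, almost erased at t = 999.
print(x_t[0].mean().item(), x_t[1].mean().item())
```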
(2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) of the distribution $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$
Note: $\mathcal{N}(\text{output};\ \text{mean},\ \text{variance})$
The derivation of the mean and variance of this distribution is given in the figure below
Implement _get_variance() to compute the variance, and step() to compute and then update the mean
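For reference, the result of that derivation, as stated in formulas (6) and (7) of the DDPM paper and mirrored in the code comments of step() and _get_variance(), is the following. Here $\bar{\alpha}_t$ is the quantity written `hat_alpha_t` in the code comments.

```latex
\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t,\ \tilde{\beta}_t I\right) \\
\tilde{\mu}_t &= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\, x_0
              + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t \\
\tilde{\beta}_t &= \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\, \beta_t
\end{aligned}
```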
class DDPMSampler:

    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one reverse (denoising) step: sample x_{t-1} given x_t and the predicted noise.

        Args:
        - timestep (int): The current timestep during diffusion.
        - latents (torch.Tensor): The latent representation of the input.
        - model_output (torch.Tensor): The output from the diffusion model.
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)

        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise, also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t} * epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t), the distribution of x_{t-1} in the reverse process
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t / sqrt{1 - hat_alpha_t} * epsilon(x_t, t)),
        #        beta_t * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t, x_0), the posterior of the forward process
        #   = N( sqrt{hat_alpha_{t-1}} * beta_t / (1 - hat_alpha_t) * x_0
        #          + sqrt{alpha_t} * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) * x_t,
        #        beta_t * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) )
        # coeff_1 = sqrt{hat_alpha_{t-1}} * beta_t / (1 - hat_alpha_t)
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # coeff_2 = sqrt{alpha_t} * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t)
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 6. Update pred_mu_t according to pred_beta_t
        ...

    def add_noise(...):
        ...
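The comments in step() quote two expressions for the mean: the posterior form `coeff_1 * x_0 + coeff_2 * x_t` and the noise-prediction form $\frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$. They agree once $x_0$ is replaced by its prediction from $x_t$ and $\epsilon$. The sketch below checks this numerically under stated assumptions: consecutive timesteps (previous step is t - 1), the schedule from ddpm.py, and random tensors standing in for the latents and the model output.

```python
import torch

torch.manual_seed(0)

# Assumed schedule, as in ddpm.py (float64 to keep round-off small).
betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000, dtype=torch.float64) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 500
alpha_t, beta_t = alphas[t], betas[t]
abar_t, abar_prev = alphas_cumprod[t], alphas_cumprod[t - 1]

x_t = torch.randn(1, 4, 8, 8, dtype=torch.float64)   # stand-in for latents
eps = torch.randn_like(x_t)                          # stand-in for model_output

# Route 1: predict x_0 first, then use the posterior mean of q(x_{t-1} | x_t, x_0).
x0_pred = (x_t - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
coeff_x0 = abar_prev.sqrt() * beta_t / (1 - abar_t)
coeff_xt = alpha_t.sqrt() * (1 - abar_prev) / (1 - abar_t)
mu_from_x0 = coeff_x0 * x0_pred + coeff_xt * x_t

# Route 2: the noise-prediction form of the same mean.
mu_from_eps = (x_t - beta_t / (1 - abar_t).sqrt() * eps) / alpha_t.sqrt()

print((mu_from_x0 - mu_from_eps).abs().max().item())
```

The printed maximum difference should be at the level of floating-point round-off.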
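Why do we compute the distribution $q(x_{t-1} \mid x_t, x_0)$?
The derivation of the Stable Diffusion loss function contains a KL-divergence term. This term measures how similar two distributions are, and it is what steers the reverse process toward the final image. A detailed explanation is left for a later post.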
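As a small illustration of what such a KL term measures, the sketch below assumes the two distributions being compared are Gaussians with the same fixed variance, in which case the KL divergence reduces to a scaled squared distance between their means. The helper name `kl_same_variance` is hypothetical, and this is only a sketch of the idea, not the loss derivation itself.

```python
import torch

def kl_same_variance(mu_q, mu_p, variance):
    """KL( N(mu_q, variance*I) || N(mu_p, variance*I) ), summed over all elements."""
    return ((mu_q - mu_p) ** 2).sum() / (2.0 * variance)

mu_q = torch.tensor([0.0, 1.0, 2.0])

# Identical means: the distributions coincide, so the divergence is zero.
kl_zero = kl_same_variance(mu_q, mu_q, 0.5)

# Every mean shifted by 1: sum of squares is 3, divided by 2 * 0.5.
kl_shifted = kl_same_variance(mu_q, mu_q + 1.0, 0.5)

print(kl_zero.item(), kl_shifted.item())
```

Under this assumption, making the two distributions similar amounts to making the predicted mean match $\tilde{\mu}_t$.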
(3) Update $\tilde{\mu}_t$ (mean)
Since $\tilde{\beta}_t$ is the variance, the standard deviation $\sqrt{\tilde{\beta}_t}$ scales the noise:
$\tilde{\mu}_t \leftarrow \tilde{\mu}_t + \sqrt{\tilde{\beta}_t}\,\epsilon \quad \left(\text{Note: } \epsilon \sim \mathcal{N}(0, I)\right)$
The updated value is the sample $x_{t-1}$.
class DDPMSampler:

    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one reverse (denoising) step: sample x_{t-1} given x_t and the predicted noise.

        Args:
        - timestep (int): The current timestep during diffusion.
        - latents (torch.Tensor): The latent representation of the input.
        - model_output (torch.Tensor): The output from the diffusion model.
        """
        ...
        ...
        ...
        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance for the current timestep as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            # sqrt{pred_beta_t} * epsilon
            # Note: despite its name, "variance" now holds standard deviation * noise
            variance = (self._get_variance(t) ** 0.5) * noise

        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # the variable "variance" is already multiplied by the noise N(0, 1).
        # For t > 0, compute the predicted variance pred_beta_t (see formulas (6) and (7) from
        # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get the previous sample
        # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon   (Note: epsilon ~ N(0, 1))
        pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample

    def add_noise(...):
        ...
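The sampling rule in step 6 can be tried on its own. The sketch below is a minimal, CPU-only illustration with a hypothetical helper `sample_prev`. It draws from $\mathcal{N}(\text{mean}, \text{variance}\cdot I)$ by scaling standard normal noise with the standard deviation, and it skips the noise at t = 0 as step() does.

```python
import torch

def sample_prev(mean, variance, t, generator=None):
    """Hypothetical helper: draw x_{t-1} ~ N(mean, variance * I); no noise is added at t == 0."""
    if t == 0:
        return mean
    noise = torch.randn(mean.shape, generator=generator, dtype=mean.dtype)
    # Scale by the standard deviation, i.e. the square root of the variance.
    return mean + (variance ** 0.5) * noise

g = torch.Generator().manual_seed(0)
mean = torch.zeros(100_000)
variance = torch.tensor(0.04)          # standard deviation 0.2

x_prev = sample_prev(mean, variance, t=10, generator=g)
x_last = sample_prev(mean, variance, t=0, generator=g)

# The empirical standard deviation should be close to 0.2, not 0.04.
print(x_prev.std().item())
```

Checking the empirical standard deviation is a quick way to catch the common mistake of multiplying the noise by the variance instead of its square root.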
Full DDPM code (ddpm.py)
import torch
import numpy as np
'''
# Forward process and the posterior q(x_{t-1} | x_t, x_0)
# Add noise to the clean image, calculate pred_mu_t and pred_beta_t of the distribution, then update pred_mu_t
# (1) Add noise to the clean image using add_noise()
#     x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1 - hat_alpha_t} * epsilon   (Note: epsilon ~ N(0, 1))
#     see formula (4) from https://arxiv.org/pdf/2006.11239.pdf
# (2) Calculate pred_mu_t and pred_beta_t of the distribution
#     q(x_{t-1} | x_t, x_0) = N(pred_mu_t, pred_beta_t * I)
#     step():          pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
#     _get_variance(): pred_beta_t = (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) * beta_t
# (3) Update pred_mu_t
#     step():          pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * noise   (Note: noise ~ N(0, 1))
#     see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf
'''
class DDPMSampler:

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        """
        Initialize the DDPM (Denoising Diffusion Probabilistic Model) parameters.

        Args:
        - generator (torch.Generator): A PyTorch random number generator.
        - num_training_steps (int, optional): Number of training steps. Default is 1000.
        - beta_start (float, optional): The starting value of beta. Default is 0.00085.
        - beta_end (float, optional): The ending value of beta. Default is 0.0120.
        """
        # Params "beta_start" and "beta_end" taken from:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        # alpha = 1 - beta
        self.alphas = 1.0 - self.betas
        # hat_alpha_t = alpha_t * alpha_{t-1} * ... * alpha_2 * alpha_1
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Define a tensor representing the value 1.0
        self.one = torch.tensor(1.0)

        # Store the generator for random number generation
        self.generator = generator

        # Number of training timesteps
        self.num_train_timesteps = num_training_steps
        # Create a tensor of timesteps in reverse order
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """
        Set the number of inference timesteps for the DDPM model.

        Args:
        - num_inference_steps (int, optional): Number of steps to use during inference. Default is 50.
        """
        # Store the number of inference steps
        self.num_inference_steps = num_inference_steps
        # Calculate the ratio between training timesteps and inference timesteps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # Generate an array of timesteps for inference:
        # - np.arange(0, num_inference_steps): create an array from 0 to num_inference_steps - 1
        # - multiply by step_ratio to space out the timesteps
        # - round() to ensure the timesteps are integers
        # - [::-1] to reverse the order, as inference proceeds backward through the timesteps
        # - copy() to ensure the array is contiguous in memory
        # - astype(np.int64) to ensure the timesteps are of type int64, which is compatible with PyTorch
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        # Convert the numpy array of timesteps to a PyTorch tensor
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """
        Calculate the previous timestep for the given timestep during inference.

        Args:
        - timestep (int): The current timestep during inference.

        Returns:
        - int: The previous timestep during inference.
        """
        # Calculate the previous timestep by subtracting the step ratio from the current timestep.
        # The step ratio is the integer division of the number of training timesteps by the number of inference timesteps.
        # timestep t-1 = timestep t - ratio
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """
        Calculate the variance for the given timestep during inference.

        Args:
        - timestep (int): The current timestep during inference.

        Returns:
        - torch.Tensor: The variance for the given timestep.
        """
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(timestep)

        # Retrieve the cumulative product of alphas at the current and previous timesteps
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[timestep]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute the predicted variance pred_beta_t (see formulas (6) and (7) from
        # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get the previous sample
        # x_{t-1} ~ N(mu, sigma^2)
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t / sqrt{1 - hat_alpha_t} * epsilon),
        #        beta_t * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) )
        # x_{t-1} ~ N(pred_prev_sample, variance), i.e. noise with this variance is added to pred_prev_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # Clamp the variance to ensure it's not zero, as we will take its log later
        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.

        Args:
        - strength (float, optional): A value between 0 and 1 indicating the amount of noise to add.
            - A strength value close to 1 means the output will be further from the input image (more noise).
            - A strength value close to 0 means the output will be closer to the input image (less noise).
        """
        # Calculate the number of inference steps to skip based on the strength
        # Higher strength means fewer steps skipped (more noise added)
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        # Update the timesteps to start from the calculated step
        # This effectively sets the starting point for the noise addition process
        self.timesteps = self.timesteps[start_step:]
        # Store the starting step for reference
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one reverse (denoising) step: sample x_{t-1} given x_t and the predicted noise.

        Args:
        - timestep (int): The current timestep during diffusion.
        - latents (torch.Tensor): The latent representation of the input.
        - model_output (torch.Tensor): The output from the diffusion model.
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)

        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise, also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t} * epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t), the distribution of x_{t-1} in the reverse process
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t / sqrt{1 - hat_alpha_t} * epsilon(x_t, t)),
        #        beta_t * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t, x_0), the posterior of the forward process
        #   = N( sqrt{hat_alpha_{t-1}} * beta_t / (1 - hat_alpha_t) * x_0
        #          + sqrt{alpha_t} * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) * x_t,
        #        beta_t * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t) )
        # coeff_1 = sqrt{hat_alpha_{t-1}} * beta_t / (1 - hat_alpha_t)
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # coeff_2 = sqrt{alpha_t} * (1 - hat_alpha_{t-1}) / (1 - hat_alpha_t)
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance for the current timestep as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            # sqrt{pred_beta_t} * epsilon
            # Note: despite its name, "variance" now holds standard deviation * noise
            variance = (self._get_variance(t) ** 0.5) * noise

        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # the variable "variance" is already multiplied by the noise N(0, 1).
        # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon   (Note: epsilon ~ N(0, 1))
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to the original samples according to the diffusion process.

        Args:
        - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
        - timesteps (torch.IntTensor): The timesteps at which the noise will be added.

        Returns:
        - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)

        # sqrt{hat_alpha_t}: square root of the cumulative product of alphas for the given timesteps
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it is a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_alpha_prod has as many dimensions as original_samples
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        # sqrt{1 - hat_alpha_t}: square root of (1 - cumulative product of alphas) for the given timesteps
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it is a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_one_minus_alpha_prod has as many dimensions as original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples is the mean: the original samples scaled by sqrt{hat_alpha_t}.
        # sqrt_one_minus_alpha_prod * noise is the noise term: standard normal noise scaled by the
        # standard deviation sqrt{1 - hat_alpha_t}.
        # Their sum is the noisy sample; the balance between image and noise depends on the timestep.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1 - hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
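To make the timestep bookkeeping in set_inference_timesteps() and set_strength() concrete, here is a small NumPy-only sketch of the same arithmetic with the default values (1000 training steps, 50 inference steps). The value strength = 0.8 is only an illustrative choice, and `img2img_timesteps` is a hypothetical name for the truncated schedule.

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 50

# Spacing between consecutive inference timesteps.
step_ratio = num_train_timesteps // num_inference_steps

# Evenly spaced timesteps in descending order: 980, 960, ..., 20, 0.
timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64)

# With strength 0.8, the first 20% of the schedule (the noisiest levels) is skipped,
# so denoising starts from a partially noised input rather than from pure noise.
strength = 0.8
start_step = num_inference_steps - int(num_inference_steps * strength)
img2img_timesteps = timesteps[start_step:]

print(timesteps[:3], timesteps[-3:])
print(start_step, img2img_timesteps[0], len(img2img_timesteps))
```

The previous timestep used in step() is then simply the current one minus `step_ratio`, which is what _get_previous_timestep() computes.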