| | import torch |
| | import torchvision |
| | from torchvision.utils import save_image |
| | import os |
| | from config import Config |
| |
|
def _reverse_step(x, pred_noise, t, alphas, alpha_bars, betas, clip_denoised=False):
    """Apply one DDPM reverse step: map x_t to x_{t-1}, or to the x_0 estimate at t == 0.

    Args:
        x: Current noisy sample x_t.
        pred_noise: The model's noise prediction for ``x`` at timestep ``t``.
        t: Integer timestep used to index the schedule tensors.
        alphas, alpha_bars, betas: Schedule tensors indexed by timestep, already
            moved to the same device as ``x``.
        clip_denoised: If True, clamp the intermediate x_0 estimate to [-1, 1]
            before it is used.

    Returns:
        The sample for the previous timestep (fresh Gaussian noise is drawn only
        when ``t > 0``).
    """
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]
    beta_t = betas[t]

    # Invert x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps to estimate x_0.
    pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
    if clip_denoised:
        pred_x0 = torch.clamp(pred_x0, -1, 1)

    if t == 0:
        # Last step: return the x_0 estimate directly, no noise is added.
        return pred_x0

    alpha_bar_prev = alpha_bars[t - 1]
    # Mean and variance of the posterior q(x_{t-1} | x_t, x_0).
    coef_x0 = torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
    coef_xt = torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    mean = coef_x0 * pred_x0 + coef_xt * x
    variance = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
    return mean + torch.sqrt(variance) * torch.randn_like(x)


def simple_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4,
                  clip_denoised=False):
    """Generate images with standard DDPM ancestral sampling and optionally log them.

    Starts from Gaussian noise of shape ``(n_samples, 3, image_size, image_size)``
    and runs ``config.T`` reverse-diffusion steps using the ``alphas``,
    ``alpha_bars`` and ``betas`` tensors on ``noise_scheduler``. ``image_size``
    and ``T`` are read from a freshly constructed ``Config()``.

    Args:
        model: Noise-prediction network, called as ``model(x, t_tensor)``.
        noise_scheduler: Object exposing ``alphas``, ``alpha_bars`` and ``betas``
            tensors indexed by timestep.
        device: Device to sample on.
        epoch: Optional epoch number. When given, the grid is saved to
            ``samples/epoch_{epoch}.png`` and, on multiples of 10, summary
            statistics are printed. It is also used as the writer's global step.
        writer: Optional object with an ``add_image(tag, img, step)`` method
            (e.g. a TensorBoard ``SummaryWriter``).
        n_samples: Number of images to generate.
        clip_denoised: If True, clamp each intermediate x_0 estimate to [-1, 1].
            Defaults to False, which reproduces the unclipped sampler.

    Returns:
        ``(x, grid)``: the generated samples clamped to [-1, 1], and an image
        grid (2 per row) with values mapped linearly from [-1, 1] to [0, 1].
    """
    config = Config()
    # Remember the caller's mode, so that sampling mid-training does not leave
    # the model stuck in eval mode (dropout off, batch-norm stats frozen).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device)

            print(f"Starting reverse diffusion for {n_samples} samples...")

            alphas = noise_scheduler.alphas.to(device)
            alpha_bars = noise_scheduler.alpha_bars.to(device)
            betas = noise_scheduler.betas.to(device)

            for step, t in enumerate(reversed(range(config.T))):
                if step % 100 == 0:
                    print(f"Step {step}/{config.T}, t={t}")

                t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)
                pred_noise = model(x, t_tensor)
                x = _reverse_step(x, pred_noise, t, alphas, alpha_bars, betas, clip_denoised)

            x = torch.clamp(x, -1, 1)
    finally:
        model.train(was_training)

    if epoch is not None and epoch % 10 == 0:
        print(
            f"Sample stats at epoch {epoch}: "
            f"range [{x.min().item():.3f}, {x.max().item():.3f}], "
            f"mean {x.mean().item():.3f}"
        )

    # Map the fixed [-1, 1] data range to [0, 1] explicitly. make_grid's
    # normalize=True (without a value range) would min-max rescale by this
    # batch's own extremes, stretching contrast and hiding dim/washed-out samples.
    grid = torchvision.utils.make_grid((x + 1) / 2, nrow=2)

    if writer is not None:
        writer.add_image('Samples', grid, epoch)

    if epoch is not None:
        os.makedirs("samples", exist_ok=True)
        save_image(grid, f"samples/epoch_{epoch}.png")

    return x, grid
| |
|
| | |
def sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Public sampling entry point; delegates to ``simple_sample``.

    All arguments are forwarded unchanged, and whatever ``simple_sample``
    returns (the ``(samples, grid)`` pair) is handed straight back.
    """
    result = simple_sample(
        model,
        noise_scheduler,
        device,
        epoch=epoch,
        writer=writer,
        n_samples=n_samples,
    )
    return result
| |
|