| | import torch |
| | from dataloader import get_dataloaders |
| | from config import Config |
| | from noise_scheduler import FrequencyAwareNoise |
| | import matplotlib.pyplot as plt |
| |
|
def _to_displayable(image):
    """Convert one channels-first image tensor into an array ``plt.imshow`` accepts.

    Undoes the ``* 0.5 + 0.5`` shift used by the original plotting code, which
    assumes pixels are normalized to [-1, 1] (TODO confirm against the
    dataloader's transform). The result is clamped to [0, 1] because noised
    samples routinely leave the data range, and imshow would otherwise clip them
    with a warning. Tensors are detached and moved to the CPU so ``.numpy()`` works
    for GPU tensors and tensors that require grad.

    Args:
        image: Tensor laid out as (C, H, W), matching the original
            ``permute(1, 2, 0)`` call.

    Returns:
        A NumPy array of shape (H, W, C), or (H, W) when C == 1, because imshow
        rejects a trailing singleton channel.
    """
    image = (image.detach().cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
    if image.shape[0] == 1:
        return image[0].numpy()
    return image.permute(1, 2, 0).numpy()


def debug_data(t=500):
    """Plot one training image next to its noised version as a visual sanity check.

    Builds the training dataloader from ``Config()``, takes its first batch,
    applies ``FrequencyAwareNoise.apply_noise`` at timestep ``t`` to every sample,
    and shows the first sample before and after noising side by side.

    Args:
        t: Timestep passed to ``apply_noise`` for every sample in the batch.
            Defaults to 500, the previously hard-coded value.
    """
    config = Config()
    train_loader, _ = get_dataloaders(config)
    x0, _ = next(iter(train_loader))

    noise_scheduler = FrequencyAwareNoise(config)
    # One int64 timestep per sample, created on the batch's device so a GPU
    # dataloader does not cause a CPU/GPU mismatch inside apply_noise.
    timesteps = torch.full((x0.shape[0],), t, dtype=torch.long, device=x0.device)
    xt = noise_scheduler.apply_noise(x0, timesteps)

    plt.figure(figsize=(10, 5))
    panels = [(x0[0], "Original"), (xt[0], f"Noisy (t={t})")]
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # cmap/vmin/vmax only affect single-channel (2-D) images; imshow ignores
        # them for RGB(A) input.
        plt.imshow(_to_displayable(image), cmap="gray", vmin=0.0, vmax=1.0)
        plt.title(title)
    plt.show()
| |
|
if __name__ == "__main__":
    # Script entry point: load a batch and show the original-vs-noisy plot.
    debug_data()