import spaces
import gradio as gr
import torch
from diffusers import UNet2DConditionModel, DDIMInverseScheduler, DDIMScheduler
from utils.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from PIL import Image

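# Loading the SDXL base weights here warms the local model cache at startup;
# generate_image below builds its own fp16 pipeline per request.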
weak_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
strong_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")


def get_generator(random_seed):
    # Seed the global CPU/CUDA RNGs and return a dedicated generator for reproducible sampling.
    seed = int(random_seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    generator = torch.Generator().manual_seed(seed)
    return generator
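

# Dropdown labels mapped to local LoRA checkpoints; "SDXL" means the plain base model (no LoRA).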
model_dict = {
    "SDXL": None,
    "Human Preference": './ckpt/xlMoreArtFullV1.pREw.safetensors',
    'Batman': './ckpt/batman89000003.BlKn.safetensors',
    'Disney': './ckpt/princessXlV2.WSt4.safetensors',
    'Parchment': './ckpt/ParchartXL.safetensors',
}


def generate_image(prompt, seed, T, high_cfg, low_cfg, high_lora, low_lora, weak_choice, strong_choice):
    # Generate three images from the same prompt and seed: weak baseline, strong LoRA model, and W2SD.
    # Fixed generation settings
    size = 1024
    guidance_scale = 5.5
    lora_scale = 0.8
    # device = 'cpu'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # A LoRA scale of exactly 0 breaks the LoRA-gap computation, so nudge it slightly off zero.
    if high_lora == 0:
        high_lora = 0.001
    if low_lora == 0:
        low_lora = -0.001

    # Build the SDXL pipeline
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    dtype = torch.float16
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=dtype,
                                                     variant='fp16',
                                                     safety_checker=None, requires_safety_checker=False).to(device)
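
    # Attach a standard DDIM scheduler for sampling and a DDIM inverse scheduler,
    # which the custom w2sd_lora reflection path below relies on for inversion.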
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder='scheduler')

    # Load the selected LoRA (e.g. a DPO / human-preference LoRA) as the strong model, if one is chosen
    lora_name = strong_choice
    if model_dict[strong_choice] is not None:
        pipe.load_lora_weights(model_dict[strong_choice], adapter_name=lora_name)

    # Weak model: base SDXL without any LoRA
    generator = get_generator(seed)
    pipe.disable_lora()
    image_sdxl = pipe(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                      num_inference_steps=T, generator=generator).images[0]

    # Strong model: base SDXL with the selected LoRA enabled
    generator = get_generator(seed)
    if model_dict[lora_name] is not None:
        pipe.enable_lora()
        pipe.set_adapters(lora_name, adapter_weights=lora_scale)
    image_dpo_lora = pipe(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                          num_inference_steps=T, generator=generator).images[0]

    # W2SD: weak-to-strong sampling with reflection between the weak and strong configurations
    generator = get_generator(seed)
    pipe.disable_lora()
    image_w2sd = pipe.w2sd_lora(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                                denoise_lora_scale=lora_scale,
                                num_inference_steps=T, generator=generator,
                                lora_gap_list=[high_lora, low_lora],
                                cfg_gap_list=[high_cfg, low_cfg], lora_name=lora_name).images[0]
    return image_sdxl, image_dpo_lora, image_w2sd
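
# Example (hypothetical direct call with the UI defaults; assumes the LoRA checkpoints exist under ./ckpt):
#   weak_img, strong_img, w2sd_img = generate_image(
#       "A young girl holding a rose.", seed=42, T=10,
#       high_cfg=2.0, low_cfg=1.0, high_lora=0.8, low_lora=-0.5,
#       weak_choice="SDXL", strong_choice="Human Preference")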

with gr.Blocks() as app:
    gr.Markdown("# Weak-to-Strong Diffusion with Reflection")
    gr.Markdown("""
    **Note:**
    1. The weak model should not be too weak. Keep the weak LoRA scale roughly within (-0.5, 0.5); otherwise performance may degrade (see Figure 9 in the paper).
    2. Due to compute limits, avoid setting Timesteps too high (the standard is 50). A value of 10-15 is recommended, since larger values slow generation down significantly.
    """)
    with gr.Row():
        weak_image = gr.Image(label="Generated Image by Weak Model", type="pil")
        strong_image = gr.Image(label="Generated Image by Strong Model", type="pil")
        w2sd_image = gr.Image(label="Generated Image via W2SD", type="pil")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="A young girl holding a rose.", lines=2)
    with gr.Row():
        seed_slider = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Seed")
        T_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Timesteps")
    with gr.Row():
        high_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=0.8, label="Select Strong LoRA Scale")
        low_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=-0.5, label="Select Weak LoRA Scale")
        high_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=2.0, label="Select Strong Guidance Scale")
        low_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=1.0, label="Select Weak Guidance Scale")
    with gr.Row():
        weak_model_dropdown = gr.Dropdown(choices=["SDXL"], label="Select Weak Model",
                                          value="SDXL")
        strong_model_dropdown = gr.Dropdown(choices=list(model_dict.keys()),
                                            label="Select Strong Model", value="Human Preference")
    generate_button = gr.Button("Generate Image")
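    # Wire the button to generate_image; the input order matches the function signature.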
    generate_button.click(generate_image,
                          inputs=[prompt_input, seed_slider, T_slider, high_cfg_slider, low_cfg_slider,
                                  high_lora_slider, low_lora_slider, weak_model_dropdown, strong_model_dropdown],
                          outputs=[weak_image, strong_image, w2sd_image])

# Enable the queue feature
app.queue()
app.launch()
# app.launch(server_name='0.0.0.0', share=True, server_port=7788)