cpuai committed on
Commit 6dfb0b7 · verified · 1 Parent(s): 2e613c7

Update app.py

Files changed (1)
  1. app.py +237 -301
app.py CHANGED
@@ -1,44 +1,51 @@
- import spaces
- from dataclasses import dataclass
- import json
- import logging
  import os
- import random
- import re
  import sys
  import warnings

- from PIL import Image
- from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer

- from prompt_check import is_unsafe_prompt

- sys.path.append(os.path.dirname(os.path.abspath(__file__)))

  from diffusers import ZImagePipeline
  from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel

- from pe import prompt_template
-
  # ==================== Environment Variables ==================================
  MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
- UNSAFE_MAX_NEW_TOKEN = int(os.environ.get("UNSAFE_MAX_NEW_TOKEN", "10"))
  DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
  HF_TOKEN = os.environ.get("HF_TOKEN")
- UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
  # =============================================================================

-
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
  warnings.filterwarnings("ignore")
  logging.getLogger("transformers").setLevel(logging.ERROR)

  RES_CHOICES = {
      "1024": [
          "1024x1024 ( 1:1 )",
@@ -82,79 +89,100 @@ RES_CHOICES = {
  }

  RESOLUTION_SET = []
- for resolutions in RES_CHOICES.values():
-     RESOLUTION_SET.extend(resolutions)

  EXAMPLE_PROMPTS = [
      ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
-     [
-         "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"
-     ],
-     [
-         "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子"
-     ],
-     [
-         "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
-     ],
-     [
-         '''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
-     ],
-     [
-         """一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报。场景设置在一个质朴的19世纪风格厨房里。画面中央,一位红棕色头发、留着小胡子的中年男子(演员阿瑟·彭哈利根饰)站在一张木桌后,他身穿白色衬衫、黑色马甲和米色围裙,正看着一位女士,手中拿着一大块生红肉,下方是一个木制切菜板。在他的右边,一位梳着高髻的黑发女子(演员埃莉诺·万斯饰)倚靠在桌子上,温柔地对他微笑。她穿着浅色衬衫和一条上白下蓝的长裙。桌上除了放有切碎的葱和卷心菜丝的切菜板外,还有一个白色陶瓷盘、新鲜香草,左侧一个木箱上放着一串深色葡萄。背景是一面粗糙的灰白色抹灰墙,墙上挂着一幅风景画。最右边的一个台面上放着一盏复古油灯。海报上有大量的文字信息。左上角是白色的无衬线字体"ARTISAN FILMS PRESENTS",其下方是"ELEANOR VANCE"和"ACADEMY AWARD® WINNER"。右上角写着"ARTHUR PENHALIGON"和"GOLDEN GLOBE® AWARD WINNER"。顶部中央是圣丹斯电影节的桂冠标志,下方写着"SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"。主标题"THE TASTE OF MEMORY"以白色的大号衬线字体醒目地显示在下半部分。标题下方注明了"A FILM BY Tongyi Interaction Lab"。底部区域用白色小字列出了完整的演职员名单,包括"SCREENPLAY BY ANNA REID"、"CULINARY DIRECTION BY JAMES CARTER"以及Artisan Films、Riverstone Pictures和Heritage Media等众多出品公司标志。整体风格是写实主义,采用温暖柔和的灯光方案,营造出一种亲密的氛围。色调以棕色、米色和柔和的绿色等大地色系为主。两位演员的身体都在腰部被截断。"""
-     ],
-     [
-         """一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片,并叠加了文字,使其具有海报或杂志封面的外观。主要拍摄对象是一片厚实、有蜡质感的叶子,从左下角到右上角呈对角线弯曲穿过画面。其表面反光性很强,捕捉到一个明亮的直射光源,形成了一道突出的高光,亮面下显露出平行的精细叶脉。背景由其他深绿色的叶子组成,这些叶子轻微失焦,营造出浅景深效果,突出了前景的主叶片。整体风格是写实摄影,明亮的叶片与黑暗的阴影背景之间形成高对比度。图像上有多处渲染文字。左上角是白色的衬线字体文字"PIXEL-PEEPERS GUILD Presents"。右上角同样是白色衬线字体的文字"[Instant Noodle] 泡面调料包"。左侧垂直排列着标题"Render Distance: Max",为白色衬线字体。左下角是五个硕大的白色宋体汉字"显卡在...燃烧"。右下角是较小的白色衬线字体文字"Leica Glow™ Unobtanium X-1",其正上方是用白色宋体字书写的名字"蔡几"。识别出的核心实体包括品牌像素偷窥者协会、其产品线泡面调料包、相机型号买不到™ X-1以及摄影师名字造相。"""
-     ],
  ]

- def get_resolution(resolution):
      match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
      if match:
          return int(match.group(1)), int(match.group(2))
      return 1024, 1024

- def load_models(model_path, enable_compile=False, attention_backend="native"):
-     print(f"Loading models from {model_path}...")

-     use_auth_token = HF_TOKEN if HF_TOKEN else True

-     if not os.path.exists(model_path):
          vae = AutoencoderKL.from_pretrained(
-             f"{model_path}",
              subfolder="vae",
-             torch_dtype=torch.bfloat16,
-             device_map="cuda",
-             use_auth_token=use_auth_token,
          )

          text_encoder = AutoModelForCausalLM.from_pretrained(
-             f"{model_path}",
              subfolder="text_encoder",
-             torch_dtype=torch.bfloat16,
-             device_map="cuda",
-             use_auth_token=use_auth_token,
          ).eval()
-
-         tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token)
-     else:
-         vae = AutoencoderKL.from_pretrained(
-             os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
          )
-
          text_encoder = AutoModelForCausalLM.from_pretrained(
              os.path.join(model_path, "text_encoder"),
-             torch_dtype=torch.bfloat16,
-             device_map="cuda",
          ).eval()
-
          tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))

      tokenizer.padding_side = "left"

-     if enable_compile:
-         print("Enabling torch.compile optimizations...")
          torch._inductor.config.conv_1x1_as_mm = True
          torch._inductor.config.coordinate_descent_tuning = True
          torch._inductor.config.epilogue_fusion = False
@@ -165,42 +193,68 @@ def load_models(model_path, enable_compile=False, attention_backend="native"):

      pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)

-     if enable_compile:
-         pipe.vae.disable_tiling()
-
-     if not os.path.exists(model_path):
          transformer = ZImageTransformer2DModel.from_pretrained(
-             f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
-         ).to("cuda", torch.bfloat16)
-     else:
-         transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
-             "cuda", torch.bfloat16
          )

      pipe.transformer = transformer
-     pipe.transformer.set_attention_backend(attention_backend)

-     if enable_compile:
-         print("Compiling transformer...")
-         pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)

-     pipe.to("cuda", torch.bfloat16)

-     from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-     from transformers import CLIPImageProcessor

-     safety_model_id = "CompVis/stable-diffusion-safety-checker"
-     safety_feature_extractor = CLIPImageProcessor.from_pretrained(safety_model_id)
-     safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, torch_dtype=torch.float16).to("cuda")

-     pipe.safety_feature_extractor = safety_feature_extractor
-     pipe.safety_checker = safety_checker
      return pipe

-
  def generate_image(
      pipe,
-     prompt,
      resolution="1024x1024",
      seed=42,
      guidance_scale=5.0,
@@ -211,48 +265,46 @@ def generate_image(
  ):
      width, height = get_resolution(resolution)

-     generator = torch.Generator("cuda").manual_seed(seed)

-     scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
      pipe.scheduler = scheduler

-     image = pipe(
          prompt=prompt,
-         height=height,
-         width=width,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
          generator=generator,
-         max_sequence_length=max_sequence_length,
-     ).images[0]
-
      return image

-
  def warmup_model(pipe, resolutions):
-     print("Starting warmup phase...")
-
      dummy_prompt = "warmup"
-
      for res_str in resolutions:
-         print(f"Warming up for resolution: {res_str}")
          try:
-             for i in range(3):
                  generate_image(
                      pipe,
                      prompt=dummy_prompt,
-                     resolution=res_str,
-                     num_inference_steps=9,
                      guidance_scale=0.0,
                      seed=42 + i,
                  )
          except Exception as e:
-             print(f"Warmup failed for {res_str}: {e}")

-     print("Warmup completed.")
-
-
- # ==================== Prompt Expander ====================
  @dataclass
  class PromptOutput:
      status: bool
@@ -261,7 +313,6 @@ class PromptOutput:
      system_prompt: str
      message: str

-
  class PromptExpander:
      def __init__(self, backend="api", **kwargs):
          self.backend = backend
@@ -269,7 +320,6 @@ class PromptExpander:
      def decide_system_prompt(self, template_name=None):
          return prompt_template

-
  class APIPromptExpander(PromptExpander):
      def __init__(self, api_config=None, **kwargs):
          super().__init__(backend="api", **kwargs)
@@ -284,15 +334,15 @@ class APIPromptExpander(PromptExpander):
              base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")

              if not api_key:
-                 print("Warning: DASHSCOPE_API_KEY not found.")
                  return None

              return OpenAI(api_key=api_key, base_url=base_url)
          except ImportError:
-             print("Please install openai: pip install openai")
              return None
          except Exception as e:
-             print(f"Failed to initialize API client: {e}")
              return None

      def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
  def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
@@ -300,7 +350,7 @@ class APIPromptExpander(PromptExpander):
300
 
301
  def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
302
  if self.client is None:
303
- return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
304
 
305
  if system_prompt is None:
306
  system_prompt = self.decide_system_prompt()
@@ -317,65 +367,60 @@ class APIPromptExpander(PromptExpander):
              temperature=0.7,
              top_p=0.8,
          )

-         content = response.choices[0].message.content
          json_start = content.find("```json")
          if json_start != -1:
              json_end = content.find("```", json_start + 7)
-             try:
                  json_str = content[json_start + 7 : json_end].strip()
-                 data = json.loads(json_str)
-                 expanded_prompt = data.get("revised_prompt", content)
-             except:
-                 expanded_prompt = content
-         else:
-             expanded_prompt = content

-         return PromptOutput(
-             status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
-         )
      except Exception as e:
          return PromptOutput(False, "", seed, system_prompt, str(e))

-
  def create_prompt_expander(backend="api", **kwargs):
      if backend == "api":
          return APIPromptExpander(**kwargs)
      raise ValueError("Only 'api' backend is supported.")

-
  pipe = None
  prompt_expander = None

-
  def init_app():
      global pipe, prompt_expander

      try:
          pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
-         print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")

-         if ENABLE_WARMUP:
-             all_resolutions = []
              for cat in RES_CHOICES.values():
-                 all_resolutions.extend(cat)
-             warmup_model(pipe, all_resolutions)

      except Exception as e:
-         print(f"Error loading model: {e}")
          pipe = None

      try:
          prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
-         print("Prompt expander initialized.")
      except Exception as e:
-         print(f"Error initializing prompt expander: {e}")
          prompt_expander = None

-
- def prompt_enhance(prompt, enable_enhance):
      if not enable_enhance or not prompt_expander:
-         return prompt, "Enhancement disabled or not available."

      if not prompt.strip():
          return "", "Please enter a prompt."
@@ -384,11 +429,35 @@ def prompt_enhance(prompt, enable_enhance):
          result = prompt_expander(prompt)
          if result.status:
              return result.prompt, result.message
-         else:
-             return prompt, f"Enhancement failed: {result.message}"
      except Exception as e:
          return prompt, f"Error: {str(e)}"

  @spaces.GPU
  def generate(
@@ -399,175 +468,42 @@ def generate(
      shift=3.0,
      random_seed=True,
      gallery_images=None,
-     enhance=False,
      progress=gr.Progress(track_tqdm=True),
  ):
-     """
-     Generate an image using the Z-Image model based on the provided prompt and settings.
-     This function is triggered when the user clicks the "Generate" button. It processes
-     the input prompt (optionally enhancing it), configures generation parameters, and
-     produces an image using the Z-Image diffusion transformer pipeline.
-     Args:
-         prompt (str): Text prompt describing the desired image content
-         resolution (str): Output resolution in the format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )")
-         seed (int): Seed for reproducible generation
-         steps (int): Number of inference steps for the diffusion process
-         shift (float): Time-shift parameter for the flow-matching scheduler
-         random_seed (bool): Whether to generate a new random seed; if True, the seed input is ignored
-         gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI)
-         enhance (bool): Whether to enhance the prompt (DISABLED; do not use)
-         progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI)
-     Returns:
-         tuple: (gallery_images, seed_str, seed_int)
-             - gallery_images: Updated list of generated images including the new image
-             - seed_str: String representation of the seed used for generation
-             - seed_int: Integer representation of the seed used for generation
-     """
-
      if random_seed:
          new_seed = random.randint(1, 1000000)
      else:
-         new_seed = seed if seed != -1 else random.randint(1, 1000000)
-
-     class UnsafeContentError(Exception):
-         pass
-
-     try:
-         if pipe is None:
-             raise gr.Error("Model not loaded.")
-
-         has_unsafe_concept = is_unsafe_prompt(
-             pipe.text_encoder,
-             pipe.tokenizer,
-             system_prompt=UNSAFE_PROMPT_CHECK,
-             user_prompt=prompt,
-             max_new_token=UNSAFE_MAX_NEW_TOKEN,
-         )
-         if has_unsafe_concept:
-             raise UnsafeContentError("Input unsafe")
-
-         final_prompt = prompt
-
-         if enhance:
-             final_prompt, _ = prompt_enhance(prompt, True)
-             print(f"Enhanced prompt: {final_prompt}")
-
-         try:
-             resolution_str = resolution.split(" ")[0]
-         except:
-             resolution_str = "1024x1024"
-
-         image = generate_image(
-             pipe=pipe,
-             prompt=final_prompt,
-             resolution=resolution_str,
-             seed=new_seed,
-             guidance_scale=0.0,
-             num_inference_steps=int(steps + 1),
-             shift=shift,
-         )
-
-         safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
-         _, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
-         has_nsfw_concept = has_nsfw_concept[0]
-         if has_nsfw_concept:
-             raise UnsafeContentError("input unsafe")

-     except UnsafeContentError:
-         image = Image.open("nsfw.png")

-     if gallery_images is None:
-         gallery_images = []
-     # gallery_images.append(image)
-     gallery_images = [image] + gallery_images  # latest output goes at the top of the list

-     return gallery_images, str(new_seed), int(new_seed)
-
-
- init_app()
-
- # ==================== AoTI (Ahead-of-Time Inductor compilation) ====================
-
- pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
- spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
-
- with gr.Blocks(title="Z-Image Demo") as demo:
-     gr.Markdown(
-         """<div align="center">
-         # Z-Image Generation Demo
-         [![GitHub](https://img.shields.io/badge/GitHub-Z--Image-181717?logo=github&logoColor=white)](https://github.com/Tongyi-MAI/Z-Image)
-         *An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
-         </div>"""
-     )
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
-             # PE components (temporarily disabled)
-             # with gr.Row():
-             #     enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
-             #     enhance_btn = gr.Button("Enhance Only")
-
-             with gr.Row():
-                 choices = [int(k) for k in RES_CHOICES.keys()]
-                 res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
-
-                 initial_res_choices = RES_CHOICES["1024"]
-                 resolution = gr.Dropdown(
-                     value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
-                 )
-
-             with gr.Row():
-                 seed = gr.Number(label="Seed", value=42, precision=0)
-                 random_seed = gr.Checkbox(label="Random Seed", value=True)
-
-             with gr.Row():
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
-                 shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
-
-             generate_btn = gr.Button("Generate", variant="primary")
-
-             # Example prompts
-             gr.Markdown("### 📝 Example Prompts")
-             gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
-
-         with gr.Column(scale=1):
-             output_gallery = gr.Gallery(
-                 label="Generated Images",
-                 columns=2,
-                 rows=2,
-                 height=600,
-                 object_fit="contain",
-                 format="png",
-                 interactive=False,
-             )
-             used_seed = gr.Textbox(label="Seed Used", interactive=False)
-
-     def update_res_choices(_res_cat):
-         if str(_res_cat) in RES_CHOICES:
-             res_choices = RES_CHOICES[str(_res_cat)]
-         else:
-             res_choices = RES_CHOICES["1024"]
-         return gr.update(value=res_choices[0], choices=res_choices)
-
-     res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
-
-     # PE enhancement button (temporarily disabled)
-     # enhance_btn.click(
-     #     prompt_enhance,
-     #     inputs=[prompt_input, enable_enhance],
-     #     outputs=[prompt_input, final_prompt_output]
-     # )
-
-     generate_btn.click(
-         generate,
-         inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
-         outputs=[output_gallery, used_seed, seed],
-         api_visibility="public",
-     )

- css = """
- .fillable{max-width: 1230px !important}
- """
- if __name__ == "__main__":
-     demo.launch(css=css, mcp_server=True)
  import os
  import sys
+ import re
+ import json
+ import random
+ import logging
  import warnings
+ from dataclasses import dataclass

  import gradio as gr
  import torch
+ from PIL import Image, ImageDraw, ImageFont

+ import spaces
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+ from transformers import AutoModelForCausalLM, AutoTokenizer

+ # ------------------------- Optional dependency: prompt-enhancer template -------------------------
+ # If the project ships a pe.py it is picked up automatically; if it is missing,
+ # nothing breaks (enhance is disabled by default).
+ try:
+     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+     from pe import prompt_template  # type: ignore
+ except Exception:
+     prompt_template = (
+         "You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. "
+         "Return JSON with key revised_prompt."
+     )

+ # ------------------------- Z-Image components (requires a diffusers build that ships them) -------------------------
  from diffusers import ZImagePipeline
  from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel

  # ==================== Environment Variables ==================================
  MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
  DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
  HF_TOKEN = os.environ.get("HF_TOKEN")
  # =============================================================================

  os.environ["TOKENIZERS_PARALLELISM"] = "false"
  warnings.filterwarnings("ignore")
  logging.getLogger("transformers").setLevel(logging.ERROR)

+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
+
  RES_CHOICES = {
      "1024": [
          "1024x1024 ( 1:1 )",

  }

  RESOLUTION_SET = []
+ for _k, v in RES_CHOICES.items():
+     RESOLUTION_SET.extend(v)

  EXAMPLE_PROMPTS = [
      ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
+     ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
  ]

+ # ------------------------- HF token compatibility kwargs -------------------------
+ def _hf_token_kwargs(token: str | None):
+     """
+     from_pretrained in transformers / diffusers migrated from `use_auth_token`
+     to `token` a while ago, and recent releases raise if both are passed at
+     once, so only `token` is forwarded here.
+     """
+     if not token:
+         return {}
+     return {"token": token}
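For example, a minimal sketch of how the helper is meant to be used with from_pretrained (the repo id below is just this file's MODEL_PATH default):

    token_kwargs = _hf_token_kwargs(os.environ.get("HF_TOKEN"))
    vae = AutoencoderKL.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        subfolder="vae",
        **token_kwargs,  # empty dict when no token is set, so public repos still work
    )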
 
+ def get_resolution(resolution: str):
      match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
      if match:
          return int(match.group(1)), int(match.group(2))
      return 1024, 1024
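The regex accepts both the ASCII "x" and the typographic "×" separator, and anything unparseable falls back to 1024x1024 (the example strings below are illustrative, not necessarily RES_CHOICES entries):

    assert get_resolution("768x1344 ( 9:16 )") == (768, 1344)
    assert get_resolution("1664×928") == (1664, 928)
    assert get_resolution("not-a-resolution") == (1024, 1024)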
 
+ def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker"):
+     img = Image.new("RGB", (width, height), (20, 20, 20))
+     draw = ImageDraw.Draw(img)
+     try:
+         font = ImageFont.load_default()
+     except Exception:
+         font = None
+     draw.rectangle([0, 0, width, 90], fill=(160, 0, 0))
+     draw.text((20, 30), text, fill=(255, 255, 255), font=font)
+     return img
+
+ def _load_nsfw_placeholder(width=1024, height=1024):
+     """
+     When NSFW content is flagged, prefer the nsfw.png in the working directory;
+     if it is missing, fall back to a generated placeholder so a missing file
+     cannot raise a second error.
+     """
+     if os.path.exists("nsfw.png"):
+         try:
+             return Image.open("nsfw.png").convert("RGB")
+         except Exception:
+             pass
+     return _make_blocked_image(width, height, "NSFW blocked")
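A quick sanity check of the fallback path (assuming no nsfw.png exists in the working directory):

    img = _load_nsfw_placeholder(512, 512)
    assert img.size == (512, 512)  # generated placeholder rather than a file read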
 
+ def load_models(model_path: str, enable_compile=False, attention_backend="native"):
+     print(f"[Init] Loading models from: {model_path}")
+     print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}, ENABLE_COMPILE={enable_compile}, ATTENTION_BACKEND={attention_backend}")

+     # Remote repo id (a path that does not exist locally) vs. a local directory
+     is_local_dir = os.path.exists(model_path)
+     token_kwargs = _hf_token_kwargs(HF_TOKEN) if not is_local_dir else {}

+     # 1) VAE
+     if not is_local_dir:
          vae = AutoencoderKL.from_pretrained(
+             model_path,
              subfolder="vae",
+             torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
+             **token_kwargs,
+         )
+     else:
+         vae = AutoencoderKL.from_pretrained(
+             os.path.join(model_path, "vae"),
+             torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
          )

+     # 2) Text encoder + tokenizer
+     if not is_local_dir:
          text_encoder = AutoModelForCausalLM.from_pretrained(
+             model_path,
              subfolder="text_encoder",
+             torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
+             **token_kwargs,
          ).eval()
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_path,
+             subfolder="tokenizer",
+             **token_kwargs,
          )
+     else:
          text_encoder = AutoModelForCausalLM.from_pretrained(
              os.path.join(model_path, "text_encoder"),
+             torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
          ).eval()
          tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))

      tokenizer.padding_side = "left"

+     # torch.compile optimizations (only worth enabling on CUDA)
+     if enable_compile and DEVICE == "cuda":
+         print("[Init] Enabling torch.compile optimizations...")
          torch._inductor.config.conv_1x1_as_mm = True
          torch._inductor.config.coordinate_descent_tuning = True
          torch._inductor.config.epilogue_fusion = False

      pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)

+     # 3) Transformer
+     if not is_local_dir:
          transformer = ZImageTransformer2DModel.from_pretrained(
+             model_path,
+             subfolder="transformer",
+             **token_kwargs,
          )
+     else:
+         transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer"))

+     transformer = transformer.to(DEVICE, DTYPE)
      pipe.transformer = transformer

+     # The attention backend may be unsupported in some environments, so fail soft
+     try:
+         pipe.transformer.set_attention_backend(attention_backend)
+     except Exception as e:
+         print(f"[Init] set_attention_backend('{attention_backend}') failed, fallback to 'native'. Error: {e}")
+         try:
+             pipe.transformer.set_attention_backend("native")
+         except Exception as e2:
+             print(f"[Init] fallback set_attention_backend('native') failed: {e2}")

+     if enable_compile and DEVICE == "cuda":
+         try:
+             print("[Init] Compiling transformer...")
+             pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
+         except Exception as e:
+             print(f"[Init] torch.compile failed, continue without compile. Error: {e}")

+     pipe = pipe.to(DEVICE, DTYPE)

+     # 4) Safety checker (used for post-generation filtering)
+     try:
+         from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+         try:
+             from transformers import CLIPImageProcessor as _CLIPProcessor
+         except Exception:
+             # Compatibility with older transformers releases
+             from transformers import CLIPFeatureExtractor as _CLIPProcessor  # type: ignore
+
+         safety_model_id = "CompVis/stable-diffusion-safety-checker"
+         safety_feature_extractor = _CLIPProcessor.from_pretrained(safety_model_id, **_hf_token_kwargs(HF_TOKEN))
+         safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+             safety_model_id,
+             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+             **_hf_token_kwargs(HF_TOKEN),
+         ).to(DEVICE)
+
+         pipe.safety_feature_extractor = safety_feature_extractor
+         pipe.safety_checker = safety_checker
+         print("[Init] Safety checker loaded.")
+     except Exception as e:
+         print(f"[Init] Safety checker init failed. NSFW filtering will be skipped. Error: {e}")
+         pipe.safety_feature_extractor = None
+         pipe.safety_checker = None

      return pipe
  def generate_image(
      pipe,
+     prompt: str,
      resolution="1024x1024",
      seed=42,
      guidance_scale=5.0,

  ):
      width, height = get_resolution(resolution)

+     if DEVICE == "cuda":
+         generator = torch.Generator(device="cuda").manual_seed(int(seed))
+     else:
+         generator = torch.Generator().manual_seed(int(seed))

+     scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
      pipe.scheduler = scheduler

+     out = pipe(
          prompt=prompt,
+         height=int(height),
+         width=int(width),
+         guidance_scale=float(guidance_scale),
+         num_inference_steps=int(num_inference_steps),
          generator=generator,
+         max_sequence_length=int(max_sequence_length),
+     )
+     image = out.images[0]
      return image
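A minimal smoke test of the loader and sampler above (a sketch; the prompt is arbitrary and the keyword values mirror this file's warmup call):

    pipe = load_models(MODEL_PATH, enable_compile=False, attention_backend="native")
    img = generate_image(
        pipe,
        prompt="a lighthouse at dusk, photorealistic",
        resolution="1024x1024",
        seed=7,
        guidance_scale=0.0,
        num_inference_steps=6,
    )
    img.save("smoke_test.png")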
  def warmup_model(pipe, resolutions):
+     print("[Warmup] Starting warmup phase...")
      dummy_prompt = "warmup"
      for res_str in resolutions:
+         print(f"[Warmup] Resolution: {res_str}")
          try:
+             for i in range(2):
                  generate_image(
                      pipe,
                      prompt=dummy_prompt,
+                     resolution=res_str.split(" ")[0],
+                     num_inference_steps=6,
                      guidance_scale=0.0,
                      seed=42 + i,
                  )
          except Exception as e:
+             print(f"[Warmup] Failed for {res_str}: {e}")
+     print("[Warmup] Completed.")

+ # ==================== Prompt Expander (kept, but disabled by default) ====================
  @dataclass
  class PromptOutput:
      status: bool

      system_prompt: str
      message: str

  class PromptExpander:
      def __init__(self, backend="api", **kwargs):
          self.backend = backend

      def decide_system_prompt(self, template_name=None):
          return prompt_template

  class APIPromptExpander(PromptExpander):
      def __init__(self, api_config=None, **kwargs):
          super().__init__(backend="api", **kwargs)

              base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")

              if not api_key:
+                 print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.")
                  return None

              return OpenAI(api_key=api_key, base_url=base_url)
          except ImportError:
+             print("[PE] Please install openai: pip install openai")
              return None
          except Exception as e:
+             print(f"[PE] Failed to initialize API client: {e}")
              return None

      def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):

      def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
          if self.client is None:
+             return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized")

          if system_prompt is None:
              system_prompt = self.decide_system_prompt()

              temperature=0.7,
              top_p=0.8,
          )
+         content = response.choices[0].message.content or ""

+         # Try to parse revised_prompt out of a ```json block
+         expanded_prompt = content
          json_start = content.find("```json")
          if json_start != -1:
              json_end = content.find("```", json_start + 7)
+             if json_end != -1:
                  json_str = content[json_start + 7 : json_end].strip()
+                 try:
+                     data = json.loads(json_str)
+                     expanded_prompt = data.get("revised_prompt", content)
+                 except Exception:
+                     expanded_prompt = content

+         return PromptOutput(True, expanded_prompt, seed, system_prompt, content)

      except Exception as e:
          return PromptOutput(False, "", seed, system_prompt, str(e))
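The parser above expects the enhancer to wrap its JSON payload in a ```json fence and expose the rewritten prompt under the revised_prompt key; a minimal illustration of that contract (the reply text is hypothetical):

    reply = 'Sure:\n```json\n{"revised_prompt": "a richly detailed prompt"}\n```'
    start = reply.find("```json")
    end = reply.find("```", start + 7)
    data = json.loads(reply[start + 7 : end].strip())
    assert data["revised_prompt"] == "a richly detailed prompt"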
  def create_prompt_expander(backend="api", **kwargs):
      if backend == "api":
          return APIPromptExpander(**kwargs)
      raise ValueError("Only 'api' backend is supported.")

  pipe = None
  prompt_expander = None

  def init_app():
      global pipe, prompt_expander

      try:
          pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
+         print("[Init] Model loaded.")

+         if ENABLE_WARMUP and pipe is not None:
+             all_res = []
              for cat in RES_CHOICES.values():
+                 all_res.extend(cat)
+             warmup_model(pipe, all_res)

      except Exception as e:
+         print(f"[Init] Error loading model: {e}")
          pipe = None

      try:
          prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
+         print("[Init] Prompt expander ready (disabled by default).")
      except Exception as e:
+         print(f"[Init] Error initializing prompt expander: {e}")
          prompt_expander = None

+ def prompt_enhance(prompt, enable_enhance: bool):
      if not enable_enhance or not prompt_expander:
+         return prompt, "Enhancement disabled or unavailable."

      if not prompt.strip():
          return "", "Please enter a prompt."

          result = prompt_expander(prompt)
          if result.status:
              return result.prompt, result.message
+         return prompt, f"Enhancement failed: {result.message}"
      except Exception as e:
          return prompt, f"Error: {str(e)}"
+ def try_enable_aoti(pipe):
+     """
+     Enable AoTI (the ZeroGPU accelerated path) when it is available;
+     skip it otherwise without affecting the main flow.
+     """
+     if pipe is None:
+         return
+     try:
+         # Try the structure from the original code first: pipe.transformer.layers
+         if hasattr(pipe, "transformer") and pipe.transformer is not None:
+             target = None
+             if hasattr(pipe.transformer, "layers"):
+                 target = pipe.transformer.layers
+                 if hasattr(target, "_repeated_blocks"):
+                     target._repeated_blocks = ["ZImageTransformerBlock"]
+             else:
+                 # Fallback: set it on the transformer directly
+                 target = pipe.transformer
+                 if hasattr(target, "_repeated_blocks"):
+                     target._repeated_blocks = ["ZImageTransformerBlock"]
+
+             if target is not None:
+                 spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3")
+                 print("[Init] AoTI blocks loaded.")
+     except Exception as e:
+         print(f"[Init] AoTI not enabled (safe to ignore). Error: {e}")
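Intended wiring at module level (a sketch; the top-level calls sit outside this hunk, mirroring the init_app() call in the old version):

    init_app()
    try_enable_aoti(pipe)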
 
  @spaces.GPU
  def generate(

      shift=3.0,
      random_seed=True,
      gallery_images=None,
+     enhance=False,  # disabled by default
      progress=gr.Progress(track_tqdm=True),
  ):
      if random_seed:
          new_seed = random.randint(1, 1000000)
      else:
+         new_seed = int(seed) if int(seed) != -1 else random.randint(1, 1000000)

+     if pipe is None:
+         raise gr.Error("Model not loaded. Please check logs.")

+     final_prompt = prompt or ""
+     if enhance:
+         # The original comment says DISABLED; the capability is kept but off by default
+         final_prompt, _msg = prompt_enhance(final_prompt, True)
+         print(f"[PE] Enhanced prompt: {final_prompt}")

+     # Parse "1024x1024 ( 1:1 )" -> "1024x1024"
+     try:
+         resolution_str = str(resolution).split(" ")[0]
+     except Exception:
+         resolution_str = "1024x1024"
+
+     width, height = get_resolution(resolution_str)
+
+     # Generate
+     image = generate_image(
+         pipe=pipe,
+         prompt=final_prompt,
+         resolution=resolution_str,
+         seed=new_seed,
+         guidance_scale=0.0,
+         num_inference_steps=int(steps) + 1,
+         shift=float(shift),
      )

+     # Post-generation NSFW safety check (prompt_check removed)
+     try:
+         if getattr(pipe, "safety_feature_extractor", None) is not None and getattr(pipe, "safety_checker", None) is not None:
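Based on the call convention visible in the deleted block above (the CompVis checker returns a tuple whose second element is a per-image NSFW flag list), the guarded check can continue along these lines; a sketch under that assumption, not the commit's verbatim continuation:

    clip_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.to(DEVICE)
    _, has_nsfw = pipe.safety_checker(images=[torch.zeros(1)], clip_input=clip_input)
    if has_nsfw[0]:
        # swap in the placeholder; clip_input dtype may need to match the checker on CUDA
        image = _load_nsfw_placeholder(width, height)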