Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,44 +1,51 @@
|
|
| 1 |
-
import spaces
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
import json
|
| 4 |
-
import logging
|
| 5 |
import os
|
| 6 |
-
import random
|
| 7 |
-
import re
|
| 8 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import warnings
|
|
|
|
| 10 |
|
| 11 |
-
from PIL import Image
|
| 12 |
-
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 13 |
import gradio as gr
|
| 14 |
import torch
|
| 15 |
-
from
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
| 21 |
from diffusers import ZImagePipeline
|
| 22 |
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
|
| 23 |
|
| 24 |
-
from pe import prompt_template
|
| 25 |
-
|
| 26 |
# ==================== Environment Variables ==================================
|
| 27 |
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
|
| 28 |
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
|
| 29 |
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
|
| 30 |
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
|
| 31 |
-
UNSAFE_MAX_NEW_TOKEN = int(os.environ.get("UNSAFE_MAX_NEW_TOKEN", "10"))
|
| 32 |
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
|
| 33 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 34 |
-
UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
|
| 35 |
# =============================================================================
|
| 36 |
|
| 37 |
-
|
| 38 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 39 |
warnings.filterwarnings("ignore")
|
| 40 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
RES_CHOICES = {
|
| 43 |
"1024": [
|
| 44 |
"1024x1024 ( 1:1 )",
|
|
@@ -82,79 +89,100 @@ RES_CHOICES = {
|
|
| 82 |
}
|
| 83 |
|
| 84 |
RESOLUTION_SET = []
|
| 85 |
-
for
|
| 86 |
-
RESOLUTION_SET.extend(
|
| 87 |
|
| 88 |
EXAMPLE_PROMPTS = [
|
| 89 |
["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
|
| 90 |
-
[
|
| 91 |
-
"极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"
|
| 92 |
-
],
|
| 93 |
-
[
|
| 94 |
-
"一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子"
|
| 95 |
-
],
|
| 96 |
-
[
|
| 97 |
-
"Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
| 98 |
-
],
|
| 99 |
-
[
|
| 100 |
-
'''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
|
| 101 |
-
],
|
| 102 |
-
[
|
| 103 |
-
"""一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报。场景设置在一个质朴的19世纪风格厨房里。画面中央,一位红棕色头发、留着小胡子的中年男子(演员阿瑟·彭哈利根饰)站在一张木桌后,他身穿白色衬衫、黑色马甲和米色围裙,正看着一位女士,手中拿着一大块生红肉,下方是一个木制切菜板。在他的右边,一位梳着高髻的黑发女子(演员埃莉诺·万斯饰)倚靠在桌子上,温柔地对他微笑。她穿着浅色衬衫和一条上白下蓝的长裙。桌上除了放有切碎的葱和卷心菜丝的切菜板外,还有一个白色陶瓷盘、新鲜香草,左侧一个木箱上放着一串深色葡萄。背景是一面粗糙的灰白色抹灰墙,墙上挂着一幅风景画。最右边的一个台面上放着一盏复古油灯。海报上有大量的文字信息。左上角是白色的无衬线字体"ARTISAN FILMS PRESENTS",其下方是"ELEANOR VANCE"和"ACADEMY AWARD® WINNER"。右上角写着"ARTHUR PENHALIGON"和"GOLDEN GLOBE® AWARD WINNER"。顶部中央是圣丹斯电影节的桂冠标志,下方写着"SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"。主标题"THE TASTE OF MEMORY"以白色的大号衬线字体醒目地显示在下半部分。标题下方注明了"A FILM BY Tongyi Interaction Lab"。底部区域用白色小字列出了完整的演职员名单,包括"SCREENPLAY BY ANNA REID"、"CULINARY DIRECTION BY JAMES CARTER"以及Artisan Films、Riverstone Pictures和Heritage Media等众多出品公司标志。整体风格是写实主义,采用温暖柔和的灯光方案,营造出一种亲密的氛围。色调以棕色、米色和柔和的绿色等大地色系为主。两位演员的身体都在腰部被截断。"""
|
| 104 |
-
],
|
| 105 |
-
[
|
| 106 |
-
"""一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片,并叠加了文字,使其具有海报或杂志封面的外观。主要拍摄对象是一片厚实、有蜡质感的叶子,从左下角到右上角呈对角线弯曲穿过画面。其表面反光性很强,捕捉到一个明亮的直射光源,形成了一道突出的高光,亮面下显露出平行的精细叶脉。背景由其他深绿色的叶子组成,这些叶子轻微失焦,营造出浅景深效果,突出了前景的主叶片。整体风格是写实摄影,明亮的叶片与黑暗的阴影背景之间形成高对比度。图像上有多处渲染文字。左上角是白色的衬线字体文字"PIXEL-PEEPERS GUILD Presents"。右上角同样是白色衬线字体的文字"[Instant Noodle] 泡面调料包"。左侧垂直排列着标题"Render Distance: Max",为白色衬线字体。左下角是五个硕大的白色宋体汉字"显卡在...燃烧"。右下角是较小的白色衬线字体文字"Leica Glow™ Unobtanium X-1",其正上方是用白色宋体字书写的名字"蔡几"。识别出的核心实体包括品牌像素偷窥者协会、其产品线泡面调料包、相机型号买不到™ X-1以及摄影师名字造相。"""
|
| 107 |
-
],
|
| 108 |
]
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
def get_resolution(resolution):
|
| 112 |
match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
|
| 113 |
if match:
|
| 114 |
return int(match.group(1)), int(match.group(2))
|
| 115 |
return 1024, 1024
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
def load_models(model_path, enable_compile=False, attention_backend="native"):
|
| 119 |
-
print(f"Loading models from {model_path}
|
|
|
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
|
|
|
| 124 |
vae = AutoencoderKL.from_pretrained(
|
| 125 |
-
|
| 126 |
subfolder="vae",
|
| 127 |
-
torch_dtype=torch.
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
|
|
|
|
|
|
|
| 132 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 133 |
-
|
| 134 |
subfolder="text_encoder",
|
| 135 |
-
torch_dtype=torch.
|
| 136 |
-
|
| 137 |
-
use_auth_token=use_auth_token,
|
| 138 |
).eval()
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
|
| 144 |
)
|
| 145 |
-
|
| 146 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 147 |
os.path.join(model_path, "text_encoder"),
|
| 148 |
-
torch_dtype=torch.
|
| 149 |
-
device_map="cuda",
|
| 150 |
).eval()
|
| 151 |
-
|
| 152 |
tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
|
| 153 |
|
| 154 |
tokenizer.padding_side = "left"
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
|
|
|
| 158 |
torch._inductor.config.conv_1x1_as_mm = True
|
| 159 |
torch._inductor.config.coordinate_descent_tuning = True
|
| 160 |
torch._inductor.config.epilogue_fusion = False
|
|
@@ -165,42 +193,68 @@ def load_models(model_path, enable_compile=False, attention_backend="native"):
|
|
| 165 |
|
| 166 |
pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
if not os.path.exists(model_path):
|
| 172 |
transformer = ZImageTransformer2DModel.from_pretrained(
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
|
| 177 |
-
"cuda", torch.bfloat16
|
| 178 |
)
|
|
|
|
|
|
|
| 179 |
|
|
|
|
| 180 |
pipe.transformer = transformer
|
| 181 |
-
pipe.transformer.set_attention_backend(attention_backend)
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
pipe.transformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
-
|
| 190 |
-
from transformers import CLIPImageProcessor
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
pipe.safety_feature_extractor = safety_feature_extractor
|
| 197 |
-
pipe.safety_checker = safety_checker
|
| 198 |
return pipe
|
| 199 |
|
| 200 |
-
|
| 201 |
def generate_image(
|
| 202 |
pipe,
|
| 203 |
-
prompt,
|
| 204 |
resolution="1024x1024",
|
| 205 |
seed=42,
|
| 206 |
guidance_scale=5.0,
|
|
@@ -211,48 +265,46 @@ def generate_image(
|
|
| 211 |
):
|
| 212 |
width, height = get_resolution(resolution)
|
| 213 |
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
-
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
|
| 217 |
pipe.scheduler = scheduler
|
| 218 |
|
| 219 |
-
|
| 220 |
prompt=prompt,
|
| 221 |
-
height=height,
|
| 222 |
-
width=width,
|
| 223 |
-
guidance_scale=guidance_scale,
|
| 224 |
-
num_inference_steps=num_inference_steps,
|
| 225 |
generator=generator,
|
| 226 |
-
max_sequence_length=max_sequence_length,
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
return image
|
| 230 |
|
| 231 |
-
|
| 232 |
def warmup_model(pipe, resolutions):
|
| 233 |
-
print("Starting warmup phase...")
|
| 234 |
-
|
| 235 |
dummy_prompt = "warmup"
|
| 236 |
-
|
| 237 |
for res_str in resolutions:
|
| 238 |
-
print(f"
|
| 239 |
try:
|
| 240 |
-
for i in range(
|
| 241 |
generate_image(
|
| 242 |
pipe,
|
| 243 |
prompt=dummy_prompt,
|
| 244 |
-
resolution=res_str,
|
| 245 |
-
num_inference_steps=
|
| 246 |
guidance_scale=0.0,
|
| 247 |
seed=42 + i,
|
| 248 |
)
|
| 249 |
except Exception as e:
|
| 250 |
-
print(f"Warmup
|
|
|
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# ==================== Prompt Expander ====================
|
| 256 |
@dataclass
|
| 257 |
class PromptOutput:
|
| 258 |
status: bool
|
|
@@ -261,7 +313,6 @@ class PromptOutput:
|
|
| 261 |
system_prompt: str
|
| 262 |
message: str
|
| 263 |
|
| 264 |
-
|
| 265 |
class PromptExpander:
|
| 266 |
def __init__(self, backend="api", **kwargs):
|
| 267 |
self.backend = backend
|
|
@@ -269,7 +320,6 @@ class PromptExpander:
|
|
| 269 |
def decide_system_prompt(self, template_name=None):
|
| 270 |
return prompt_template
|
| 271 |
|
| 272 |
-
|
| 273 |
class APIPromptExpander(PromptExpander):
|
| 274 |
def __init__(self, api_config=None, **kwargs):
|
| 275 |
super().__init__(backend="api", **kwargs)
|
|
@@ -284,15 +334,15 @@ class APIPromptExpander(PromptExpander):
|
|
| 284 |
base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
| 285 |
|
| 286 |
if not api_key:
|
| 287 |
-
print("Warning: DASHSCOPE_API_KEY not found.")
|
| 288 |
return None
|
| 289 |
|
| 290 |
return OpenAI(api_key=api_key, base_url=base_url)
|
| 291 |
except ImportError:
|
| 292 |
-
print("Please install openai: pip install openai")
|
| 293 |
return None
|
| 294 |
except Exception as e:
|
| 295 |
-
print(f"Failed to initialize API client: {e}")
|
| 296 |
return None
|
| 297 |
|
| 298 |
def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
|
|
@@ -300,7 +350,7 @@ class APIPromptExpander(PromptExpander):
|
|
| 300 |
|
| 301 |
def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
|
| 302 |
if self.client is None:
|
| 303 |
-
return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
|
| 304 |
|
| 305 |
if system_prompt is None:
|
| 306 |
system_prompt = self.decide_system_prompt()
|
|
@@ -317,65 +367,60 @@ class APIPromptExpander(PromptExpander):
|
|
| 317 |
temperature=0.7,
|
| 318 |
top_p=0.8,
|
| 319 |
)
|
|
|
|
| 320 |
|
| 321 |
-
|
|
|
|
| 322 |
json_start = content.find("```json")
|
| 323 |
if json_start != -1:
|
| 324 |
json_end = content.find("```", json_start + 7)
|
| 325 |
-
|
| 326 |
json_str = content[json_start + 7 : json_end].strip()
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
expanded_prompt = content
|
| 333 |
|
| 334 |
-
return PromptOutput(
|
| 335 |
-
status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
|
| 336 |
-
)
|
| 337 |
except Exception as e:
|
| 338 |
return PromptOutput(False, "", seed, system_prompt, str(e))
|
| 339 |
|
| 340 |
-
|
| 341 |
def create_prompt_expander(backend="api", **kwargs):
|
| 342 |
if backend == "api":
|
| 343 |
return APIPromptExpander(**kwargs)
|
| 344 |
raise ValueError("Only 'api' backend is supported.")
|
| 345 |
|
| 346 |
-
|
| 347 |
pipe = None
|
| 348 |
prompt_expander = None
|
| 349 |
|
| 350 |
-
|
| 351 |
def init_app():
|
| 352 |
global pipe, prompt_expander
|
| 353 |
|
| 354 |
try:
|
| 355 |
pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
|
| 356 |
-
print(
|
| 357 |
|
| 358 |
-
if ENABLE_WARMUP:
|
| 359 |
-
|
| 360 |
for cat in RES_CHOICES.values():
|
| 361 |
-
|
| 362 |
-
warmup_model(pipe,
|
| 363 |
|
| 364 |
except Exception as e:
|
| 365 |
-
print(f"Error loading model: {e}")
|
| 366 |
pipe = None
|
| 367 |
|
| 368 |
try:
|
| 369 |
prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
|
| 370 |
-
print("Prompt expander
|
| 371 |
except Exception as e:
|
| 372 |
-
print(f"Error initializing prompt expander: {e}")
|
| 373 |
prompt_expander = None
|
| 374 |
|
| 375 |
-
|
| 376 |
-
def prompt_enhance(prompt, enable_enhance):
|
| 377 |
if not enable_enhance or not prompt_expander:
|
| 378 |
-
return prompt, "Enhancement disabled or
|
| 379 |
|
| 380 |
if not prompt.strip():
|
| 381 |
return "", "Please enter a prompt."
|
|
@@ -384,11 +429,35 @@ def prompt_enhance(prompt, enable_enhance):
|
|
| 384 |
result = prompt_expander(prompt)
|
| 385 |
if result.status:
|
| 386 |
return result.prompt, result.message
|
| 387 |
-
|
| 388 |
-
return prompt, f"Enhancement failed: {result.message}"
|
| 389 |
except Exception as e:
|
| 390 |
return prompt, f"Error: {str(e)}"
|
| 391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
@spaces.GPU
|
| 394 |
def generate(
|
|
@@ -399,175 +468,42 @@ def generate(
|
|
| 399 |
shift=3.0,
|
| 400 |
random_seed=True,
|
| 401 |
gallery_images=None,
|
| 402 |
-
enhance=False,
|
| 403 |
progress=gr.Progress(track_tqdm=True),
|
| 404 |
):
|
| 405 |
-
"""
|
| 406 |
-
Generate an image using the Z-Image model based on the provided prompt and settings.
|
| 407 |
-
This function is triggered when the user clicks the "Generate" button. It processes
|
| 408 |
-
the input prompt (optionally enhancing it), configures generation parameters, and
|
| 409 |
-
produces an image using the Z-Image diffusion transformer pipeline.
|
| 410 |
-
Args:
|
| 411 |
-
prompt (str): Text prompt describing the desired image content
|
| 412 |
-
resolution (str): Output resolution in format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )")
|
| 413 |
-
seed (int): Seed for reproducible generation
|
| 414 |
-
steps (int): Number of inference steps for the diffusion process
|
| 415 |
-
shift (float): Time shift parameter for the flow matching scheduler
|
| 416 |
-
random_seed (bool): Whether to generate a new random seed, if True will ignore the seed input
|
| 417 |
-
gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI)
|
| 418 |
-
enhance (bool): This was Whether to enhance the prompt (DISABLED! Do not use)
|
| 419 |
-
progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI)
|
| 420 |
-
Returns:
|
| 421 |
-
tuple: (gallery_images, seed_str, seed_int)
|
| 422 |
-
- gallery_images: Updated list of generated images including the new image
|
| 423 |
-
- seed_str: String representation of the seed used for generation
|
| 424 |
-
- seed_int: Integer representation of the seed used for generation
|
| 425 |
-
"""
|
| 426 |
-
|
| 427 |
if random_seed:
|
| 428 |
new_seed = random.randint(1, 1000000)
|
| 429 |
else:
|
| 430 |
-
new_seed = seed if seed != -1 else random.randint(1, 1000000)
|
| 431 |
-
|
| 432 |
-
class UnsafeContentError(Exception):
|
| 433 |
-
pass
|
| 434 |
-
|
| 435 |
-
try:
|
| 436 |
-
if pipe is None:
|
| 437 |
-
raise gr.Error("Model not loaded.")
|
| 438 |
-
|
| 439 |
-
has_unsafe_concept = is_unsafe_prompt(
|
| 440 |
-
pipe.text_encoder,
|
| 441 |
-
pipe.tokenizer,
|
| 442 |
-
system_prompt=UNSAFE_PROMPT_CHECK,
|
| 443 |
-
user_prompt=prompt,
|
| 444 |
-
max_new_token=UNSAFE_MAX_NEW_TOKEN,
|
| 445 |
-
)
|
| 446 |
-
if has_unsafe_concept:
|
| 447 |
-
raise UnsafeContentError("Input unsafe")
|
| 448 |
-
|
| 449 |
-
final_prompt = prompt
|
| 450 |
-
|
| 451 |
-
if enhance:
|
| 452 |
-
final_prompt, _ = prompt_enhance(prompt, True)
|
| 453 |
-
print(f"Enhanced prompt: {final_prompt}")
|
| 454 |
-
|
| 455 |
-
try:
|
| 456 |
-
resolution_str = resolution.split(" ")[0]
|
| 457 |
-
except:
|
| 458 |
-
resolution_str = "1024x1024"
|
| 459 |
-
|
| 460 |
-
image = generate_image(
|
| 461 |
-
pipe=pipe,
|
| 462 |
-
prompt=final_prompt,
|
| 463 |
-
resolution=resolution_str,
|
| 464 |
-
seed=new_seed,
|
| 465 |
-
guidance_scale=0.0,
|
| 466 |
-
num_inference_steps=int(steps + 1),
|
| 467 |
-
shift=shift,
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
|
| 471 |
-
_, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
|
| 472 |
-
has_nsfw_concept = has_nsfw_concept[0]
|
| 473 |
-
if has_nsfw_concept:
|
| 474 |
-
raise UnsafeContentError("input unsafe")
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
|
|
|
| 483 |
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
)
|
| 502 |
-
|
| 503 |
-
with gr.Row():
|
| 504 |
-
with gr.Column(scale=1):
|
| 505 |
-
prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
|
| 506 |
-
# PE components (Temporarily disabled)
|
| 507 |
-
# with gr.Row():
|
| 508 |
-
# enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
|
| 509 |
-
# enhance_btn = gr.Button("Enhance Only")
|
| 510 |
-
|
| 511 |
-
with gr.Row():
|
| 512 |
-
choices = [int(k) for k in RES_CHOICES.keys()]
|
| 513 |
-
res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
|
| 514 |
-
|
| 515 |
-
initial_res_choices = RES_CHOICES["1024"]
|
| 516 |
-
resolution = gr.Dropdown(
|
| 517 |
-
value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
|
| 518 |
-
)
|
| 519 |
-
|
| 520 |
-
with gr.Row():
|
| 521 |
-
seed = gr.Number(label="Seed", value=42, precision=0)
|
| 522 |
-
random_seed = gr.Checkbox(label="Random Seed", value=True)
|
| 523 |
-
|
| 524 |
-
with gr.Row():
|
| 525 |
-
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
|
| 526 |
-
shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
|
| 527 |
-
|
| 528 |
-
generate_btn = gr.Button("Generate", variant="primary")
|
| 529 |
-
|
| 530 |
-
# Example prompts
|
| 531 |
-
gr.Markdown("### 📝 Example Prompts")
|
| 532 |
-
gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
|
| 533 |
-
|
| 534 |
-
with gr.Column(scale=1):
|
| 535 |
-
output_gallery = gr.Gallery(
|
| 536 |
-
label="Generated Images",
|
| 537 |
-
columns=2,
|
| 538 |
-
rows=2,
|
| 539 |
-
height=600,
|
| 540 |
-
object_fit="contain",
|
| 541 |
-
format="png",
|
| 542 |
-
interactive=False,
|
| 543 |
-
)
|
| 544 |
-
used_seed = gr.Textbox(label="Seed Used", interactive=False)
|
| 545 |
-
|
| 546 |
-
def update_res_choices(_res_cat):
|
| 547 |
-
if str(_res_cat) in RES_CHOICES:
|
| 548 |
-
res_choices = RES_CHOICES[str(_res_cat)]
|
| 549 |
-
else:
|
| 550 |
-
res_choices = RES_CHOICES["1024"]
|
| 551 |
-
return gr.update(value=res_choices[0], choices=res_choices)
|
| 552 |
-
|
| 553 |
-
res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
|
| 554 |
-
|
| 555 |
-
# PE enhancement button (Temporarily disabled)
|
| 556 |
-
# enhance_btn.click(
|
| 557 |
-
# prompt_enhance,
|
| 558 |
-
# inputs=[prompt_input, enable_enhance],
|
| 559 |
-
# outputs=[prompt_input, final_prompt_output]
|
| 560 |
-
# )
|
| 561 |
-
|
| 562 |
-
generate_btn.click(
|
| 563 |
-
generate,
|
| 564 |
-
inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
|
| 565 |
-
outputs=[output_gallery, used_seed, seed],
|
| 566 |
-
api_visibility="public",
|
| 567 |
)
|
| 568 |
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
"""
|
| 572 |
-
if __name__ == "__main__":
|
| 573 |
-
demo.launch(css=css, mcp_server=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import sys
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import logging
|
| 7 |
import warnings
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
|
|
|
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
import torch
|
| 12 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 13 |
|
| 14 |
+
import spaces
|
| 15 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
|
| 18 |
+
# ------------------------- 可选依赖:Prompt Enhancer 模板 -------------------------
|
| 19 |
+
# 你的原工程里如果有 pe.py,会自动使用;没有也不会报错(enhance 默认关闭)
|
| 20 |
+
try:
|
| 21 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
from pe import prompt_template # type: ignore
|
| 23 |
+
except Exception:
|
| 24 |
+
prompt_template = (
|
| 25 |
+
"You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. "
|
| 26 |
+
"Return JSON with key revised_prompt."
|
| 27 |
+
)
|
| 28 |
|
| 29 |
+
# ------------------------- Z-Image 相关(依赖你环境中 diffusers 的实现) -------------------------
|
| 30 |
from diffusers import ZImagePipeline
|
| 31 |
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
|
| 32 |
|
|
|
|
|
|
|
| 33 |
# ==================== Environment Variables ==================================
|
| 34 |
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
|
| 35 |
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
|
| 36 |
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
|
| 37 |
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
|
|
|
|
| 38 |
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
|
| 39 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
| 40 |
# =============================================================================
|
| 41 |
|
|
|
|
| 42 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 43 |
warnings.filterwarnings("ignore")
|
| 44 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 45 |
|
| 46 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 47 |
+
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
|
| 48 |
+
|
| 49 |
RES_CHOICES = {
|
| 50 |
"1024": [
|
| 51 |
"1024x1024 ( 1:1 )",
|
|
|
|
| 89 |
}
|
| 90 |
|
| 91 |
RESOLUTION_SET = []
|
| 92 |
+
for _k, v in RES_CHOICES.items():
|
| 93 |
+
RESOLUTION_SET.extend(v)
|
| 94 |
|
| 95 |
EXAMPLE_PROMPTS = [
|
| 96 |
["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
|
| 97 |
+
["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
]
|
| 99 |
|
| 100 |
+
# ------------------------- HF token 兼容参数 -------------------------
|
| 101 |
+
def _hf_token_kwargs(token: str | None):
|
| 102 |
+
"""
|
| 103 |
+
transformers / diffusers 的 from_pretrained 近年来从 use_auth_token 迁移到 token。
|
| 104 |
+
这里做一个兼容:优先传 token,不支持则回退 use_auth_token。
|
| 105 |
+
"""
|
| 106 |
+
if not token:
|
| 107 |
+
return {}
|
| 108 |
+
return {"token": token, "use_auth_token": token}
|
| 109 |
|
| 110 |
+
def get_resolution(resolution: str):
|
| 111 |
match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
|
| 112 |
if match:
|
| 113 |
return int(match.group(1)), int(match.group(2))
|
| 114 |
return 1024, 1024
|
| 115 |
|
| 116 |
+
def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker"):
|
| 117 |
+
img = Image.new("RGB", (width, height), (20, 20, 20))
|
| 118 |
+
draw = ImageDraw.Draw(img)
|
| 119 |
+
try:
|
| 120 |
+
font = ImageFont.load_default()
|
| 121 |
+
except Exception:
|
| 122 |
+
font = None
|
| 123 |
+
draw.rectangle([0, 0, width, 90], fill=(160, 0, 0))
|
| 124 |
+
draw.text((20, 30), text, fill=(255, 255, 255), font=font)
|
| 125 |
+
return img
|
| 126 |
+
|
| 127 |
+
def _load_nsfw_placeholder(width=1024, height=1024):
|
| 128 |
+
"""
|
| 129 |
+
命中 NSFW 时优先加载工作目录的 nsfw.png;
|
| 130 |
+
不存在就生成一张占位图,避免文件缺失导致再次报错。
|
| 131 |
+
"""
|
| 132 |
+
if os.path.exists("nsfw.png"):
|
| 133 |
+
try:
|
| 134 |
+
return Image.open("nsfw.png").convert("RGB")
|
| 135 |
+
except Exception:
|
| 136 |
+
pass
|
| 137 |
+
return _make_blocked_image(width, height, "NSFW blocked")
|
| 138 |
|
| 139 |
+
def load_models(model_path: str, enable_compile=False, attention_backend="native"):
|
| 140 |
+
print(f"[Init] Loading models from: {model_path}")
|
| 141 |
+
print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}, ENABLE_COMPILE={enable_compile}, ATTENTION_BACKEND={attention_backend}")
|
| 142 |
|
| 143 |
+
# 远端 repo-id(不存在的本地路径) vs 本地目录
|
| 144 |
+
is_local_dir = os.path.exists(model_path)
|
| 145 |
+
token_kwargs = _hf_token_kwargs(HF_TOKEN) if not is_local_dir else {}
|
| 146 |
|
| 147 |
+
# 1) VAE
|
| 148 |
+
if not is_local_dir:
|
| 149 |
vae = AutoencoderKL.from_pretrained(
|
| 150 |
+
model_path,
|
| 151 |
subfolder="vae",
|
| 152 |
+
torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
|
| 153 |
+
**token_kwargs,
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
vae = AutoencoderKL.from_pretrained(
|
| 157 |
+
os.path.join(model_path, "vae"),
|
| 158 |
+
torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
|
| 159 |
)
|
| 160 |
|
| 161 |
+
# 2) Text Encoder + Tokenizer
|
| 162 |
+
if not is_local_dir:
|
| 163 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 164 |
+
model_path,
|
| 165 |
subfolder="text_encoder",
|
| 166 |
+
torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
|
| 167 |
+
**token_kwargs,
|
|
|
|
| 168 |
).eval()
|
| 169 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 170 |
+
model_path,
|
| 171 |
+
subfolder="tokenizer",
|
| 172 |
+
**token_kwargs,
|
|
|
|
| 173 |
)
|
| 174 |
+
else:
|
| 175 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 176 |
os.path.join(model_path, "text_encoder"),
|
| 177 |
+
torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
|
|
|
|
| 178 |
).eval()
|
|
|
|
| 179 |
tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
|
| 180 |
|
| 181 |
tokenizer.padding_side = "left"
|
| 182 |
|
| 183 |
+
# compile 优化(仅 CUDA 才建议打开)
|
| 184 |
+
if enable_compile and DEVICE == "cuda":
|
| 185 |
+
print("[Init] Enabling torch.compile optimizations...")
|
| 186 |
torch._inductor.config.conv_1x1_as_mm = True
|
| 187 |
torch._inductor.config.coordinate_descent_tuning = True
|
| 188 |
torch._inductor.config.epilogue_fusion = False
|
|
|
|
| 193 |
|
| 194 |
pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
|
| 195 |
|
| 196 |
+
# 3) Transformer
|
| 197 |
+
if not is_local_dir:
|
|
|
|
|
|
|
| 198 |
transformer = ZImageTransformer2DModel.from_pretrained(
|
| 199 |
+
model_path,
|
| 200 |
+
subfolder="transformer",
|
| 201 |
+
**token_kwargs,
|
|
|
|
|
|
|
| 202 |
)
|
| 203 |
+
else:
|
| 204 |
+
transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer"))
|
| 205 |
|
| 206 |
+
transformer = transformer.to(DEVICE, DTYPE)
|
| 207 |
pipe.transformer = transformer
|
|
|
|
| 208 |
|
| 209 |
+
# attention backend 可能在不同环境不支持,做容错
|
| 210 |
+
try:
|
| 211 |
+
pipe.transformer.set_attention_backend(attention_backend)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
print(f"[Init] set_attention_backend('{attention_backend}') failed, fallback to 'native'. Error: {e}")
|
| 214 |
+
try:
|
| 215 |
+
pipe.transformer.set_attention_backend("native")
|
| 216 |
+
except Exception as e2:
|
| 217 |
+
print(f"[Init] fallback set_attention_backend('native') failed: {e2}")
|
| 218 |
|
| 219 |
+
if enable_compile and DEVICE == "cuda":
|
| 220 |
+
try:
|
| 221 |
+
print("[Init] Compiling transformer...")
|
| 222 |
+
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
print(f"[Init] torch.compile failed, continue without compile. Error: {e}")
|
| 225 |
|
| 226 |
+
pipe = pipe.to(DEVICE, DTYPE)
|
|
|
|
| 227 |
|
| 228 |
+
# 4) Safety Checker(用于生成后过滤)
|
| 229 |
+
try:
|
| 230 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 231 |
+
try:
|
| 232 |
+
from transformers import CLIPImageProcessor as _CLIPProcessor
|
| 233 |
+
except Exception:
|
| 234 |
+
# 老版本兼容
|
| 235 |
+
from transformers import CLIPFeatureExtractor as _CLIPProcessor # type: ignore
|
| 236 |
+
|
| 237 |
+
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
| 238 |
+
safety_feature_extractor = _CLIPProcessor.from_pretrained(safety_model_id, **_hf_token_kwargs(HF_TOKEN))
|
| 239 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
| 240 |
+
safety_model_id,
|
| 241 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
| 242 |
+
**_hf_token_kwargs(HF_TOKEN),
|
| 243 |
+
).to(DEVICE)
|
| 244 |
+
|
| 245 |
+
pipe.safety_feature_extractor = safety_feature_extractor
|
| 246 |
+
pipe.safety_checker = safety_checker
|
| 247 |
+
print("[Init] Safety checker loaded.")
|
| 248 |
+
except Exception as e:
|
| 249 |
+
print(f"[Init] Safety checker init failed. NSFW filtering will be skipped. Error: {e}")
|
| 250 |
+
pipe.safety_feature_extractor = None
|
| 251 |
+
pipe.safety_checker = None
|
| 252 |
|
|
|
|
|
|
|
| 253 |
return pipe
|
| 254 |
|
|
|
|
| 255 |
def generate_image(
|
| 256 |
pipe,
|
| 257 |
+
prompt: str,
|
| 258 |
resolution="1024x1024",
|
| 259 |
seed=42,
|
| 260 |
guidance_scale=5.0,
|
|
|
|
| 265 |
):
|
| 266 |
width, height = get_resolution(resolution)
|
| 267 |
|
| 268 |
+
if DEVICE == "cuda":
|
| 269 |
+
generator = torch.Generator(device="cuda").manual_seed(int(seed))
|
| 270 |
+
else:
|
| 271 |
+
generator = torch.Generator().manual_seed(int(seed))
|
| 272 |
|
| 273 |
+
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
|
| 274 |
pipe.scheduler = scheduler
|
| 275 |
|
| 276 |
+
out = pipe(
|
| 277 |
prompt=prompt,
|
| 278 |
+
height=int(height),
|
| 279 |
+
width=int(width),
|
| 280 |
+
guidance_scale=float(guidance_scale),
|
| 281 |
+
num_inference_steps=int(num_inference_steps),
|
| 282 |
generator=generator,
|
| 283 |
+
max_sequence_length=int(max_sequence_length),
|
| 284 |
+
)
|
| 285 |
+
image = out.images[0]
|
| 286 |
return image
|
| 287 |
|
|
|
|
| 288 |
def warmup_model(pipe, resolutions):
|
| 289 |
+
print("[Warmup] Starting warmup phase...")
|
|
|
|
| 290 |
dummy_prompt = "warmup"
|
|
|
|
| 291 |
for res_str in resolutions:
|
| 292 |
+
print(f"[Warmup] Resolution: {res_str}")
|
| 293 |
try:
|
| 294 |
+
for i in range(2):
|
| 295 |
generate_image(
|
| 296 |
pipe,
|
| 297 |
prompt=dummy_prompt,
|
| 298 |
+
resolution=res_str.split(" ")[0],
|
| 299 |
+
num_inference_steps=6,
|
| 300 |
guidance_scale=0.0,
|
| 301 |
seed=42 + i,
|
| 302 |
)
|
| 303 |
except Exception as e:
|
| 304 |
+
print(f"[Warmup] Failed for {res_str}: {e}")
|
| 305 |
+
print("[Warmup] Completed.")
|
| 306 |
|
| 307 |
+
# ==================== Prompt Expander(保留但默认不启用) ====================
|
|
|
|
|
|
|
|
|
|
| 308 |
@dataclass
|
| 309 |
class PromptOutput:
|
| 310 |
status: bool
|
|
|
|
| 313 |
system_prompt: str
|
| 314 |
message: str
|
| 315 |
|
|
|
|
| 316 |
class PromptExpander:
|
| 317 |
def __init__(self, backend="api", **kwargs):
|
| 318 |
self.backend = backend
|
|
|
|
| 320 |
def decide_system_prompt(self, template_name=None):
|
| 321 |
return prompt_template
|
| 322 |
|
|
|
|
| 323 |
class APIPromptExpander(PromptExpander):
|
| 324 |
def __init__(self, api_config=None, **kwargs):
|
| 325 |
super().__init__(backend="api", **kwargs)
|
|
|
|
| 334 |
base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
| 335 |
|
| 336 |
if not api_key:
|
| 337 |
+
print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.")
|
| 338 |
return None
|
| 339 |
|
| 340 |
return OpenAI(api_key=api_key, base_url=base_url)
|
| 341 |
except ImportError:
|
| 342 |
+
print("[PE] Please install openai: pip install openai")
|
| 343 |
return None
|
| 344 |
except Exception as e:
|
| 345 |
+
print(f"[PE] Failed to initialize API client: {e}")
|
| 346 |
return None
|
| 347 |
|
| 348 |
def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
|
|
|
|
| 350 |
|
| 351 |
def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
|
| 352 |
if self.client is None:
|
| 353 |
+
return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized")
|
| 354 |
|
| 355 |
if system_prompt is None:
|
| 356 |
system_prompt = self.decide_system_prompt()
|
|
|
|
| 367 |
temperature=0.7,
|
| 368 |
top_p=0.8,
|
| 369 |
)
|
| 370 |
+
content = response.choices[0].message.content or ""
|
| 371 |
|
| 372 |
+
# 尝试从 ```json 块中解析 revised_prompt
|
| 373 |
+
expanded_prompt = content
|
| 374 |
json_start = content.find("```json")
|
| 375 |
if json_start != -1:
|
| 376 |
json_end = content.find("```", json_start + 7)
|
| 377 |
+
if json_end != -1:
|
| 378 |
json_str = content[json_start + 7 : json_end].strip()
|
| 379 |
+
try:
|
| 380 |
+
data = json.loads(json_str)
|
| 381 |
+
expanded_prompt = data.get("revised_prompt", content)
|
| 382 |
+
except Exception:
|
| 383 |
+
expanded_prompt = content
|
|
|
|
| 384 |
|
| 385 |
+
return PromptOutput(True, expanded_prompt, seed, system_prompt, content)
|
|
|
|
|
|
|
| 386 |
except Exception as e:
|
| 387 |
return PromptOutput(False, "", seed, system_prompt, str(e))
|
| 388 |
|
|
|
|
| 389 |
def create_prompt_expander(backend="api", **kwargs):
|
| 390 |
if backend == "api":
|
| 391 |
return APIPromptExpander(**kwargs)
|
| 392 |
raise ValueError("Only 'api' backend is supported.")
|
| 393 |
|
|
|
|
| 394 |
pipe = None
|
| 395 |
prompt_expander = None
|
| 396 |
|
|
|
|
| 397 |
def init_app():
|
| 398 |
global pipe, prompt_expander
|
| 399 |
|
| 400 |
try:
|
| 401 |
pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
|
| 402 |
+
print("[Init] Model loaded.")
|
| 403 |
|
| 404 |
+
if ENABLE_WARMUP and pipe is not None:
|
| 405 |
+
all_res = []
|
| 406 |
for cat in RES_CHOICES.values():
|
| 407 |
+
all_res.extend(cat)
|
| 408 |
+
warmup_model(pipe, all_res)
|
| 409 |
|
| 410 |
except Exception as e:
|
| 411 |
+
print(f"[Init] Error loading model: {e}")
|
| 412 |
pipe = None
|
| 413 |
|
| 414 |
try:
|
| 415 |
prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
|
| 416 |
+
print("[Init] Prompt expander ready (disabled by default).")
|
| 417 |
except Exception as e:
|
| 418 |
+
print(f"[Init] Error initializing prompt expander: {e}")
|
| 419 |
prompt_expander = None
|
| 420 |
|
| 421 |
+
def prompt_enhance(prompt, enable_enhance: bool):
|
|
|
|
| 422 |
if not enable_enhance or not prompt_expander:
|
| 423 |
+
return prompt, "Enhancement disabled or unavailable."
|
| 424 |
|
| 425 |
if not prompt.strip():
|
| 426 |
return "", "Please enter a prompt."
|
|
|
|
| 429 |
result = prompt_expander(prompt)
|
| 430 |
if result.status:
|
| 431 |
return result.prompt, result.message
|
| 432 |
+
return prompt, f"Enhancement failed: {result.message}"
|
|
|
|
| 433 |
except Exception as e:
|
| 434 |
return prompt, f"Error: {str(e)}"
|
| 435 |
|
| 436 |
+
def try_enable_aoti(pipe):
|
| 437 |
+
"""
|
| 438 |
+
AoTI(ZeroGPU 加速)可用则启用;不可用则跳过,不影响主流程。
|
| 439 |
+
"""
|
| 440 |
+
if pipe is None:
|
| 441 |
+
return
|
| 442 |
+
try:
|
| 443 |
+
# 优先按你原代码的结构尝试:pipe.transformer.layers
|
| 444 |
+
if hasattr(pipe, "transformer") and pipe.transformer is not None:
|
| 445 |
+
target = None
|
| 446 |
+
if hasattr(pipe.transformer, "layers"):
|
| 447 |
+
target = pipe.transformer.layers
|
| 448 |
+
if hasattr(target, "_repeated_blocks"):
|
| 449 |
+
target._repeated_blocks = ["ZImageTransformerBlock"]
|
| 450 |
+
else:
|
| 451 |
+
# 兜底:直接对 transformer 设置
|
| 452 |
+
target = pipe.transformer
|
| 453 |
+
if hasattr(target, "_repeated_blocks"):
|
| 454 |
+
target._repeated_blocks = ["ZImageTransformerBlock"]
|
| 455 |
+
|
| 456 |
+
if target is not None:
|
| 457 |
+
spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3")
|
| 458 |
+
print("[Init] AoTI blocks loaded.")
|
| 459 |
+
except Exception as e:
|
| 460 |
+
print(f"[Init] AoTI not enabled (safe to ignore). Error: {e}")
|
| 461 |
|
| 462 |
@spaces.GPU
|
| 463 |
def generate(
|
|
|
|
| 468 |
shift=3.0,
|
| 469 |
random_seed=True,
|
| 470 |
gallery_images=None,
|
| 471 |
+
enhance=False, # 默认不启用
|
| 472 |
progress=gr.Progress(track_tqdm=True),
|
| 473 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
if random_seed:
|
| 475 |
new_seed = random.randint(1, 1000000)
|
| 476 |
else:
|
| 477 |
+
new_seed = int(seed) if int(seed) != -1 else random.randint(1, 1000000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
+
if pipe is None:
|
| 480 |
+
raise gr.Error("Model not loaded. Please check logs.")
|
| 481 |
|
| 482 |
+
final_prompt = prompt or ""
|
| 483 |
+
if enhance:
|
| 484 |
+
# 你原注释说 DISABLED,这里仍保留能力但默认关闭
|
| 485 |
+
final_prompt, _msg = prompt_enhance(final_prompt, True)
|
| 486 |
+
print(f"[PE] Enhanced prompt: {final_prompt}")
|
| 487 |
|
| 488 |
+
# 解析 "1024x1024 ( 1:1 )" -> "1024x1024"
|
| 489 |
+
try:
|
| 490 |
+
resolution_str = str(resolution).split(" ")[0]
|
| 491 |
+
except Exception:
|
| 492 |
+
resolution_str = "1024x1024"
|
| 493 |
+
|
| 494 |
+
width, height = get_resolution(resolution_str)
|
| 495 |
+
|
| 496 |
+
# 生成
|
| 497 |
+
image = generate_image(
|
| 498 |
+
pipe=pipe,
|
| 499 |
+
prompt=final_prompt,
|
| 500 |
+
resolution=resolution_str,
|
| 501 |
+
seed=new_seed,
|
| 502 |
+
guidance_scale=0.0,
|
| 503 |
+
num_inference_steps=int(steps) + 1,
|
| 504 |
+
shift=float(shift),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
)
|
| 506 |
|
| 507 |
+
# 生成后 NSFW 安全检查(已去掉 prompt_check)
|
| 508 |
+
try:
|
| 509 |
+
if getattr(pipe, "safety_feature_extractor", None) is not None and getattr(pipe, "safety_checker", None) is not_
|
|
|
|
|
|