elismasilva committed
Commit 682ea96 · 1 Parent(s): 9c55707

update app

Files changed (7)
  1. .gitignore +1 -0
  2. README.md +152 -1
  3. app.py +312 -187
  4. flux_pipeline_mod.py +52 -118
  5. infer.py +37 -29
  6. requirements.txt +3 -1
  7. requirements_local.txt +3 -1
.gitignore CHANGED
@@ -10,3 +10,4 @@ venv/
10
  .DS_Store
11
  .gradio
12
  download.py
 
 
10
  .DS_Store
11
  .gradio
12
  download.py
13
+ outputs/
README.md CHANGED
@@ -11,4 +11,155 @@ license: apache-2.0
11
  short_description: Flux 1 Panorama
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
11
  short_description: Flux 1 Panorama
12
  ---
13
 
14
+ # Panorama FLUX 🏞️✨
15
+
16
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/elismasilva/flux-1-panorama) <!--- Replace with your final space link -->
17
+
18
+ Create stunning, seamless panoramic images by combining multiple distinct scenes with the power of the **FLUX.1-schnell** model. This application uses an advanced "Mixture of Diffusers" tiling pipeline to generate high-resolution compositions from left, center, and right text prompts.
19
+
20
+ ![Example Panorama Image](https://i.imgur.com/example.png) <!--- Optional: Replace with a link to an example image you generated -->
21
+
22
+ ## What is Panorama FLUX?
23
+
24
+ Panorama FLUX is a creative tool that leverages a sophisticated tiling mechanism to generate a single, wide-format image from three separate text prompts. Instead of stretching a single concept, you can describe different but related scenes for the left, center, and right portions of the image. The pipeline then intelligently generates each part and seamlessly blends them together.
25
+
26
+ This is ideal for:
27
+ * **Creating expansive landscapes:** Describe a beach that transitions into an ocean, which then meets a distant jungle.
28
+ * **Composing complex scenes:** Place different characters or objects side-by-side in a shared environment.
29
+ * **Generating ultra-wide art:** Create unique, high-resolution images perfect for wallpapers or digital art.
30
+
31
+ The core technology uses a custom `FluxMoDTilingPipeline` built on the Diffusers library, specifically adapted for the **FLUX.1-schnell** model's "Embedded Guidance" mechanism for fast, high-quality results.
32
+
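+ A minimal sketch of the core call, using only arguments that appear in this repo's `infer.py` (the 1x3 list-of-lists is the prompt grid: left, center, right):
+ ```python
+ # Sketch: compose a 3072x1024 panorama from three prompts.
+ # Assumes `pipe` is a loaded FluxMoDTilingPipeline (see "Running the App Locally").
+ image = pipe(
+     prompt=[["left scene", "center scene", "right scene"]],  # 1 row x 3 columns
+     height=1024,
+     width=3072,
+     tile_overlap=256,                 # pixels shared by neighboring tiles
+     tile_weighting_method="Cosine",   # or "Gaussian"
+     num_inference_steps=4,            # FLUX.1-schnell is distilled for ~4 steps
+     guidance_scale=0.0,               # schnell uses embedded guidance
+ ).images[0]
+ image.save("panorama.png")
+ ```
+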
33
+ ### Key Features
34
+ * **Multi-Prompt Composition:** Control the left, center, and right of your image with unique prompts.
35
+ * **Seamless Stitching:** Uses advanced blending methods (Cosine or Gaussian) to eliminate visible seams between tiles; a sketch of the Cosine window follows this list.
36
+ * **High-Resolution Output:** Generates images far wider than what a standard pipeline can handle in a single pass.
37
+ * **Efficient Memory Management:** Integrates `mmgp` for local use on consumer GPUs and supports standard `diffusers` offloading for cloud environments via the `USE_MMGP` environment variable.
38
+ * **Optimized for FLUX.1-schnell:** Tailored to the 4-step inference and `guidance_scale=0.0` architecture of the distilled FLUX model.
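+ * **Multilingual Prompts:** Korean and Chinese prompts are translated to English on the fly with small Helsinki-NLP models (see the Prompt Language selector in the app).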
39
+
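+ For the curious, the Cosine blending option reduces to a separable raised-cosine window over each tile; this is a distilled sketch of `_generate_cosine_weights` from `flux_pipeline_mod.py`:
+ ```python
+ import numpy as np
+
+ def cosine_window(width: int, height: int) -> np.ndarray:
+     """Weight map that peaks at the tile center and decays toward the
+     edges, so overlapping tiles cross-fade instead of leaving seams."""
+     x, y = np.arange(width), np.arange(height)
+     x_probs = np.cos(np.pi * (x - (width - 1) / 2) / width)
+     y_probs = np.cos(np.pi * (y - (height - 1) / 2) / height)
+     return np.outer(y_probs, x_probs)  # shape (height, width)
+ ```
+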
40
+ ---
41
+
42
+ ## Running the App Locally
43
+
44
+ Follow these steps to run the Gradio application on your own machine.
45
+
46
+ ### 1. Prerequisites
47
+ * Python 3.9+
48
+ * Git and Git LFS installed (`git-lfs` is required to clone large model files).
49
+
50
+ ### 2. Clone the Repository
51
+ ```bash
52
+ git clone https://huggingface.co/spaces/elismasilva/flux-1-panorama
53
+ cd flux-1-panorama
54
+ ```
55
+
56
+ ### 3. Set Up a Virtual Environment (Recommended)
57
+ ```bash
58
+ # Windows
59
+ python -m venv venv
60
+ .\venv\Scripts\activate
61
+
62
+ # macOS / Linux
63
+ python3 -m venv venv
64
+ source venv/bin/activate
65
+ ```
66
+
67
+ ### 4. Install Dependencies
68
+ This project includes a specific requirements file for local execution.
69
+ ```bash
70
+ pip install -r requirements_local.txt
71
+ ```
72
+
73
+ ### 5. Configure the Model Path
74
+ By default, the app loads the model from the Hugging Face Hub (`"black-forest-labs/FLUX.1-schnell"`). If you have downloaded the model locally (e.g., to `F:\models\flux_schnell`), point the app at it with the `MODEL_PATH` environment variable instead of editing the code.
75
+
76
+ `app.py` reads the variable at startup:
77
+ ```python
78
+ # app.py
79
+ MODEL_PATH = os.getenv("MODEL_PATH", "black-forest-labs/FLUX.1-schnell")
80
+
81
+ pipe = FluxMoDTilingPipeline.from_pretrained(
82
+     MODEL_PATH,  # <-- set the MODEL_PATH env var to override
83
+     torch_dtype=torch.bfloat16
84
+ ).to("cuda")
85
+ ```
85
+
86
+ ### 6. Run the Gradio App
87
+ ```bash
88
+ python app.py
89
+ ```
90
+ The application will start and provide a local URL (usually `http://127.0.0.1:7860`) that you can open in your web browser.
91
+
92
+ ---
93
+
94
+ ## Using the Command-Line Script (`infer.py`)
95
+
96
+ The `infer.py` script is a great way to test the pipeline directly, without the Gradio interface. This is useful for debugging, checking performance, and ensuring everything works correctly.
97
+
98
+ ### 1. Configure the Script
99
+ Open the `infer.py` file in a text editor. You can modify the parameters inside the `main()` function to match your desired output.
100
+
101
+ ```python
102
+ # infer.py
103
+
104
+ # ... (imports)
105
+
106
+ def main():
107
+ # --- 1. Load Model ---
108
+ MODEL_PATH = "black-forest-labs/FLUX.1-schnell" # Or your local path
109
+
110
+ # ... (model loading code)
111
+
112
+ # --- 2. Set Up Inference Parameters ---
113
+ prompt_grid = [[
114
+ "Your left prompt here.",
115
+ "Your center prompt here.",
116
+ "Your right prompt here."
117
+ ]]
118
+
119
+ target_height = 1024
120
+ target_width = 3072
121
+ # ... and so on for other parameters like steps, seed, etc.
122
+ ```
123
+
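+ Note that tile sizes are snapped down to a multiple of 16, so the actual output can be slightly narrower than the requested width. A worked example of the arithmetic, mirroring `do_calc_tile` in `app.py` for the default 3x1 grid:
+ ```python
+ target_width, target_height, overlap = 3072, 1024, 256
+ num_cols, num_rows = 3, 1
+
+ tile_width = (target_width + (num_cols - 1) * overlap) // num_cols    # 1194
+ tile_width -= tile_width % 16                                         # 1184
+ tile_height = (target_height + (num_rows - 1) * overlap) // num_rows  # 1024
+ tile_height -= tile_height % 16                                       # 1024
+
+ final_width = tile_width * num_cols - (num_cols - 1) * overlap        # 3040
+ final_height = tile_height * num_rows - (num_rows - 1) * overlap      # 1024
+ ```
+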
124
+ ### 2. Run the Script
125
+ Execute the script from your terminal:
126
+ ```bash
127
+ python infer.py
128
+ ```
129
+ The script will print its progress to the console, including the `tqdm` progress bar, and save the final image to `outputs/inference_output.png` in the project directory.
130
+
131
+ ---
132
+
133
+ ## Environment Variables
134
+
135
+ ### `USE_MMGP`
136
+ This variable controls which memory optimization strategy to use.
137
+
138
+ * **To use `mmgp` (Recommended for local use):**
139
+ Ensure the variable is **not set**, or set it to `true`. This is the default behavior.
140
+ ```bash
141
+ # (No action needed, or run)
142
+ # Linux/macOS: export USE_MMGP=true
143
+ # Windows CMD: set USE_MMGP=true
144
+ python app.py
145
+ ```
146
+
147
+ * **To disable `mmgp` and use standard `diffusers` CPU offloading (For Hugging Face Spaces or troubleshooting):**
148
+ Set the variable to `false`.
149
+ ```bash
150
+ # Linux/macOS
151
+ USE_MMGP=false python app.py
152
+
153
+ # Windows CMD
154
+ set USE_MMGP=false
155
+ python app.py
156
+
157
+ # Windows PowerShell
158
+ $env:USE_MMGP="false"
159
+ python app.py
160
+ ```
161
+
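+ ### `MODEL_PATH`
+ This variable points the app at a model location. It defaults to `black-forest-labs/FLUX.1-schnell` and is read in `app.py` via `os.getenv("MODEL_PATH", ...)`, so you can use a local download without touching the code:
+ ```bash
+ # Linux/macOS
+ MODEL_PATH=/path/to/flux_schnell python app.py
+
+ # Windows CMD
+ set MODEL_PATH=F:\models\flux_schnell
+ python app.py
+ ```
+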
162
+ ## Acknowledgements
163
+ * **Black Forest Labs** for the powerful FLUX models.
164
+ * The original authors of the **Mixture of Diffusers** technique.
165
+ * **Hugging Face** for the `diffusers` library.
app.py CHANGED
@@ -4,64 +4,103 @@ import os
4
  import random
5
  import numpy as np
6
  import torch
 
7
 
8
- # Import the corrected unified pipeline
9
- from flux_pipeline_mod import FluxMoDTilingPipeline
10
 
11
- # 1. Conditional MMGP Setup based on Environment Variable
12
- # Check the 'USE_MMGP' environment variable. Default to 'true' if not set.
13
- USE_MMGP_ENV = os.getenv('USE_MMGP', 'true').lower()
14
- if USE_MMGP_ENV in ('false', '0', 'no', 'none'):
15
- USE_MMGP = False
16
- print("INFO: USE_MMGP environment variable set to false. MMGP will NOT be used.")
17
- else:
18
- USE_MMGP = True
19
- print("INFO: USE_MMGP is true or not set. Attempting to use MMGP.")
 
 
 
 
 
 
20
 
21
- # Conditionally import mmgp
22
  offload = None
23
  if USE_MMGP:
 
24
  try:
25
  from mmgp import offload, profile_type
 
26
  print("Successfully imported MMGP.")
27
  except ImportError:
28
- print("WARNING: USE_MMGP is true, but the 'mmgp' library could not be found. Falling back to standard offload.")
29
- USE_MMGP = False # Update flag as it can't be used
 
 
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
 
33
- # 2. Load the Pipeline
34
- print("Loading the FLUX Tiling pipeline. This may take a moment...")
35
- pipe = FluxMoDTilingPipeline.from_pretrained(
36
- #"F:\\models\\flux_dev",
37
- "black-forest-labs/FLUX.1-schnell",
38
- torch_dtype=torch.bfloat16
39
- ).to("cuda")
 
 
40
 
41
- # 3. Apply Memory Optimization based on the flag
42
  if USE_MMGP and offload:
43
  print("Applying LowRAM_LowVRAM offload profile via MMGP...")
44
  offload.profile(pipe, profile_type.LowRAM_LowVRAM)
45
  else:
46
- print("MMGP is disabled. Attempting to use the standard Diffusers CPU offload...")
47
  try:
48
  pipe.enable_model_cpu_offload()
49
  except Exception as e:
50
  print(f"Could not apply standard offload: {e}")
51
 
52
- #pipe.enable_vae_tiling()
53
- #pipe.enable_vae_slicing()
54
-
55
  print("Pipeline loaded and ready.")
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def create_hdr_effect(image, hdr_strength):
58
  if hdr_strength == 0:
59
  return image
60
  from PIL import ImageEnhance, Image
61
- if isinstance(image, Image.Image): image = np.array(image)
 
 
62
  from scipy.ndimage import gaussian_filter
 
63
  blurred = gaussian_filter(image, sigma=5)
64
- sharpened = np.clip(image + hdr_strength * (image - blurred), 0, 255).astype(np.uint8)
 
 
65
  pil_img = Image.fromarray(sharpened)
66
  converter = ImageEnhance.Color(pil_img)
67
  return converter.enhance(1 + hdr_strength)
@@ -69,25 +108,38 @@ def create_hdr_effect(image, hdr_strength):
69
 
70
  @spaces.GPU(duration=120)
71
  def predict(
72
- left_prompt, center_prompt, right_prompt, negative_prompt,
73
- left_gs, center_gs, right_gs, overlap_pixels, steps,
74
- generation_seed, tile_weighting_method,
75
- _, __,
76
- target_height, target_width, hdr,
77
  progress=gr.Progress(track_tqdm=True),
78
  ):
79
  global pipe
80
- generator_device = "cpu"
81
- generator = torch.Generator(generator_device).manual_seed(generation_seed)
82
-
83
  final_height, final_width = int(target_height), int(target_width)
84
 
85
- print("Starting generation with Unified Tiling Pipeline (Composition Mode)...")
 
 
 
 
 
86
  image = pipe(
87
- prompt=[[left_prompt, center_prompt, right_prompt]],
88
  height=final_height,
89
  width=final_width,
90
- negative_prompt=negative_prompt,
91
  tile_overlap=overlap_pixels,
92
  guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
93
  tile_weighting_method=tile_weighting_method,
@@ -98,22 +150,16 @@ def predict(
98
 
99
  return create_hdr_effect(image, hdr)
100
 
 
101
  def do_calc_tile(target_height, target_width, overlap_pixels):
102
  num_cols = 3
103
  num_rows = 1
104
-
105
  tile_width = (target_width + (num_cols - 1) * overlap_pixels) // num_cols
106
  tile_height = (target_height + (num_rows - 1) * overlap_pixels) // num_rows
107
  tile_width -= tile_width % 16
108
  tile_height -= tile_height % 16
109
-
110
  final_width = tile_width * num_cols - (num_cols - 1) * overlap_pixels
111
  final_height = tile_height * num_rows - (num_rows - 1) * overlap_pixels
112
-
113
- print("--- UI Tile Size Preview ---")
114
- print(f"Ideal Tile Height/Width: {tile_height}/{tile_width}")
115
- print(f"Calculated Final Height/Width: {final_height}/{final_width}\n")
116
-
117
  return (
118
  gr.update(value=tile_height),
119
  gr.update(value=tile_width),
@@ -121,122 +167,65 @@ def do_calc_tile(target_height, target_width, overlap_pixels):
121
  gr.update(value=final_width),
122
  )
123
 
 
124
  def clear_result():
125
  return gr.update(value=None)
126
 
 
127
  def run_for_examples(
128
- left_prompt, center_prompt, right_prompt, negative_prompt,
129
- left_gs, center_gs, right_gs, overlap_pixels, steps,
130
- generation_seed, tile_weighting_method, tile_height, tile_width,
131
- target_height, target_width, hdr,
 
 
 
 
 
 
 
 
 
 
 
132
  ):
133
  return predict(
134
- left_prompt, center_prompt, right_prompt, negative_prompt,
135
- left_gs, center_gs, right_gs, overlap_pixels, steps,
136
- generation_seed, tile_weighting_method, tile_height, tile_width,
137
- target_height, target_width, hdr,
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
 
 
140
  def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
141
  if randomize_seed:
142
  generation_seed = random.randint(0, MAX_SEED)
143
  return generation_seed
144
 
 
145
  # UI Layout
146
- css = "..."
147
- title = "..."
148
-
149
- # theme = gr.themes.Default(
150
- # primary_hue='indigo',
151
- # secondary_hue='cyan',
152
- # neutral_hue='gray'
153
- # ).set(
154
- # body_background_fill='*neutral_100',
155
- # body_background_fill_dark='*neutral_900',
156
- # body_text_color='*neutral_900',
157
- # body_text_color_dark='*neutral_100',
158
- # input_background_fill='white',
159
- # input_background_fill_dark='*neutral_800',
160
- # button_primary_background_fill='*primary_500',
161
- # button_primary_background_fill_dark='*primary_700',
162
- # button_primary_text_color='white',
163
- # button_primary_text_color_dark='white',
164
- # button_secondary_background_fill='*secondary_500',
165
- # button_secondary_background_fill_dark='*secondary_700',
166
- # button_secondary_text_color='white',
167
- # button_secondary_text_color_dark='white'
168
- # )
169
  theme = gr.themes.Default(
170
- primary_hue='blue',
171
- secondary_hue='teal',
172
- neutral_hue='neutral'
173
- ).set(
174
- body_background_fill='*neutral_100',
175
- body_background_fill_dark='*neutral_900',
176
- body_text_color='*neutral_700',
177
- body_text_color_dark='*neutral_200',
178
- body_text_weight='400',
179
- link_text_color='*primary_500',
180
- link_text_color_dark='*primary_400',
181
- code_background_fill='*neutral_100',
182
- code_background_fill_dark='*neutral_800',
183
- shadow_drop='0 1px 3px rgba(0,0,0,0.1)',
184
- shadow_inset='inset 0 2px 4px rgba(0,0,0,0.05)',
185
- block_background_fill='*neutral_50',
186
- block_background_fill_dark='*neutral_700',
187
- block_border_color='*neutral_200',
188
- block_border_color_dark='*neutral_600',
189
- block_border_width='1px',
190
- block_border_width_dark='1px',
191
- block_label_background_fill='*primary_50',
192
- block_label_background_fill_dark='*primary_600',
193
- block_label_text_color='*primary_600',
194
- block_label_text_color_dark='*primary_50',
195
- panel_background_fill='white',
196
- panel_background_fill_dark='*neutral_800',
197
- panel_border_color='*neutral_200',
198
- panel_border_color_dark='*neutral_700',
199
- panel_border_width='1px',
200
- panel_border_width_dark='1px',
201
- input_background_fill='white',
202
- input_background_fill_dark='*neutral_800',
203
- input_border_color='*neutral_300',
204
- input_border_color_dark='*neutral_700',
205
- slider_color='*primary_500',
206
- slider_color_dark='*primary_400',
207
- button_primary_background_fill='*primary_600',
208
- button_primary_background_fill_dark='*primary_500',
209
- button_primary_background_fill_hover='*primary_700',
210
- button_primary_background_fill_hover_dark='*primary_400',
211
- button_primary_border_color='transparent',
212
- button_primary_border_color_dark='transparent',
213
- button_primary_text_color='white',
214
- button_primary_text_color_dark='white',
215
- button_secondary_background_fill='*neutral_200',
216
- button_secondary_background_fill_dark='*neutral_600',
217
- button_secondary_background_fill_hover='*neutral_300',
218
- button_secondary_background_fill_hover_dark='*neutral_500',
219
- button_secondary_border_color='transparent',
220
- button_secondary_border_color_dark='transparent',
221
- button_secondary_text_color='*neutral_700',
222
- button_secondary_text_color_dark='*neutral_200'
223
  )
224
- # css = """
225
- # body { font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; }
226
- # """
227
- css_code = ""
228
- try:
229
- with open("./style.css", "r", encoding="utf-8") as f:
230
- css_code += f.read() + "\n"
231
- except FileNotFoundError:
232
- pass
233
- title = """<h1 align="center">Panorama FLUX - Mixture-of-Diffusers for FLUX ✨</h1>
234
  <div style="text-align: center;">
235
- <span>An advanced tiling pipeline for creative composition and large-scale image generation with the FLUX model.</span>
236
  </div>
237
  """
238
 
239
- with gr.Blocks(css=css_code, theme=theme, title="Panorama FLUX") as app:
240
  gr.Markdown(title)
241
  with gr.Row():
242
  with gr.Column(scale=7):
@@ -245,94 +234,230 @@ with gr.Blocks(css=css_code, theme=theme, title="Panorama FLUX") as app:
245
  with gr.Column(scale=1):
246
  gr.Markdown("### Left Region")
247
  left_prompt = gr.Textbox(lines=4, label="Prompt for left side")
248
- left_gs = gr.Slider(minimum=0, maximum=15, value=7, step=0.5, label="Left CFG scale")
 
 
 
 
 
 
249
  with gr.Column(scale=1):
250
  gr.Markdown("### Center Region")
251
  center_prompt = gr.Textbox(lines=4, label="Prompt for the center")
252
- center_gs = gr.Slider(minimum=0, maximum=15, value=7, step=0.5, label="Center CFG scale")
 
 
 
 
 
 
253
  with gr.Column(scale=1):
254
  gr.Markdown("### Right Region")
255
  right_prompt = gr.Textbox(lines=4, label="Prompt for right side")
256
- right_gs = gr.Slider(minimum=0, maximum=15, value=7, step=0.5, label="Right CFG scale")
 
 
 
 
 
 
 
257
  with gr.Row():
258
- negative_prompt = gr.Textbox(
259
- lines=2,
260
- label="Negative prompt (for the whole image)",
261
- value="nsfw, lowres, bad anatomy, bad hands, duplicate, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry",
 
262
  )
263
- with gr.Row():
264
- result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False)
265
-
266
  with gr.Sidebar():
267
  gr.Markdown("### Tiling & Generation Parameters")
 
 
 
 
 
 
 
 
 
268
  with gr.Row():
269
- height = gr.Slider(label="Target Height", value=1024, step=16, visible=True, minimum=512, maximum=2048)
270
- width = gr.Slider(label="Target Width", value=3072, step=16, visible=True, minimum=512, maximum=4096)
 
 
 
 
271
  with gr.Row():
272
- overlap = gr.Slider(minimum=0, maximum=512, value=256, step=16, label="Tile Overlap")
273
- tile_weighting_method = gr.Dropdown(label="Blending Method", choices=["Cosine", "Gaussian"], value="Cosine")
 
 
 
 
274
  with gr.Row():
275
- calc_tile = gr.Button("Calculate Final Dimensions")
276
  with gr.Row():
277
- new_target_height = gr.Textbox(label="Actual Image Height", value=1024, interactive=False)
278
- new_target_width = gr.Textbox(label="Actual Image Width", value=3072, interactive=False)
 
 
 
 
279
  with gr.Row():
280
- tile_height = gr.Textbox(label="Ideal Tile Height", value=1024, interactive=False)
281
- tile_width = gr.Textbox(label="Ideal Tile Width", value=1152, interactive=False)
 
 
 
 
282
  with gr.Row():
283
- steps = gr.Slider(minimum=4, maximum=50, value=28, step=1, label="Inference Steps")
 
 
284
  with gr.Row():
285
- generation_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
286
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
287
  with gr.Row():
288
- hdr = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="HDR Effect")
289
-
 
 
290
  with gr.Row():
291
  gr.Examples(
292
  examples=[
293
- [
294
- "Iron Man, repulsor rays blasting enemies in destroyed cityscape, cinematic lighting, photorealistic. Focus: Iron Man.",
295
- "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
296
- "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor.",
297
- negative_prompt.value, 5, 5, 5, 160, 30, 619517442, "Cosine", 1024, 1152, 1024, 3072, 0,
298
  ],
299
  [
300
- "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, sharp focus, artstation, stunning masterpiece",
301
- "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, sharp focus, artstation, stunning masterpiece",
302
- "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, sharp focus, artstation, stunning masterpiece",
303
- negative_prompt.value, 7, 7, 7, 256, 28, 358867853, "Gaussian", 1024, 1152, 1024, 2944, 0.1,
 
 
 
 
 
 
 
 
 
 
 
304
  ],
305
  ],
306
  inputs=[
307
- left_prompt, center_prompt, right_prompt, negative_prompt,
308
- left_gs, center_gs, right_gs, overlap, steps,
309
- generation_seed, tile_weighting_method, tile_height, tile_width, height, width, hdr,
  ],
311
  fn=run_for_examples,
312
  outputs=result,
313
  cache_examples=False,
314
  )
315
- # Event handling
 
316
  event_calc_tile_size = {
317
  "fn": do_calc_tile,
318
  "inputs": [height, width, overlap],
319
  "outputs": [tile_height, tile_width, new_target_height, new_target_width],
320
  }
321
-
322
  predict_inputs = [
323
- left_prompt, center_prompt, right_prompt, negative_prompt,
324
- left_gs, center_gs, right_gs, overlap, steps,
325
- generation_seed, tile_weighting_method, tile_height, tile_width,
326
- new_target_height, new_target_width, hdr,
327
  ]
 
328
  calc_tile.click(**event_calc_tile_size)
329
  generate_button.click(
330
- fn=clear_result, inputs=None, outputs=result, queue=False,
 
 
 
331
  ).then(**event_calc_tile_size).then(
332
- fn=randomize_seed_fn, inputs=[generation_seed, randomize_seed], outputs=generation_seed, queue=False,
 
 
 
333
  ).then(
334
- fn=predict, inputs=predict_inputs, outputs=result, show_progress='full'
335
  )
336
 
337
  app.queue().launch(share=True)
338
-
 
4
  import random
5
  import numpy as np
6
  import torch
7
+ from transformers import pipeline
8
 
9
+ # Import the pipeline
10
+ from flux_pipeline_mod import FluxMoDTilingPipeline
11
 
12
+ # --- 1. Load Translation Models ---
13
+ # These models are small and run efficiently on CPU.
14
+ print("Loading translation models...")
15
+ try:
16
+ ko_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
17
+ zh_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
18
+ print("Translation models loaded successfully.")
19
+ except Exception as e:
20
+ print(f"Could not load translation models: {e}")
21
+ ko_en_translator = None
22
+ zh_en_translator = None
23
+
24
+ # --- 2. Conditional MMGP Setup ---
25
+ USE_MMGP_ENV = os.getenv("USE_MMGP", "true").lower()
26
+ USE_MMGP = USE_MMGP_ENV not in ("false", "0", "no", "none")
27
 
 
28
  offload = None
29
  if USE_MMGP:
30
+ print("INFO: Attempting to use MMGP.")
31
  try:
32
  from mmgp import offload, profile_type
33
+
34
  print("Successfully imported MMGP.")
35
  except ImportError:
36
+ print("WARNING: MMGP import failed. Falling back to standard offload.")
37
+ USE_MMGP = False
38
+ else:
39
+ print("INFO: MMGP is disabled.")
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
 
43
+ # --- 3. Load the Main Pipeline ---
44
+ print("Loading the FLUX Tiling pipeline...")
45
+ # Use an environment variable for the model path to make it flexible
46
+ MODEL_PATH = os.getenv("MODEL_PATH", "black-forest-labs/FLUX.1-schnell")
47
+ print(f"Loading model from: {MODEL_PATH}")
48
+
49
+ pipe = FluxMoDTilingPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(
50
+ "cuda"
51
+ )
52
 
 
53
  if USE_MMGP and offload:
54
  print("Applying LowRAM_LowVRAM offload profile via MMGP...")
55
  offload.profile(pipe, profile_type.LowRAM_LowVRAM)
56
  else:
57
+ print("Attempting to use the standard Diffusers CPU offload...")
58
  try:
59
  pipe.enable_model_cpu_offload()
60
  except Exception as e:
61
  print(f"Could not apply standard offload: {e}")
62
 
 
 
 
63
  print("Pipeline loaded and ready.")
64
 
65
+
66
+ # Helper Functions
67
+ def translate_prompt(text: str, language: str) -> str:
68
+ """Translates text to English if the selected language is not English."""
69
+ if language == "English" or not text.strip():
70
+ return text
71
+
72
+ translated_text = text
73
+ if language == "Korean" and ko_en_translator:
74
+ if any(
75
+ "\uac00" <= char <= "\ud7a3" for char in text
76
+ ): # Check if Korean characters are present
77
+ print(f"Translating Korean to English: '{text}'")
78
+ translated_text = ko_en_translator(text)[0]["translation_text"]
79
+ print(f" -> Translated: '{translated_text}'")
80
+ elif language == "Chinese" and zh_en_translator:
81
+ if any(
82
+ "\u4e00" <= char <= "\u9fff" for char in text
83
+ ): # Check if Chinese characters are present
84
+ print(f"Translating Chinese to English: '{text}'")
85
+ translated_text = zh_en_translator(text)[0]["translation_text"]
86
+ print(f" -> Translated: '{translated_text}'")
87
+
88
+ return translated_text
89
+
90
+
91
  def create_hdr_effect(image, hdr_strength):
92
  if hdr_strength == 0:
93
  return image
94
  from PIL import ImageEnhance, Image
95
+
96
+ if isinstance(image, Image.Image):
97
+ image = np.array(image)
98
  from scipy.ndimage import gaussian_filter
99
+
100
  blurred = gaussian_filter(image, sigma=5)
101
+ sharpened = np.clip(image + hdr_strength * (image - blurred), 0, 255).astype(
102
+ np.uint8
103
+ )
104
  pil_img = Image.fromarray(sharpened)
105
  converter = ImageEnhance.Color(pil_img)
106
  return converter.enhance(1 + hdr_strength)
 
108
 
109
  @spaces.GPU(duration=120)
110
  def predict(
111
+ left_prompt,
112
+ center_prompt,
113
+ right_prompt,
114
+ left_gs,
115
+ center_gs,
116
+ right_gs,
117
+ overlap_pixels,
118
+ steps,
119
+ generation_seed,
120
+ tile_weighting_method,
121
+ prompt_language,
122
+ _,
123
+ __,
124
+ target_height,
125
+ target_width,
126
+ hdr,
127
  progress=gr.Progress(track_tqdm=True),
128
  ):
129
  global pipe
130
+ generator = torch.Generator("cuda").manual_seed(generation_seed)
 
 
131
  final_height, final_width = int(target_height), int(target_width)
132
 
133
+ # Translate prompts if necessary
134
+ translated_left = translate_prompt(left_prompt, prompt_language)
135
+ translated_center = translate_prompt(center_prompt, prompt_language)
136
+ translated_right = translate_prompt(right_prompt, prompt_language)
137
+
138
+ print("Starting generation with Tiling Pipeline (Composition Mode)...")
139
  image = pipe(
140
+ prompt=[[translated_left, translated_center, translated_right]],
141
  height=final_height,
142
  width=final_width,
 
143
  tile_overlap=overlap_pixels,
144
  guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
145
  tile_weighting_method=tile_weighting_method,
 
150
 
151
  return create_hdr_effect(image, hdr)
152
 
153
+
154
  def do_calc_tile(target_height, target_width, overlap_pixels):
155
  num_cols = 3
156
  num_rows = 1
 
157
  tile_width = (target_width + (num_cols - 1) * overlap_pixels) // num_cols
158
  tile_height = (target_height + (num_rows - 1) * overlap_pixels) // num_rows
159
  tile_width -= tile_width % 16
160
  tile_height -= tile_height % 16
 
161
  final_width = tile_width * num_cols - (num_cols - 1) * overlap_pixels
162
  final_height = tile_height * num_rows - (num_rows - 1) * overlap_pixels
 
 
 
 
 
163
  return (
164
  gr.update(value=tile_height),
165
  gr.update(value=tile_width),
 
167
  gr.update(value=final_width),
168
  )
169
 
170
+
171
  def clear_result():
172
  return gr.update(value=None)
173
 
174
+
175
  def run_for_examples(
176
+ left_prompt,
177
+ center_prompt,
178
+ right_prompt,
179
+ left_gs,
180
+ center_gs,
181
+ right_gs,
182
+ overlap_pixels,
183
+ steps,
184
+ generation_seed,
185
+ tile_weighting_method,
186
+ tile_height,
187
+ tile_width,
188
+ target_height,
189
+ target_width,
190
+ hdr,
191
  ):
192
  return predict(
193
+ left_prompt,
194
+ center_prompt,
195
+ right_prompt,
196
+ left_gs,
197
+ center_gs,
198
+ right_gs,
199
+ overlap_pixels,
200
+ steps,
201
+ generation_seed,
202
+ tile_weighting_method,
203
+ "English",
204
+ tile_height,
205
+ tile_width,
206
+ target_height,
207
+ target_width,
208
+ hdr,
209
  )
210
 
211
+
212
  def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
213
  if randomize_seed:
214
  generation_seed = random.randint(0, MAX_SEED)
215
  return generation_seed
216
 
217
+
218
  # UI Layout
219
  theme = gr.themes.Default(
220
+ primary_hue="blue", secondary_hue="teal", neutral_hue="neutral"
221
  )
222
+ title = """<h1 align="center">Panorama FLUX 🏞️✨</h1>
 
 
 
 
 
 
 
 
 
223
  <div style="text-align: center;">
224
+ <span>An advanced tiling pipeline for creative composition and large-scale image generation with the FLUX.1-schnell model.</span>
225
  </div>
226
  """
227
 
228
+ with gr.Blocks(theme=theme, title="Panorama FLUX") as app:
229
  gr.Markdown(title)
230
  with gr.Row():
231
  with gr.Column(scale=7):
 
234
  with gr.Column(scale=1):
235
  gr.Markdown("### Left Region")
236
  left_prompt = gr.Textbox(lines=4, label="Prompt for left side")
237
+ left_gs = gr.Slider(
238
+ minimum=0.0,
239
+ maximum=10.0,
240
+ value=0.0,
241
+ step=0.1,
242
+ label="Left Guidance",
243
+ )
244
  with gr.Column(scale=1):
245
  gr.Markdown("### Center Region")
246
  center_prompt = gr.Textbox(lines=4, label="Prompt for the center")
247
+ center_gs = gr.Slider(
248
+ minimum=0.0,
249
+ maximum=10.0,
250
+ value=0.0,
251
+ step=0.1,
252
+ label="Center Guidance",
253
+ )
254
  with gr.Column(scale=1):
255
  gr.Markdown("### Right Region")
256
  right_prompt = gr.Textbox(lines=4, label="Prompt for right side")
257
+ right_gs = gr.Slider(
258
+ minimum=0.0,
259
+ maximum=10.0,
260
+ value=0.0,
261
+ step=0.1,
262
+ label="Right Guidance",
263
+ )
264
+
265
  with gr.Row():
266
+ result = gr.Image(
267
+ label="Generated Image",
268
+ show_label=True,
269
+ format="png",
270
+ interactive=False,
271
  )
272
+
 
 
273
  with gr.Sidebar():
274
  gr.Markdown("### Tiling & Generation Parameters")
275
+
276
+ # New Language Selector
277
+ prompt_language = gr.Radio(
278
+ choices=["English", "Korean", "Chinese"],
279
+ value="English",
280
+ label="Prompt Language",
281
+ info="Select the language you will type your prompts in.",
282
+ )
283
+
284
  with gr.Row():
285
+ height = gr.Slider(
286
+ label="Target Height", value=1024, step=16, minimum=512, maximum=2048
287
+ )
288
+ width = gr.Slider(
289
+ label="Target Width", value=3072, step=16, minimum=512, maximum=4096
290
+ )
291
  with gr.Row():
292
+ overlap = gr.Slider(
293
+ minimum=0, maximum=512, value=256, step=16, label="Tile Overlap"
294
+ )
295
+ tile_weighting_method = gr.Dropdown(
296
+ label="Blending Method", choices=["Cosine", "Gaussian"], value="Cosine"
297
+ )
298
  with gr.Row():
299
+ calc_tile = gr.Button("Calculate Final Dimensions", variant="primary")
300
  with gr.Row():
301
+ new_target_height = gr.Textbox(
302
+ label="Actual Image Height", value=1024, interactive=False
303
+ )
304
+ new_target_width = gr.Textbox(
305
+ label="Actual Image Width", value=3072, interactive=False
306
+ )
307
  with gr.Row():
308
+ tile_height = gr.Textbox(
309
+ label="Ideal Tile Height", value=1024, interactive=False
310
+ )
311
+ tile_width = gr.Textbox(
312
+ label="Ideal Tile Width", value=1152, interactive=False
313
+ )
314
  with gr.Row():
315
+ steps = gr.Slider(
316
+ minimum=1, maximum=10, value=4, step=1, label="Inference Steps"
317
+ )
318
  with gr.Row():
319
+ generation_seed = gr.Slider(
320
+ label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
321
+ )
322
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
323
  with gr.Row():
324
+ hdr = gr.Slider(
325
+ minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="HDR Effect"
326
+ )
327
+
328
  with gr.Row():
329
  gr.Examples(
330
  examples=[
331
+ [
332
+ "A vibrant medieval marketplace...",
333
+ "A majestic stone castle...",
334
+ "A dense, dark forest...",
335
+ 0.0,
336
+ 0.0,
337
+ 0.0,
338
+ 256,
339
+ 4,
340
+ 12345,
341
+ "Cosine",
342
+ 1024,
343
+ 1152,
344
+ 1024,
345
+ 3072,
346
+ 0,
347
+ ],
348
+ [
349
+ "A vibrant mountain slope in full spring bloom, covered in colorful wildflowers and lush green grass, a small stream meandering down, cinematic photo, bright morning light.",
350
+ "The majestic, rocky peak of the same mountain under a clear summer sky, patches of green tundra, eagles soaring high above, strong midday sun. cinematic photo.",
351
+ "The other side of the mountain descending into a valley ablaze with autumn colors, forests of red, orange, and yellow trees, a gentle haze in the air. cinematic photo, golden hour light.",
352
+ 0.0,
353
+ 0.0,
354
+ 0.0,
355
+ 280,
356
+ 4,
357
+ 20240521,
358
+ "Cosine",
359
+ 1024,
360
+ 1152,
361
+ 1024,
362
+ 3072,
363
+ 0,
364
+ ],
365
+ [
366
+ "A futuristic neon-lit city street...",
367
+ "The entrance to a grimy nightclub...",
368
+ "A dark alleyway off the main street...",
369
+ 3.5,
370
+ 3.5,
371
+ 3.5,
372
+ 300,
373
+ 8,
374
+ 98765,
375
+ "Cosine",
376
+ 1024,
377
+ 1280,
378
+ 1024,
379
+ 3240,
380
+ 0,
381
  ],
382
  [
383
+ "Iron Man, repulsor rays...",
384
+ "Captain America charging forward...",
385
+ "Thor wielding Stormbreaker...",
386
+ 0.0,
387
+ 0.0,
388
+ 0.0,
389
+ 160,
390
+ 4,
391
+ 619517442,
392
+ "Cosine",
393
+ 1024,
394
+ 1152,
395
+ 1024,
396
+ 3072,
397
+ 0,
398
  ],
399
  ],
400
  inputs=[
401
+ left_prompt,
402
+ center_prompt,
403
+ right_prompt,
404
+ left_gs,
405
+ center_gs,
406
+ right_gs,
407
+ overlap,
408
+ steps,
409
+ generation_seed,
410
+ tile_weighting_method,
411
+ tile_height,
412
+ tile_width,
413
+ height,
414
+ width,
415
+ hdr,
416
  ],
417
  fn=run_for_examples,
418
  outputs=result,
419
  cache_examples=False,
420
  )
421
+
422
+ # Event Handling
423
  event_calc_tile_size = {
424
  "fn": do_calc_tile,
425
  "inputs": [height, width, overlap],
426
  "outputs": [tile_height, tile_width, new_target_height, new_target_width],
427
  }
428
+
429
  predict_inputs = [
430
+ left_prompt,
431
+ center_prompt,
432
+ right_prompt,
433
+ left_gs,
434
+ center_gs,
435
+ right_gs,
436
+ overlap,
437
+ steps,
438
+ generation_seed,
439
+ tile_weighting_method,
440
+ prompt_language,
441
+ tile_height,
442
+ tile_width,
443
+ new_target_height,
444
+ new_target_width,
445
+ hdr,
446
  ]
447
+
448
  calc_tile.click(**event_calc_tile_size)
449
  generate_button.click(
450
+ fn=clear_result,
451
+ inputs=None,
452
+ outputs=result,
453
+ queue=False,
454
  ).then(**event_calc_tile_size).then(
455
+ fn=randomize_seed_fn,
456
+ inputs=[generation_seed, randomize_seed],
457
+ outputs=generation_seed,
458
+ queue=False,
459
  ).then(
460
+ fn=predict, inputs=predict_inputs, outputs=result, show_progress="full"
461
  )
462
 
463
  app.queue().launch(share=True)
 
flux_pipeline_mod.py CHANGED
@@ -24,40 +24,28 @@ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOut
24
 
25
  logger = logging.get_logger(__name__)
26
 
 
27
  def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
28
- width, height = image_size
29
- aspect_ratio = width / height
30
  if aspect_ratio > 1:
31
- tile_width = min(width, max_tile_size)
32
- tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
33
  else:
34
- tile_height = min(height, max_tile_size)
35
- tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
36
- tile_width = max(tile_width, base_tile_size)
37
- tile_height = max(tile_height, base_tile_size)
38
- return tile_width, tile_height
39
 
40
  def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
41
- if image_dim <= tile_dim:
42
- return [0]
43
- positions = []
44
- current_pos = 0
45
- stride = tile_dim - overlap
46
  while True:
47
  positions.append(current_pos)
48
- if current_pos + tile_dim >= image_dim:
49
- break
50
  current_pos += stride
51
- if current_pos > image_dim - tile_dim:
52
- break
53
- last_pos = positions[-1]
54
- if last_pos + tile_dim < image_dim:
55
- positions.append(image_dim - tile_dim)
56
  return sorted(list(set(positions)))
57
 
58
  def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
59
- px_row_init = tile_row_pos
60
- px_col_init = tile_col_pos
61
  px_row_end = min(px_row_init + tile_height, image_height)
62
  px_col_end = min(px_col_init + tile_width, image_width)
63
  return px_row_init, px_row_end, px_col_init, px_col_end
@@ -69,48 +57,34 @@ def release_memory(device):
69
  gc.collect()
70
  if torch.cuda.is_available():
71
  with torch.cuda.device(device):
72
- torch.cuda.empty_cache()
73
- torch.cuda.synchronize()
74
 
75
  class FluxMoDTilingPipeline(FluxPipeline):
76
  class TileWeightingMethod(Enum):
77
- COSINE = "Cosine"
78
- GAUSSIAN = "Gaussian"
79
 
80
- def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.05):
81
- latent_width = tile_width // self.vae_scale_factor
82
- latent_height = tile_height // self.vae_scale_factor
83
- x = np.linspace(-1, 1, latent_width)
84
- y = np.linspace(-1, 1, latent_height)
85
  xx, yy = np.meshgrid(x, y)
86
- gaussian_weight = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
87
- weights_torch = torch.tensor(gaussian_weight, device=device, dtype=dtype)
88
- return torch.tile(weights_torch, (nbatches, self.transformer.config.in_channels // 4, 1, 1))
 
89
 
90
  def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
91
- latent_width = tile_width // self.vae_scale_factor
92
- latent_height = tile_height // self.vae_scale_factor
93
- x = np.arange(latent_width)
94
- y = np.arange(latent_height)
95
- mid_x = (latent_width - 1) / 2
96
- mid_y = (latent_height - 1) / 2
97
- x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
98
- y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
99
- weights_np = np.outer(y_probs, x_probs)
100
- weights_torch = torch.tensor(weights_np, device=device, dtype=dtype)
101
- return torch.tile(weights_torch, (nbatches, self.transformer.config.in_channels // 4, 1, 1))
102
 
103
- def prepare_tiles_weights(
104
- self, y_steps, x_steps, tile_height, tile_width, final_height, final_width,
105
- tile_weighting_method, tile_gaussian_sigma, batch_size, device, dtype
106
- ):
107
  tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
108
  for row, y_start in enumerate(y_steps):
109
  for col, x_start in enumerate(x_steps):
110
  _, px_row_end, _, px_col_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
111
- current_tile_h = px_row_end - y_start
112
- current_tile_w = px_col_end - x_start
113
-
114
  if tile_weighting_method == self.TileWeightingMethod.COSINE.value:
115
  tile_weights[row, col] = self._generate_cosine_weights(current_tile_w, current_tile_h, batch_size, device, dtype)
116
  else:
@@ -124,17 +98,17 @@ class FluxMoDTilingPipeline(FluxPipeline):
124
  height: int = 1024,
125
  width: int = 1024,
126
  negative_prompt: Optional[Union[str, List[List[str]]]] = "",
127
- num_inference_steps: int = 28,
128
- guidance_scale: float = 7.0,
129
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
130
  max_tile_size: int = 1024,
131
  tile_overlap: int = 256,
132
  tile_weighting_method: str = "Cosine",
133
- tile_gaussian_sigma: float = 0.05,
134
  guidance_scale_tiles: Optional[List[List[float]]] = None,
135
  max_sequence_length: int = 512,
136
  output_type: Optional[str] = "pil",
137
- return_dict: bool = True,
138
  ):
139
  device = self._execution_device
140
  batch_size = 1
@@ -146,42 +120,29 @@ class FluxMoDTilingPipeline(FluxPipeline):
146
  grid_rows, grid_cols = len(prompt), len(prompt[0])
147
  tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
148
  tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
149
- tile_width -= tile_width % PIXEL_MULTIPLE
150
- tile_height -= tile_height % PIXEL_MULTIPLE
151
  final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
152
  final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
153
- stride_x = tile_width - tile_overlap
154
- stride_y = tile_height - tile_overlap
155
- x_steps = [i * stride_x for i in range(grid_cols)]
156
- y_steps = [i * stride_y for i in range(grid_rows)]
157
- logger.info(f"Prompt grid provided. Using fixed {grid_rows}x{grid_cols} grid.")
158
- logger.info(f"Target resolution: {width}x{height}. Actual resolution: {final_width}x{final_height}.")
159
- else:
160
  final_width, final_height = width, height
161
  tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
162
- tile_width -= tile_width % PIXEL_MULTIPLE
163
- tile_height -= tile_height % PIXEL_MULTIPLE
164
  y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
165
  x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
166
  grid_rows, grid_cols = len(y_steps), len(x_steps)
167
 
168
  logger.info(f"Processing image in a {grid_rows}x{grid_cols} grid of tiles.")
169
 
170
- if not isinstance(negative_prompt, list) or not all(isinstance(p, list) for p in negative_prompt):
171
- negative_prompt = [[negative_prompt] * grid_cols for _ in range(grid_rows)]
172
-
173
  text_embeddings = []
174
  for r in range(grid_rows):
175
  row_embeddings = []
176
  for c in range(grid_cols):
177
  p = prompt[r][c] if is_prompt_grid else prompt
178
- np_ = negative_prompt[r][c] if is_prompt_grid else negative_prompt[0][0]
179
  prompt_embeds, pooled, text_ids = self.encode_prompt(p, device=device, max_sequence_length=max_sequence_length)
180
- neg_embeds, neg_pooled, neg_ids = self.encode_prompt(np_, device=device, max_sequence_length=max_sequence_length)
181
- row_embeddings.append({
182
- "prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled, "txt_ids": text_ids,
183
- "neg_prompt_embeds": neg_embeds, "neg_pooled_prompt_embeds": neg_pooled, "neg_txt_ids": neg_ids,
184
- })
185
  text_embeddings.append(row_embeddings)
186
 
187
  prompt_dtype = text_embeddings[0][0]["prompt_embeds"].dtype
@@ -191,35 +152,21 @@ class FluxMoDTilingPipeline(FluxPipeline):
191
  latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=prompt_dtype)
192
 
193
  image_seq_len = (tile_height // self.vae_scale_factor // 2) * (tile_width // self.vae_scale_factor // 2)
194
- mu = calculate_shift(image_seq_len)
195
- timesteps, _ = retrieve_timesteps(self.scheduler, num_inference_steps, device, mu=mu)
196
 
197
- if self.transformer.config.guidance_embeds:
198
- guidance = torch.tensor([guidance_scale], device=device)
199
- else:
200
- guidance = None
201
-
202
- tile_weights = self.prepare_tiles_weights(
203
- y_steps, x_steps, tile_height, tile_width, final_height, final_width,
204
- tile_weighting_method, tile_gaussian_sigma, batch_size, device, latents.dtype
205
- )
206
 
207
- self.text_encoder.to("cpu");
208
- self.text_encoder_2.to("cpu");
209
  release_memory(device)
210
 
211
  with self.progress_bar(total=num_inference_steps) as progress_bar:
212
  for i, t in enumerate(timesteps):
213
  noise_preds_tiles = np.empty((grid_rows, grid_cols), dtype=object)
214
-
215
  for r, y_start in enumerate(y_steps):
216
  for c, x_start in enumerate(x_steps):
217
  px_r_init, px_r_end, px_c_init, px_c_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
218
-
219
- # Store the PIXEL dimensions of the current tile
220
- current_tile_pixel_height = px_r_end - px_r_init
221
- current_tile_pixel_width = px_c_end - px_c_init
222
-
223
  r_init, r_end, c_init, c_end = _tile2latent_indices(px_r_init, px_r_end, px_c_init, px_c_end, self.vae_scale_factor)
224
 
225
  tile_latents = latents[:, :, r_init:r_end, c_init:c_end]
@@ -231,30 +178,19 @@ class FluxMoDTilingPipeline(FluxPipeline):
231
  timestep = t.expand(b).to(packed_latents.dtype)
232
 
233
  current_gs_value = guidance_scale_tiles[r][c] if (is_prompt_grid and guidance_scale_tiles) else guidance_scale
234
- current_guidance = torch.tensor([current_gs_value], device=device) if guidance is not None else None
235
-
236
- noise_pred_uncond_packed = self.transformer(
237
- hidden_states=packed_latents, timestep=timestep / 1000, guidance=current_guidance,
238
- pooled_projections=embeds["neg_pooled_prompt_embeds"],
239
- encoder_hidden_states=embeds["neg_prompt_embeds"],
240
- txt_ids=embeds["neg_txt_ids"], img_ids=latent_image_ids,
241
- )[0]
242
-
243
- noise_pred_text_packed = self.transformer(
244
  hidden_states=packed_latents, timestep=timestep / 1000, guidance=current_guidance,
245
  pooled_projections=embeds["pooled_prompt_embeds"],
246
  encoder_hidden_states=embeds["prompt_embeds"],
247
  txt_ids=embeds["txt_ids"], img_ids=latent_image_ids,
248
  )[0]
249
-
250
- # Pass the correct PIXEL dimensions of the tile to _unpack_latents
251
- noise_pred_uncond = self._unpack_latents(noise_pred_uncond_packed, current_tile_pixel_height, current_tile_pixel_width, self.vae_scale_factor)
252
- noise_pred_text = self._unpack_latents(noise_pred_text_packed, current_tile_pixel_height, current_tile_pixel_width, self.vae_scale_factor)
253
 
254
- noise_pred_tile = noise_pred_uncond + current_gs_value * (noise_pred_text - noise_pred_uncond)
255
  noise_preds_tiles[r, c] = noise_pred_tile
256
 
257
- # Stitch noise predictions
258
  noise_pred = torch.zeros_like(latents)
259
  contributors = torch.zeros_like(latents)
260
  for r, y_start in enumerate(y_steps):
@@ -268,19 +204,17 @@ class FluxMoDTilingPipeline(FluxPipeline):
268
 
269
  latents_dtype = latents.dtype
270
  latents = self.scheduler.step(noise_pred, t, latents)[0]
271
- if latents.dtype != latents_dtype:
272
- latents = latents.to(latents_dtype)
273
-
274
  progress_bar.update()
275
 
276
  # Post-processing
277
- if output_type == "latent":
278
- image = latents
279
  else:
280
  self.vae.to(device)
281
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
282
  image = self.vae.decode(latents.to(self.vae.dtype))[0]
283
  image = self.image_processor.postprocess(image, output_type=output_type)
284
 
285
- self.maybe_free_model_hooks()
 
286
  return FluxPipelineOutput(images=image)
 
24
 
25
  logger = logging.get_logger(__name__)
26
 
27
+
28
  def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
29
+ width, height = image_size
+ aspect_ratio = width / height
 
30
  if aspect_ratio > 1:
31
+ tile_width = min(width, max_tile_size)
+ tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
 
32
  else:
33
+ tile_height = min(height, max_tile_size)
+ tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
34
+ return max(tile_width, base_tile_size), max(tile_height, base_tile_size)
 
 
 
35
 
36
  def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
37
+ if image_dim <= tile_dim:
+     return [0]
38
+ positions = []
+ current_pos = 0
+ stride = tile_dim - overlap
 
 
 
39
  while True:
40
  positions.append(current_pos)
41
+ if current_pos + tile_dim >= image_dim:
+     break
 
42
  current_pos += stride
43
+ if current_pos > image_dim - tile_dim: break
44
+ if positions[-1] + tile_dim < image_dim:
+     positions.append(image_dim - tile_dim)
 
 
 
45
  return sorted(list(set(positions)))
46
 
47
  def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
48
+ px_row_init = tile_row_pos
+ px_col_init = tile_col_pos
 
49
  px_row_end = min(px_row_init + tile_height, image_height)
50
  px_col_end = min(px_col_init + tile_width, image_width)
51
  return px_row_init, px_row_end, px_col_init, px_col_end
 
57
  gc.collect()
58
  if torch.cuda.is_available():
59
  with torch.cuda.device(device):
60
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
 
61
 
62
  class FluxMoDTilingPipeline(FluxPipeline):
63
  class TileWeightingMethod(Enum):
64
+ COSINE = "Cosine"
+ GAUSSIAN = "Gaussian"
 
65
 
66
+ def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
67
+ latent_width, latent_height = tile_width // self.vae_scale_factor, tile_height // self.vae_scale_factor
68
+ x, y = np.linspace(-1, 1, latent_width), np.linspace(-1, 1, latent_height)
 
 
69
  xx, yy = np.meshgrid(x, y)
70
+ gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
71
+ weights_torch_f32 = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32)
72
+ weights_torch_target_dtype = weights_torch_f32.to(dtype)
73
+ return torch.tile(weights_torch_target_dtype, (nbatches, self.transformer.config.in_channels // 4, 1, 1))
74
 
75
  def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
76
+ latent_width, latent_height = tile_width // self.vae_scale_factor, tile_height // self.vae_scale_factor
77
+ x, y = np.arange(latent_width), np.arange(latent_height)
78
+ mid_x, mid_y = (latent_width - 1) / 2, (latent_height - 1) / 2
79
+ x_probs, y_probs = np.cos(np.pi * (x - mid_x) / latent_width), np.cos(np.pi * (y - mid_y) / latent_height)
80
+ return torch.tile(torch.tensor(np.outer(y_probs, x_probs), device=device, dtype=dtype), (nbatches, self.transformer.config.in_channels // 4, 1, 1))
 
 
 
 
 
 
81
 
82
+ def prepare_tiles_weights(self, y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, dtype):
 
 
 
83
  tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
84
  for row, y_start in enumerate(y_steps):
85
  for col, x_start in enumerate(x_steps):
86
  _, px_row_end, _, px_col_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
87
+ current_tile_h, current_tile_w = px_row_end - y_start, px_col_end - x_start
 
 
88
  if tile_weighting_method == self.TileWeightingMethod.COSINE.value:
89
  tile_weights[row, col] = self._generate_cosine_weights(current_tile_w, current_tile_h, batch_size, device, dtype)
90
  else:
 
98
  height: int = 1024,
99
  width: int = 1024,
100
  negative_prompt: Optional[Union[str, List[List[str]]]] = "",
101
+ num_inference_steps: int = 4,
102
+ guidance_scale: float = 0.0,
103
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
104
  max_tile_size: int = 1024,
105
  tile_overlap: int = 256,
106
  tile_weighting_method: str = "Cosine",
107
+ tile_gaussian_sigma: float = 0.4,
108
  guidance_scale_tiles: Optional[List[List[float]]] = None,
109
  max_sequence_length: int = 512,
110
  output_type: Optional[str] = "pil",
111
+ return_dict: bool = True,
112
  ):
113
  device = self._execution_device
114
  batch_size = 1
 
120
  grid_rows, grid_cols = len(prompt), len(prompt[0])
121
  tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
122
  tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
123
+ tile_width -= tile_width % PIXEL_MULTIPLE
+ tile_height -= tile_height % PIXEL_MULTIPLE
 
124
  final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
125
  final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
126
+ x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
127
+ y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
128
+ logger.info(f"Prompt grid provided. Using fixed {grid_rows}x{grid_cols} grid. Actual resolution: {final_width}x{final_height}.")
129
+ else: # Tiling Mode
 
 
 
130
  final_width, final_height = width, height
131
  tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
132
+ tile_width -= tile_width % PIXEL_MULTIPLE
+ tile_height -= tile_height % PIXEL_MULTIPLE
 
133
  y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
134
  x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
135
  grid_rows, grid_cols = len(y_steps), len(x_steps)
136
 
137
  logger.info(f"Processing image in a {grid_rows}x{grid_cols} grid of tiles.")
138
 
 
 
 
139
  text_embeddings = []
140
  for r in range(grid_rows):
141
  row_embeddings = []
142
  for c in range(grid_cols):
143
  p = prompt[r][c] if is_prompt_grid else prompt
 
144
  prompt_embeds, pooled, text_ids = self.encode_prompt(p, device=device, max_sequence_length=max_sequence_length)
145
+ row_embeddings.append({"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled, "txt_ids": text_ids})
 
 
 
 
146
  text_embeddings.append(row_embeddings)
147
 
148
  prompt_dtype = text_embeddings[0][0]["prompt_embeds"].dtype
 
152
  latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=prompt_dtype)
153
 
154
  image_seq_len = (tile_height // self.vae_scale_factor // 2) * (tile_width // self.vae_scale_factor // 2)
155
+ mu = calculate_shift(image_seq_len)
+ timesteps, _ = retrieve_timesteps(self.scheduler, num_inference_steps, device, mu=mu)
 
156
 
157
+ tile_weights = self.prepare_tiles_weights(y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, latents.dtype)
 
 
 
 
 
 
 
 
158
 
159
+ self.text_encoder.to("cpu")
160
+ self.text_encoder_2.to("cpu")
161
  release_memory(device)
162
 
163
  with self.progress_bar(total=num_inference_steps) as progress_bar:
164
  for i, t in enumerate(timesteps):
165
  noise_preds_tiles = np.empty((grid_rows, grid_cols), dtype=object)
 
166
  for r, y_start in enumerate(y_steps):
167
  for c, x_start in enumerate(x_steps):
168
  px_r_init, px_r_end, px_c_init, px_c_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
169
+ current_tile_pixel_height = px_r_end - px_r_init
+ current_tile_pixel_width = px_c_end - px_c_init
 
 
 
 
170
  r_init, r_end, c_init, c_end = _tile2latent_indices(px_r_init, px_r_end, px_c_init, px_c_end, self.vae_scale_factor)
171
 
172
  tile_latents = latents[:, :, r_init:r_end, c_init:c_end]
 
178
  timestep = t.expand(b).to(packed_latents.dtype)
179
 
180
  current_gs_value = guidance_scale_tiles[r][c] if (is_prompt_grid and guidance_scale_tiles) else guidance_scale
181
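+ # Guidance-distilled checkpoints (guidance_embeds=True, e.g. FLUX.1-dev) take the scale as a transformer input; for FLUX.1-schnell (guidance_embeds=False) this stays None.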
+ current_guidance = torch.tensor([current_gs_value], device=device) if self.transformer.config.guidance_embeds else None
182
+
183
+ noise_pred_packed = self.transformer(
 
 
 
 
 
 
 
184
  hidden_states=packed_latents, timestep=timestep / 1000, guidance=current_guidance,
185
  pooled_projections=embeds["pooled_prompt_embeds"],
186
  encoder_hidden_states=embeds["prompt_embeds"],
187
  txt_ids=embeds["txt_ids"], img_ids=latent_image_ids,
188
  )[0]
 
 
 
 
189
 
190
+ noise_pred_tile = self._unpack_latents(noise_pred_packed, current_tile_pixel_height, current_tile_pixel_width, self.vae_scale_factor)
191
  noise_preds_tiles[r, c] = noise_pred_tile
192
 
193
+ # Stitch the per-tile noise predictions, then take one scheduler step
194
  noise_pred = torch.zeros_like(latents)
195
  contributors = torch.zeros_like(latents)
196
  for r, y_start in enumerate(y_steps):
 
204
 
205
  latents_dtype = latents.dtype
206
  latents = self.scheduler.step(noise_pred, t, latents)[0]
207
+ if latents.dtype != latents_dtype:
+     latents = latents.to(latents_dtype)
 
 
208
  progress_bar.update()
209
 
210
  # Post-processing
211
+ if output_type == "latent":
+     image = latents
 
212
  else:
213
  self.vae.to(device)
214
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
215
  image = self.vae.decode(latents.to(self.vae.dtype))[0]
216
  image = self.image_processor.postprocess(image, output_type=output_type)
217
 
218
+ self.maybe_free_model_hooks()
219
+
220
  return FluxPipelineOutput(images=image)
infer.py CHANGED
@@ -1,36 +1,49 @@
1
  # infer.py
2
- # A command-line inference script to test the FluxUnifiedTilingPipeline.
3
  # This script runs the first example from the Gradio app to verify functionality
4
  # and observe the progress bar in the terminal.
5
 
 
6
  import torch
7
- from PIL import Image
8
  import time
9
 
10
- # Make sure flux_unified_tiling_pipeline.py is in the same directory
11
  from flux_pipeline_mod import FluxMoDTilingPipeline
 
 
 
12
 
13
  # Optional: for memory offloading
14
- try:
15
- from mmgp import offload, profile_type
16
- except ImportError:
17
- print("Warning: 'mmgp' library not found. Offload will not be applied.")
18
- offload = None
19
-
 
 
 
20
  def main():
21
  """Main function to run the inference process."""
22
 
23
  # 1. Load Model
24
  print("--- 1. Loading Model ---")
25
  # !!! IMPORTANT: Make sure this path is correct for your system !!!
26
- MODEL_PATH = "F:\\Models\\Flux_dev"
 
27
 
28
  start_load_time = time.time()
29
- pipe = FluxMoDTilingPipeline.from_pretrained(
30
- MODEL_PATH,
31
- torch_dtype=torch.bfloat16
32
- )
33
-
 
 
 
 
 
 
34
  # Apply memory optimization
35
  if offload:
36
  print("Applying LowRAM_LowVRAM offload profile...")
@@ -57,8 +70,7 @@ def main():
57
  "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
58
  "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor."
59
  ]]
60
- negative_prompt = "nsfw, lowres, bad anatomy, bad hands, duplicate, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry"
61
-
62
  # Tiling and Dimensions
63
  target_height = 1024
64
  target_width = 3072
@@ -66,17 +78,14 @@ def main():
66
  tile_weighting_method = "Cosine"
67
 
68
  # Generation
69
- num_inference_steps = 30
70
- guidance_scale_tiles = [[5.0, 5.0, 5.0]]
71
  seed = 619517442
72
 
73
  # Create a generator for reproducibility
74
- generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
75
 
76
- print("Parameters set:")
77
- print(f" Resolution: {target_width}x{target_height}")
78
- print(f" Steps: {num_inference_steps}")
79
- print(f" Seed: {seed}")
80
 
81
  # 3. Start Inference
82
  print("\n--- 3. Starting Inference ---")
@@ -86,12 +95,11 @@ def main():
86
  image = pipe(
87
  prompt=prompt_grid,
88
  height=target_height,
89
- width=target_width,
90
- negative_prompt=negative_prompt,
91
  tile_overlap=tile_overlap,
92
- guidance_scale_tiles=guidance_scale_tiles,
93
- tile_weighting_method=tile_weighting_method,
94
  generator=generator,
 
95
  num_inference_steps=num_inference_steps
96
  ).images[0]
97
 
@@ -100,7 +108,7 @@ def main():
100
 
101
  # 4. Save Output
102
  print("\n--- 4. Saving Output ---")
103
- output_filename = "inference_output.png"
104
  image.save(output_filename)
105
  print(f"Image successfully saved as '{output_filename}'")
106
 
 
1
  # infer.py
2
+ # A command-line inference script to test the FluxMoDTilingPipeline.
3
  # This script runs the first example from the Gradio app to verify functionality
4
  # and observe the progress bar in the terminal.
5
 
6
+ import os
7
  import torch
 
8
  import time
9
 
10
+ # Make sure flux_pipeline_mod.py is in the same directory
11
  from flux_pipeline_mod import FluxMoDTilingPipeline
12
+ # Conditional MMGP Setup based on Environment Variable
13
+ USE_MMGP_ENV = os.getenv('USE_MMGP', 'true').lower()
14
+ USE_MMGP = USE_MMGP_ENV not in ('false', '0', 'no', 'none')
15
 
16
  # Optional: for memory offloading
17
+ if USE_MMGP:
18
+ try:
19
+ from mmgp import offload, profile_type
20
+ except ImportError:
21
+ print("Warning: 'mmgp' library not found. Offload will not be applied.")
22
+ offload = None
23
+ else:
24
+ print("INFO: MMGP is disabled.")
25
+
26
  def main():
27
  """Main function to run the inference process."""
28
 
29
  # 1. Load Model
30
  print("--- 1. Loading Model ---")
31
  # !!! IMPORTANT: Make sure this path is correct for your system !!!
32
+ #MODEL_PATH = "F:\\Models\\FLUX.1-schnell"
33
+ MODEL_PATH = "black-forest-labs/FLUX.1-schnell"
34
 
35
  start_load_time = time.time()
36
+ if USE_MMGP:
37
+ pipe = FluxMoDTilingPipeline.from_pretrained(
38
+ MODEL_PATH,
39
+ torch_dtype=torch.bfloat16
40
+ )
41
+ else:
42
+ pipe = FluxMoDTilingPipeline.from_pretrained(
43
+ MODEL_PATH,
44
+ torch_dtype=torch.bfloat16
45
+ ).to("cuda")
46
+
47
  # Apply memory optimization
48
  if offload:
49
  print("Applying LowRAM_LowVRAM offload profile...")
 
70
  "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
71
  "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor."
72
  ]]
73
+
 
74
  # Tiling and Dimensions
75
  target_height = 1024
76
  target_width = 3072
 
78
  tile_weighting_method = "Cosine"
79
 
80
  # Generation
81
+ num_inference_steps = 4
82
+ guidance_scale = 0.0
83
  seed = 619517442
84
 
85
  # Create a generator for reproducibility
86
+ generator = torch.Generator("cuda").manual_seed(seed)
87
 
88
+ print(f"Resolution: {target_width}x{target_height}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
 
 
 
89
 
90
  # 3. Start Inference
91
  print("\n--- 3. Starting Inference ---")
 
95
  image = pipe(
96
  prompt=prompt_grid,
97
  height=target_height,
98
+ width=target_width,
 
99
  tile_overlap=tile_overlap,
100
+ guidance_scale=guidance_scale,
 
101
  generator=generator,
102
+ tile_weighting_method=tile_weighting_method,
103
  num_inference_steps=num_inference_steps
104
  ).images[0]
105
 
 
108
 
109
  # 4. Save Output
110
  print("\n--- 4. Saving Output ---")
111
+ os.makedirs("outputs", exist_ok=True)  # ensure the output directory exists
+ output_filename = "outputs/inference_output.png"
112
  image.save(output_filename)
113
  print(f"Image successfully saved as '{output_filename}'")
114
 
requirements.txt CHANGED
@@ -11,4 +11,6 @@ hf_xet
11
  protobuf
12
  sentencepiece
13
  ligo-segments
14
- scipy
 
 
 
11
  protobuf
12
  sentencepiece
13
  ligo-segments
14
+ scipy
15
+ triton-windows<3.5; sys_platform == 'win32'
16
+ triton==3.4.0; sys_platform != 'win32'
requirements_local.txt CHANGED
@@ -14,4 +14,6 @@ mmgp
14
  protobuf
15
  sentencepiece
16
  ligo-segments
17
- scipy
 
 
 
14
  protobuf
15
  sentencepiece
16
  ligo-segments
17
+ scipy
18
+ triton-windows<3.5; sys_platform == 'win32'
19
+ triton==3.4.0; sys_platform != 'win32'