diff options
| author | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 01:06:23 -0400 |
|---|---|---|
| committer | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 01:06:23 -0400 |
| commit | 4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 (patch) | |
| tree | 35e958e423f55abd7feb149a4b6ef77bddd95c70 | |
| parent | 23b68c871d95353258d180639fa30b636a982423 (diff) | |
add img2img pipeline
| -rw-r--r-- | run_flux.py | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/run_flux.py b/run_flux.py index 8c4c27e..4a27816 100644 --- a/run_flux.py +++ b/run_flux.py | |||
| @@ -2,15 +2,24 @@ import torch | |||
| 2 | import os | 2 | import os |
| 3 | import base64 | 3 | import base64 |
| 4 | import argparse | 4 | import argparse |
| 5 | from diffusers import FluxPipeline | 5 | from diffusers import FluxPipeline, FluxImg2ImgPipeline |
| 6 | from diffusers.utils import load_image | ||
| 6 | from typing import Tuple | 7 | from typing import Tuple |
| 7 | from pathlib import Path | 8 | from pathlib import Path |
| 9 | from PIL import Image | ||
| 8 | import uuid | 10 | import uuid |
| 9 | 11 | ||
| 10 | STORAGE_DIR: Path = Path.home() / "Pictures" / "Flux" | 12 | STORAGE_DIR: Path = Path.home() / "Pictures" |
| 11 | 13 | ||
| 12 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) | 14 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) |
| 13 | 15 | ||
| 16 | def load_flux_img_to_img(): | ||
| 17 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | ||
| 18 | pipeline.enable_model_cpu_offload() | ||
| 19 | pipeline.vae.enable_slicing() | ||
| 20 | pipeline.vae.enable_tiling() | ||
| 21 | return pipeline | ||
| 22 | |||
| 14 | def load_flux(): | 23 | def load_flux(): |
| 15 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | 24 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) |
| 16 | pipeline.enable_model_cpu_offload() | 25 | pipeline.enable_model_cpu_offload() |
| @@ -18,7 +27,7 @@ def load_flux(): | |||
| 18 | pipeline.vae.enable_tiling() | 27 | pipeline.vae.enable_tiling() |
| 19 | return pipeline | 28 | return pipeline |
| 20 | 29 | ||
| 21 | def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | 30 | def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): |
| 22 | images = pipeline( | 31 | images = pipeline( |
| 23 | prompt=prompt, | 32 | prompt=prompt, |
| 24 | guidance_scale=guideance_scale, | 33 | guidance_scale=guideance_scale, |
| @@ -26,7 +35,8 @@ def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, | |||
| 26 | width=width, | 35 | width=width, |
| 27 | max_sequence_length=256, | 36 | max_sequence_length=256, |
| 28 | num_inference_steps=num_inference_steps, | 37 | num_inference_steps=num_inference_steps, |
| 29 | num_images_per_prompt=num_images_per_prompt | 38 | num_images_per_prompt=num_images_per_prompt, |
| 39 | image=init_image | ||
| 30 | ).images | 40 | ).images |
| 31 | return images | 41 | return images |
| 32 | 42 | ||
| @@ -44,16 +54,27 @@ def main(): | |||
| 44 | try: | 54 | try: |
| 45 | args = parser.parse_args() | 55 | args = parser.parse_args() |
| 46 | 56 | ||
| 47 | pipeline = load_flux() | 57 | pipeline = load_flux_img_to_img() |
| 48 | pipeline.to(torch.float16) | 58 | pipeline.to(torch.float16) |
| 49 | 59 | ||
| 60 | target_img = STORAGE_DIR / "1517481062292.jpg" | ||
| 61 | |||
| 62 | target_img_path = target_img.resolve(strict=True) | ||
| 63 | |||
| 64 | image = Image.open(target_img_path) | ||
| 65 | |||
| 66 | init_image = load_image(image).resize((256, 256)) | ||
| 67 | |||
| 50 | width, height = args.size | 68 | width, height = args.size |
| 51 | 69 | ||
| 52 | images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) | 70 | images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) |
| 53 | for image in images: | 71 | for image in images: |
| 54 | filename = generate_random_string() | 72 | filename = generate_random_string() |
| 55 | filepath = STORAGE_DIR / f"{filename}.png" | 73 | filepath = STORAGE_DIR / "Flux" / f"{filename}.png" |
| 56 | image.save(filepath) | 74 | image.save(filepath) |
| 75 | except FileNotFoundError: | ||
| 76 | print("\n Target image doesn't exist. Exiting...") | ||
| 77 | exit(0) | ||
| 57 | except KeyboardInterrupt: | 78 | except KeyboardInterrupt: |
| 58 | print('\nExiting early...') | 79 | print('\nExiting early...') |
| 59 | exit(0) | 80 | exit(0) |
