From 4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 Mon Sep 17 00:00:00 2001 From: Zach Berwaldt Date: Wed, 30 Oct 2024 01:06:23 -0400 Subject: add img2img pipeline --- run_flux.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/run_flux.py b/run_flux.py index 8c4c27e..4a27816 100644 --- a/run_flux.py +++ b/run_flux.py @@ -2,15 +2,24 @@ import torch import os import base64 import argparse -from diffusers import FluxPipeline +from diffusers import FluxPipeline, FluxImg2ImgPipeline +from diffusers.utils import load_image from typing import Tuple from pathlib import Path +from PIL import Image import uuid -STORAGE_DIR: Path = Path.home() / "Pictures" / "Flux" +STORAGE_DIR: Path = Path.home() / "Pictures" STORAGE_DIR.mkdir(parents=True, exist_ok=True) +def load_flux_img_to_img(): + pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + pipeline.enable_model_cpu_offload() + pipeline.vae.enable_slicing() + pipeline.vae.enable_tiling() + return pipeline + def load_flux(): pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipeline.enable_model_cpu_offload() @@ -18,7 +27,7 @@ def load_flux(): pipeline.vae.enable_tiling() return pipeline -def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): +def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): images = pipeline( prompt=prompt, guidance_scale=guideance_scale, @@ -26,7 +35,8 @@ def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, width=width, max_sequence_length=256, num_inference_steps=num_inference_steps, - num_images_per_prompt=num_images_per_prompt + num_images_per_prompt=num_images_per_prompt, + image=init_image ).images return images @@ -44,16 +54,27 @@ def main(): try: args = parser.parse_args() - pipeline = load_flux() + pipeline = load_flux_img_to_img() pipeline.to(torch.float16) + target_img = STORAGE_DIR / "1517481062292.jpg" + + target_img_path = target_img.resolve(strict=True) + + image = Image.open(target_img_path) + + init_image = load_image(image).resize((256, 256)) + width, height = args.size - images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) + images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) for image in images: filename = generate_random_string() - filepath = STORAGE_DIR / f"{filename}.png" + filepath = STORAGE_DIR / "Flux" / f"{filename}.png" image.save(filepath) + except FileNotFoundError: + print("\n Target image doesn't exist. Exiting...") + exit(0) except KeyboardInterrupt: print('\nExiting early...') exit(0) -- cgit v1.1