summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZach Berwaldt <zberwaldt@tutamail.com>2024-10-30 01:06:23 -0400
committerZach Berwaldt <zberwaldt@tutamail.com>2024-10-30 01:06:23 -0400
commit4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 (patch)
tree35e958e423f55abd7feb149a4b6ef77bddd95c70
parent23b68c871d95353258d180639fa30b636a982423 (diff)
add img2img pipeline
-rw-r--r--run_flux.py35
1 files changed, 28 insertions, 7 deletions
diff --git a/run_flux.py b/run_flux.py
index 8c4c27e..4a27816 100644
--- a/run_flux.py
+++ b/run_flux.py
@@ -2,15 +2,24 @@ import torch
2import os 2import os
3import base64 3import base64
4import argparse 4import argparse
5from diffusers import FluxPipeline 5from diffusers import FluxPipeline, FluxImg2ImgPipeline
6from diffusers.utils import load_image
6from typing import Tuple 7from typing import Tuple
7from pathlib import Path 8from pathlib import Path
9from PIL import Image
8import uuid 10import uuid
9 11
10STORAGE_DIR: Path = Path.home() / "Pictures" / "Flux" 12STORAGE_DIR: Path = Path.home() / "Pictures"
11 13
12STORAGE_DIR.mkdir(parents=True, exist_ok=True) 14STORAGE_DIR.mkdir(parents=True, exist_ok=True)
13 15
16def load_flux_img_to_img():
17 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
18 pipeline.enable_model_cpu_offload()
19 pipeline.vae.enable_slicing()
20 pipeline.vae.enable_tiling()
21 return pipeline
22
14def load_flux(): 23def load_flux():
15 pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 24 pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
16 pipeline.enable_model_cpu_offload() 25 pipeline.enable_model_cpu_offload()
@@ -18,7 +27,7 @@ def load_flux():
18 pipeline.vae.enable_tiling() 27 pipeline.vae.enable_tiling()
19 return pipeline 28 return pipeline
20 29
21def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): 30def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50):
22 images = pipeline( 31 images = pipeline(
23 prompt=prompt, 32 prompt=prompt,
24 guidance_scale=guideance_scale, 33 guidance_scale=guideance_scale,
@@ -26,7 +35,8 @@ def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0,
26 width=width, 35 width=width,
27 max_sequence_length=256, 36 max_sequence_length=256,
28 num_inference_steps=num_inference_steps, 37 num_inference_steps=num_inference_steps,
29 num_images_per_prompt=num_images_per_prompt 38 num_images_per_prompt=num_images_per_prompt,
39 image=init_image
30 ).images 40 ).images
31 return images 41 return images
32 42
@@ -44,16 +54,27 @@ def main():
44 try: 54 try:
45 args = parser.parse_args() 55 args = parser.parse_args()
46 56
47 pipeline = load_flux() 57 pipeline = load_flux_img_to_img()
48 pipeline.to(torch.float16) 58 pipeline.to(torch.float16)
49 59
60 target_img = STORAGE_DIR / "1517481062292.jpg"
61
62 target_img_path = target_img.resolve(strict=True)
63
64 image = Image.open(target_img_path)
65
66 init_image = load_image(image).resize((256, 256))
67
50 width, height = args.size 68 width, height = args.size
51 69
52 images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) 70 images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number)
53 for image in images: 71 for image in images:
54 filename = generate_random_string() 72 filename = generate_random_string()
55 filepath = STORAGE_DIR / f"{filename}.png" 73 filepath = STORAGE_DIR / "Flux" / f"{filename}.png"
56 image.save(filepath) 74 image.save(filepath)
75 except FileNotFoundError:
76 print("\n Target image doesn't exist. Exiting...")
77 exit(0)
57 except KeyboardInterrupt: 78 except KeyboardInterrupt:
58 print('\nExiting early...') 79 print('\nExiting early...')
59 exit(0) 80 exit(0)