diff options
author | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 01:06:23 -0400 |
---|---|---|
committer | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 01:06:23 -0400 |
commit | 4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 (patch) | |
tree | 35e958e423f55abd7feb149a4b6ef77bddd95c70 | |
parent | 23b68c871d95353258d180639fa30b636a982423 (diff) |
add img2img pipeline
-rw-r--r-- | run_flux.py | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/run_flux.py b/run_flux.py index 8c4c27e..4a27816 100644 --- a/run_flux.py +++ b/run_flux.py | |||
@@ -2,15 +2,24 @@ import torch | |||
2 | import os | 2 | import os |
3 | import base64 | 3 | import base64 |
4 | import argparse | 4 | import argparse |
5 | from diffusers import FluxPipeline | 5 | from diffusers import FluxPipeline, FluxImg2ImgPipeline |
6 | from diffusers.utils import load_image | ||
6 | from typing import Tuple | 7 | from typing import Tuple |
7 | from pathlib import Path | 8 | from pathlib import Path |
9 | from PIL import Image | ||
8 | import uuid | 10 | import uuid |
9 | 11 | ||
10 | STORAGE_DIR: Path = Path.home() / "Pictures" / "Flux" | 12 | STORAGE_DIR: Path = Path.home() / "Pictures" |
11 | 13 | ||
12 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) | 14 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) |
13 | 15 | ||
16 | def load_flux_img_to_img(): | ||
17 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | ||
18 | pipeline.enable_model_cpu_offload() | ||
19 | pipeline.vae.enable_slicing() | ||
20 | pipeline.vae.enable_tiling() | ||
21 | return pipeline | ||
22 | |||
14 | def load_flux(): | 23 | def load_flux(): |
15 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | 24 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) |
16 | pipeline.enable_model_cpu_offload() | 25 | pipeline.enable_model_cpu_offload() |
@@ -18,7 +27,7 @@ def load_flux(): | |||
18 | pipeline.vae.enable_tiling() | 27 | pipeline.vae.enable_tiling() |
19 | return pipeline | 28 | return pipeline |
20 | 29 | ||
21 | def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | 30 | def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): |
22 | images = pipeline( | 31 | images = pipeline( |
23 | prompt=prompt, | 32 | prompt=prompt, |
24 | guidance_scale=guideance_scale, | 33 | guidance_scale=guideance_scale, |
@@ -26,7 +35,8 @@ def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, | |||
26 | width=width, | 35 | width=width, |
27 | max_sequence_length=256, | 36 | max_sequence_length=256, |
28 | num_inference_steps=num_inference_steps, | 37 | num_inference_steps=num_inference_steps, |
29 | num_images_per_prompt=num_images_per_prompt | 38 | num_images_per_prompt=num_images_per_prompt, |
39 | image=init_image | ||
30 | ).images | 40 | ).images |
31 | return images | 41 | return images |
32 | 42 | ||
@@ -44,16 +54,27 @@ def main(): | |||
44 | try: | 54 | try: |
45 | args = parser.parse_args() | 55 | args = parser.parse_args() |
46 | 56 | ||
47 | pipeline = load_flux() | 57 | pipeline = load_flux_img_to_img() |
48 | pipeline.to(torch.float16) | 58 | pipeline.to(torch.float16) |
49 | 59 | ||
60 | target_img = STORAGE_DIR / "1517481062292.jpg" | ||
61 | |||
62 | target_img_path = target_img.resolve(strict=True) | ||
63 | |||
64 | image = Image.open(target_img_path) | ||
65 | |||
66 | init_image = load_image(image).resize((256, 256)) | ||
67 | |||
50 | width, height = args.size | 68 | width, height = args.size |
51 | 69 | ||
52 | images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) | 70 | images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) |
53 | for image in images: | 71 | for image in images: |
54 | filename = generate_random_string() | 72 | filename = generate_random_string() |
55 | filepath = STORAGE_DIR / f"{filename}.png" | 73 | filepath = STORAGE_DIR / "Flux" / f"{filename}.png" |
56 | image.save(filepath) | 74 | image.save(filepath) |
75 | except FileNotFoundError: | ||
76 | print("\n Target image doesn't exist. Exiting...") | ||
77 | exit(0) | ||
57 | except KeyboardInterrupt: | 78 | except KeyboardInterrupt: |
58 | print('\nExiting early...') | 79 | print('\nExiting early...') |
59 | exit(0) | 80 | exit(0) |