summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Zach Berwaldt <zberwaldt@tutamail.com> 2025-01-28 17:53:57 -0500
committer: Zach Berwaldt <zberwaldt@tutamail.com> 2025-01-28 17:53:57 -0500
commit: 671c17918c877e64aee7c730f32aac9ded9d5915 (patch)
tree: 5dd489a1da533123245d4f276e2ef11c87764047
parent: b067c48e5db91796c96c982011421713b0187ae4 (diff)
reduce args for generate_image by using a config class. update ignore to account for idea folder, re-org code to increase speed.
-rw-r--r-- .gitignore 2
-rw-r--r-- pyproject.toml 12
-rw-r--r-- run_flux.py 117
3 files changed, 72 insertions, 59 deletions
diff --git a/.gitignore b/.gitignore
index 96ef086..73f016a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
1*.png 1*.png
2*.log 2*.log
3prompts.txt 3prompts.txt
4.idea
5Flux.egg-info \ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..2b12a4f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,12 @@
1[project]
2name = "Flux"
3version = "0.0.1"
4authors = [
5 { name="Zach Berwaldt", email="zberwaldt@gmail.com" }
6]
7description = "A simple CLI to generate images with flux locally."
8readme = "README.md"
9requires-python = ">=3.12"
10
11[project.scripts]
12flux = "run_flux:main" \ No newline at end of file
diff --git a/run_flux.py b/run_flux.py
index b288935..446f57a 100644
--- a/run_flux.py
+++ b/run_flux.py
@@ -1,11 +1,7 @@
1import torch
2import os 1import os
3import base64
4import argparse 2import argparse
5import argcomplete 3from dataclasses import dataclass, asdict
6from diffusers import FluxPipeline, FluxImg2ImgPipeline 4from typing import Tuple, Optional
7from diffusers.utils import load_image
8from typing import Tuple
9from pathlib import Path 5from pathlib import Path
10from PIL import Image 6from PIL import Image
11import uuid 7import uuid
@@ -39,6 +35,9 @@ def record_prompt(prompt, filename="prompts.txt"):
39 logger.info(f"Prompt already exists in the file: \"{prompt}\"") 35 logger.info(f"Prompt already exists in the file: \"{prompt}\"")
40 36
41def load_flux_img_to_img(): 37def load_flux_img_to_img():
38 import torch
39 from diffusers import FluxImg2ImgPipeline
40
42 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 41 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
43 pipeline.enable_model_cpu_offload() 42 pipeline.enable_model_cpu_offload()
44 pipeline.vae.enable_slicing() 43 pipeline.vae.enable_slicing()
@@ -46,31 +45,32 @@ def load_flux_img_to_img():
46 return pipeline 45 return pipeline
47 46
48def load_flux(): 47def load_flux():
48 import torch
49 from diffusers import FluxPipeline
50
49 pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 51 pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
50 pipeline.enable_model_cpu_offload() 52 pipeline.enable_model_cpu_offload()
51 pipeline.vae.enable_slicing() 53 pipeline.vae.enable_slicing()
52 pipeline.vae.enable_tiling() 54 pipeline.vae.enable_tiling()
53 return pipeline 55 return pipeline
54 56
55def generate_image(pipeline, prompt, strength=None, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): 57@dataclass(frozen=True)
56 kwargs = { 58class GenerateImageConfig:
57 "prompt": prompt, 59 prompt: str
58 "prompt_2": prompt_2, 60 prompt_2: Optional[str] = None
59 "guidance_scale": guideance_scale, 61 init_image: Optional[Image] = None
60 "height":height, 62 strength: int = 0.0
61 "width": width, 63 guidance_scale: float = 0.0
62 "max_sequence_length": 256, 64 height: int = 1024
63 "num_inference_steps": num_inference_steps, 65 width: int = 1024
64 "num_images_per_prompt": num_images_per_prompt 66 num_images_per_prompt: int = 1
65 } 67 num_inference_steps: int = 50
66 68
67 if strength: 69 def to_dict(self):
68 kwargs["strength"] = strength 70 return {k: v for k, v in asdict(self).items() if v is not None}
69 71
70 if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None: 72def generate_image(pipeline, config: GenerateImageConfig):
71 kwargs["image"] = init_image 73 images = pipeline(**config.to_dict()).images
72
73 images = pipeline(**kwargs).images
74 return images 74 return images
75 75
76def generate_random_string(length=16) -> str: 76def generate_random_string(length=16) -> str:
@@ -79,27 +79,36 @@ def generate_random_string(length=16) -> str:
79def parse_dimensions(dim_str: str) -> Tuple[int, int]: 79def parse_dimensions(dim_str: str) -> Tuple[int, int]:
80 try: 80 try:
81 width, height = map(int, dim_str.split(':')) 81 width, height = map(int, dim_str.split(':'))
82 return (width, height) 82 return width, height
83 except ValueError: 83 except ValueError:
84 raise argparse.ArgumentError('Dimensions must be in format width:height') 84 raise argparse.ArgumentError('Dimensions must be in format width:height')
85 85
86def main(): 86def main():
87 try: 87 logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s',
88 logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p") 88 datefmt="%m/%d/%Y %I:%M:%S %p")
89 89
90 logger.info("Parsing arguments") 90 logger.info("Parsing arguments")
91
92 parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!")
93
94 parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate")
95 parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image")
96 parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt")
97 parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt")
98 parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
99 parser.add_argument("--strength", type=float)
100 parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
101 parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
102 args = parser.parse_args()
91 103
92 args = parser.parse_args() 104 try:
105 import torch
106 from diffusers.utils import load_image
93 107
94 logger.info("Choosing model...") 108 logger.info("Choosing model...")
95 109
96 pipeline = load_flux_img_to_img() if args.use_image else load_flux() 110 pipeline = load_flux_img_to_img() if args.use_image else load_flux()
97 111
98 if isinstance(pipeline, FluxPipeline):
99 logger.info("Using text-to-image model")
100 else:
101 logger.info("Using image-to-image model")
102
103 pipeline.to(torch.float16) 112 pipeline.to(torch.float16)
104 113
105 # target_img = STORAGE_DIR / "1517481062292.jpg" 114 # target_img = STORAGE_DIR / "1517481062292.jpg"
@@ -119,16 +128,20 @@ def main():
119 128
120 logger.info("Generating image(s)...") 129 logger.info("Generating image(s)...")
121 130
131 config = GenerateImageConfig(
132 init_image=init_image if args.use_image else None,
133 prompt=args.prompt,
134 prompt_2=args.prompt2 if args.prompt2 else None,
135 width=width,
136 height=height,
137 strength=args.strength,
138 guidance_scale=args.guideance_scale,
139 num_images_per_prompt=args.number
140 )
141
122 images = generate_image( 142 images = generate_image(
123 pipeline=pipeline, 143 pipeline=pipeline,
124 init_image=init_image, 144 config=config
125 prompt=args.prompt,
126 prompt_2=args.prompt2,
127 width=width,
128 height=height,
129 strength=args.strength,
130 guideance_scale=args.guideance_scale,
131 num_images_per_prompt=args.number
132 ) 145 )
133 146
134 for image in images: 147 for image in images:
@@ -148,20 +161,6 @@ def main():
148 print(f"An error occured: {e}") 161 print(f"An error occured: {e}")
149 exit(1) 162 exit(1)
150 163
151parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!")
152
153parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate")
154parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image")
155parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt")
156parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt")
157parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
158parser.add_argument("--strength", type=float)
159parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
160parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
161
162# parser.add_argument("-b", "--base-image").completer = image_completer
163
164# argcomplete.autocomplete(parser)
165 164
166if __name__ == "__main__": 165if __name__ == "__main__":
167 main() 166 main()