author     Zach Berwaldt <zberwaldt@tutamail.com>  2025-01-28 17:53:57 -0500
committer  Zach Berwaldt <zberwaldt@tutamail.com>  2025-01-28 17:53:57 -0500
commit     671c17918c877e64aee7c730f32aac9ded9d5915
tree       5dd489a1da533123245d4f276e2ef11c87764047
parent     b067c48e5db91796c96c982011421713b0187ae4
Reduce args for generate_image by using a config class. Update .gitignore to account for the .idea folder. Reorganize code to speed up CLI startup.
-rw-r--r--  .gitignore      |   2
-rw-r--r--  pyproject.toml  |  12
-rw-r--r--  run_flux.py     | 117
3 files changed, 72 insertions(+), 59 deletions(-)
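
In practice the new interface looks roughly like this (a minimal sketch based on the run_flux.py diff below; the prompt and size values are illustrative, and it assumes run_flux.py is importable from the working directory):

from run_flux import GenerateImageConfig

# Illustrative values only; the dataclass is the one added in run_flux.py below.
config = GenerateImageConfig(
    prompt="a lighthouse at dusk",   # the only required field
    width=768,
    height=768,
    num_images_per_prompt=2,
)

# generate_image(pipeline, config) now just calls pipeline(**config.to_dict()).images,
# so this dict replaces the old nine-parameter keyword list.
print(config.to_dict())

The frozen dataclass keeps generate_image down to two parameters, and to_dict() drops unset optional fields before they reach the pipeline.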
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 *.png
 *.log
 prompts.txt
+.idea
+Flux.egg-info
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..2b12a4f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,12 @@
+[project]
+name = "Flux"
+version = "0.0.1"
+authors = [
+    { name="Zach Berwaldt", email="zberwaldt@gmail.com" }
+]
+description = "A simple CLI to generate images with flux locally."
+readme = "README.md"
+requires-python = ">=3.12"
+
+[project.scripts]
+flux = "run_flux:main"
\ No newline at end of file
diff --git a/run_flux.py b/run_flux.py
index b288935..446f57a 100644
--- a/run_flux.py
+++ b/run_flux.py
@@ -1,11 +1,7 @@
-import torch
 import os
-import base64
 import argparse
-import argcomplete
-from diffusers import FluxPipeline, FluxImg2ImgPipeline
-from diffusers.utils import load_image
-from typing import Tuple
+from dataclasses import dataclass, asdict
+from typing import Tuple, Optional
 from pathlib import Path
 from PIL import Image
 import uuid
@@ -39,6 +35,9 @@ def record_prompt(prompt, filename="prompts.txt"):
         logger.info(f"Prompt already exists in the file: \"{prompt}\"")
 
 def load_flux_img_to_img():
+    import torch
+    from diffusers import FluxImg2ImgPipeline
+
     pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
     pipeline.enable_model_cpu_offload()
     pipeline.vae.enable_slicing()
@@ -46,31 +45,32 @@ def load_flux_img_to_img():
     return pipeline
 
 def load_flux():
+    import torch
+    from diffusers import FluxPipeline
+
     pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
     pipeline.enable_model_cpu_offload()
     pipeline.vae.enable_slicing()
     pipeline.vae.enable_tiling()
     return pipeline
 
-def generate_image(pipeline, prompt, strength=None, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50):
-    kwargs = {
-        "prompt": prompt,
-        "prompt_2": prompt_2,
-        "guidance_scale": guideance_scale,
-        "height":height,
-        "width": width,
-        "max_sequence_length": 256,
-        "num_inference_steps": num_inference_steps,
-        "num_images_per_prompt": num_images_per_prompt
-    }
+@dataclass(frozen=True)
+class GenerateImageConfig:
+    prompt: str
+    prompt_2: Optional[str] = None
+    init_image: Optional[Image] = None
+    strength: float = 0.0
+    guidance_scale: float = 0.0
+    height: int = 1024
+    width: int = 1024
+    num_images_per_prompt: int = 1
+    num_inference_steps: int = 50
 
-    if strength:
-        kwargs["strength"] = strength
+    def to_dict(self):
+        return {k: v for k, v in asdict(self).items() if v is not None}
 
-    if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None:
-        kwargs["image"] = init_image
-
-    images = pipeline(**kwargs).images
+def generate_image(pipeline, config: GenerateImageConfig):
+    images = pipeline(**config.to_dict()).images
     return images
 
 def generate_random_string(length=16) -> str:
@@ -79,27 +79,36 @@ def generate_random_string(length=16) -> str:
 def parse_dimensions(dim_str: str) -> Tuple[int, int]:
     try:
         width, height = map(int, dim_str.split(':'))
-        return (width, height)
+        return width, height
     except ValueError:
         raise argparse.ArgumentError('Dimensions must be in format width:height')
 
 def main():
-    try:
-        logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p")
+    logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s',
+                        datefmt="%m/%d/%Y %I:%M:%S %p")
 
-        logger.info("Parsing arguments")
+    logger.info("Parsing arguments")
+
+    parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!")
+
+    parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate")
+    parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image")
+    parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt")
+    parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt")
+    parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
+    parser.add_argument("--strength", type=float)
+    parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
+    parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
+    args = parser.parse_args()
 
-        args = parser.parse_args()
+    try:
+        import torch
+        from diffusers.utils import load_image
 
         logger.info("Choosing model...")
 
         pipeline = load_flux_img_to_img() if args.use_image else load_flux()
 
-        if isinstance(pipeline, FluxPipeline):
-            logger.info("Using text-to-image model")
-        else:
-            logger.info("Using image-to-image model")
-
         pipeline.to(torch.float16)
 
         # target_img = STORAGE_DIR / "1517481062292.jpg"
@@ -119,16 +128,20 @@ def main():
 
         logger.info("Generating image(s)...")
 
+        config = GenerateImageConfig(
+            init_image=init_image if args.use_image else None,
+            prompt=args.prompt,
+            prompt_2=args.prompt2 if args.prompt2 else None,
+            width=width,
+            height=height,
+            strength=args.strength,
+            guidance_scale=args.guideance_scale,
+            num_images_per_prompt=args.number
+        )
+
         images = generate_image(
             pipeline=pipeline,
-            init_image=init_image,
-            prompt=args.prompt,
-            prompt_2=args.prompt2,
-            width=width,
-            height=height,
-            strength=args.strength,
-            guideance_scale=args.guideance_scale,
-            num_images_per_prompt=args.number
+            config=config
         )
 
         for image in images:
@@ -148,20 +161,6 @@ def main():
         print(f"An error occured: {e}")
         exit(1)
 
-parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!")
-
-parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate")
-parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image")
-parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt")
-parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt")
-parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
-parser.add_argument("--strength", type=float)
-parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
-parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
-
-# parser.add_argument("-b", "--base-image").completer = image_completer
-
-# argcomplete.autocomplete(parser)
 
 if __name__ == "__main__":
     main()
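
The "reorganize code" part of this commit moves the torch and diffusers imports out of module scope and into load_flux(), load_flux_img_to_img(), and main()'s try block, so building the argument parser (and running --help) no longer pays the model-stack import cost. A rough, illustrative way to see the difference (not part of the commit; timings depend on the machine and assume torch and diffusers are installed):

import time

start = time.perf_counter()
import argparse                      # cheap import, needed just to build the CLI
print(f"argparse imported in {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
import torch                         # heavy imports, now deferred until a pipeline is loaded
from diffusers import FluxPipeline
print(f"torch/diffusers imported in {time.perf_counter() - start:.3f}s")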