From 671c17918c877e64aee7c730f32aac9ded9d5915 Mon Sep 17 00:00:00 2001 From: Zach Berwaldt Date: Tue, 28 Jan 2025 17:53:57 -0500 Subject: reduce args for generate_image by using a config class. update ignore to account for idea folder, re-org code to increase speed. --- .gitignore | 2 + pyproject.toml | 12 ++++++ run_flux.py | 117 ++++++++++++++++++++++++++++----------------------------- 3 files changed, 72 insertions(+), 59 deletions(-) create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 96ef086..73f016a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.png *.log prompts.txt +.idea +Flux.egg-info \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2b12a4f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "Flux" +version = "0.0.1" +authors = [ + { name="Zach Berwaldt", email="zberwaldt@gmail.com" } +] +description = "A simple CLI to generate images with flux locally." +readme = "README.md" +requires-python = ">=3.12" + +[project.scripts] +flux = "run_flux:main" \ No newline at end of file diff --git a/run_flux.py b/run_flux.py index b288935..446f57a 100644 --- a/run_flux.py +++ b/run_flux.py @@ -1,11 +1,7 @@ -import torch import os -import base64 import argparse -import argcomplete -from diffusers import FluxPipeline, FluxImg2ImgPipeline -from diffusers.utils import load_image -from typing import Tuple +from dataclasses import dataclass, asdict +from typing import Tuple, Optional from pathlib import Path from PIL import Image import uuid @@ -39,6 +35,9 @@ def record_prompt(prompt, filename="prompts.txt"): logger.info(f"Prompt already exists in the file: \"{prompt}\"") def load_flux_img_to_img(): + import torch + from diffusers import FluxImg2ImgPipeline + pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipeline.enable_model_cpu_offload() pipeline.vae.enable_slicing() @@ -46,31 +45,32 @@ def load_flux_img_to_img(): return pipeline def load_flux(): + import torch + from diffusers import FluxPipeline + pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) pipeline.enable_model_cpu_offload() pipeline.vae.enable_slicing() pipeline.vae.enable_tiling() return pipeline -def generate_image(pipeline, prompt, strength=None, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): - kwargs = { - "prompt": prompt, - "prompt_2": prompt_2, - "guidance_scale": guideance_scale, - "height":height, - "width": width, - "max_sequence_length": 256, - "num_inference_steps": num_inference_steps, - "num_images_per_prompt": num_images_per_prompt - } - - if strength: - kwargs["strength"] = strength - - if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None: - kwargs["image"] = init_image - - images = pipeline(**kwargs).images +@dataclass(frozen=True) +class GenerateImageConfig: + prompt: str + prompt_2: Optional[str] = None + init_image: Optional[Image] = None + strength: int = 0.0 + guidance_scale: float = 0.0 + height: int = 1024 + width: int = 1024 + num_images_per_prompt: int = 1 + num_inference_steps: int = 50 + + def to_dict(self): + return {k: v for k, v in asdict(self).items() if v is not None} + +def generate_image(pipeline, config: GenerateImageConfig): + images = pipeline(**config.to_dict()).images return images def generate_random_string(length=16) -> str: @@ -79,27 +79,36 @@ def generate_random_string(length=16) -> str: def parse_dimensions(dim_str: str) -> Tuple[int, int]: try: width, height = map(int, dim_str.split(':')) - return (width, height) + return width, height except ValueError: raise argparse.ArgumentError('Dimensions must be in format width:height') def main(): - try: - logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p") - - logger.info("Parsing arguments") + logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', + datefmt="%m/%d/%Y %I:%M:%S %p") + + logger.info("Parsing arguments") + + parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") + + parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") + parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") + parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") + parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt") + parser.add_argument("-gs", "--guideance-scale", type=float, default=0) + parser.add_argument("--strength", type=float) + parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") + parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image") + args = parser.parse_args() - args = parser.parse_args() + try: + import torch + from diffusers.utils import load_image logger.info("Choosing model...") pipeline = load_flux_img_to_img() if args.use_image else load_flux() - if isinstance(pipeline, FluxPipeline): - logger.info("Using text-to-image model") - else: - logger.info("Using image-to-image model") - pipeline.to(torch.float16) # target_img = STORAGE_DIR / "1517481062292.jpg" @@ -119,16 +128,20 @@ def main(): logger.info("Generating image(s)...") + config = GenerateImageConfig( + init_image=init_image if args.use_image else None, + prompt=args.prompt, + prompt_2=args.prompt2 if args.prompt2 else None, + width=width, + height=height, + strength=args.strength, + guidance_scale=args.guideance_scale, + num_images_per_prompt=args.number + ) + images = generate_image( - pipeline=pipeline, - init_image=init_image, - prompt=args.prompt, - prompt_2=args.prompt2, - width=width, - height=height, - strength=args.strength, - guideance_scale=args.guideance_scale, - num_images_per_prompt=args.number + pipeline=pipeline, + config=config ) for image in images: @@ -148,20 +161,6 @@ def main(): print(f"An error occured: {e}") exit(1) -parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") - -parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") -parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") -parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") -parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt") -parser.add_argument("-gs", "--guideance-scale", type=float, default=0) -parser.add_argument("--strength", type=float) -parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") -parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image") - -# parser.add_argument("-b", "--base-image").completer = image_completer - -# argcomplete.autocomplete(parser) if __name__ == "__main__": main() -- cgit v1.1