diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | pyproject.toml | 12 | ||||
| -rw-r--r-- | run_flux.py | 117 |
3 files changed, 72 insertions, 59 deletions
| @@ -1,3 +1,5 @@ | |||
| 1 | *.png | 1 | *.png |
| 2 | *.log | 2 | *.log |
| 3 | prompts.txt | 3 | prompts.txt |
| 4 | .idea | ||
| 5 | Flux.egg-info \ No newline at end of file | ||
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2b12a4f --- /dev/null +++ b/pyproject.toml | |||
| @@ -0,0 +1,12 @@ | |||
| 1 | [project] | ||
| 2 | name = "Flux" | ||
| 3 | version = "0.0.1" | ||
| 4 | authors = [ | ||
| 5 | { name="Zach Berwaldt", email="zberwaldt@gmail.com" } | ||
| 6 | ] | ||
| 7 | description = "A simple CLI to generate images with flux locally." | ||
| 8 | readme = "README.md" | ||
| 9 | requires-python = ">=3.12" | ||
| 10 | |||
| 11 | [project.scripts] | ||
| 12 | flux = "run_flux:main" \ No newline at end of file | ||
diff --git a/run_flux.py b/run_flux.py index b288935..446f57a 100644 --- a/run_flux.py +++ b/run_flux.py | |||
| @@ -1,11 +1,7 @@ | |||
| 1 | import torch | ||
| 2 | import os | 1 | import os |
| 3 | import base64 | ||
| 4 | import argparse | 2 | import argparse |
| 5 | import argcomplete | 3 | from dataclasses import dataclass, asdict |
| 6 | from diffusers import FluxPipeline, FluxImg2ImgPipeline | 4 | from typing import Tuple, Optional |
| 7 | from diffusers.utils import load_image | ||
| 8 | from typing import Tuple | ||
| 9 | from pathlib import Path | 5 | from pathlib import Path |
| 10 | from PIL import Image | 6 | from PIL import Image |
| 11 | import uuid | 7 | import uuid |
| @@ -39,6 +35,9 @@ def record_prompt(prompt, filename="prompts.txt"): | |||
| 39 | logger.info(f"Prompt already exists in the file: \"{prompt}\"") | 35 | logger.info(f"Prompt already exists in the file: \"{prompt}\"") |
| 40 | 36 | ||
| 41 | def load_flux_img_to_img(): | 37 | def load_flux_img_to_img(): |
| 38 | import torch | ||
| 39 | from diffusers import FluxImg2ImgPipeline | ||
| 40 | |||
| 42 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | 41 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) |
| 43 | pipeline.enable_model_cpu_offload() | 42 | pipeline.enable_model_cpu_offload() |
| 44 | pipeline.vae.enable_slicing() | 43 | pipeline.vae.enable_slicing() |
| @@ -46,31 +45,32 @@ def load_flux_img_to_img(): | |||
| 46 | return pipeline | 45 | return pipeline |
| 47 | 46 | ||
| 48 | def load_flux(): | 47 | def load_flux(): |
| 48 | import torch | ||
| 49 | from diffusers import FluxPipeline | ||
| 50 | |||
| 49 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | 51 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) |
| 50 | pipeline.enable_model_cpu_offload() | 52 | pipeline.enable_model_cpu_offload() |
| 51 | pipeline.vae.enable_slicing() | 53 | pipeline.vae.enable_slicing() |
| 52 | pipeline.vae.enable_tiling() | 54 | pipeline.vae.enable_tiling() |
| 53 | return pipeline | 55 | return pipeline |
| 54 | 56 | ||
| 55 | def generate_image(pipeline, prompt, strength=None, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | 57 | @dataclass(frozen=True) |
| 56 | kwargs = { | 58 | class GenerateImageConfig: |
| 57 | "prompt": prompt, | 59 | prompt: str |
| 58 | "prompt_2": prompt_2, | 60 | prompt_2: Optional[str] = None |
| 59 | "guidance_scale": guideance_scale, | 61 | init_image: Optional[Image] = None |
| 60 | "height":height, | 62 | strength: int = 0.0 |
| 61 | "width": width, | 63 | guidance_scale: float = 0.0 |
| 62 | "max_sequence_length": 256, | 64 | height: int = 1024 |
| 63 | "num_inference_steps": num_inference_steps, | 65 | width: int = 1024 |
| 64 | "num_images_per_prompt": num_images_per_prompt | 66 | num_images_per_prompt: int = 1 |
| 65 | } | 67 | num_inference_steps: int = 50 |
| 66 | 68 | ||
| 67 | if strength: | 69 | def to_dict(self): |
| 68 | kwargs["strength"] = strength | 70 | return {k: v for k, v in asdict(self).items() if v is not None} |
| 69 | 71 | ||
| 70 | if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None: | 72 | def generate_image(pipeline, config: GenerateImageConfig): |
| 71 | kwargs["image"] = init_image | 73 | images = pipeline(**config.to_dict()).images |
| 72 | |||
| 73 | images = pipeline(**kwargs).images | ||
| 74 | return images | 74 | return images |
| 75 | 75 | ||
| 76 | def generate_random_string(length=16) -> str: | 76 | def generate_random_string(length=16) -> str: |
| @@ -79,27 +79,36 @@ def generate_random_string(length=16) -> str: | |||
| 79 | def parse_dimensions(dim_str: str) -> Tuple[int, int]: | 79 | def parse_dimensions(dim_str: str) -> Tuple[int, int]: |
| 80 | try: | 80 | try: |
| 81 | width, height = map(int, dim_str.split(':')) | 81 | width, height = map(int, dim_str.split(':')) |
| 82 | return (width, height) | 82 | return width, height |
| 83 | except ValueError: | 83 | except ValueError: |
| 84 | raise argparse.ArgumentError('Dimensions must be in format width:height') | 84 | raise argparse.ArgumentError('Dimensions must be in format width:height') |
| 85 | 85 | ||
| 86 | def main(): | 86 | def main(): |
| 87 | try: | 87 | logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', |
| 88 | logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p") | 88 | datefmt="%m/%d/%Y %I:%M:%S %p") |
| 89 | 89 | ||
| 90 | logger.info("Parsing arguments") | 90 | logger.info("Parsing arguments") |
| 91 | |||
| 92 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") | ||
| 93 | |||
| 94 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") | ||
| 95 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") | ||
| 96 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") | ||
| 97 | parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt") | ||
| 98 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) | ||
| 99 | parser.add_argument("--strength", type=float) | ||
| 100 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") | ||
| 101 | parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image") | ||
| 102 | args = parser.parse_args() | ||
| 91 | 103 | ||
| 92 | args = parser.parse_args() | 104 | try: |
| 105 | import torch | ||
| 106 | from diffusers.utils import load_image | ||
| 93 | 107 | ||
| 94 | logger.info("Choosing model...") | 108 | logger.info("Choosing model...") |
| 95 | 109 | ||
| 96 | pipeline = load_flux_img_to_img() if args.use_image else load_flux() | 110 | pipeline = load_flux_img_to_img() if args.use_image else load_flux() |
| 97 | 111 | ||
| 98 | if isinstance(pipeline, FluxPipeline): | ||
| 99 | logger.info("Using text-to-image model") | ||
| 100 | else: | ||
| 101 | logger.info("Using image-to-image model") | ||
| 102 | |||
| 103 | pipeline.to(torch.float16) | 112 | pipeline.to(torch.float16) |
| 104 | 113 | ||
| 105 | # target_img = STORAGE_DIR / "1517481062292.jpg" | 114 | # target_img = STORAGE_DIR / "1517481062292.jpg" |
| @@ -119,16 +128,20 @@ def main(): | |||
| 119 | 128 | ||
| 120 | logger.info("Generating image(s)...") | 129 | logger.info("Generating image(s)...") |
| 121 | 130 | ||
| 131 | config = GenerateImageConfig( | ||
| 132 | init_image=init_image if args.use_image else None, | ||
| 133 | prompt=args.prompt, | ||
| 134 | prompt_2=args.prompt2 if args.prompt2 else None, | ||
| 135 | width=width, | ||
| 136 | height=height, | ||
| 137 | strength=args.strength, | ||
| 138 | guidance_scale=args.guideance_scale, | ||
| 139 | num_images_per_prompt=args.number | ||
| 140 | ) | ||
| 141 | |||
| 122 | images = generate_image( | 142 | images = generate_image( |
| 123 | pipeline=pipeline, | 143 | pipeline=pipeline, |
| 124 | init_image=init_image, | 144 | config=config |
| 125 | prompt=args.prompt, | ||
| 126 | prompt_2=args.prompt2, | ||
| 127 | width=width, | ||
| 128 | height=height, | ||
| 129 | strength=args.strength, | ||
| 130 | guideance_scale=args.guideance_scale, | ||
| 131 | num_images_per_prompt=args.number | ||
| 132 | ) | 145 | ) |
| 133 | 146 | ||
| 134 | for image in images: | 147 | for image in images: |
| @@ -148,20 +161,6 @@ def main(): | |||
| 148 | print(f"An error occured: {e}") | 161 | print(f"An error occured: {e}") |
| 149 | exit(1) | 162 | exit(1) |
| 150 | 163 | ||
| 151 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") | ||
| 152 | |||
| 153 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") | ||
| 154 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") | ||
| 155 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") | ||
| 156 | parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt") | ||
| 157 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) | ||
| 158 | parser.add_argument("--strength", type=float) | ||
| 159 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") | ||
| 160 | parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image") | ||
| 161 | |||
| 162 | # parser.add_argument("-b", "--base-image").completer = image_completer | ||
| 163 | |||
| 164 | # argcomplete.autocomplete(parser) | ||
| 165 | 164 | ||
| 166 | if __name__ == "__main__": | 165 | if __name__ == "__main__": |
| 167 | main() | 166 | main() |
