diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | run_flux.py | 64 |
2 files changed, 66 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2fa80d6 --- /dev/null +++ b/.gitignore | |||
| @@ -0,0 +1,2 @@ | |||
| 1 | *.png | ||
| 2 | |||
diff --git a/run_flux.py b/run_flux.py new file mode 100644 index 0000000..c095fce --- /dev/null +++ b/run_flux.py | |||
| @@ -0,0 +1,64 @@ | |||
| 1 | import torch | ||
| 2 | import os | ||
| 3 | import base64 | ||
| 4 | import argparse | ||
| 5 | from diffusers import FluxPipeline | ||
| 6 | from typing import Tuple | ||
| 7 | import uuid | ||
| 8 | |||
| 9 | def load_flux(): | ||
| 10 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | ||
| 11 | pipeline.enable_model_cpu_offload() | ||
| 12 | pipeline.vae.enable_slicing() | ||
| 13 | pipeline.vae.enable_tiling() | ||
| 14 | return pipeline | ||
| 15 | |||
| 16 | def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | ||
| 17 | images = pipeline( | ||
| 18 | prompt=prompt, | ||
| 19 | guidance_scale=guideance_scale, | ||
| 20 | height=height, | ||
| 21 | width=width, | ||
| 22 | max_sequence_length=256, | ||
| 23 | num_inference_steps=num_inference_steps, | ||
| 24 | num_images_per_prompt=num_images_per_prompt | ||
| 25 | ).images | ||
| 26 | return images | ||
| 27 | |||
| 28 | def generate_random_string(length=16) -> str: | ||
| 29 | return str(uuid.uuid4()) | ||
| 30 | |||
| 31 | def parse_dimensions(dim_str: str) -> Tuple[int, int]: | ||
| 32 | try: | ||
| 33 | width, height = map(int, dim_str.split(':')) | ||
| 34 | return (width, height) | ||
| 35 | except ValueError: | ||
| 36 | raise argparse.ArgumentError('Dimensions must be in format width:height') | ||
| 37 | |||
| 38 | def main(): | ||
| 39 | try: | ||
| 40 | args = parser.parse_args() | ||
| 41 | |||
| 42 | pipeline = load_flux() | ||
| 43 | pipeline.to(torch.float16) | ||
| 44 | |||
| 45 | width, height = args.size | ||
| 46 | |||
| 47 | images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) | ||
| 48 | for image in images: | ||
| 49 | filename = generate_random_string() | ||
| 50 | image.save(f"{filename}.png") | ||
| 51 | except KeyboardInterrupt: | ||
| 52 | print('\nExiting early...') | ||
| 53 | exit(0) | ||
| 54 | |||
| 55 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") | ||
| 56 | |||
| 57 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") | ||
| 58 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") | ||
| 59 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") | ||
| 60 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) | ||
| 61 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024") | ||
| 62 | |||
| 63 | if __name__ == "__main__": | ||
| 64 | main() | ||
