diff options
author | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-29 21:51:54 -0400 |
---|---|---|
committer | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-29 21:51:54 -0400 |
commit | 47084f8024b2be503f5abe0c0464512a226b7356 (patch) | |
tree | 0b13f0f759f9bf6d1d4f26f8507700bdf61808bb |
first commit
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | run_flux.py | 64 |
2 files changed, 66 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2fa80d6 --- /dev/null +++ b/.gitignore | |||
@@ -0,0 +1,2 @@ | |||
1 | *.png | ||
2 | |||
diff --git a/run_flux.py b/run_flux.py new file mode 100644 index 0000000..c095fce --- /dev/null +++ b/run_flux.py | |||
@@ -0,0 +1,64 @@ | |||
1 | import torch | ||
2 | import os | ||
3 | import base64 | ||
4 | import argparse | ||
5 | from diffusers import FluxPipeline | ||
6 | from typing import Tuple | ||
7 | import uuid | ||
8 | |||
9 | def load_flux(): | ||
10 | pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | ||
11 | pipeline.enable_model_cpu_offload() | ||
12 | pipeline.vae.enable_slicing() | ||
13 | pipeline.vae.enable_tiling() | ||
14 | return pipeline | ||
15 | |||
16 | def generate_image(pipeline, prompt, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | ||
17 | images = pipeline( | ||
18 | prompt=prompt, | ||
19 | guidance_scale=guideance_scale, | ||
20 | height=height, | ||
21 | width=width, | ||
22 | max_sequence_length=256, | ||
23 | num_inference_steps=num_inference_steps, | ||
24 | num_images_per_prompt=num_images_per_prompt | ||
25 | ).images | ||
26 | return images | ||
27 | |||
28 | def generate_random_string(length=16) -> str: | ||
29 | return str(uuid.uuid4()) | ||
30 | |||
31 | def parse_dimensions(dim_str: str) -> Tuple[int, int]: | ||
32 | try: | ||
33 | width, height = map(int, dim_str.split(':')) | ||
34 | return (width, height) | ||
35 | except ValueError: | ||
36 | raise argparse.ArgumentError('Dimensions must be in format width:height') | ||
37 | |||
38 | def main(): | ||
39 | try: | ||
40 | args = parser.parse_args() | ||
41 | |||
42 | pipeline = load_flux() | ||
43 | pipeline.to(torch.float16) | ||
44 | |||
45 | width, height = args.size | ||
46 | |||
47 | images = generate_image(pipeline, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) | ||
48 | for image in images: | ||
49 | filename = generate_random_string() | ||
50 | image.save(f"{filename}.png") | ||
51 | except KeyboardInterrupt: | ||
52 | print('\nExiting early...') | ||
53 | exit(0) | ||
54 | |||
55 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") | ||
56 | |||
57 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") | ||
58 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") | ||
59 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") | ||
60 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) | ||
61 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024") | ||
62 | |||
63 | if __name__ == "__main__": | ||
64 | main() | ||