diff options
author | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 22:00:56 -0400 |
---|---|---|
committer | Zach Berwaldt <zberwaldt@tutamail.com> | 2024-10-30 22:00:56 -0400 |
commit | ffbc2e788c113b278b9306aac4cea1d644eaf048 (patch) | |
tree | 31794861d69249c923d4f8a10e5b7772abd7372f | |
parent | 4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 (diff) |
add logging, refactor generate image, keep track of prompts used.
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | prompts.txt | 2 | ||||
-rw-r--r-- | run_flux.py | 101 |
3 files changed, 89 insertions, 16 deletions
@@ -1,2 +1,2 @@ | |||
1 | *.png | 1 | *.png |
2 | 2 | *.log | |
diff --git a/prompts.txt b/prompts.txt new file mode 100644 index 0000000..3268a51 --- /dev/null +++ b/prompts.txt | |||
@@ -0,0 +1,2 @@ | |||
1 | A space pilot, wearing her space flight suit, graying hair, determined to face the future. | ||
2 | Female with very large breasts. | ||
diff --git a/run_flux.py b/run_flux.py index 4a27816..89d188b 100644 --- a/run_flux.py +++ b/run_flux.py | |||
@@ -2,17 +2,42 @@ import torch | |||
2 | import os | 2 | import os |
3 | import base64 | 3 | import base64 |
4 | import argparse | 4 | import argparse |
5 | import argcomplete | ||
5 | from diffusers import FluxPipeline, FluxImg2ImgPipeline | 6 | from diffusers import FluxPipeline, FluxImg2ImgPipeline |
6 | from diffusers.utils import load_image | 7 | from diffusers.utils import load_image |
7 | from typing import Tuple | 8 | from typing import Tuple |
8 | from pathlib import Path | 9 | from pathlib import Path |
9 | from PIL import Image | 10 | from PIL import Image |
10 | import uuid | 11 | import uuid |
12 | import logging | ||
11 | 13 | ||
12 | STORAGE_DIR: Path = Path.home() / "Pictures" | 14 | STORAGE_DIR: Path = Path.home() / "Pictures" |
13 | 15 | ||
14 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) | 16 | STORAGE_DIR.mkdir(parents=True, exist_ok=True) |
15 | 17 | ||
18 | logger = logging.getLogger("run_flux") | ||
19 | |||
20 | def image_completer(prefix, parsed_args, **kwargs): | ||
21 | image_dir = STORAGE_DIR / "Flux" | ||
22 | return [ | ||
23 | filename for filename in os.listdir(image_dir) | ||
24 | if filename.startswith(prefix) and os.path.isfile(os.path.join(image_dir, filename)) | ||
25 | ] | ||
26 | |||
27 | def record_prompt(prompt, filename="prompts.txt"): | ||
28 | try: | ||
29 | with open(filename, "r") as file: | ||
30 | existing_prompts = set(line.strip() for line in file) | ||
31 | except FileNotFoundError: | ||
32 | existing_prompts = set() | ||
33 | |||
34 | if prompt not in existing_prompts: | ||
35 | with open(filename, "a") as file: | ||
36 | file.write(prompt + "\n") | ||
37 | logger.info(f"Recording new prompt: \"{prompt}\"") | ||
38 | else: | ||
39 | logger.info(f"Prompt already exists in the file: \"{prompt}\"") | ||
40 | |||
16 | def load_flux_img_to_img(): | 41 | def load_flux_img_to_img(): |
17 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) | 42 | pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) |
18 | pipeline.enable_model_cpu_offload() | 43 | pipeline.enable_model_cpu_offload() |
@@ -27,17 +52,22 @@ def load_flux(): | |||
27 | pipeline.vae.enable_tiling() | 52 | pipeline.vae.enable_tiling() |
28 | return pipeline | 53 | return pipeline |
29 | 54 | ||
30 | def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): | 55 | def generate_image(pipeline, prompt, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): |
31 | images = pipeline( | 56 | kwargs = { |
32 | prompt=prompt, | 57 | "prompt": prompt, |
33 | guidance_scale=guideance_scale, | 58 | "prompt_2": prompt_2, |
34 | height=height, | 59 | "guidance_scale": guideance_scale, |
35 | width=width, | 60 | "height":height, |
36 | max_sequence_length=256, | 61 | "width": width, |
37 | num_inference_steps=num_inference_steps, | 62 | "max_sequence_length": 256, |
38 | num_images_per_prompt=num_images_per_prompt, | 63 | "num_inference_steps": num_inference_steps, |
39 | image=init_image | 64 | "num_images_per_prompt": num_images_per_prompt |
40 | ).images | 65 | } |
66 | |||
67 | if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None: | ||
68 | kwargs["image"] = init_image | ||
69 | |||
70 | images = pipeline(**kwargs).images | ||
41 | return images | 71 | return images |
42 | 72 | ||
43 | def generate_random_string(length=16) -> str: | 73 | def generate_random_string(length=16) -> str: |
@@ -52,12 +82,25 @@ def parse_dimensions(dim_str: str) -> Tuple[int, int]: | |||
52 | 82 | ||
53 | def main(): | 83 | def main(): |
54 | try: | 84 | try: |
85 | logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p") | ||
86 | |||
87 | logger.info("Parsing arguments") | ||
88 | |||
55 | args = parser.parse_args() | 89 | args = parser.parse_args() |
56 | 90 | ||
57 | pipeline = load_flux_img_to_img() | 91 | logger.info("Choosing model...") |
92 | |||
93 | pipeline = load_flux_img_to_img() if args.use_image else load_flux() | ||
94 | |||
95 | if isinstance(pipeline, FluxPipeline): | ||
96 | logger.info("Using text-to-image model") | ||
97 | else: | ||
98 | logger.info("Using image-to-image model") | ||
99 | |||
58 | pipeline.to(torch.float16) | 100 | pipeline.to(torch.float16) |
59 | 101 | ||
60 | target_img = STORAGE_DIR / "1517481062292.jpg" | 102 | # target_img = STORAGE_DIR / "1517481062292.jpg" |
103 | target_img = STORAGE_DIR / "Flux" / "a23aae99-c8f1-4ce5-b91f-0b732774dadd.png" | ||
61 | 104 | ||
62 | target_img_path = target_img.resolve(strict=True) | 105 | target_img_path = target_img.resolve(strict=True) |
63 | 106 | ||
@@ -67,25 +110,53 @@ def main(): | |||
67 | 110 | ||
68 | width, height = args.size | 111 | width, height = args.size |
69 | 112 | ||
70 | images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) | 113 | record_prompt(args.prompt) |
114 | |||
115 | logger.info(f"Using prompt: \"{args.prompt}\"") | ||
116 | |||
117 | logger.info("Generating image(s)...") | ||
118 | |||
119 | images = generate_image( | ||
120 | pipeline=pipeline, | ||
121 | init_image=init_image, | ||
122 | prompt=args.prompt, | ||
123 | prompt_2=args.prompt2, | ||
124 | width=width, | ||
125 | height=height, | ||
126 | guideance_scale=args.guideance_scale, | ||
127 | num_images_per_prompt=args.number | ||
128 | ) | ||
129 | |||
71 | for image in images: | 130 | for image in images: |
72 | filename = generate_random_string() | 131 | filename = generate_random_string() |
73 | filepath = STORAGE_DIR / "Flux" / f"{filename}.png" | 132 | filepath = STORAGE_DIR / "Flux" / f"{filename}.png" |
133 | logger.info(f"Saving {filepath}...") | ||
74 | image.save(filepath) | 134 | image.save(filepath) |
135 | |||
136 | logger.info("Finished") | ||
75 | except FileNotFoundError: | 137 | except FileNotFoundError: |
76 | print("\n Target image doesn't exist. Exiting...") | 138 | print("\n Target image doesn't exist. Exiting...") |
77 | exit(0) | 139 | exit(0) |
78 | except KeyboardInterrupt: | 140 | except KeyboardInterrupt: |
79 | print('\nExiting early...') | 141 | print('\nExiting early...') |
80 | exit(0) | 142 | exit(0) |
143 | except Exception as e: | ||
144 | print(f"An error occured: {e}") | ||
145 | exit(1) | ||
81 | 146 | ||
82 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") | 147 | parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") |
83 | 148 | ||
84 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") | 149 | parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") |
85 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") | 150 | parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") |
86 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") | 151 | parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") |
152 | parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt") | ||
87 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) | 153 | parser.add_argument("-gs", "--guideance-scale", type=float, default=0) |
88 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024") | 154 | parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") |
155 | parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image") | ||
156 | |||
157 | # parser.add_argument("-b", "--base-image").completer = image_completer | ||
158 | |||
159 | # argcomplete.autocomplete(parser) | ||
89 | 160 | ||
90 | if __name__ == "__main__": | 161 | if __name__ == "__main__": |
91 | main() | 162 | main() |