summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZach Berwaldt <zberwaldt@tutamail.com>2024-10-30 22:00:56 -0400
committerZach Berwaldt <zberwaldt@tutamail.com>2024-10-30 22:00:56 -0400
commitffbc2e788c113b278b9306aac4cea1d644eaf048 (patch)
tree31794861d69249c923d4f8a10e5b7772abd7372f
parent4445e6f202e9d3a7f7e9bc94894c1eb3bc6c7945 (diff)
add logging, refactor generate image, keep track of prompts used.
-rw-r--r--.gitignore2
-rw-r--r--prompts.txt2
-rw-r--r--run_flux.py101
3 files changed, 89 insertions, 16 deletions
diff --git a/.gitignore b/.gitignore
index 2fa80d6..894a3f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
1*.png 1*.png
2 2*.log
diff --git a/prompts.txt b/prompts.txt
new file mode 100644
index 0000000..3268a51
--- /dev/null
+++ b/prompts.txt
@@ -0,0 +1,2 @@
1A space pilot, wearing her space flight suit, graying hair, determined to face the future.
2Female with very large breasts.
diff --git a/run_flux.py b/run_flux.py
index 4a27816..89d188b 100644
--- a/run_flux.py
+++ b/run_flux.py
@@ -2,17 +2,42 @@ import torch
2import os 2import os
3import base64 3import base64
4import argparse 4import argparse
5import argcomplete
5from diffusers import FluxPipeline, FluxImg2ImgPipeline 6from diffusers import FluxPipeline, FluxImg2ImgPipeline
6from diffusers.utils import load_image 7from diffusers.utils import load_image
7from typing import Tuple 8from typing import Tuple
8from pathlib import Path 9from pathlib import Path
9from PIL import Image 10from PIL import Image
10import uuid 11import uuid
12import logging
11 13
12STORAGE_DIR: Path = Path.home() / "Pictures" 14STORAGE_DIR: Path = Path.home() / "Pictures"
13 15
14STORAGE_DIR.mkdir(parents=True, exist_ok=True) 16STORAGE_DIR.mkdir(parents=True, exist_ok=True)
15 17
18logger = logging.getLogger("run_flux")
19
20def image_completer(prefix, parsed_args, **kwargs):
21 image_dir = STORAGE_DIR / "Flux"
22 return [
23 filename for filename in os.listdir(image_dir)
24 if filename.startswith(prefix) and os.path.isfile(os.path.join(image_dir, filename))
25 ]
26
27def record_prompt(prompt, filename="prompts.txt"):
28 try:
29 with open(filename, "r") as file:
30 existing_prompts = set(line.strip() for line in file)
31 except FileNotFoundError:
32 existing_prompts = set()
33
34 if prompt not in existing_prompts:
35 with open(filename, "a") as file:
36 file.write(prompt + "\n")
37 logger.info(f"Recording new prompt: \"{prompt}\"")
38 else:
39 logger.info(f"Prompt already exists in the file: \"{prompt}\"")
40
16def load_flux_img_to_img(): 41def load_flux_img_to_img():
17 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 42 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
18 pipeline.enable_model_cpu_offload() 43 pipeline.enable_model_cpu_offload()
@@ -27,17 +52,22 @@ def load_flux():
27 pipeline.vae.enable_tiling() 52 pipeline.vae.enable_tiling()
28 return pipeline 53 return pipeline
29 54
30def generate_image(pipeline, prompt, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50): 55def generate_image(pipeline, prompt, prompt_2=None, init_image=None, height=1024, width=1024, guideance_scale=0, num_images_per_prompt=1, num_inference_steps=50):
31 images = pipeline( 56 kwargs = {
32 prompt=prompt, 57 "prompt": prompt,
33 guidance_scale=guideance_scale, 58 "prompt_2": prompt_2,
34 height=height, 59 "guidance_scale": guideance_scale,
35 width=width, 60 "height":height,
36 max_sequence_length=256, 61 "width": width,
37 num_inference_steps=num_inference_steps, 62 "max_sequence_length": 256,
38 num_images_per_prompt=num_images_per_prompt, 63 "num_inference_steps": num_inference_steps,
39 image=init_image 64 "num_images_per_prompt": num_images_per_prompt
40 ).images 65 }
66
67 if isinstance(pipeline, FluxImg2ImgPipeline) and init_image is not None:
68 kwargs["image"] = init_image
69
70 images = pipeline(**kwargs).images
41 return images 71 return images
42 72
43def generate_random_string(length=16) -> str: 73def generate_random_string(length=16) -> str:
@@ -52,12 +82,25 @@ def parse_dimensions(dim_str: str) -> Tuple[int, int]:
52 82
53def main(): 83def main():
54 try: 84 try:
85 logging.basicConfig(filename="flux.log", level=logging.INFO, format='%(asctime)s - %(levelname)s -> %(message)s', datefmt="%m/%d/%Y %I:%M:%S %p")
86
87 logger.info("Parsing arguments")
88
55 args = parser.parse_args() 89 args = parser.parse_args()
56 90
57 pipeline = load_flux_img_to_img() 91 logger.info("Choosing model...")
92
93 pipeline = load_flux_img_to_img() if args.use_image else load_flux()
94
95 if isinstance(pipeline, FluxPipeline):
96 logger.info("Using text-to-image model")
97 else:
98 logger.info("Using image-to-image model")
99
58 pipeline.to(torch.float16) 100 pipeline.to(torch.float16)
59 101
60 target_img = STORAGE_DIR / "1517481062292.jpg" 102 # target_img = STORAGE_DIR / "1517481062292.jpg"
103 target_img = STORAGE_DIR / "Flux" / "a23aae99-c8f1-4ce5-b91f-0b732774dadd.png"
61 104
62 target_img_path = target_img.resolve(strict=True) 105 target_img_path = target_img.resolve(strict=True)
63 106
@@ -67,25 +110,53 @@ def main():
67 110
68 width, height = args.size 111 width, height = args.size
69 112
70 images = generate_image(pipeline, init_image=init_image, prompt=args.prompt, width=width, height=height, guideance_scale=args.guideance_scale, num_images_per_prompt=args.number) 113 record_prompt(args.prompt)
114
115 logger.info(f"Using prompt: \"{args.prompt}\"")
116
117 logger.info("Generating image(s)...")
118
119 images = generate_image(
120 pipeline=pipeline,
121 init_image=init_image,
122 prompt=args.prompt,
123 prompt_2=args.prompt2,
124 width=width,
125 height=height,
126 guideance_scale=args.guideance_scale,
127 num_images_per_prompt=args.number
128 )
129
71 for image in images: 130 for image in images:
72 filename = generate_random_string() 131 filename = generate_random_string()
73 filepath = STORAGE_DIR / "Flux" / f"{filename}.png" 132 filepath = STORAGE_DIR / "Flux" / f"{filename}.png"
133 logger.info(f"Saving {filepath}...")
74 image.save(filepath) 134 image.save(filepath)
135
136 logger.info("Finished")
75 except FileNotFoundError: 137 except FileNotFoundError:
76 print("\n Target image doesn't exist. Exiting...") 138 print("\n Target image doesn't exist. Exiting...")
77 exit(0) 139 exit(0)
78 except KeyboardInterrupt: 140 except KeyboardInterrupt:
79 print('\nExiting early...') 141 print('\nExiting early...')
80 exit(0) 142 exit(0)
143 except Exception as e:
144 print(f"An error occured: {e}")
145 exit(1)
81 146
82parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!") 147parser = argparse.ArgumentParser(description="Generate some A.I. images", epilog="All done!")
83 148
84parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate") 149parser.add_argument("-n", "--number", type=int, default=1, help="the number of images you want to generate")
85parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image") 150parser.add_argument("-o", "--output", type=str, default="image", help="the name of the output image")
86parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt") 151parser.add_argument("-p", "--prompt", type=str, required=True, help="the prompt")
152parser.add_argument("-p2", "--prompt2", type=str, help="A second prompt")
87parser.add_argument("-gs", "--guideance-scale", type=float, default=0) 153parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
88parser.add_argument("--size", type=parse_dimensions, default="1024:1024") 154parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
155parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
156
157# parser.add_argument("-b", "--base-image").completer = image_completer
158
159# argcomplete.autocomplete(parser)
89 160
90if __name__ == "__main__": 161if __name__ == "__main__":
91 main() 162 main()