summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZach Berwaldt <zberwaldt@tutamail.com>2025-01-28 18:05:48 -0500
committerZach Berwaldt <zberwaldt@tutamail.com>2025-01-28 18:05:48 -0500
commit812b0530578a2b338f9fed72c5df01d9b659da1c (patch)
treed1646268bddfbc5b46219e17bcec56d663fb15e7
parent671c17918c877e64aee7c730f32aac9ded9d5915 (diff)
remove image to image support.
-rw-r--r--run_flux.py23
1 files changed, 1 insertions, 22 deletions
diff --git a/run_flux.py b/run_flux.py
index 446f57a..9d6c2b5 100644
--- a/run_flux.py
+++ b/run_flux.py
@@ -34,16 +34,6 @@ def record_prompt(prompt, filename="prompts.txt"):
34 else: 34 else:
35 logger.info(f"Prompt already exists in the file: \"{prompt}\"") 35 logger.info(f"Prompt already exists in the file: \"{prompt}\"")
36 36
37def load_flux_img_to_img():
38 import torch
39 from diffusers import FluxImg2ImgPipeline
40
41 pipeline = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
42 pipeline.enable_model_cpu_offload()
43 pipeline.vae.enable_slicing()
44 pipeline.vae.enable_tiling()
45 return pipeline
46
47def load_flux(): 37def load_flux():
48 import torch 38 import torch
49 from diffusers import FluxPipeline 39 from diffusers import FluxPipeline
@@ -98,7 +88,6 @@ def main():
98 parser.add_argument("-gs", "--guideance-scale", type=float, default=0) 88 parser.add_argument("-gs", "--guideance-scale", type=float, default=0)
99 parser.add_argument("--strength", type=float) 89 parser.add_argument("--strength", type=float)
100 parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images") 90 parser.add_argument("--size", type=parse_dimensions, default="1024:1024", help="the size of the output images")
101 parser.add_argument("-u", "--use-image", action="store_true", help="use a predefined image")
102 args = parser.parse_args() 91 args = parser.parse_args()
103 92
104 try: 93 try:
@@ -107,19 +96,10 @@ def main():
107 96
108 logger.info("Choosing model...") 97 logger.info("Choosing model...")
109 98
110 pipeline = load_flux_img_to_img() if args.use_image else load_flux() 99 pipeline = load_flux()
111 100
112 pipeline.to(torch.float16) 101 pipeline.to(torch.float16)
113 102
114 # target_img = STORAGE_DIR / "1517481062292.jpg"
115 target_img = STORAGE_DIR / "Flux" / "a23aae99-c8f1-4ce5-b91f-0b732774dadd.png"
116
117 target_img_path = target_img.resolve(strict=True)
118
119 image = Image.open(target_img_path)
120
121 init_image = load_image(image).resize((1024, 1024))
122
123 width, height = args.size 103 width, height = args.size
124 104
125 record_prompt(args.prompt) 105 record_prompt(args.prompt)
@@ -129,7 +109,6 @@ def main():
129 logger.info("Generating image(s)...") 109 logger.info("Generating image(s)...")
130 110
131 config = GenerateImageConfig( 111 config = GenerateImageConfig(
132 init_image=init_image if args.use_image else None,
133 prompt=args.prompt, 112 prompt=args.prompt,
134 prompt_2=args.prompt2 if args.prompt2 else None, 113 prompt_2=args.prompt2 if args.prompt2 else None,
135 width=width, 114 width=width,