More Stable Diffusion! This time I am attempting to add inpainting / masking to my previous code, merging the capabilities of both txt2img.py and img2img.py and disregarding the out-of-box inpainting.py code, which does not have parameters for positive or negative prompts. Keyword being attempting...
This is part 4 in what is turning out to be a series - skip to post 6 for the latest code:
- (16 Sep 22): AI-generated images with Stable Diffusion on an M1 mac a.k.a. txt2img
- (17 Sep 22): Stable Diffusion image-to-image mode a.k.a. img2img
- (21 Sep 22): My simplified Stable Diffusion Python script a.k.a. txtimg2img
- (26 Sep 22): this post: Stable Diffusion script with inpainting mask a.k.a. mask+txtimg2img
- (28 Sep 22): Adding CLIPSeg automatic masking to Stable Diffusion a.k.a. txt2mask+txtimg2img
- (9 Oct 22): myByways Simple-SD v1.0 Python script
(29 Jul 23): Ignore all of the above as outdated - jump to Stable Diffusion SDXL 1.0 with ComfyUI instead.
Background
According to Wikipedia, inpainting is the process of filling in missing parts of an image. In my case, I use a mask to specify which parts of an image should be replaced by Stable Diffusion. Both the image and the mask must be PNGs of exactly 512x512 pixels, and the mask is grayscale: white portions will be overwritten and black portions will be retained (although Stable Diffusion will still change the retained content slightly).
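To make the white/black convention concrete, here is a minimal NumPy sketch. It is not part of my script: the helper names mask_to_weights and composite are made up for this illustration, and it blends plain pixel values rather than Stable Diffusion latents.

```python
import numpy as np

def mask_to_weights(gray):
    """Map 8-bit grayscale (0 = black, 255 = white) to weights in [0, 1]."""
    return np.asarray(gray, dtype=np.float32) / 255.0

def composite(original, generated, weights):
    """White (1.0) takes the generated value, black (0.0) keeps the original."""
    return weights * generated + (1.0 - weights) * original

# Tiny 2x2 stand-ins for the mask, the source image and the new content.
mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)
original = np.full((2, 2), 0.2, dtype=np.float32)
generated = np.full((2, 2), 0.9, dtype=np.float32)

result = composite(original, generated, mask_to_weights(mask))
print(result)
```

Gray values in between would give a partial blend, which is presumably why a soft-edged mask can hide the seam a little.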
Before going any further: I don’t know what I’m doing. I merely copied snippets or concepts from Stable Diffusion Web UI (sd-webui) on GitHub. Specifically:
- The modification to the diffusion model's DDIMSampler in ldm/models/diffusion/ddim.py, with the two extra inputs needed,
- How to initialize the latent z_enc so that the masked portion of the image is obliterated; otherwise, the masked area retains its primarily white color.
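The second point can be sketched roughly as follows, with NumPy arrays standing in for the torch latents. The function name init_masked_latent and the tiny 8x8 latent are assumptions for illustration only; the real script does the equivalent with torch.randn inside setup_sampler().

```python
import numpy as np

def init_masked_latent(z_enc, mask, rng):
    """Replace the masked (mask == 1) part of a latent with Gaussian noise."""
    noise = rng.standard_normal(z_enc.shape).astype(z_enc.dtype)
    return mask * noise + (1.0 - mask) * z_enc

rng = np.random.default_rng(42)
z_enc = np.ones((1, 4, 8, 8), dtype=np.float32)  # stand-in for the encoded image
mask = np.zeros_like(z_enc)
mask[..., 4:] = 1.0                              # right half is "white"

z_init = init_masked_latent(z_enc, mask, rng)
print(z_init[0, 0, 0])  # left half unchanged, right half random
```

Starting the masked area from pure noise means the sampler is not nudged toward whatever was underneath the mask in the source image.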
Don’t use my code, please. I cannot emphasize this enough: I don’t know what I am doing, all this is for fun and part of my own learning journey. For the inexperienced, please use the Stable Diffusion Web UI or other tools instead. The only reason I don’t is that I have an M1 mac, not a Windows or Linux machine with a GPU.
My less simple Python code
Pre-requisites:
- An M1/M2 Apple mac with Homebrew installed. Homebrew will be used to install Python 3.10 and a few other dependencies.
- Whatever I did to install Stable Diffusion in my first post, AI-generated images with Stable Diffusion on an M1 mac using bfirsh’s stable-diffusion/apple-silicon-mps-support branch
- And whatever I coded in my third post, My simplified Stable Diffusion Python script.
My changes involve:
- Making the modifications to ddim.py,
- Modifying my previous script, and modifying the global variables (in UPPERCASE),
- The script should have execute permission: chmod +x simple-sd.py,
- And the Python virtual environment should be set up, so: source venv/bin/activate,
- Then finally testing ./simple-sd.py... and testing, and testing, and testing...
Modified ddim.py
Being too lazy to do a proper job, all I did was copy the main change into ldm/models/diffusion/ddim.py by adding a new function decode2() at the end. Please note the indentation must be at the same level as the preceding original function decode(), which is left in place as a fallback:
- the decode2() function now takes z_mask and x0 inputs,
- and the addition of the block starting from the line if z_mask is not None and i < total_steps - 2: within the iteration loop.
    @torch.no_grad()
    def decode2(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                use_original_steps=False, z_mask=None, x0=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            # new: keep the unmasked area pinned to the (noised) original image
            if z_mask is not None and i < total_steps - 2:
                img_orig = self.model.q_sample(x0, ts)
                mask_inv = 1. - z_mask
                x_dec = (img_orig * mask_inv) + (z_mask * x_dec)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec
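To see what the condition i < total_steps - 2 does in practice, here is a small pure-Python sketch. The helper blend_schedule is hypothetical, and I am assuming total_steps equals the t_start value my script passes in, int(STRENGTH * STEPS), which holds as long as the DDIM schedule has at least that many entries.

```python
def blend_schedule(total_steps):
    """For each loop index i, report whether the original image is blended back in."""
    return [i < total_steps - 2 for i in range(total_steps)]

STRENGTH, STEPS = 0.75, 50
total_steps = int(STRENGTH * STEPS)  # t_start passed to decode2()
schedule = blend_schedule(total_steps)

print(total_steps, "steps in total")
print(sum(schedule), "blended steps,", schedule.count(False), "free steps")
```

With the defaults in my script that is 37 steps, of which the first 35 re-impose the original image outside the mask and the last 2 run freely over the whole latent. That would be consistent with the retained area still changing slightly.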
Simplified Script
I only made the minimum changes to my previous simplified Stable Diffusion script to achieve what I wanted. There are no special features and few parameters, since I don’t understand the model anyway! Therefore, the modifications:
- just one new MASK variable, which is the filename of the mask PNG (or None for the previous behaviour without a mask), obviously to be used in conjunction with the IMAGE variable,
- deleted the function load_image() and replaced it with load_image_mask(), which is pretty similar, but also converts the mask to Luminance L (i.e. grayscale) only,
- modified setup_sampler() to incorporate the x0, z_enc and mask latents,
- and changed generate_samples() to call decode2() instead of the default function.
#! python
# copyright (c) 2022 C.Y. Wong, myByways.com simplified Stable Diffusion v0.2
import os, sys, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import logging
PROMPTS = [ # --prompt, one or more in an array
'the death star',
]
NEGATIVES = [ # negative prompt, one or more, default None (or an empty array)
]
HEIGHT = 512 # --H, default 512, beyond causes M1 to crawl
WIDTH = 512 # --W, default 512, beyond causes M1 to crawl
FACTOR = 8 # --f downsampling factor, default 8
FIXED = 0 # --fixed_code, 1 for repeatable results, default 0
SEED = 42 # --seed, default 42
NOISE = 0 # --ddim_eta, 0 deterministic, no noise - 1.0 random noise, ignored for PLMS (must be 0)
PLMS = 0 # --plms, default 1 on M1 for txt2img but ignored for img2img (must be DDIM)
ITERATIONS = 1 # --n_iter, default 1
SCALE = 7.5 # --scale, 5 further -> 15 closer to prompt, default 7.5
STEPS = 50 # --ddim_steps, practically little improvement >50 but takes longer, default 50
IMAGE = 'src.png' # --init-img, img2img initial latent seed, default None
STRENGTH = 0.75 # --strength 0 more -> 1 less like image, default 0.75
MASK = 'mask.png' # in-paint to masked area, white=overwrite, black=retain, default=None
FOLDER = 'outputs' # --outdir for images and history file below
HISTORY = 'history.txt'
CONFIG = 'configs/stable-diffusion/v1-inference.yaml'
CHECKPOINT = 'models/ldm/stable-diffusion-v1/model.ckpt'
def seed_pre():
    if not FIXED:
        seed_everything(SEED)
def seed_post(device):
    if FIXED:
        seed_everything(SEED)
        return torch.randn([1, 4, HEIGHT // FACTOR, WIDTH // FACTOR], device='cpu').to(torch.device(device.type))
    return None
def load_model(config, ckpt=CHECKPOINT):
    pl_sd = torch.load(ckpt, map_location='cpu')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=False)
    return model
def set_device(model):
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        precision = nullcontext
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        precision = torch.autocast
    else:
        device = torch.device('cpu')
        precision = torch.autocast
    model.to(device.type)
    model.eval()
    return device, precision
def load_image_mask(device, image_file=IMAGE, mask_file=MASK):
    image = Image.open(image_file).convert('RGB')
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # round down to a multiple of 32
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> 1CHW
    image = torch.from_numpy(image)
    image = 2. * image - 1.  # scale [0, 1] to [-1, 1]
    image = image.to(device.type)
    mask = None
    if mask_file:  # use the argument rather than the MASK global
        mask = Image.open(mask_file).convert('L')
        mask = mask.resize((w // FACTOR, h // FACTOR), resample=Image.Resampling.LANCZOS)
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))  # one copy per latent channel
        mask = mask[None]  # add batch dimension (the old transpose(0, 1, 2, 3) was a no-op)
        mask = torch.from_numpy(mask).to(device.type)
    return image, mask
def setup_sampler(model):
    global NOISE
    if IMAGE:
        sampler = DDIMSampler(model)
        sampler.make_schedule(ddim_num_steps=STEPS, ddim_eta=NOISE, verbose=False)
        image, mask = load_image_mask(model.device, IMAGE, MASK)
        sampler.t_enc = int(STRENGTH * STEPS)
        sampler.x0 = model.get_first_stage_encoding(model.encode_first_stage(image))
        sampler.z_enc = sampler.stochastic_encode(sampler.x0, torch.tensor([sampler.t_enc]).to(model.device.type))
        sampler.z_mask = mask
        if MASK:
            random = torch.randn(mask.shape, device=model.device)
            sampler.z_enc = (mask * random) + ((1-mask) * sampler.z_enc)
    elif PLMS:
        sampler = PLMSSampler(model)
        NOISE = 0
    else:
        sampler = DDIMSampler(model)
    return sampler
def get_prompts():
    global NEGATIVES
    if NEGATIVES is None:
        NEGATIVES = [''] * len(PROMPTS)
    else:
        NEGATIVES.extend([''] * (len(PROMPTS)-len(NEGATIVES)))
    return zip(PROMPTS, NEGATIVES)
def generate_samples(model, sampler, prompt, negative, start):
    uncond = model.get_learned_conditioning(negative) if SCALE != 1.0 else None
    cond = model.get_learned_conditioning(prompt)
    if IMAGE:
        samples = sampler.decode2(sampler.z_enc, cond, sampler.t_enc,
                                  unconditional_guidance_scale=SCALE, unconditional_conditioning=uncond,
                                  z_mask=sampler.z_mask, x0=sampler.x0)
    else:
        shape = [4, HEIGHT // FACTOR, WIDTH // FACTOR]
        samples, _ = sampler.sample(S=STEPS, conditioning=cond, batch_size=1,
                                    shape=shape, verbose=False, unconditional_guidance_scale=SCALE,
                                    unconditional_conditioning=uncond, eta=NOISE, x_T=start)
    x_samples = model.decode_first_stage(samples)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    return x_samples
def save_image(image):
    name = f'{time.strftime("%Y%m%d_%H%M%S")}.png'
    image = 255. * rearrange(image.cpu().numpy(), 'c h w -> h w c')
    img = Image.fromarray(image.astype(np.uint8))
    img.save(os.path.join(FOLDER, name))
    return name
def save_history(name, prompt, negative):
    with open(os.path.join(FOLDER, HISTORY), 'a') as history:
        history.write(f'{name} -> {"PLMS" if PLMS else "DDIM"}, Seed={SEED}{" fixed" if FIXED else ""}, Scale={SCALE}, Steps={STEPS}, Noise={NOISE}')
        if IMAGE:
            history.write(f', Image={IMAGE}, Strength={STRENGTH}')
            if MASK:
                history.write(f', Mask={MASK}')
        if len(negative):
            history.write(f'\n + {prompt}\n - {negative}\n')
        else:
            history.write(f'\n + {prompt}\n')
def main():
    print('*** Loading Stable Diffusion - myByways.com simple-sd version 0.2')
    tic1 = time.time()
    logging.set_verbosity_error()
    os.makedirs(FOLDER, exist_ok=True)
    seed_pre()
    config = OmegaConf.load(CONFIG)
    model = load_model(config)
    device, precision_scope = set_device(model)
    sampler = setup_sampler(model)
    start_code = seed_post(device)
    toc1 = time.time()
    print(f'*** Model setup time: {(toc1 - tic1):.2f}s')
    counter = 0
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                for iteration in range(ITERATIONS):
                    for prompt, negative in get_prompts():
                        print(f'*** Iteration {iteration + 1}: {prompt}')
                        tic2 = time.time()
                        images = generate_samples(model, sampler, prompt, negative, start_code)
                        for image in images:
                            name = save_image(image)
                            save_history(name, prompt, negative)
                            print(f'*** Saved image: {name}')
                        counter += len(images)
                        toc2 = time.time()
                        print(f'*** Synthesis time: {(toc2 - tic2):.2f}s')
    print(f'*** Total time: {(toc2 - tic1):.2f}s')
    print(f'*** Saved {counter} image(s) to {FOLDER} folder.')
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print('*** User abort, goodbye.')
    except FileNotFoundError as e:
        print(f'*** {e}')
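The size and step arithmetic buried in load_image_mask() and setup_sampler() is easy to lose track of, so here it is pulled out as a stand-alone sketch. The three helper names are made up for this illustration and are not in my script; the defaults (32, 8 and 4 channels) are simply the values the script uses.

```python
def working_size(width, height, multiple=32):
    """Round each side down to a multiple of 32, as load_image_mask() does."""
    return width - width % multiple, height - height % multiple

def latent_mask_shape(width, height, factor=8, channels=4):
    """Shape of the mask tensor after resizing and tiling: (1, C, h/f, w/f)."""
    w, h = working_size(width, height)
    return (1, channels, h // factor, w // factor)

def encode_steps(strength, steps):
    """Number of DDIM steps the source image is noised by (t_enc)."""
    return int(strength * steps)

print(working_size(512, 512))       # image size actually used
print(latent_mask_shape(512, 512))  # mask tensor shape in latent space
print(encode_steps(0.75, 50))       # t_enc for STRENGTH=0.75, STEPS=50
```

For a 512x512 source this gives a 64x64 latent mask repeated across the 4 latent channels. Note that a lower STRENGTH also means fewer decoding steps, since decode2() only runs t_enc of them.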
Sample Output
First, I took my original image from the post, Stable Diffusion image-to-image mode, and replaced the planet using the prompt "the death star" and a mask that looked like this:
Here’s the output from my most recent run - to confirm the masked portion is correctly replaced:
Next, I tested using the original example image of the girl smoking in front of a chain-link fence, data/inpainting_examples/6458524847_2f4c361183_k.png, and flipped the mask data/inpainting_examples/6458524847_2f4c361183_k_mask.png horizontally:
I got the images below with the prompt "a photo of a handsome young man in jeans and a blazer, standing next to a woman with long brown hair, with a fence in the background". Don’t blame SD for the missing smokes (it’s because I couldn’t be bothered to mask properly) and the gentleman’s long hair (it’s poor prompt engineering). Anyway, after many runs, this was the best result... creepy, but accurate!