More Stable Diffusion! This time I am attempting to add inpainting / masking on top of my previous code, merging the txt2img.py and img2img.py capabilities while disregarding the out-of-the-box inpainting.py code, which has no parameters for positive or negative prompts. Keyword being attempting...

This is part 4 in what is turning out to be a series - skip to post 6 for the latest code:

  1. (16 Sep 22): AI-generated images with Stable Diffusion on an M1 mac a.k.a. txt2img
  2. (17 Sep 22): Stable Diffusion image-to-image mode a.k.a. img2img
  3. (21 Sep 22): My simplified Stable Diffusion Python script a.k.a. txtimg2img
  4. (26 Sep 22): this post: Stable Diffusion script with inpainting mask a.k.a. mask+txtimg2img
  5. (28 Sep 22): Adding CLIPSeg automatic masking to Stable Diffusion a.k.a. txt2mask+txtimg2img
  6. (9 Oct 22): myByways Simple-SD v1.0 Python script

    (29 Jul 23): Ignore all the above as outdated - jump to Stable Diffusion SDXL 1.0 with ComfyUI instead.

Background

According to Wikipedia, inpainting is the process of filling in missing parts of an image. In my case, I use a mask to specify which parts of an image should be replaced by Stable Diffusion. Both the image and the mask must be PNGs of exactly 512x512 pixels, and the mask is grayscale - white portions will be overwritten and black portions retained (although Stable Diffusion will still change the retained content slightly).
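
For reference, a mask can be drawn in any image editor, or sketched programmatically - something like this minimal Pillow sketch works (the ellipse coordinates are just placeholders; any white shape marks the area to be re-generated):

from PIL import Image, ImageDraw

# 512x512 grayscale mask: black = keep the original pixels, white = let Stable Diffusion repaint
mask = Image.new('L', (512, 512), 0)
draw = ImageDraw.Draw(mask)
draw.ellipse((100, 80, 300, 280), fill=255)  # placeholder region to be in-painted
mask.save('mask.png')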

Before going any further, I don’t know what I’m doing. I merely copied snippets or concepts from Stable Diffusion Web UI (sd-webui) on GitHub. Specifically:

  1. The modification to the DDIM sampler, ldm/models/diffusion/ddim.py, adding the two extra inputs needed,
  2. How to initialize the latent z_enc and obliterate the masked portion of the image with noise - otherwise, the masked area retains its mostly white color.

Don’t use my code, please. I cannot emphasize this enough: I don’t know what I am doing, all this is for fun and part of my own learning journey. For the inexperienced, please use the Stable Diffusion Web UI or other tools instead. The only reason I don’t is that I have an M1 Mac, not a Windows or Linux machine with a GPU.

My less simple Python code

Pre-requisites: the working Stable Diffusion setup and simplified script from my previous posts.

My changes involve:

  1. Making the modifications to ddim.py,
  2. Modifying my previous script, and editing the global variables (in UPPERCASE),
  3. Making sure the script has execute permission, chmod +x simple-sd.py,
  4. Activating the Python virtual environment, source venv/bin/activate,
  5. And finally running ./simple-sd.py... and testing, and testing, and testing...

Modified ddim.py

Being too lazy to do a proper job, all I did was copy the main change into ldm/models/diffusion/ddim.py by adding a new function decode2() at the end - please note the indentation must be at the same level as the preceding original function decode(), which is left in place as a fallback:

  • the decode2() function now takes two new inputs, z_mask and x0,
  • and the addition of the block starting from the line if z_mask is not None and i < total_steps - 2: within the iteration loop.
    @torch.no_grad()
    def decode2(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, z_mask = None, x0=None):

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)

            if z_mask is not None and i < total_steps - 2:
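                # re-noise the original latent x0 to the current timestep and paste it into the
                # unmasked (black) region, so only the masked (white) region keeps being repainted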
                img_orig = self.model.q_sample(x0, ts)
                mask_inv = 1. - z_mask
                x_dec = (img_orig * mask_inv) + (z_mask * x_dec)

            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec

Simplified Script

I only made the minimum changes to my previous simplified Stable Diffusion script to achieve what I wanted. There are no special features and few parameters, since I don’t understand the model anyway! The modifications are:

  • just one new MASK variable, which is the filename of the mask PNG (or None for the previous behaviour without a mask) - obviously, to be used in conjunction with the IMAGE variable,
  • deleted the function load_image() and replaced it with load_image_mask(), which is pretty similar but also loads the mask and converts it to luminance only (L, i.e. grayscale),
  • modified setup_sampler() to incorporate x0, z_enc and mask latents,
  • and changed generate_samples() to call generate2() instead of the default function.
#! python
# copyright (c) 2022 C.Y. Wong, myByways.com simplified Stable Diffusion v0.2

import os, sys, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import logging

PROMPTS = [         # --prompt, one or more in an array
    'the death star',
]
NEGATIVES = [       # negative prompt, one or more, default None (or an empty array)
]

HEIGHT = 512        # --H, default 512, beyond causes M1 to crawl
WIDTH = 512         # --W, default 512, beyond causes M1 to crawl
FACTOR = 8          # --f downsampling factor, default 8

FIXED = 0           # --fixed_code, 1 for repeatable results, default 0
SEED = 42           # --seed, default 42
NOISE = 0           # --ddim_eta, 0 deterministic, no noise - 1.0 random noise, ignored for PLMS (must be 0)
PLMS = 0            # --plms, default 1 on M1 for txt2img but ignored for img2img (must be DDIM)
ITERATIONS = 1      # --n_iter, default 1
SCALE = 7.5         # --scale, 5 further -> 15 closer to prompt, default 7.5
STEPS = 50          # --ddim_steps, practically little improvement >50 but takes longer, default 50

IMAGE = 'src.png'   # --init-img, img2img initial latent seed, default None
STRENGTH = 0.75     # --strength 0 more -> 1 less like image, default 0.75

MASK = 'mask.png'   # in-paint to masked area, white=overwrite, black=retain, default=None

FOLDER = 'outputs'  # --outdir for images and history file below
HISTORY = 'history.txt'
CONFIG = 'configs/stable-diffusion/v1-inference.yaml'
CHECKPOINT = 'models/ldm/stable-diffusion-v1/model.ckpt'

def seed_pre():
    if not FIXED:
        seed_everything(SEED)

def seed_post(device):
    if FIXED:
        seed_everything(SEED)
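        # create the fixed starting noise on the CPU, then move it to the device, so runs are repeatable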
        return torch.randn([1, 4, HEIGHT // FACTOR, WIDTH // FACTOR], device='cpu').to(torch.device(device.type))
    return None

def load_model(config, ckpt=CHECKPOINT):
    pl_sd = torch.load(ckpt, map_location='cpu')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=False)
    return model

def set_device(model):
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        precision = nullcontext
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        precision = torch.autocast
    else:
        device = torch.device('cpu')
        precision = torch.autocast
    model.to(device.type)
    model.eval()
    return device, precision

def load_image_mask(device, image_file=IMAGE, mask_file=MASK):
    image = Image.open(image_file).convert('RGB')
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2. * image - 1.
    image = image.to(device.type)
    mask = None
    if MASK:
        mask = Image.open(mask_file).convert('L')
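        # downscale the mask to latent resolution, replicate it across the 4 latent channels,
        # and add a batch dimension so it matches the latent tensor shape (1, 4, H/f, W/f)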
        mask = mask.resize((w // FACTOR, h // FACTOR), resample=Image.Resampling.LANCZOS)
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = mask[None].transpose(0, 1, 2, 3)
        mask = torch.from_numpy(mask).to(device.type)
    return image, mask

def setup_sampler(model):
    global NOISE
    if IMAGE:
        sampler = DDIMSampler(model)
        sampler.make_schedule(ddim_num_steps=STEPS, ddim_eta=NOISE, verbose=False)
        image, mask = load_image_mask(model.device, IMAGE, MASK)
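        # t_enc = how far to noise the image forward (and how many denoising steps to run back);
        # x0 = the image encoded into latent space; z_enc = x0 noised forward to step t_enc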
        sampler.t_enc = int(STRENGTH * STEPS)
        sampler.x0 = model.get_first_stage_encoding(model.encode_first_stage(image))
        sampler.z_enc = sampler.stochastic_encode(sampler.x0, torch.tensor([sampler.t_enc]).to(model.device.type))
        sampler.z_mask = mask
        if MASK:
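            # obliterate the masked region of z_enc with fresh noise, otherwise the original
            # (mostly white) content bleeds through into the in-painted area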
            random = torch.randn(mask.shape, device=model.device)
            sampler.z_enc = (mask * random) + ((1-mask) * sampler.z_enc)
    elif PLMS:
        sampler = PLMSSampler(model)
        NOISE = 0
    else:
        sampler = DDIMSampler(model)
    return sampler

def get_prompts():
    global NEGATIVES
    if NEGATIVES is None:
        NEGATIVES = [''] * len(PROMPTS)
    else:
        NEGATIVES.extend([''] * (len(PROMPTS)-len(NEGATIVES)))
    return zip(PROMPTS, NEGATIVES)

def generate_samples(model, sampler, prompt, negative, start):
    uncond = model.get_learned_conditioning(negative) if SCALE != 1.0 else None
    cond = model.get_learned_conditioning(prompt)
    if IMAGE:
        samples = sampler.decode2(sampler.z_enc, cond, sampler.t_enc, 
            unconditional_guidance_scale=SCALE, unconditional_conditioning=uncond, 
            z_mask=sampler.z_mask, x0=sampler.x0)
    else:
        shape = [4, HEIGHT // FACTOR, WIDTH // FACTOR]
        samples, _ = sampler.sample(S=STEPS, conditioning=cond, batch_size=1,
            shape=shape, verbose=False, unconditional_guidance_scale=SCALE, 
            unconditional_conditioning=uncond, eta=NOISE, x_T=start)
    x_samples = model.decode_first_stage(samples)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    return x_samples

def save_image(image):
    name = f'{time.strftime("%Y%m%d_%H%M%S")}.png'
    image = 255. * rearrange(image.cpu().numpy(), 'c h w -> h w c')
    img = Image.fromarray(image.astype(np.uint8))
    img.save(os.path.join(FOLDER, name))
    return name

def save_history(name, prompt, negative):
    with open(os.path.join(FOLDER, HISTORY), 'a') as history:
        history.write(f'{name} -> {"PLMS" if PLMS else "DDIM"}, Seed={SEED}{" fixed" if FIXED else ""}, Scale={SCALE}, Steps={STEPS}, Noise={NOISE}')
        if IMAGE:
            history.write(f', Image={IMAGE}, Strength={STRENGTH}')
        if MASK:
            history.write(f', Mask={MASK}')
        if len(negative):
            history.write(f'\n + {prompt}\n - {negative}\n')
        else:
            history.write(f'\n + {prompt}\n')

def main():
    print('*** Loading Stable Diffusion - myByways.com simple-sd version 0.2')
    tic1 = time.time()
    logging.set_verbosity_error()
    os.makedirs(FOLDER, exist_ok=True)

    seed_pre()
    config = OmegaConf.load(CONFIG)
    model = load_model(config)
    device, precision_scope = set_device(model)
    sampler = setup_sampler(model)
    start_code = seed_post(device)

    toc1 = time.time()
    print(f'*** Model setup time: {(toc1 - tic1):.2f}s')

    counter = 0
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():

                for iteration in range(ITERATIONS):
                    for prompt, negative in get_prompts():
                        print(f'*** Iteration {iteration + 1}: {prompt}')
                        tic2 = time.time()
                        images = generate_samples(model, sampler, prompt, negative, start_code)
                        for image in images:
                            name = save_image(image)
                            save_history(name, prompt, negative)
                            print(f'*** Saved image: {name}')

                        counter += len(images)
                        toc2 = time.time()
                        print(f'*** Synthesis time: {(toc2 - tic2):.2f}s')

    print(f'*** Total time: {(toc2 - tic1):.2f}s')
    print(f'*** Saved {counter} image(s) to {FOLDER} folder.')

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print('*** User abort, goodbye.')
    except FileNotFoundError as e:
        print(f'*** {e}')

Sample Output

First, I took my original image from the post Stable Diffusion image-to-image mode, and replaced the planet using the prompt the death star and a mask that looked like this:

Here’s the output from my most recent run - to confirm the masked portion is correctly replaced:

Stable Diffusion in-painting with mask - example 1

Next, I tested using the original example image of the girl smoking in front of a chain-link fence, data/inpainting_examples/6458524847_2f4c361183_k.png, and flipped the mask data/inpainting_examples/6458524847_2f4c361183_k_mask.png horizontally:

I got the images below with the prompt a photo of a handsome young man in jeans and a blazer, standing next to a woman with long brown hair, with a fence in the background. Don’t blame SD for the missing smokes (it’s because I couldn’t be bothered to mask properly) and the gentleman’s long hair (it’s poor prompt engineering). Anyway, after many runs, this was the best result... creepy, but accurate!

Stable Diffusion in-painting with mask - example 2