I have more ideas for Stable Diffusion. My nights and weekends are consumed! This time: for inpainting, why create a mask image manually when A.I. can build one automatically from a text prompt? Someone much smarter has already published a paper (arXiv:2112.10003 [cs.CV]), with source code, to do just this!
Background
I first got the idea from a Reddit post that mentioned ThereforeGames’ txt2mask add-on for AUTOMATIC1111’s Stable Diffusion Web UI. A quick look revealed that the heavy lifting is done by Timo Lüddecke’s CLIPSeg for “Image Segmentation Using Text and Image Prompts”.
To summarize: this post adds txt2mask capabilities. It builds on my previous code, which added inpainting with a mask+txtimg. That code was in turn based on my simplified Python script that merged Stable Diffusion’s txt2img and img2img capabilities! All-in-one... txt2mask+txtimg2img?
This is part 5 in what is turning out to be a series - skip to post 6 for the latest code:
- (16 Sep 22): AI-generated images with Stable Diffusion on an M1 mac a.k.a. txt2img
- (17 Sep 22): Stable Diffusion image-to-image mode a.k.a. img2img
- (21 Sep 22): My simplified Stable Diffusion Python script a.k.a. txtimg2img
- (26 Sep 22): Stable Diffusion script with inpainting mask a.k.a. mask+txtimg2img
- (28 Sep 22): this post: Adding CLIPSeg automatic masking to Stable Diffusion a.k.a. txt2mask+txtimg2img
- (9 Oct 22): myByways Simple-SD v1.0 Python script
- (29 Jul 23): Ignore all the above as being outdated - jump to Stable Diffusion SDXL 1.0 with ComfyUI instead.
As usual, I don’t understand the models, I just use them... the usual warnings and disclaimers apply: do not run my code blindly, etc.
My slightly more complex Python code
CLIPSeg was surprisingly easy to incorporate into my script. I didn’t feel like restructuring the script, so I just added functions to load CLIPSeg and create a mask file. The existing code then reads this file without further changes. There were a few challenges around merging the masks from multiple prompts and flattening the mask with a threshold. I also had to address one issue. The resulting images left a vague outline or halo of the original objects, probably because the mask I generate with CLIPSeg is a little off at the edges. Being lazy, I just imported cv2 and dilated the mask, following Erosion and Dilation of images using OpenCV.
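To show what dilation does to a mask, here is a minimal pure-NumPy sketch of binary dilation with a square kernel of ones. This is the operation the script's cv2.dilate call applies. The function name dilate_binary is hypothetical. The anchor convention (kernel centre at size // 2) is my assumption of how OpenCV places an even-sized kernel. In practice, just use cv2.dilate.

```python
import numpy as np

def dilate_binary(mask: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical pure-NumPy stand-in for
    cv2.dilate(mask, np.ones((size, size), np.uint8)) on a 2-D uint8 mask.
    Each output pixel is the maximum of the size x size window around it,
    so white (non-zero) areas grow and black areas shrink."""
    h, w = mask.shape
    # Assumed anchor: kernel centre at size // 2, as for OpenCV's default anchor
    top = size // 2
    bottom = size - 1 - top
    padded = np.pad(mask, ((top, bottom), (top, bottom)), mode='constant', constant_values=0)
    out = np.zeros_like(mask)
    # Slide the kernel by taking the max over every shifted copy of the mask
    for dy in range(size):
        for dx in range(size):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out

mask = np.zeros((7, 7), np.uint8)
mask[3, 3] = 255                # a single white pixel
print(dilate_binary(mask, 3))   # becomes a 3x3 white square centred on (3, 3)
```

Note that with this anchor convention, a DILATE value of 10 grows the white area by only about half that (4 to 5 pixels) on each side, not by 10 pixels.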
This time, it’s mostly my algorithms, though a little hack-y... but I am proud and pretty pleased with what I have achieved so far (though I do intend to improve it further). So:
- load_clipseg() simply loads the model; I am using the CPU because MPS does not make much difference
- if the MASK variable is a List (array), then create_mask() will:
  - load the mask prompts from the MASK array (otherwise, MASK is a string PNG mask filename)
  - run the prompts through the CLIPSeg model, and then sum the output values (i.e. merge more than one prompt)
  - normalize and compress the output values (range 0-1): I am undecided between a single arbitrary THRESHOLDS value (black and white) or two THRESHOLDS (grayscale) - both options are still present in my code (methinks the former will do)
  - save the file...
- and if DILATE is set, then dilate_mask() will:
  - load the image using OpenCV
  - perform the dilation
  - save the image, overwriting the previous file...
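The create_mask() steps above (merge prompts, normalize, threshold) can be sketched in NumPy, with small arrays standing in for CLIPSeg's per-prompt logits. This is an illustration under assumptions, not the script itself. The name merge_and_threshold is hypothetical, and it uses min-max normalization, dividing by (max - min), so that the merged values span the full 0-1 range before thresholding.

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def merge_and_threshold(logits: np.ndarray, thresholds) -> np.ndarray:
    """logits: (num_prompts, H, W) raw per-prompt outputs (stand-in for CLIPSeg).
    thresholds: a float (or 1-item list) -> black and white mask,
                [low, high] -> <= low becomes 0, >= high becomes 1, rest stays grey."""
    # Merge prompts by summing their per-pixel probabilities
    merged = sigmoid(logits).sum(axis=0)
    # Min-max normalize into 0..1 (guarding against a constant map)
    lo, hi = merged.min(), merged.max()
    merged = (merged - lo) / (hi - lo) if hi > lo else np.zeros_like(merged)
    if isinstance(thresholds, (list, tuple)) and len(thresholds) >= 2:
        # Grayscale: only clamp the extremes
        merged = np.where(merged >= thresholds[1], 1.0, merged)
        return np.where(merged <= thresholds[0], 0.0, merged)
    t = thresholds[0] if isinstance(thresholds, (list, tuple)) else thresholds
    # Black and white: a single cut-off
    return np.where(merged >= t, 1.0, 0.0)

# Two "prompts", each confident about a different corner of a 2x2 image
logits = np.array([[[5., -5.], [-5., -5.]],
                   [[-5., -5.], [-5., 5.]]])
print(merge_and_threshold(logits, 0.1))
```

Summing the probabilities acts like a soft OR across prompts. A pixel that any one prompt is confident about ends up near the top of the normalized range.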
To set up, see my previous posts where I installed Stable Diffusion on an M1 Mac and created my own simplified Python script. Then:
- in the stable-diffusion folder, download CLIPSeg (via git clone or pip install)
- then, download and unzip the weights (refer to the CLIPSeg documentation for the latest download)
- finally, move clipseg_weights/rd64-uni-refined.pth to stable-diffusion/models, to match the CLIPSEG path in the script

```shell
# in the stable-diffusion folder
pip install git+https://github.com/timojl/clipseg.git
curl https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download --output weights.zip
unzip weights.zip -d models
```

- edit the global variables before running...
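A wrong path in the globals only fails once the models start loading, so a quick pre-flight check may save some time. This is a minimal sketch that assumes the default paths from the script's CONFIG, CHECKPOINT and CLIPSEG globals. REQUIRED and missing_files are hypothetical names, and the paths are relative to the stable-diffusion folder.

```python
import os

# Hypothetical table mirroring the script's CONFIG, CHECKPOINT and CLIPSEG globals;
# edit to match your own setup.
REQUIRED = {
    'CONFIG': 'configs/stable-diffusion/v1-inference.yaml',
    'CHECKPOINT': 'models/ldm/stable-diffusion-v1/model.ckpt',
    'CLIPSEG': 'models/clipseg_weights/rd64-uni-refined.pth',
}

def missing_files(paths: dict, root: str = '.') -> list:
    """Return the names of entries whose file does not exist under root."""
    return [name for name, rel in paths.items()
            if not os.path.isfile(os.path.join(root, rel))]

if __name__ == '__main__':
    missing = missing_files(REQUIRED)
    if missing:
        print('*** Missing: ' + ', '.join(missing))
    else:
        print('*** All model files found')
```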
Here’s the code:
```python
#! python
# myByways simplified Stable Diffusion v0.3 - add clipseg
import os, sys, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import logging
from torchvision import transforms
from torchvision.utils import save_image as tv_save_image
from clipseg.clipseg import CLIPDensePredT
import cv2

PROMPTS = [         # --prompt, one or more in an array, one per output image
    'a photo of an old man in a tweed jacket and boots leaning on a fence, dramatic lighting',
]
NEGATIVES = [       # negative prompt, one or more, default None (or an empty array)
]
HEIGHT = 512        # --H, default 512, beyond causes M1 to crawl
WIDTH = 512         # --W, default 512, beyond causes M1 to crawl
FACTOR = 8          # --f downsampling factor, default 8
FIXED = 0           # --fixed_code, 1 for repeatable results, default 0
SEED = 42           # --seed, default 42
NOISE = 0           # --ddim_eta, 0 deterministic, no noise - 1.0 random noise, ignored for PLMS (must be 0)
PLMS = 0            # --plms, default 1 on M1 for txt2img but ignored for img2img (must be DDIM)
ITERATIONS = 1      # --n_iter, default 1
SCALE = 7.5         # --scale, 5 further -> 15 closer to prompt, default 7.5
STEPS = 50          # --ddim_steps, practically little improvement >50 but takes longer, default 50
IMAGE = '1.png'     # --init-img, img2img initial latent seed, default None
STRENGTH = 0.75     # --strength 0 more -> 1 less like image, default 0.75
MASK = [            # inpaint mask PNG where white=overwrite, black=retain, or an array of prompts for CLIPSeg, default=None
    'a woman'
]
FOLDER = 'outputs'  # --outdir for images and history file below
HISTORY = 'history.txt'
CONFIG = 'configs/stable-diffusion/v1-inference.yaml'
CHECKPOINT = 'models/ldm/stable-diffusion-v1/model.ckpt'
THRESHOLDS = 0.1    # CLIPSeg threshold to convert to black and white, or an array of 2 numbers
DILATE = 10         # size of the square dilation kernel in pixels (grows the mask by about half this per side)
CLIPSEG = 'models/clipseg_weights/rd64-uni-refined.pth'

def seed_pre():
    if not FIXED:
        seed_everything(SEED)

def seed_post(device):
    if FIXED:
        seed_everything(SEED)
        return torch.randn([1, 4, HEIGHT // FACTOR, WIDTH // FACTOR], device='cpu').to(torch.device(device.type))
    return None

def load_model(config, ckpt_file=CHECKPOINT):
    pl_sd = torch.load(ckpt_file, map_location='cpu')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=False)
    return model

def set_device(model):
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        precision = nullcontext
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        precision = torch.autocast
    else:
        device = torch.device('cpu')
        precision = torch.autocast
    model.to(device.type)
    model.eval()
    return device, precision

def load_image_mask(device, image_file=IMAGE, mask_file=MASK):
    image = Image.open(image_file).convert('RGB')
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))    # round down to a multiple of 32
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2. * image - 1.
    image = image.to(device.type)
    mask = None
    if mask_file:
        mask = Image.open(mask_file).convert('L')
        mask = mask.resize((w // FACTOR, h // FACTOR), resample=Image.Resampling.LANCZOS)
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))     # one copy per latent channel
        mask = mask[None]                   # add batch dimension -> (1, 4, h, w)
        mask = torch.from_numpy(mask).to(device.type)
    return image, mask

def load_clipseg(weights_file=CLIPSEG):
    clipseg = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
    clipseg.eval()
    clipseg.load_state_dict(torch.load(weights_file, map_location=torch.device('cpu')), strict=False)
    return clipseg

def create_mask(clipseg, image_file=IMAGE, prompt=MASK):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((HEIGHT, WIDTH)) ])   # Resize takes (height, width)
    image = transform(Image.open(image_file).convert('RGB')).unsqueeze(0)
    size = len(prompt)
    with torch.no_grad():
        preds = clipseg(image.repeat(size, 1, 1, 1), prompt)[0]
    image = torch.sigmoid(preds[0][0])          # merge prompts by summing their probabilities
    for i in range(1, size):
        image += torch.sigmoid(preds[i][0])
    image = (image - image.min()) / (image.max() - image.min())    # normalize to 0-1
    if isinstance(THRESHOLDS, list):
        if len(THRESHOLDS) >= 2:                # grayscale: clamp only the extremes
            image = torch.where(image >= THRESHOLDS[1], 1., image)
            image = torch.where(image <= THRESHOLDS[0], 0., image)
        else:
            image = torch.where(image >= THRESHOLDS[0], 1., 0.)
    else:                                       # black and white
        image = torch.where(image >= THRESHOLDS, 1., 0.)
    mask_file = os.path.join(FOLDER, f'mask_{time.strftime("%Y%m%d_%H%M%S")}.png')
    tv_save_image(image, mask_file)
    return mask_file

def dilate_mask(mask_file):
    image = cv2.imread(mask_file, 0)            # 0 = load as grayscale
    kernel = np.ones((DILATE, DILATE), np.uint8)
    dilated = cv2.dilate(image, kernel, iterations=1)
    cv2.imwrite(mask_file, dilated)             # overwrite the mask file
    return mask_file

def setup_sampler(model):
    global NOISE
    if IMAGE:
        sampler = DDIMSampler(model)
        sampler.make_schedule(ddim_num_steps=STEPS, ddim_eta=NOISE, verbose=False)
        image, mask = load_image_mask(model.device, IMAGE, MASK)
        sampler.t_enc = int(STRENGTH * STEPS)
        sampler.x0 = model.get_first_stage_encoding(model.encode_first_stage(image))
        sampler.z_enc = sampler.stochastic_encode(sampler.x0, torch.tensor([sampler.t_enc]).to(model.device.type))
        sampler.z_mask = mask
        if MASK:
            # start the masked (white) area from fresh noise, keep the rest of the encoded image
            random = torch.randn(mask.shape, device=model.device)
            sampler.z_enc = (mask * random) + ((1 - mask) * sampler.z_enc)
    elif PLMS:
        sampler = PLMSSampler(model)
        NOISE = 0
    else:
        sampler = DDIMSampler(model)
    return sampler

def get_prompts():
    global NEGATIVES
    if NEGATIVES is None:
        NEGATIVES = [''] * len(PROMPTS)
    else:
        NEGATIVES.extend([''] * (len(PROMPTS) - len(NEGATIVES)))
    return zip(PROMPTS, NEGATIVES)

def generate_samples(model, sampler, prompt, negative, start):
    uncond = model.get_learned_conditioning(negative) if SCALE != 1.0 else None
    cond = model.get_learned_conditioning(prompt)
    if IMAGE:
        # decode2() is a custom DDIMSampler method (not in the stock ldm code) from the earlier inpainting post
        samples = sampler.decode2(sampler.z_enc, cond, sampler.t_enc,
            unconditional_guidance_scale=SCALE, unconditional_conditioning=uncond,
            z_mask=sampler.z_mask, x0=sampler.x0)
    else:
        shape = [4, HEIGHT // FACTOR, WIDTH // FACTOR]
        samples, _ = sampler.sample(S=STEPS, conditioning=cond, batch_size=1,
            shape=shape, verbose=False, unconditional_guidance_scale=SCALE,
            unconditional_conditioning=uncond, eta=NOISE, x_T=start)
    x_samples = model.decode_first_stage(samples)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    return x_samples

def save_image(image):
    name = f'{time.strftime("%Y%m%d_%H%M%S")}.png'
    image = 255. * rearrange(image.cpu().numpy(), 'c h w -> h w c')
    img = Image.fromarray(image.astype(np.uint8))
    img.save(os.path.join(FOLDER, name))
    return name

def save_history(name, prompt, negative):
    with open(os.path.join(FOLDER, HISTORY), 'a') as history:
        history.write(f'{name} -> {"PLMS" if PLMS else "DDIM"}, Seed={SEED}{" fixed" if FIXED else ""}, Scale={SCALE}, Steps={STEPS}, Noise={NOISE}')
        if IMAGE:
            history.write(f', Image={IMAGE}, Strength={STRENGTH}')
        if MASK:
            history.write(f', Mask={MASK}')
        if len(negative):
            history.write(f'\n + {prompt}\n - {negative}\n')
        else:
            history.write(f'\n + {prompt}\n')

def main():
    global MASK
    print('*** Simple Stable Diffusion - myByways.com simple-sd version 0.3')
    logging.set_verbosity_error()
    os.makedirs(FOLDER, exist_ok=True)
    seed_pre()
    if isinstance(MASK, list):                  # prompts, so build the mask with CLIPSeg
        print('*** Loading CLIPSeg and creating mask')
        tic1 = time.time()
        clipseg = load_clipseg()
        MASK = create_mask(clipseg, IMAGE, MASK)    # MASK now holds the PNG filename
        if DILATE:
            MASK = dilate_mask(MASK)
        toc1 = time.time()
        print(f'*** CLIPSeg masking time: {(toc1 - tic1):.2f}s')
    if not PROMPTS:
        sys.exit(0)
    print('*** Loading Stable Diffusion')
    tic1 = time.time()
    config = OmegaConf.load(CONFIG)
    model = load_model(config)
    device, precision_scope = set_device(model)
    sampler = setup_sampler(model)
    start_code = seed_post(device)
    toc1 = time.time()
    print(f'*** Stable Diffusion setup time: {(toc1 - tic1):.2f}s')
    counter = 0
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                for iteration in range(ITERATIONS):
                    for prompt, negative in get_prompts():
                        print(f'*** Iteration {iteration + 1}: {prompt}')
                        tic2 = time.time()
                        images = generate_samples(model, sampler, prompt, negative, start_code)
                        for image in images:
                            name = save_image(image)
                            save_history(name, prompt, negative)
                            print(f'*** Saved image: {name}')
                        counter += len(images)
                        toc2 = time.time()
                        print(f'*** Synthesis time: {(toc2 - tic2):.2f}s')
    print(f'*** Total time: {(toc2 - tic1):.2f}s')
    print(f'*** Saved {counter} image(s) to {FOLDER} folder.')

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print('*** User abort, goodbye.')
    except FileNotFoundError as e:
        print(f'*** {e}')
```
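The inpainting trick in setup_sampler() is simple arithmetic on the latent. The mask is shrunk by FACTOR to latent size and copied across the 4 latent channels. It is then used to start the masked (white) area from fresh noise, while the black area keeps the encoded image. Here is a NumPy sketch of just that blend, with random arrays standing in for the encoded image. blend_start_latent is a hypothetical name, and the script's LANCZOS resize of the mask is skipped by building the mask directly at latent size.

```python
import numpy as np

FACTOR = 8  # latent downsampling factor, as in the script

def blend_start_latent(z_enc: np.ndarray, mask: np.ndarray,
                       rng: np.random.Generator) -> np.ndarray:
    """z_enc: (1, 4, h, w) encoded image latent.
    mask: (h, w) with values 0..1, white (1) = overwrite, black (0) = retain.
    Returns a starting latent where masked pixels are pure noise."""
    m = np.tile(mask, (4, 1, 1))[None]          # (1, 4, h, w), same mask for every channel
    noise = rng.standard_normal(z_enc.shape)
    return m * noise + (1.0 - m) * z_enc

rng = np.random.default_rng(42)
h, w = 512 // FACTOR, 512 // FACTOR             # 64 x 64 latent for a 512 x 512 image
z_enc = rng.standard_normal((1, 4, h, w))       # stand-in for the encoded image
mask = np.zeros((h, w))
mask[:, w // 2:] = 1.0                          # overwrite only the right half
z = blend_start_latent(z_enc, mask, rng)
```

Grey mask values (from the two-value THRESHOLDS option) simply mix noise and image latent in proportion. This may be why a softer edge can blend better than a hard black and white mask.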
Yesterday, CLIPSeg released a new -refined model! The output is less chunky and the mask has much smoother edges. Make sure you have the latest code!
To update CLIPSeg, I ran pip install --force-reinstall, which was a big mistake. It ruined my setup by reinstalling the latest PyTorch nightly, v1.13.0.dev20220927, which is broken. My render time went from 1-2 minutes (MPS) to 10-15 minutes, despite no code change! I had to slowly retrace my steps until I found the versions that worked for me:
```shell
pip3 uninstall torch torchvision
pip3 install --pre torch==1.13.0.dev20220915 torchvision==0.14.0.dev20220915 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
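To avoid the same surprise after a forced reinstall, a small check of the installed torch version against the two nightlies mentioned above might help. This sketch uses Python's standard importlib.metadata. KNOWN_GOOD and KNOWN_BAD simply record the versions from this post, and check_torch is a hypothetical helper rather than part of my script.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

KNOWN_GOOD = '1.13.0.dev20220915'   # the nightly pinned above
KNOWN_BAD = {'1.13.0.dev20220927'}  # the nightly that slowed MPS rendering for me

def installed_torch() -> Optional[str]:
    """Return the installed torch version string, or None if not installed."""
    try:
        return version('torch')
    except PackageNotFoundError:
        return None

def check_torch(installed: Optional[str]) -> str:
    """Classify a torch version string as 'good', 'bad', 'untested' or 'missing'."""
    if installed is None:
        return 'missing'
    base = installed.split('+')[0]  # drop local tags such as +cpu
    if base in KNOWN_BAD:
        return 'bad'
    if base == KNOWN_GOOD:
        return 'good'
    return 'untested'

if __name__ == '__main__':
    current = installed_torch()
    print(f'*** torch {current}: {check_torch(current)}')
```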
Sample Output
First, using my previous img2img trial: I masked an astronaut and a planet and replaced them with a scary blue alien and the red planet jupiter in cartoon style. The first image is the original. The second is without DILATE and without tweaking THRESHOLDS. The third uses a good mask.
Next, using the sample image provided by Stable Diffusion: the first image is the original, the second is the mask generated with CLIPSeg for the prompt a woman, and the third image replaces the mask with a photo of an old man in a tweed jacket and boots leaning on a fence, dramatic lighting:
Here’s another example, using just prompts to both create the image, and then to mask and replace parts of the image. In the second image, I masked wings and hair, and in the third image I additionally masked gold dress.
Incredible!