I refactored my previous Stable Diffusion code to clean it up, make it a little more object-oriented, and add new features such as tiling, upscaling, and PNG metadata. As I mentioned before, I don’t understand AI/ML... but I do understand programming! So here is my new, more elegant Simple-SD v1.0 Python script.

For a gentle explanation of Stable Diffusion for dummies, I thought this Illustrated Stable Diffusion was very good.

This is part 6 in what is turning out to be a series:

  1. (16 Sep 22): AI-generated images with Stable Diffusion on an M1 mac a.k.a. txt2img
  2. (17 Sep 22): Stable Diffusion image-to-image mode a.k.a. img2img
  3. (21 Sep 22): My simplified Stable Diffusion Python script a.k.a. txtimg2img
  4. (26 Sep 22): Stable Diffusion script with inpainting mask a.k.a. mask+txtimg2img
  5. (28 Sep 22): Adding CLIPSeg automatic masking to Stable Diffusion a.k.a. txt2mask+txtimg2img
  6. (9 Oct 22): this post, myByways Simple-SD v1.0 Python script

Features

Stable Diffusion v1-4:

  • Re-implement the original Stable Diffusion v1-4 algorithms, merging txt2img.py, img2img.py and inpaint.py into one,
  • Use Apple Metal Performance Shaders (MPS) for M1/M2 chips, similar to bfirsh’s fork, except for inpaint, which still requires the CPU,
  • Add proper inpainting based on a text prompt to fill in a masked area, similar to Stable Diffusion Web UI by overriding DDIMSampler.decode() in ldm/models/diffusion/ddim.py,
  • Add seamless tile generation, based on InvokeAI merge #339 to set padding_mode = circular,
  • Enable offline mode by setting TRANSFORMERS_OFFLINE=1 environment variable in code (note: adding HF_DATASETS_OFFLINE and HF_HUB_OFFLINE don’t seem to do anything),
  • Reset seed generation each run for reproducible results - the seed is reset for every prompt rather than incrementing by 1.
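
To see why switching padding_mode to circular produces seamless tiles, here is a toy 1-D sketch in plain Python. It is not the PyTorch code the script uses; pad, box_filter and roll are hypothetical helpers made up for this illustration, and a 3-tap average stands in for a real convolution. With circular padding the filter treats the row as if it repeats forever, so shifting the input simply shifts the output, and the left and right edges join up.

```python
def pad(row, n, mode):
    """Pad a 1-D row by n cells on each side, the way a conv layer pads its input."""
    if mode == 'circular':
        return row[-n:] + row + row[:n]   # wrap around: the right edge feeds the left
    return [0.0] * n + row + [0.0] * n    # 'zeros': the edges see blank space


def box_filter(row, mode):
    """3-tap average, standing in for a convolution with kernel size 3."""
    p = pad(row, 1, mode)
    return [(p[i] + p[i + 1] + p[i + 2]) / 3 for i in range(len(row))]


def roll(row, k):
    """Cyclically shift a row by k cells to the right."""
    return row[-k:] + row[:-k]


row = [9.0, 0.0, 0.0, 3.0]

# Circular padding: filtering a shifted row equals shifting the filtered row.
print(box_filter(roll(row, 1), 'circular') == roll(box_filter(row, 'circular'), 1))

# Zero padding: the borders are treated differently, so a seam appears.
print(box_filter(roll(row, 1), 'zeros') == roll(box_filter(row, 'zeros'), 1))
```

The same idea applies in 2-D: every Conv2d layer wraps around in both directions, so the generated image can be repeated side by side without visible joins.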

CLIPSeg September 2022 model:

  • Add text-to-mask to automatically create a mask, using CLIPSeg’s new fine-grained predictions,
  • Sum multiple positive and negative prompts (this code I’m in two minds about),
  • Dilate the mask to completely cover edges,
  • Add some sort of range threshold to “feather” grayscale.
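
The mask post-processing above can be sketched on a single row of pixel scores. This is a simplified 1-D illustration in plain Python, not the torch/OpenCV code in the script: feather, subtract and dilate are hypothetical helper names, the thresholds mirror the CS.BLACK and CS.WHITE defaults, and dilate takes a radius whereas the script passes a square kernel size to cv2.dilate.

```python
def feather(values, black=0.2, white=0.5):
    """Snap confident scores to pure black/white, keep the grey band in between."""
    return [1.0 if v >= white else 0.0 if v <= black else v for v in values]


def subtract(positive, negative):
    """Remove the negative (unmask) scores from the positive mask, never below 0."""
    return [max(0.0, p - n) for p, n in zip(positive, negative)]


def dilate(values, radius):
    """Grow bright areas: each cell takes the maximum of its neighbourhood."""
    n = len(values)
    return [max(values[max(0, i - radius):min(n, i + radius + 1)]) for i in range(n)]


scores = [0.1, 0.3, 0.75, 0.25, 0.0, 0.0]   # made-up per-pixel scores for one prompt
mask = feather(scores)
grown = dilate(mask, 1)
print(mask)
print(grown)
```

Dilation is what lets the mask spill slightly over the object's outline, so the inpainted area does not leave a halo of the original object behind.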

Real-ESRGAN v0.3.0

PNG metadata:

  • Store most configuration parameters used by Stable Diffusion in the output PNG file in iTXt chunks (note: Author and Software are always set),
  • Able to load and reuse parameters from the PNG metadata, if you want to re-run Stable Diffusion with the same parameters,
  • Use ExifTool by Phil Harvey (or write a function in Python like mine) to retrieve PNG metadata.
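
If you would rather not install ExifTool, reading the text chunks takes only a few lines of standard-library Python. The sketch below is not the function from my script; read_png_text, make_chunk and itxt are hypothetical names, and it assumes a well-formed file (CRCs are not verified). It handles both tEXt and iTXt chunks, and builds a tiny in-memory PNG skeleton to demonstrate on.

```python
import struct
import zlib

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'


def make_chunk(kind, data):
    """Build one PNG chunk: length, type, data, CRC (used here only for the demo)."""
    crc = zlib.crc32(kind + data)
    return struct.pack('>I', len(data)) + kind + data + struct.pack('>I', crc)


def read_png_text(png):
    """Return a dict of keyword -> text from the tEXt and iTXt chunks of a PNG."""
    if png[:8] != PNG_SIGNATURE:
        raise ValueError('not a PNG file')
    text, pos = {}, 8
    while pos + 8 <= len(png):
        length, kind = struct.unpack('>I4s', png[pos:pos + 8])
        data = png[pos + 8:pos + 8 + length]
        pos += 12 + length                      # header (8) + data + CRC (4)
        if kind == b'tEXt':                     # keyword NUL text, Latin-1
            key, _, value = data.partition(b'\x00')
            text[key.decode('latin-1')] = value.decode('latin-1')
        elif kind == b'iTXt':                   # keyword NUL flag method lang NUL translated NUL text
            key, _, rest = data.partition(b'\x00')
            compressed, rest = rest[0], rest[2:]
            _lang, _, rest = rest.partition(b'\x00')
            _translated, _, value = rest.partition(b'\x00')
            if compressed:
                value = zlib.decompress(value)
            text[key.decode('latin-1')] = value.decode('utf-8')
        elif kind == b'IEND':
            break
    return text


def itxt(key, value):
    """Uncompressed iTXt chunk with empty language tag and translated keyword."""
    body = key.encode('latin-1') + b'\x00\x00\x00\x00\x00' + value.encode('utf-8')
    return make_chunk(b'iTXt', body)


# Demo: a PNG skeleton (no pixel data) carrying two metadata entries.
sample = (PNG_SIGNATURE
          + make_chunk(b'IHDR', struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0))
          + itxt('prompt', 'the quick brown fox')
          + itxt('seed', '42')
          + make_chunk(b'IEND', b''))
print(read_png_text(sample))
```

For a real file, pass the bytes from open(path, 'rb').read() instead of the demo skeleton.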

Some developer’s notes:

  • Everything is in a single Python file that expects the user to override the global variables in the CONFIG class directly in code, rather than having to remember the CLI!
  • I have an M1 Pro with 16 GPU cores + 16 neural engine cores and have hardcoded values like f32 precision scope and setting batch size to 1 (anything larger is too slow) - I do not know if an M1, M1 Max, M1 Ultra, or M2 behaves in the same way,
  • Stable Diffusion inpainting and GFPGAN face enhancement are CPU only, and cannot run on MPS, because PyTorch support for aten::_slow_conv2d_forward is not yet implemented, see issue #77764,
  • Finally, I only use 512x512 PNG files. I do not have any error handling. Good luck.
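
The "override the class variables in code" approach relies on how Python evaluates a class body: statements run top to bottom, so a later assignment replaces an earlier default, and a nested class that inherits from another picks up its overrides too. A minimal sketch, using a hypothetical Settings class with made-up values rather than the real CONFIG:

```python
class Settings:
    """Hypothetical stand-in for a CONFIG-style class; all names are illustrative."""

    PROMPTS = []                 # default: nothing to generate

    class SD:                    # core parameters
        STEPS = 50
        STRENGTH = 0.75

    class IP(SD):                # inherits every SD value it does not redefine
        MODEL = 'inpaint.ckpt'

    # Overrides go last: the class body runs top to bottom, so these win.
    PROMPTS = ['the quick brown fox']
    SD.STRENGTH = 0.8


print(Settings.PROMPTS)          # the override, not the empty default
print(Settings.IP.STRENGTH)      # looked up on SD, so the override applies here too
```

The trade-off is that there is no validation: a misspelled variable name silently creates a new attribute instead of raising an error.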

Usage Scenarios

Refer to the OVERRIDES comment in the code for the usage scenarios:

n | Usage | Required Variables
1 | Text-to-image generation | One or more PROMPTS in an array (list), and optionally one or more NEGATIVES - prompts correspond to negatives on a one-to-one basis (but not all prompts need to have negatives)
2 | Seamless tile generation | PROMPTS and TILING = True
3 | Image-to-image generation | PROMPTS and IMAGE, which is a PNG image
4 | Inpainting based on surrounding | IMAGE and MASK, which is a PNG file where white = areas to be replaced and black = areas to be retained, while PROMPTS = None
5 | Text-to-mask followed by inpainting | IMAGE and MASK, which this time is an array of one or more text prompts that are applied in unison, and optionally an UNMASK text array to subtract from the mask
6 | Text-to-mask followed by image-to-image with a mask | PROMPTS, IMAGE and MASK
7 | Upscaling only | IMAGE and an UPSCALE factor of 1 or higher, optionally FACE_ENHANCE = True
8 | Text-to-image with upscaling with face enhancement | PROMPTS and UPSCALE
9 | Text-to-mask, image-to-image, mask inpainting, upscaling | PROMPTS, IMAGE, MASK, UPSCALE, and optional tuning parameters like SD.STRENGTH, CS.DILATE, etc.
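
One way to read the table is as a small decision procedure: which variables are set determines which stages run, and in what order. The function below is a hypothetical sketch of that mapping, not the dispatch code from the script; plan_stages and the stage labels are made up for illustration, and it assumes a MASK given as a list means CLIPSeg prompts while a string means a PNG file.

```python
def plan_stages(prompts=None, image=None, mask=None, upscale=0):
    """Return the ordered list of stages implied by the variables that are set."""
    stages = []
    if image and isinstance(mask, list) and mask:
        stages.append('txt2mask')             # MASK is a list of text prompts
    if prompts and image:
        stages.append('img2img+mask' if mask else 'img2img')
    elif prompts:
        stages.append('txt2img')
    elif image and mask:
        stages.append('inpaint')              # no prompt: fill from surroundings
    if upscale:
        stages.append('upscale')
    return stages


print(plan_stages(prompts=['a fox']))                          # scenario 1
print(plan_stages(image='bench.png', mask='bench_mask.png'))   # scenario 4
print(plan_stages(prompts=['a castle'], image='city.png',
                  mask=['buildings'], upscale=2))              # scenario 9
```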

Examples

Here are some example outputs I generated from the sample images provided in the Stable Diffusion repo. Uncomment the OVERRIDE lines in the code. In the case of Rick (#7) I just took a crop of his face.

myByways Simple-SD Python script sample images

Installation

The code below assumes you want Stable Diffusion and dependencies installed in a folder called stable-diffusion.

# Install dependencies using Homebrew
brew update
brew install python cmake protobuf rust

# Clone Stable Diffusion and model
git clone https://github.com/CompVis/stable-diffusion/
cd stable-diffusion
mkdir -p models/ldm/stable-diffusion-v1/
# Quote the URL so the shell (zsh on macOS) does not treat "?" as a glob character
curl -L 'https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media' -o models/ldm/stable-diffusion-v1/sd-v1-4.ckpt

# Install Python pre-requisites
python3 -m pip install virtualenv
python3 -m virtualenv venv
source venv/bin/activate
pip install -r requirements.txt
pip install basicsr
pip install gfpgan
pip install torch

# Clone CLIPSeg and weights
git clone https://github.com/timojl/clipseg
cd clipseg
curl -L https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -o weights.zip
unzip weights.zip
mv clipseg-weights weights
rm weights.zip
python setup.py develop
cd ..

# Clone Real-ESRGAN and weights
git clone https://github.com/xinntao/Real-ESRGAN
mv Real-ESRGAN realesrgan
cd realesrgan
curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -o weights/RealESRGAN_x4plus.pth
curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -o weights/RealESRGAN_x4plus_anime_6B.pth
curl -L https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -o weights/GFPGANv1.4.pth
python setup.py develop
cd ..

# Get correct version of PyTorch 
pip install --pre torchvision==0.14.0.dev20221001 --extra-index-url https://download.pytorch.org/whl/nightly/cpu  
pip install --pre torch==1.13.0.dev20220915 --extra-index-url https://download.pytorch.org/whl/nightly/cpu

For me, Stable Diffusion decoding on MPS takes ~2 minutes. Some PyTorch nightly versions have MPS bugs that make decoding take ~10 times longer.

My Simple-SD v1.0 Script

Place this script in the root of the stable-diffusion folder (not the scripts folder). Run chmod +x simple-sd.py to make it executable, then run it with ./simple-sd.py.

The first time the script is run, various files and models will be downloaded. What is downloaded depends on the usage scenario. Once you are sure the files for all scenarios have been downloaded, feel free to enable offline mode by changing the TRANSFORMERS_OFFLINE line in the script from '0' to '1'.

#!/usr/bin/env python -O
# copyright (c) 2022 C.Y. Wong, myByways.com simplified Stable Diffusion v1.0
import os, time, pickle, warnings, logging, logging.handlers
import pickletools
import sys
os.environ['TRANSFORMERS_OFFLINE'] = '0' #transformers/utils/hub.py
import torch
import numpy as np
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from PIL import Image

#######################################
# Global Variables
#######################################
import json

class CONFIG():
    """Global configuration variables - edit or override"""

    # Stable Diffusion `txt2img`:
    PROMPTS = []            # one or more prompts in an array, one per output image
    NEGATIVES = None        # negative prompt, one or more, default None / []
    TILING = False          # seamless tiling, default=False

    # Stable Diffusion `img2img` with mask and `inpaint`:
    IMAGE = None            # img2img initial seed image in PNG format, default None
    MASK = None             # either: inpainting mask in PNG format, white to overwrite, black to retain, default None / []
    MASK = []               # or: one or more prompts for CLIPSeg to generate the mask
    UNMASK = []             # and optionally: negative mask to deduct from MASK, default None / []

    # Real-ESRGAN upscaling and face enhancement:
    UPSCALE = 0             # default 0 to skip upscaling, or set to 1 for no size change but with enhancement
    FACE_ENHANCE = False    # when upscaling, enable face enhancement (runs on CPU only), default False 

    # Also important:
    SEED = None             # Seed, default None for random (same seed for predictable txt2img results)
    WIDTH = 512             # image width default 512 (larger causes M1 to crawl)
    HEIGHT = 512            # image height default 512 (larger causes M1 to crawl)
    OUTPUT_DIR = 'outputs'          # output directory for images, generated masks, and history file
    HISTORY_FILE = 'history.txt'    # history log file name
    FILENAME = '%y%m%d_%H%M%S.png'

    class SD:               # Core Stable Diffusion:
        ITERATIONS = 1      # iterations (repeat runs using the same set of prompts), default 1
        FACTOR = 8          # down sampling factor, default 8 (can be larger, but not smaller)
        STEPS = 50          # number of steps, seldom visible improvement beyond ~50, default 50
        NOISE = 0           # additional noise, 0 deterministic - 1.0 random noise, ignored for PLMS (i.e. --ddim_eta)
        SCALE = 7.5         # guidance scale, 5 further -> 15 closer to prompt, default 7.5
        SAMPLER = 'PLMS'    # sampler, default 'PLMS' for txt2img which will set NOISE=0, and anything else for DDIM
        STRENGTH = 0.75     # img2img strength, 0 more like original -> 1 less like, default 0.75
        CONFIG = 'configs/stable-diffusion/v1-inference.yaml'
        MODEL = 'models/ldm/stable-diffusion-v1/sd-v1-4.ckpt'

    class IP(SD):           # Inpainting without mask:
        CONFIG = 'models/ldm/inpainting_big/config.yaml'
        MODEL = 'models/ldm/inpainting_big/last.ckpt'

    class CS:               # CLIPSeg `txt2mask` :
        BLACK = 0.2         # threshold <= to which to convert to black, to feather the edges 
        WHITE = 0.5         # threshold >= to which to convert to white, to feather the edges
        DILATE = 20         # mask dilate amount in pixels, default 20
        MODEL = 'clipseg/weights/rd64-uni-refined.pth'
        FILENAME = '%y%m%d_%H%M%Sm.png'

    class RE:               # Real-ESRGAN up-scaling, change RealESRGAN_x4plus to RealESRGAN_x4plus_anime_6B for anime
        MODEL = 'realesrgan/weights/RealESRGAN_x4plus.pth' # or 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth' only
        FACE_MODEL = 'realesrgan/weights/GFPGANv1.4.pth'
        FILENAME = '%y%m%d_%H%M%Su.png'

    class METADATA:
        SAVE_LEVEL = 1      # save config to PNG metadata, default 0=no metadata, 1=prompt + seed, 2=prompt + seed + SD parameters
        LOAD_FILE = None    # filename of PNG to retrieve config from, default None

# OVERRIDES: EXAMPLES / TEST CASES

    # 1. Stable Diffusion txt2img
    PROMPTS = ['the quick brown fox jumps over the lazy dog']

    # 2. Stable Diffusion txt2img with tiling
    #PROMPTS = ['multicolored parquet flooring']
    #TILING = True

    # 3. Stable Diffusion img2img
    #PROMPTS = ['a dark dreary cyberpunk city with many neon signs']
    #IMAGE = 'data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png'

    # 4. Stable Diffusion inpainting
    #IMAGE = 'data/inpainting_examples/bench2.png'
    #MASK = 'data/inpainting_examples/bench2_mask.png'

    # 5. Stable Diffusion inpainting with CLIPSeg text2mask
    #IMAGE = 'data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
    #MASK = ['dog']

    # 6. Stable Diffusion txt+img2img with CLIPSeg txt2mask
    #IMAGE = 'data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
    #MASK = ['dog']
    #PROMPTS = ['a small tiger cub sitting on a bench']

    # 7. Real-ESRGAN upscaling and GFPGAN face enhancement only
    #IMAGE = 'assets/rick.jpeg'
    #UPSCALE = 2
    #FACE_ENHANCE = True

    # 8. Stable Diffusion with upscaling 
    #PROMPTS = ['a black_knight and a red_knight clashing swords in a lush forrest, realistic, 4k, hd']
    #SEED = 588513860
    #UPSCALE = 2

    # 9. Everything
    #PROMPTS = ['a medevial castle with a large tree in front']
    #IMAGE = 'data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png'
    #SD.STRENGTH = 0.8
    #MASK = ['buildings']
    #UNMASK = ['tree', 'trunk']
    #CS.DILATE = 50
    #UPSCALE = 2
    #FACE_ENHANCE = False

#######################################
# StableDiffusion txt2img
#######################################
from einops import rearrange
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from transformers import logging as transformers_logging

class StableDiffusion_Text():
    """Stable Diffusion text to image (M1 MPS)"""

    def __init__(self, device):
        """Initialize Stable Diffusion with the given config and checkpoint files"""
        logging.debug(f'StableDiffusion_Text Init Config={CONFIG.SD.CONFIG}')
        tic = time.perf_counter()
        config = OmegaConf.load(CONFIG.SD.CONFIG)
        model = instantiate_from_config(config.model)
        pl_sd = torch.load(CONFIG.SD.MODEL, map_location='cpu')
        model.load_state_dict(pl_sd['state_dict'], strict=False)
        model.to(device.type)
        model.eval()
        self.model = model
        self.device = device
        logging.debug(f'StableDiffusion_Text Init Time={(time.perf_counter()-tic):.2f}s')

    def setup_sampler(self):
        """Select the sampler, PLMS or DDIM, and some setup if tiling enabled"""
        if CONFIG.SD.SAMPLER == 'PLMS':
            self.sampler = PLMSSampler(self.model)
            CONFIG.SD.NOISE = 0
        else:
            self.sampler = DDIMSampler(self.model)
            CONFIG.SD.SAMPLER = 'DDIM'
        tiling = 'circular' if CONFIG.TILING else 'zeros'
        for m in self.model.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                if tiling != m.padding_mode: m.padding_mode = tiling
        CONFIG.SEED = seed_everything(CONFIG.SEED)

    def _generate_sample(self, conditioning, unconditional):
        h = CONFIG.HEIGHT // CONFIG.SD.FACTOR
        w = CONFIG.WIDTH // CONFIG.SD.FACTOR
        shape = [4, h, w]
        start = torch.randn([1, 4, h, w], device='cpu').to(torch.device(self.device.type))
        sample, _ = self.sampler.sample(S=CONFIG.SD.STEPS, conditioning=conditioning, batch_size=1, shape=shape, verbose=False, unconditional_guidance_scale=CONFIG.SD.SCALE, unconditional_conditioning=unconditional, eta=CONFIG.SD.NOISE, x_T=start)
        return sample

    def __save_image(self, image):
        image = 255. * rearrange(image.cpu().numpy(), 'c h w -> h w c')
        img = Image.fromarray(image.astype(np.uint8))
        image_file = os.path.join(CONFIG.OUTPUT_DIR, time.strftime(CONFIG.FILENAME))
        img.save(image_file)
        image_size = os.path.getsize(image_file)
        logging.info(f'  ==> Output Image={image_file}, Size={image_size}')
        if image_size < 1000: logging.warning(f'  ❌ Output image file size is too small ❌')
        return image_file

    def create_image(self, positive:str, negative:str)-> str:
        """Create a single image from the prompts"""
        logging.info(f'  Positive=\'{positive}\'')
        if negative: logging.info(f'  Negative=\'{negative}\'')
        with torch.no_grad():
            precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
            with precision_scope(self.device.type):
                with self.model.ema_scope():
                    uncond = self.model.get_learned_conditioning(negative) if CONFIG.SD.SCALE != 1.0 else None
                    cond = self.model.get_learned_conditioning(positive)
                    sample = self._generate_sample(cond, uncond)
                    sample = self.model.decode_first_stage(sample)
                    images = torch.clamp((sample + 1.0) / 2.0, min=0.0, max=1.0)
        image_file = self.__save_image(images[0])
        METADATA.save(image_file, positive, negative)
        return image_file

    def _create_images(self, positives, negatives, iterations, reset_seed):
        image_files = []
        if negatives:
            negatives.extend([''] * (len(positives)-len(negatives)))
        else:
            negatives = [''] * len(positives)
        for positive, negative in zip(positives, negatives):
            for iteration in range(iterations):
                if reset_seed: CONFIG.SEED = seed_everything(CONFIG.SEED)
                else: CONFIG.SEED += 1
                image_files.append(self.create_image(positive, negative))
        return image_files

    def create_images(self, positives:[str], negatives:[str]=None, iterations=CONFIG.SD.ITERATIONS, reset_seed=True)-> [str]:
        """Create images, repeating a number of times using the same seed - raison d'être"""
        logging.info(f'StableDiffusion_Text Sampler={CONFIG.SD.SAMPLER if CONFIG.SD.SAMPLER else "DDIM"}, Size={CONFIG.WIDTH}x{CONFIG.HEIGHT}, Factor={CONFIG.SD.FACTOR}')
        logging.info(f'  Seed={CONFIG.SEED}, Steps={CONFIG.SD.STEPS}, Noise={CONFIG.SD.NOISE}, Scale={CONFIG.SD.SCALE}{", Tiling" if CONFIG.TILING else ""}')
        tic = time.perf_counter()
        image_files = self._create_images(positives, negatives, iterations, reset_seed)
        logging.debug(f'StableDiffusion_Text Run Time={(time.perf_counter()-tic):.2f}s')
        return image_files

#######################################
# StableDiffusion img2img / text+mask2img / img+mask2img
#######################################
from tqdm import tqdm

class DDIMSampler_Mask(DDIMSampler):
    """Override original DDIMSampler to add z_mask and x0 parameters"""

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, use_original_steps=False, z_mask = None, x0=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            if z_mask is not None and i < total_steps - 2:
                img_orig = self.model.q_sample(x0, ts)
                mask_inv = 1. - z_mask
                x_dec = (img_orig * mask_inv) + (z_mask * x_dec)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 
                unconditional_guidance_scale=unconditional_guidance_scale, 
                unconditional_conditioning=unconditional_conditioning)
        return x_dec

class StableDiffusion_Image(StableDiffusion_Text):
    """Stable Diffusion text+image/mask to image (M1 MPS)"""

    def load_image(self, image_file, mask_file=None):
        """Load the image and mask in PNG format, 512x512"""
        self.image_file = image_file
        self.mask_file = mask_file
        image = Image.open(image_file).convert('RGB')
        w, h = image.size
        w, h = map(lambda x: x - x % 32, (w, h))
        w, h = max(w, CONFIG.WIDTH), max(h, CONFIG.HEIGHT)
        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image = 2. * image - 1.
        image = image.to(self.device.type)
        self.image = image
        if mask_file:
            mask = Image.open(mask_file).convert('L')
            w, h = mask.size
            w, h = map(lambda x: x - x % 32, (w, h))
            w, h = max(w, CONFIG.WIDTH), max(h, CONFIG.HEIGHT)
            mask = mask.resize((w // CONFIG.SD.FACTOR, h // CONFIG.SD.FACTOR), resample=Image.Resampling.LANCZOS)
            mask = np.array(mask).astype(np.float32) / 255.0
            mask = np.tile(mask, (4, 1, 1))
            mask = mask[None].transpose(0, 1, 2, 3)
            mask = torch.from_numpy(mask).to(self.device.type)
            self.mask = mask
        else:
            self.mask = None

    def setup_sampler(self):
        """Select the sampler, DDIM only"""
        CONFIG.SD.SAMPLER = 'DDIM'
        self.sampler = DDIMSampler_Mask(self.model)
        CONFIG.SEED = seed_everything(CONFIG.SEED)

    def _generate_sample(self, conditioning, unconditional):
        self.sampler.make_schedule(ddim_num_steps=CONFIG.SD.STEPS, ddim_eta=CONFIG.SD.NOISE, verbose=False)
        t_enc = int(CONFIG.SD.STRENGTH * CONFIG.SD.STEPS)
        x0 = self.model.encode_first_stage(self.image)
        x0 = self.model.get_first_stage_encoding(x0) 
        z_enc = self.sampler.stochastic_encode(x0, torch.tensor([t_enc]).to(self.device.type))
        if self.mask_file:
            random = torch.randn(self.mask.shape, device=self.device)
            z_enc = (self.mask * random) + ((1-self.mask) * z_enc)
        sample = self.sampler.decode(z_enc, conditioning, t_enc, unconditional_guidance_scale=CONFIG.SD.SCALE, unconditional_conditioning=unconditional, z_mask=self.mask, x0=x0)
        return sample

    def create_images(self, positives:[str], negatives:[str]=None, iterations=1, reset_seed=True)-> [str]:
        """Create images, repeating a number of times using the same seed - raison d'être"""
        logging.info(f'StableDiffusion_Image Sampler={CONFIG.SD.SAMPLER if CONFIG.SD.SAMPLER else "DDIM"}, Size={CONFIG.WIDTH}x{CONFIG.HEIGHT}, Factor={CONFIG.SD.FACTOR}')
        logging.info(f'  Seed={CONFIG.SEED}, Strength={CONFIG.SD.STRENGTH}, Steps={CONFIG.SD.STEPS}, Noise={CONFIG.SD.NOISE}, Scale={CONFIG.SD.SCALE}{", Tiling" if CONFIG.TILING else ""}')
        logging.info(f'  Image={self.image_file}{f", Mask={self.mask_file}" if self.mask_file else ""}')
        tic = time.perf_counter()
        image_files = super()._create_images(positives, negatives, iterations, reset_seed)
        logging.debug(f'StableDiffusion_Image Run Time={(time.perf_counter()-tic):.2f}s')
        return image_files

#######################################
# StableDiffusion inpaint
#######################################

class StableDiffusion_Erase():
    """Stable Diffusion inpainting (CPU)"""

    def __init__(self, device):
        """Initialize Stable Diffusion Erase with the given config and checkpoint files"""
        logging.debug(f'StableDiffusion_Erase Init Config={CONFIG.IP.CONFIG}')
        tic = time.perf_counter()
        config = OmegaConf.load(CONFIG.IP.CONFIG)
        model = instantiate_from_config(config.model)
        pl_sd = torch.load(CONFIG.IP.MODEL, map_location='cpu')
        model.load_state_dict(pl_sd['state_dict'], strict=False)
        model.to(device.type)
        self.device = device
        self.model = model
        logging.debug(f'StableDiffusion_Erase Init Time={(time.perf_counter()-tic):.2f}s')

    def setup_sampler(self):
        """Select the sampler, DDIM"""
        CONFIG.IP.SAMPLER = 'DDIM'
        self.sampler = DDIMSampler(self.model)
        CONFIG.SEED = seed_everything(CONFIG.SEED)

    def load_image(self, image_file:str, mask_file:str):
        """Load the image and mask PNG files, max 512x512"""
        self.image_file = image_file
        self.mask_file = mask_file
        image = np.array(Image.open(image_file).convert('RGB'))
        image = image.astype(np.float32)/255.0
        image = image[None].transpose(0,3,1,2)
        image = torch.from_numpy(image)
        mask = np.array(Image.open(mask_file).convert('L'))
        mask = mask.astype(np.float32)/255.0
        mask = mask[None, None]
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)
        masked_image = (1-mask) * image
        image = image.to(self.device.type)
        image = 2. * image - 1.
        self.image = image
        mask = mask.to(self.device.type)
        mask = 2. * mask - 1.
        self.mask = mask
        masked_image = masked_image.to(self.device.type)
        masked_image = 2. * masked_image - 1.
        self.masked_image = masked_image

    def __save_image(self, image):
        img = Image.fromarray(image.astype(np.uint8))
        image_file = os.path.join(CONFIG.OUTPUT_DIR, time.strftime(CONFIG.FILENAME))
        img.save(image_file)
        logging.info(f'  ==> Output Image={image_file}')
        return image_file

    def create_image(self)-> str:
        """Create image by erasing masked area - raison d'être"""
        logging.info(f'StableDiffusion_Erase Sampler={CONFIG.IP.SAMPLER if CONFIG.IP.SAMPLER else "DDIM"}, Steps={CONFIG.IP.STEPS}')
        if self.device.type == 'cpu': logging.warning(f'  StableDiffusion_Erase running on CPU')
        logging.info(f'  Image={self.image_file}{f", Mask={self.mask_file}" if self.mask_file else ""}')
        tic = time.perf_counter()
        with torch.no_grad():
            with self.model.ema_scope():
                cond = self.model.cond_stage_model.encode(self.masked_image)
                uncond = torch.nn.functional.interpolate(self.mask, size=cond.shape[-2:])
                cond = torch.cat((cond, uncond), dim=1)
                shape = (cond.shape[1]-1,)+cond.shape[2:]
                samples, _ = self.sampler.sample(S=CONFIG.IP.STEPS, conditioning=cond, batch_size=cond.shape[0], shape=shape, verbose=False)
                samples = self.model.decode_first_stage(samples)
                image = torch.clamp((self.image+1.0)/2.0, min=0.0, max=1.0)
                mask = torch.clamp((self.mask+1.0)/2.0, min=0.0, max=1.0)
                predicted_image = torch.clamp((samples+1.0)/2.0, min=0.0, max=1.0)
                inpainted = (1-mask)*image+mask*predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
        image_file = self.__save_image(inpainted)
        logging.debug(f'StableDiffusion_Erase Run Time={(time.perf_counter()-tic):.2f}s')
        return image_file

#######################################
# CLIPSeg
#######################################
import cv2
from torchvision import transforms
from torchvision.utils import save_image as tv_save_image
from clipseg.models.clipseg import CLIPDensePredT

class CLIPSeg_Mask():
    """CLIPSeg for text to mask (CPU)"""

    def __init__(self, device):
        """Initialize CLIPSeg with the given weights file"""
        logging.debug(f'CLIPSeg_Mask Init Weights={CONFIG.CS.MODEL}')
        tic = time.perf_counter()
        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
        pl_sd = torch.load(CONFIG.CS.MODEL, map_location='cpu')
        model.load_state_dict(pl_sd, strict=False)
        model.eval()
        self.device = device
        self.model = model
        logging.debug(f'CLIPSeg_Mask Init Time={(time.perf_counter()-tic):.2f}s')

    def __load_image(self, image_file):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((CONFIG.HEIGHT, CONFIG.WIDTH)),  # Resize expects (height, width)
        ])
        image = Image.open(image_file).convert('RGB')  # drop any alpha channel so Normalize sees 3 channels
        return transform(image).unsqueeze(0)

    def __generate_samples(self, image, prompts):
        length = len(prompts)
        with torch.no_grad():
            preds = self.model(image.repeat(length,1,1,1), prompts)[0]
            mask = sum(torch.sigmoid(preds[i][0]) for i in range(length))
            mask = (mask - mask.min()) / (mask.max() - mask.min())  # min-max normalize to [0, 1]
            mask = torch.where(mask >= CONFIG.CS.WHITE, 1., mask)
            mask = torch.where(mask <= CONFIG.CS.BLACK, 0., mask)
        return mask

    def __dilate_mask(self, mask_file, dilate):
        mask = cv2.imread(mask_file, 0)
        kernel = np.ones((dilate, dilate), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=1)
        cv2.imwrite(mask_file, mask)
        return mask_file

    def create_mask(self, image_file:str, positives:[str], negatives:[str]=None)-> str:
        """Create mask from the prompts, in PNG format - raison d'être"""
        logging.info(f'CLIPSeg_Mask Image={image_file}, Size={CONFIG.WIDTH}x{CONFIG.HEIGHT}, Dilate={CONFIG.CS.DILATE}, Black={CONFIG.CS.BLACK}, White={CONFIG.CS.WHITE}')
        logging.info(f'  Positive={positives}')
        tic = time.perf_counter()
        image = self.__load_image(image_file)
        mask = self.__generate_samples(image, positives)
        if negatives:
            logging.info(f'  Negative={negatives}')
            mask -= self.__generate_samples(image, negatives)
        mask_file = os.path.join(CONFIG.OUTPUT_DIR, time.strftime(CONFIG.CS.FILENAME))
        tv_save_image(mask, mask_file)
        if CONFIG.CS.DILATE: self.__dilate_mask(mask_file, CONFIG.CS.DILATE)
        logging.info(f'  ==> Output Mask={mask_file}')
        logging.debug(f'CLIPSeg_Mask Run Time={(time.perf_counter()-tic):.2f}s')
        return mask_file

#######################################
# Real-ESRGAN and GFPGAN upscaling
#######################################
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer

class RealESRGAN_Upscaler():
    """Real-ESRGAN Upscaler (M1 MPS)"""

    def __init__(self, device):
        """Initialize Real-ESRGAN with the given weights"""
        logging.debug(f'RealESRGAN_Upscaler Init Weights={CONFIG.RE.MODEL}')
        tic = time.perf_counter()
        self.device = device
        if 'RealESRGAN_x4plus_anime_6B' in CONFIG.RE.MODEL:
            self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        else:
            self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        self.model.to(device.type)
        self.upsampler = RealESRGANer(scale=4, model_path=CONFIG.RE.MODEL, model=self.model, device=device.type)
        logging.debug(f'RealESRGAN_Upscaler Init Time={(time.perf_counter()-tic):.2f}s')

    def load_image(self, image_file:str):
        """Load image, applying MPS fix if required"""
        self.image_file = image_file
        image = cv2.imread(image_file)
        if self.device.type == 'mps':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = image.flatten().reshape((image.shape[2], image.shape[0], image.shape[1])).transpose((1,2,0))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.image = image

    def _save_image(self, image):
        image_file = os.path.join(CONFIG.OUTPUT_DIR, time.strftime(CONFIG.RE.FILENAME))
        cv2.imwrite(image_file, image)
        if CONFIG.PROMPTS and CONFIG.METADATA.SAVE_LEVEL:
            prompts = METADATA.load_str(self.image_file, 'prompt', 'negative', 'Software')
            if 'myByways' in prompts[2]: METADATA.save(image_file, prompts[0], prompts[1])
        logging.info(f'  ==> Output Image={image_file}')
        return image_file

    def upscale_image(self, scale=CONFIG.UPSCALE)->str:
        """Upscale the image - raison d'être"""
        logging.info(f'RealESRGAN_Upscaler Image={self.image_file}, Scale={scale}')
        tic = time.perf_counter()
        image, _ = self.upsampler.enhance(self.image, outscale=scale)
        image_file = self._save_image(image)
        logging.debug(f'RealESRGAN_Upscaler Run Time={(time.perf_counter()-tic):.2f}s')
        return image_file

class RealESRGAN_FaceUpscaler(RealESRGAN_Upscaler):
    """Real-ESRGAN with GFPGAN Face Enhancement (CPU)"""

    def upscale_image(self, scale=CONFIG.UPSCALE):
        """Upscale the image and enhance faces - raison d'être"""
        logging.info(f'RealESRGAN_FaceUpscaler Image={self.image_file}, Scale={scale}, Face Enhance=True')
        if self.device.type == 'cpu': logging.warning(f'  RealESRGAN_FaceUpscaler running on CPU')
        tic = time.perf_counter()
        face_enhancer = GFPGANer(model_path=CONFIG.RE.FACE_MODEL, upscale=scale, bg_upsampler=self.upsampler, device=self.device)
        _, _, image = face_enhancer.enhance(self.image, has_aligned=False, only_center_face=False, paste_back=True)
        image_file = self._save_image(image)
        logging.debug(f'RealESRGAN_FaceUpscaler Run Time={(time.perf_counter()-tic):.2f}s')
        return image_file

#######################################
# PNG metadata 
#######################################
from PIL.PngImagePlugin import PngInfo

class METADATA(CONFIG):
    """Stores Stable Diffusion parameters in PNG iTXt chunk (UTF8)"""

    def __save_list(values):
        if not isinstance(values, list): return values
        masks = [mask.replace('|', ' ') for mask in values]  # encode the list passed in (MASK or UNMASK), not always CONFIG.MASK
        return '|'+'|'.join(masks)+'|'

    def __load_list(values):
        if not values.startswith('|'): return values  # also safe for an empty string
        values = values.split('|')
        return [value for value in values if value]

    def save(image_file, prompt, negative=None, level=CONFIG.METADATA.SAVE_LEVEL):
        metadata = PngInfo()
        metadata.add_itxt('Software', 'Simple-SD v1.0 copyright (c) 2022 C.Y. Wong, myByways.com')
        metadata.add_itxt('Author', 'Stable Diffusion v1')    
        if level >= 1:
            metadata.add_itxt('prompt', prompt)
            if negative: metadata.add_itxt('negative', negative)
            metadata.add_itxt('seed', str(CONFIG.SEED))
            if CONFIG.TILING: metadata.add_itxt('tiling', str(1))
        if level >= 2: 
            metadata.add_itxt('steps', str(CONFIG.SD.STEPS))
            metadata.add_itxt('noise', str(CONFIG.SD.NOISE))
            metadata.add_itxt('scale',str(CONFIG.SD.SCALE))
            if CONFIG.SD.SAMPLER: metadata.add_itxt('sampler', CONFIG.SD.SAMPLER)
            if CONFIG.UPSCALE: metadata.add_itxt('upscale', str(CONFIG.UPSCALE))
        if level >= 3:
            if CONFIG.IMAGE: 
                metadata.add_itxt('image', CONFIG.IMAGE)
                metadata.add_itxt('strength', str(CONFIG.SD.STRENGTH))
            if CONFIG.MASK: metadata.add_itxt('mask', METADATA.__save_list(CONFIG.MASK))
            if CONFIG.UNMASK: metadata.add_itxt('unmask', METADATA.__save_list(CONFIG.UNMASK))
        if image_file.lower().endswith('.png'):
            image = Image.open(image_file)
            image.save(image_file, pnginfo=metadata)
        return metadata

    def load_str(image_file, *keys):
        values = []
        image = Image.open(image_file)
        text = getattr(image, 'text', {})  # only PNG images expose .text
        for key in keys:
            values.append(text.get(key, None))
        return values

    def load(image_file=CONFIG.METADATA.LOAD_FILE):
        logging.debug(f'METADATA Loading from PNG File={image_file}')
        if not image_file.lower().endswith('.png'): return None
        image = Image.open(image_file)
        CONFIG.TILING = False
        CONFIG.SD.SAMPLER = ''
        for k, v in image.text.items():
            if k == 'prompt': CONFIG.PROMPTS = [v]
            elif k == 'negative': CONFIG.NEGATIVES = [v]
            elif k == 'seed': CONFIG.SEED = int(v) if v != 'None' else None
            elif k == 'tiling': CONFIG.TILING = True
            elif k == 'steps': CONFIG.SD.STEPS = int(v)
            elif k == 'noise': CONFIG.SD.NOISE = float(v)
            elif k == 'scale': CONFIG.SD.SCALE = float(v)
            elif k == 'sampler': CONFIG.SD.SAMPLER = v
            elif k == 'upscale': CONFIG.UPSCALE = float(v)
            elif k == 'image': CONFIG.IMAGE = v
            elif k == 'strength': CONFIG.SD.STRENGTH = float(v)
            elif k == 'mask': CONFIG.MASK = METADATA.__load_list(v)
            elif k == 'unmask': CONFIG.UNMASK = METADATA.__load_list(v)
        return image.text

    def __str__(self):
        return f'Prompts = {CONFIG.PROMPTS}\nNegatives = {CONFIG.NEGATIVES}\nTiling = {CONFIG.TILING}\n' \
            f'Image = {CONFIG.IMAGE}\nMask = {CONFIG.MASK}\nUnmask = {CONFIG.UNMASK}\n' \
            f'Upscale = {CONFIG.UPSCALE}\nFace_Enhance = {CONFIG.FACE_ENHANCE}\nSeed = {CONFIG.SEED}\n' \
            f'Width = {CONFIG.WIDTH}\nHeight = {CONFIG.HEIGHT}\nFilename = {CONFIG.FILENAME}\n' \
            f'SD.Factor = {CONFIG.SD.FACTOR}\nSD.Steps = {CONFIG.SD.STEPS}\nSD.Noise = {CONFIG.SD.NOISE}\n' \
            f'SD.Scale = {CONFIG.SD.SCALE}\nSD.Strength = {CONFIG.SD.STRENGTH}\nSD.Config = {CONFIG.SD.CONFIG}\n' \
            f'SD.Iterations = {CONFIG.SD.ITERATIONS}\nIP.Config = {CONFIG.IP.CONFIG}\nIP.Model = {CONFIG.IP.MODEL}\n' \
            f'CS.Black = {CONFIG.CS.BLACK}\nCS.White = {CONFIG.CS.WHITE}\nCS.Dilate = {CONFIG.CS.DILATE}\n' \
            f'CS.Model = {CONFIG.CS.MODEL}\nCS.Filename = {CONFIG.CS.FILENAME}\nRE.Model = {CONFIG.RE.MODEL}\n' \
            f'RE.Face_Model = {CONFIG.RE.FACE_MODEL}\nRE.Filename = {CONFIG.RE.FILENAME}'

def setup_logging():
    class ConsoleLog_Formatter(logging.Formatter):
        formats = {
            logging.DEBUG:    logging.Formatter('🔵 \x1b[34m%(message)s\x1b[0m'),
            logging.INFO:     logging.Formatter('🟢 \x1b[32m%(message)s\x1b[0m'),
            logging.WARNING:  logging.Formatter('🟡 \x1b[33;1m%(message)s\x1b[0m'),
            logging.ERROR:    logging.Formatter('❌ \x1b[31;1m%(message)s\x1b[0m'),
            logging.CRITICAL: logging.Formatter('❌ \x1b[31;1m%(message)s\x1b[0m')
        }
        def format(self, log_record):
            return self.formats[log_record.levelno].format(log_record)

    class FileLog_InfoFilter(object):
        def filter(self, log_record):
            return log_record.levelno <= logging.INFO

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    if len(logger.root.handlers) == 0:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setLevel(logging.DEBUG)
    logger.handlers[0].setFormatter(ConsoleLog_Formatter())
    file_log = logging.FileHandler(os.path.join(CONFIG.OUTPUT_DIR, CONFIG.HISTORY_FILE))
    file_log.setLevel(logging.INFO)
    file_log.addFilter(FileLog_InfoFilter())
    file_log.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(file_log)
    warnings.filterwarnings('ignore')
    logging.getLogger('PIL.PngImagePlugin').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
    transformers_logging.set_verbosity_error()

#######################################
# Main
#######################################

def main():
    print('\x1b[44m🙈🙉🙊 \x1b[1;34m\x1b[4mmyByways Simple Stable Diffusion v1.0 (1 Oct 2022)\x1b[0m\x1b[44m 🙈🙉🙊\x1b[0m')
    setup_logging()
    os.makedirs(CONFIG.OUTPUT_DIR, exist_ok=True)

    if CONFIG.METADATA.LOAD_FILE:
        METADATA.load(CONFIG.METADATA.LOAD_FILE)

    if torch.backends.mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    mask_file = CONFIG.MASK
    if CONFIG.MASK and isinstance(CONFIG.MASK, list):
        cs = CLIPSeg_Mask(device)
        mask_file = cs.create_mask(CONFIG.IMAGE, CONFIG.MASK, CONFIG.UNMASK)

    image_files = None
    if CONFIG.PROMPTS:
        if CONFIG.IMAGE:
            sd = StableDiffusion_Image(device)
            sd.load_image(CONFIG.IMAGE, mask_file)
        else:
            sd = StableDiffusion_Text(device)
        sd.setup_sampler()
        image_files = sd.create_images(CONFIG.PROMPTS, CONFIG.NEGATIVES)
    elif CONFIG.IMAGE and mask_file:
        sd = StableDiffusion_Erase(torch.device('cpu'))
        sd.setup_sampler()
        sd.load_image(CONFIG.IMAGE, mask_file)
        image_files = [sd.create_image()]
    else:
        image_files = [CONFIG.IMAGE]

    upscaled_files = []
    if CONFIG.UPSCALE and image_files:
        if CONFIG.FACE_ENHANCE:
            re = RealESRGAN_FaceUpscaler(torch.device('cpu'))
        else:
            re = RealESRGAN_Upscaler(device) 
        for image_file in image_files:
            re.load_image(image_file)
            upscaled_files.append(re.upscale_image(CONFIG.UPSCALE))

    if mask_file != CONFIG.MASK: logging.debug(f'😎 Mask     = [{mask_file}]')
    if image_files:              logging.debug(f'😎 Images   = {image_files}')
    if upscaled_files:           logging.debug(f'😎 Upscaled = {upscaled_files}')

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logging.error('User abort, goodbye')
    except FileNotFoundError as e:
        logging.error(f'File not found {e}')
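
The METADATA class above stores settings as PNG iTXt text chunks and packs lists (such as multiple mask prompts) into a single `|`-delimited string. Here is a minimal, standalone sketch of that round trip, assuming only that Pillow is installed. The helper names `encode_list`, `decode_list` and `roundtrip` are hypothetical and not part of the script; the sketch writes to an in-memory buffer rather than a real output file.

```python
import io
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def encode_list(values):
    """Join a list into '|a|b|' form; plain strings pass through unchanged."""
    if not isinstance(values, list):
        return values
    return '|' + '|'.join(v.replace('|', ' ') for v in values) + '|'

def decode_list(text):
    """Reverse encode_list; strings without a leading '|' are returned as-is."""
    if not text.startswith('|'):
        return text
    return [v for v in text.split('|') if v]

def roundtrip(settings):
    """Write settings into an in-memory PNG as iTXt chunks, then read them back."""
    info = PngInfo()
    for key, value in settings.items():
        info.add_itxt(key, encode_list(value))
    buffer = io.BytesIO()
    Image.new('RGB', (8, 8)).save(buffer, format='PNG', pnginfo=info)
    buffer.seek(0)
    # PNG images expose their text chunks through the .text mapping
    return {k: decode_list(v) for k, v in Image.open(buffer).text.items()}

result = roundtrip({'prompt': 'a red fox', 'seed': '42', 'mask': ['face', 'hat']})
print(result)
```

Note that every value comes back as a string (or a list of strings), which is why the script's `load()` converts `seed`, `steps` and the other numeric settings with `int()` or `float()`.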

Don’t use my code if you can’t understand it. See the disclaimers! I provide no support or documentation other than the code itself, so... good luck!