Introduction & Setup
Inspired by the amazing paper “DiffEdit: Diffusion-based semantic image editing with mask guidance” which Jeremy discussed in last week’s class of Practical Deep Learning for Coders Pt2, this is an attempt (v1) at implementing step one of the paper’s semantic image editing method.
See below an image from the paper, illustrating this novel method of editing images simply using a text query:
Seeing the automatic mask generation was fascinating for me, and it was evident from interactions during the class that this was indeed a novel, interesting way to generate pixel masks. Stable Diffusion, the poster-child for modern AI and the basis for cool tech like Dream Studio and Midjourney, could do something more pragmatic too! The mask-generation strategy is simple: use Stable Diffusion to create a pixel mask for an input image from just two text prompts. Instead of going off on a verbose explanation, I'll let a screenshot from the DiffEdit paper do the job.
If you wish to delve deeper into Stable Diffusion I highly recommend going through this amazing repo from fast.ai - diffusion nbs on GitHub.
To quote the paper, for the mask creation: "We add noise to the input image, and denoise it: once conditioned on the query text, and once conditioned on a reference text (or unconditionally). We derive a mask based on the difference in the denoising results". The above screenshot is part of a detailed schematic, where the input is the image x0, the query text Q = 'Zebra', and the reference text R = 'Horse'. The StableDiffusionImg2ImgPipeline pipeline from Hugging Face takes care of adding Gaussian noise to an input image and denoising the result by running inference with a pretrained model. We can play around with the inference input parameters (input image, number of inference steps, strength of prompt guidance, etc.) to influence the denoised image output.
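To make the recipe concrete before diving into the experiments, here is a rough roadmap of how the pieces fit together. The pipe object, init_image, and the norm_diff_abs / GSandBinarise helpers referenced here are all set up further down in the post, and the specific parameter values are just placeholders:

# Roadmap only: pipe, init_image, norm_diff_abs and GSandBinarise are defined later in this post
torch.manual_seed(1000)
denoised_q = pipe(prompt="a zebra", init_image=init_image, strength=0.8, num_inference_steps=8).images[0]   # denoise conditioned on the query text Q
torch.manual_seed(1000)
denoised_r = pipe(prompt="a horse", init_image=init_image, strength=0.8, num_inference_steps=8).images[0]   # denoise conditioned on the reference text R
diff = norm_diff_abs(denoised_q, denoised_r)         # normalised difference of the two denoised results
mask = GSandBinarise(diff, thresh_method='OTSU')     # grayscale + threshold -> the binary mask M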
Note: This post documents multiple experiments, with varying levels of quality in the final output. The intent is to capture the thought process, the important trial and error, and the resulting insights that went into trying to generate a good pixel mask using SD.
Load the pretrained StableDiffusionImg2ImgPipeline onto pipe:

from diffusers import StableDiffusionImg2ImgPipeline
import torch

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, safety_checker=None).to("cuda")
Let's load and take a look at the input image, x0:
from PIL import Image

init_image = Image.open("dark-horse-riverV2.jpeg")
pixels = init_image.load()
init_image.show()
Single-output-based mask generation
In this case, we'll keep num_images_per_prompt=1 in the image-to-image Stable Diffusion pipeline, and see how good a mask we can generate using that. We'll play around a bit with num_inference_steps to create the best possible mask.
pipe takes care of adding noise to the input image init_image, and denoising it for num_inference_steps=10, with respect to the reference text - "a horse".
torch.manual_seed(1000)
prompt = "a horse" #Denoising image with respect to the identifying phrase
image = pipe(prompt=prompt, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=10).images
image[0]
That's a scary looking horse, one must admit! For some reason I've seen from my experiments (not recorded here) that the model often grows extra appendages from the tail pixels of horse images. We can try reducing num_inference_steps to prevent that, but the downside is that the output image will be noisy, since we haven't allowed for enough inference steps. For now, we'll move on and see what happens with the query text denoising.
We use pipe to add noise to the input image init_image, and denoise it for num_inference_steps=10, with respect to the query text - "a zebra".
torch.manual_seed(1000)
prompt_q = "a zebra" #Denoising image with respect to the identifying phrase
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=10).images
image_q[0]
That was unexpected! Let's reduce num_inference_steps to 8 to prevent the inference from going too far!
torch.manual_seed(1000)
prompt_q = "a zebra" #Denoising image with respect to the identifying phrase
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=8).images
image_q[0]
Better. Notice how the background is changing considerably based on the prompt. Let’s try altering the prompt to make sure the backgrounds are relatively similar in the images denoised with respect to both Query Q and Reference R.
torch.manual_seed(1000)
prompt = "a horse drinking water" #Denoising image with respect to the identifying phrase
image = pipe(prompt=prompt, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=10).images
image[0]
torch.manual_seed(1000)
prompt_q = "a zebra drinking water" #Denoising image with respect to the identifying phrase
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=10).images
image_q[0]
Notice how giving more information in the prompt about the action of the subject with the input “a zebra drinking water” has allowed the model to run for 10 inference steps without morphing the zebra too badly!
Neither of those is a particularly pretty image on its own; but let's attempt to find the "normalised difference between the denoised images" and generate a rough mask. Then we can try to play with num_inference_steps and create a better mask. Once that parameter's ideal value is clear, we can move on to using multiple denoised images.
from torchvision import transforms

convert_to_tensor = transforms.ToTensor()
convert_to_image = transforms.ToPILImage()
Since we’ll be using this operation often, let’s define a function to compute the normalised absolute difference between two images.
def norm_diff_abs(a, b):
    # Convert both PIL images to tensors and compute the normalised absolute difference
    a_tens = convert_to_tensor(a)
    b_tens = convert_to_tensor(b)
    ftens = (abs(a_tens - b_tens)) / (a_tens + b_tens)
    fimage = convert_to_image(ftens)
    return fimage
Note that abs() is an important part of getting a meaningful normalised difference image. Without it the resulting difference image doesn't serve our purpose well. Ponder the reason, dear reader ;)
im_diff_normabs = norm_diff_abs(image_q[0], image[0])
im_diff_normabs
Now that we have the normalised difference between the denoised images (a visual representation of the contrast between reference text and query), let's write a function to convert this image to grayscale and binarize this. This function, when called, should yield the final mask M.
import cv2
from numpy import asarray

def GSandBinarise(im, **kwargs):
    # Convert to a NumPy array, grayscale it, then threshold it into a binary mask
    im_array = asarray(im)
    im_grayscale = cv2.cvtColor(im_array, cv2.COLOR_BGR2GRAY)
    if kwargs["thresh_method"] == 'manual':
        th, im_binary_array = cv2.threshold(im_grayscale, kwargs["thresh"], 255, cv2.THRESH_BINARY) #manual thresholding
    else:
        th, im_binary_array = cv2.threshold(im_grayscale, 0, 255, cv2.THRESH_OTSU) #auto thresholding
    im_binary = Image.fromarray(im_binary_array)
    return im_binary
Now let’s try binarising the difference using the OTSU thresholding method, which will automatically pick a pixel threshold value.
GSandBinarise(im_diff_normabs, thresh_method='OTSU')
The background is not too shabby; it's almost fully blacked out, but the horse body has a lot of black areas. Now let's try binarising the difference using the binary thresholding method, which allows us to pick a pixel threshold value.
GSandBinarise(im_diff_normabs, thresh_method='manual', thresh=30)
Now, in this case, more of the horse’s body is white, making the mask better at capturing the subject; but lowering the pixel threshold manually has given rise to many white patches in the background as well. Let’s try solving this.
SideNote: For the exact same model, random seed, and parameters, the mask seems to come out slightly better on an A4000 GPU, as compared to a P5000. All outputs shown here are from running inference on a Paperspace A4000 instance.
Now that we have the whole mask generation process ready, let's play a bit with the num_inference_steps parameter to see if we can get a better mask!! Let's start with num_inference_steps=6.
torch.manual_seed(1000)
prompt = "a horse drinking water" #Denoising image with respect to the identifying phrase (lower number of inf steps)
image = pipe(prompt=prompt, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=6).images
image[0]
torch.manual_seed(1000)
prompt_q = "a zebra drinking water" #Denoising image with respect to the identifying phrase (lower number of inf steps)
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=6).images
image_q[0]
im_diff_normabs = norm_diff_abs(image_q[0], image[0])
GSandBinarise(im_diff_normabs, thresh_method='OTSU')
GSandBinarise(im_diff_normabs, thresh_method='manual', thresh=30)
As we can see, this leads to a much clearer outline of the horse than before, but the image is littered with too many white patches. I suspect that can be solved by averaging out multiple outputs (see next section), and so num_inference_steps=6 can be considered a good candidate for our final run. Now let's try num_inference_steps=8:
torch.manual_seed(1000)
prompt = "a horse drinking water" #Denoising image with respect to the identifying phrase
image = pipe(prompt=prompt, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=8).images
image[0]
torch.manual_seed(1000)
prompt_q = "a zebra drinking water" #Denoising image with respect to the identifying phrase
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=8).images
image_q[0]
im_diff_normabs = norm_diff_abs(image_q[0], image[0])
GSandBinarise(im_diff_normabs, thresh_method='OTSU')
GSandBinarise(im_diff_normabs, thresh_method='manual', thresh=30)
Well, why not use a higher num_inference_steps to allow for more denoising? Fair point; but what often seems to happen is that weird artifacts are generated at a high value of num_inference_steps:
torch.manual_seed(1000)
prompt_q = "a zebra drinking water" #Denoising image with respect to the identifying phrase
image_q = pipe(prompt=prompt_q, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=30).images
image_q[0]
A super clean-looking image, but now the zebra has two heads!
torch.manual_seed(1000)
prompt = "a horse drinking water" #Denoising image with respect to the identifying phrase
image = pipe(prompt=prompt, num_images_per_prompt=1, init_image=init_image, strength=0.8, num_inference_steps=30).images
image[0]
im_diff_normabs = norm_diff_abs(image_q[0], image[0])
GSandBinarise(im_diff_normabs, thresh_method='OTSU')
GSandBinarise(im_diff_normabs, thresh_method='manual', thresh=30)
It's still just as noisy as the mask outputs we got using a much lower value of num_inference_steps, and now the outline has become less accurate because of all the extra heads in the denoised images! Let's stick to a lower value for num_inference_steps and come up with a different way to get a clearer mask.
Multi-output-based mask generation
Since we are denoising an image after adding Gaussian noise to it, we could try to generate multiple denoised images for each prompt, and average them out before taking the normalised difference of query versus reference. The idea is that averaging multiple images denoised based on the same prompt will help cancel out the random noise, and also create a clearer image of the model's idea of a "zebra" or "horse".
Note: Judging from the DiffEdit paper’s schematic of the process, this is likely the method adopted by its authors.
pipe.enable_attention_slicing()
# dark-horse-riverv2-rsz
init_image = Image.open("dark-horse-riverV2-rsz.jpg")
pixels = init_image.load()
init_image.show()
The image has been cropped and resized to avoid a CUDA OOM error when running for a high value of num_images_per_prompt.
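The exact crop/resize used to produce dark-horse-riverV2-rsz.jpg isn't recorded here; something along these lines would do the job, where the 384x384 target (a multiple of 64, which Stable Diffusion prefers) is just an assumption:

from PIL import Image

# Hypothetical preprocessing: centre-crop to a square, then shrink to keep VRAM usage down
im = Image.open("dark-horse-riverV2.jpeg")
w, h = im.size
side = min(w, h)
im = im.crop(((w - side) // 2, (h - side) // 2, (w + side) // 2, (h + side) // 2))
im = im.resize((384, 384))
im.save("dark-horse-riverV2-rsz.jpg")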
First, we try to implement this with 10 images per prompt, denoised for num_inference_steps=6.
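The grids below are tiled using an image_grid helper that isn't defined in this excerpt; a minimal sketch of the usual helper (along the lines of the Hugging Face diffusers examples, assuming all images are the same size) would be:

from PIL import Image

def image_grid(imgs, rows, cols):
    # Paste equally-sized PIL images into a rows x cols grid
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid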
torch.manual_seed(1000)
prompt = "a horse drinking water"
images_r6 = pipe(prompt=prompt, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=6).images
image_grid(images_r6, rows=2, cols=5)
torch.manual_seed(1000)
prompt_q = "a zebra drinking water"
images_q6 = pipe(prompt=prompt_q, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=6).images
image_grid(images_q6, rows=2, cols=5)
Let's build out the function for passing in the list of images and getting the averaged-out image! How? Stack the list of generated images and create an average of them along the "number of images" axis.
Note - This function, like all others, was built by experimenting with each line of its components and then putting them together - a nifty trick, courtesy of Jeremy!
def get_averageIm(ImList):
    # Convert each PIL image to a tensor
    imtensors = []
    for Im in ImList:
        imtensor = convert_to_tensor(Im)
        imtensors.append(imtensor)
    # Sum the tensors, then divide by the number of images to get the average
    init_tensor = torch.zeros(imtensors[0].shape)
    for tensor in imtensors:
        init_tensor += tensor
    total_tensor = init_tensor
    av_imtensor = total_tensor / len(ImList)
    av_image = convert_to_image(av_imtensor)
    return av_image
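As an aside, the same averaging can be written more compactly by stacking the image tensors and taking the mean over the new "number of images" dimension; a quick sketch (reusing the convert_to_tensor / convert_to_image transforms defined earlier, with a hypothetical name):

import torch

def get_averageIm_stacked(ImList):
    # Stack all image tensors along a new dim 0, then average over that dimension
    stacked = torch.stack([convert_to_tensor(Im) for Im in ImList])
    return convert_to_image(stacked.mean(dim=0))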
av_image_r6 = get_averageIm(images_r6)
av_image_r6
av_image_q6 = get_averageIm(images_q6)
av_image_q6
Like it or not, this is what peak “horse” and “zebra” performance look like! (reference - Chapter 4 of fastbook).
im_diff_normabs6 = norm_diff_abs(av_image_r6, av_image_q6)
GSandBinarise(im_diff_normabs6, thresh_method='OTSU')
mask_6step = GSandBinarise(im_diff_normabs6, thresh_method='manual', thresh=20)
mask_6step
Repeat the same for other num_inference_steps values to see whether we can get a better mask. Also remember to play around with the manual threshold value to get an ideal mask.
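Since the next few runs repeat the same generate -> average -> diff -> threshold sequence, it can be handy to wrap the whole recipe in a small helper and just vary num_inference_steps and the threshold. This is only a convenience sketch relying on pipe, init_image and the functions defined above; the cells below keep every step explicit:

def mask_for_steps(steps, thresh=25, seed=1000):
    # Run the full multi-output mask-generation recipe for a given number of inference steps
    torch.manual_seed(seed)
    images_r = pipe(prompt="a horse drinking water", num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=steps).images
    torch.manual_seed(seed)
    images_q = pipe(prompt="a zebra drinking water", num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=steps).images
    diff = norm_diff_abs(get_averageIm(images_r), get_averageIm(images_q))
    return GSandBinarise(diff, thresh_method='manual', thresh=thresh)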
The mask output looked decent for num_inference_steps=8 in the single-output section, so let's try that here:
torch.manual_seed(1000)
prompt = "a horse drinking water"
images_r8 = pipe(prompt=prompt, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=8).images
image_grid(images_r8, rows=2, cols=5)
torch.manual_seed(1000)
prompt_q = "a zebra drinking water"
images_q8 = pipe(prompt=prompt_q, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=8).images
image_grid(images_q8, rows=2, cols=5)
av_image_r8 = get_averageIm(images_r8)
av_image_r8
av_image_q8 = get_averageIm(images_q8)
av_image_q8
im_diff_normabs8 = norm_diff_abs(av_image_r8, av_image_q8)
mask_otsu_8step = GSandBinarise(im_diff_normabs8, thresh_method='OTSU')
mask_otsu_8step
mask_8step = GSandBinarise(im_diff_normabs8, thresh_method='manual', thresh=25)
mask_8step
That looks better than the previous mask. We've considerably reduced the white patches in the background. Let's try num_inference_steps=7.
#inf steps 7
torch.manual_seed(1000)
prompt = "a horse drinking water"
images_r7 = pipe(prompt=prompt, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=7).images
image_grid(images_r7, rows=2, cols=5)
torch.manual_seed(1000)
prompt_q = "a zebra drinking water"
images_q7 = pipe(prompt=prompt_q, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=7).images
image_grid(images_q7, rows=2, cols=5)
av_image_r7 = get_averageIm(images_r7)
av_image_r7
av_image_q7 = get_averageIm(images_q7)
av_image_q7
im_diff_normabs7 = norm_diff_abs(av_image_r7, av_image_q7)
mask_otsu_7step = GSandBinarise(im_diff_normabs7, thresh_method='OTSU')
mask_otsu_7step
mask_7step = GSandBinarise(im_diff_normabs7, thresh_method='manual', thresh=25)
mask_7step
This mask (num_inference_steps=7) has a slightly cleaner outline than the one for num_inference_steps=8. Let's try going a bit higher: num_inference_steps=10.
torch.manual_seed(1000)
prompt = "a horse drinking water"
images_r10 = pipe(prompt=prompt, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=10).images
image_grid(images_r10, rows=2, cols=5)
torch.manual_seed(1000)
prompt_q = "a zebra drinking water"
images_q10 = pipe(prompt=prompt_q, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=10).images
image_grid(images_q10, rows=2, cols=5)
As one can see, there are a few weird images, and all the images are quite deformed with respect to the original init_image. Let's look at the generated mask.
av_image_r10 = get_averageIm(images_r10)
av_image_r10
av_image_q10 = get_averageIm(images_q10)
av_image_q10
im_diff_normabs10 = norm_diff_abs(av_image_r10, av_image_q10)
mask_otsu_10step = GSandBinarise(im_diff_normabs10, thresh_method='OTSU')
mask_otsu_10step
mask_10step = GSandBinarise(im_diff_normabs10, thresh_method='manual', thresh=25)
mask_10step
We start to see a more patchy outline, because the denoised images have deviated more and more from the original outline. Let's go down one step and try num_inference_steps=9.
torch.manual_seed(1000)
prompt = "a horse drinking water"
images_r9 = pipe(prompt=prompt, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=9).images
image_grid(images_r9, rows=2, cols=5)
torch.manual_seed(1000)
prompt_q = "a zebra drinking water"
images_q9 = pipe(prompt=prompt_q, num_images_per_prompt=10, init_image=init_image, strength=0.8, num_inference_steps=9).images
image_grid(images_q9, rows=2, cols=5)
The images are still weird and vary considerably in the position of the subject, so we can expect the mask to be “noisy” as well.
av_image_r9 = get_averageIm(images_r9)
av_image_r9
av_image_q9 = get_averageIm(images_q9)
av_image_q9
im_diff_normabs9 = norm_diff_abs(av_image_r9, av_image_q9)
mask_otsu_9step = GSandBinarise(im_diff_normabs9, thresh_method='OTSU')
mask_otsu_9step
mask_9step = GSandBinarise(im_diff_normabs9, thresh_method='manual', thresh=30)
mask_9step
Though better than the previous run, this still has a rather unclear outline. From the above analyses, we can pick num_inference_steps=6, 7, 8 as yielding usable pixel masks.
mask_list = [mask_6step, mask_7step, mask_8step]
image_grid(mask_list, rows=1, cols=3)
If we could see how these masks look on the original image, I suppose they'd be more convincing. I'm currently facing a minor bug in overlaying a red mask onto the original image, and this notebook will be updated once that's resolved.
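For reference, one possible way to do such an overlay (just a sketch, not necessarily the approach I was attempting) is to tint the masked pixels red with NumPy:

import numpy as np
from PIL import Image

def overlay_red_mask(image, mask, alpha=0.5):
    # Blend a red tint into the pixels where the binary mask is white
    # (the mask may need resizing to the image's size first)
    img = np.asarray(image.convert("RGB")).astype(np.float32)
    m = (np.asarray(mask.convert("L")) > 127)[..., None]   # boolean mask broadcast over the RGB channels
    red = np.array([255.0, 0.0, 0.0])
    blended = np.where(m, (1 - alpha) * img + alpha * red, img)
    return Image.fromarray(blended.astype(np.uint8))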
Though a far cry from the clean mask illustrated in the DiffEdit paper, I suppose this isn't a bad start. As I experiment with other techniques (like using other noise schedulers) to generate a fully clean mask, I'll keep adding relevant updates to this post.
Hopefully you got something from your time here! If you read this and have any suggestions/comments on how to improve my code (or words!), please reach out to me @Twitter.
Cheers!! 😄