How do Textual Inversion tokens destroy prompts?
This blog poses a question: how exactly do Textual Inversion tokens destroy prompts and over-dominate the cross-attention in diffusion models? To clarify, I think I got close to the answer, but I was only able to form a guess.
Background
For evaluating custom generation methods like DreamBooth and Textual Inversion, a commonly used evaluation metric is the CLIP score. The CLIP score uses CLIP, a model made by OpenAI, to measure image alignment, which is how similar the generated images are to the training images, and text alignment, which measures how closely the generated images follow our prompt.
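As a rough sketch of what those two scores measure, here is how one could compute the underlying CLIP similarities with the transformers library (the checkpoint, file names, and prompt are just placeholder examples, and real CLIP-score implementations differ in details such as averaging over many images):

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Any CLIP checkpoint works here; this one is just an example choice.
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

generated = Image.open("generated.png")       # an image made by the tuned model (placeholder path)
reference = Image.open("training_image.png")  # one of the training images (placeholder path)
prompt = "a photo of a cat toy next to a man"

inputs = processor(text=[prompt], images=[generated, reference], return_tensors="pt", padding=True)
with torch.no_grad():
    image_feats = model.get_image_features(pixel_values=inputs["pixel_values"])
    text_feats = model.get_text_features(input_ids=inputs["input_ids"],
                                         attention_mask=inputs["attention_mask"])

# Cosine similarities between the normalized embeddings
image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
text_alignment = (image_feats[0] @ text_feats[0]).item()    # generated image vs. prompt
image_alignment = (image_feats[0] @ image_feats[1]).item()  # generated image vs. training image
print(text_alignment, image_alignment)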
While Textual Inversion can generate images with pretty good fidelity, it's not good at following prompts, as shown by the Custom Diffusion paper below, where Textual Inversion consistently has the lowest text alignment.
The reason for this is that Textual Inversion is trained to ignore the prompt and just generate the images you train it with. For example, in diffusers, we tell our model to generate the same images given all of the following prompts:
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
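To make that concrete, here is roughly what the diffusers textual inversion script does with these templates during training (the placeholder name is just an example):

import random

placeholder_token = "<cat-toy>"  # example placeholder; use whatever token you train

# Each training step captions the same training images with a random template,
# so the text varies while the target image content stays the same.
prompt = random.choice(imagenet_templates_small).format(placeholder_token)
print(prompt)  # e.g. "a photo of a clean <cat-toy>"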
This motivates the model to generate the subject and ignore the rest of the prompt, effectively overwriting the prompt with the concept. For example, when we ask our model to generate “A <cat-toy> next to a man with a friend” we get
which seems to ignore the “man with a friend” portion and replaces both with <cat-toy>. This is pretty interesting, since each token is only a 768-dimensional vector, which is extremely small compared to the rest of the diffusion model. Also, it only affects one word in the CLIP text encoder.
However, first of all, let us confirm that there is a problem.
The Problem
For this blog, I have only tested with one example, but I might test with a few more if I get time. I am using daam, which can show the contribution of each token to the output, like so
As for how this is done, it comes from an interesting characteristic of diffusion models: the cross-attention map for a particular token tends to stay at that token's index, like so (image taken from the Prompt-to-Prompt paper)
So we can just look at the cross-attention map at the index of the token bear and see the contribution of that token to the output! For more information on this, take a look here.
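For reference, here is a minimal sketch of how I use daam to pull out a per-word heat map (the function names follow my reading of the daam README, so treat them as assumptions and check them against the repo):

import torch
from diffusers import StableDiffusionPipeline
from daam import trace  # https://github.com/castorini/daam

# Any SD 1.x checkpoint works; this one is just an example.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

prompt = "A bear on the beach"
with torch.no_grad(), trace(pipe) as tc:
    out = pipe(prompt, num_inference_steps=30)
    # Aggregate the UNet cross-attention over layers and timesteps, then pull out one word.
    heat_map = tc.compute_global_heat_map()
    bear_map = heat_map.compute_word_heat_map("bear")
    bear_map.plot_overlay(out.images[0])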
Now, if we look at the contribution of each token for the prompt “A <cat-toy> next to a man with a friend”, the mean attention maps of all the normal tokens have norms of around 10~70. However, the <cat-toy> token's attention map consistently had a norm of around 200. The <cat-toy> token also seemed to have a cleaner attention map than the other tokens. For example, below is a comparison of the attention maps of <cat-toy> vs the token next
So there is something unique about textual inversion tokens.
To identify what that is, I mainly ran 4 tests
- The norm
- The angle
- CLIP Attention
- Checking relation with the SOS token
to see exactly what is causing this cross-attention destruction.
The norm
The main theory for how this happens, which I found in the LAION Discord server and in the paper “Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models”, otherwise known as E4T, is that the norm of the textual inversion vector is what causes this collapse. To keep the norm from becoming too large, E4T regularizes it with an L1 penalty. And the claim has some substance. The norms of the pretrained token embeddings follow a distribution like the one below
where the y-axis is the frequency and the x-axis is the norm. The tiny bump at 0 is from tokens that were never trained, and the mean is around 0.385. For the <cat-toy> token above, the norm is 3.1, around 8 times larger than that of the average token. So is this what is causing the token to be over-represented?
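Before answering that, here is a rough sketch of how these norms can be computed, assuming the standard diffusers learned_embeds.bin output, which stores a dict mapping the placeholder token to its 768-dimensional embedding:

import torch
from transformers import CLIPTextModel

# The CLIP text encoder that Stable Diffusion uses (checkpoint name is just an example).
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)
vocab_embeds = text_encoder.get_input_embeddings().weight.detach()  # shape (49408, 768)
print(vocab_embeds.norm(dim=-1).mean())   # roughly 0.385

# The diffusers textual inversion script saves a dict {placeholder_token: embedding}.
learned = torch.load("learned_embeds.bin")["<cat-toy>"]
print(learned.norm())                     # around 3.1 for the token in this post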
One counterargument is that if you look at the code for the CLIP text model, which is the text encoder for Stable Diffusion, we see these lines applied right before the output is plugged into our diffusion model
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
So essentially, all possible scale information might be lost here. However, it's very possible that this scale somehow gets mingled into the other tokens on the CLIP side to over-represent the concept. If we assume this, then scaling the token's norm down to around 0.385, the mean token norm, should make the token get less representation. However, what we get is the below
at norm 0.385. If we scale to 0.385*2, 0.385*3, and so on up to the original scale, we get the following images
At least to me, it seems as though the larger norm increases quality slightly, but lowering it has a negligible effect in terms of increasing prompt alignment. I find this very fascinating: a single 768-dimensional vector can overdominate the prompt this much.
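For reference, a sketch of how the rescaling above can be done, assuming the <cat-toy> token has already been added to the tokenizer and its learned embedding loaded into the text encoder, as in the diffusers textual inversion inference example:

import torch

target_norm = 0.385  # mean norm of the pretrained vocabulary embeddings
token_id = tokenizer.convert_tokens_to_ids("<cat-toy>")

with torch.no_grad():
    embeds = text_encoder.get_input_embeddings().weight
    # Keep the direction of the learned embedding, shrink its length to the target
    # (use target_norm * 2, * 3, ... to sweep back up towards the original scale).
    embeds[token_id] = embeds[token_id] * (target_norm / embeds[token_id].norm())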
The angle
Now, if the norm is not the culprit, is it the angle? For this test, I took the dot product of 1000 token embeddings with each other (excluding each token with itself). I didn't do this over the entire 49408-token vocabulary for compute reasons. The resulting dot products are below
Just for a refresher, when we do the dot product, we can also get the cosine between two vectors by dividing by their norms like so
The cosines are below
So I think we can say pretty confidently that the input embedding of each token is pretty dissimilar from the others. And since their norms are consistently around 0.385, we can imagine each vector taking up its own piece of a sphere in token space!
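Here is a sketch of that comparison: sample 1000 vocabulary embeddings, take their pairwise dot products, and divide by the norms to get the cosines (text_encoder comes from the earlier sketch):

import torch

embeds = text_encoder.get_input_embeddings().weight.detach()  # (49408, 768)
idx = torch.randperm(embeds.shape[0])[:1000]
sample = embeds[idx]

dots = sample @ sample.T                      # pairwise dot products
norms = sample.norm(dim=-1, keepdim=True)
cosines = dots / (norms * norms.T)            # cos(theta) = a.b / (|a| |b|)

# Ignore the diagonal so we don't compare a token against itself.
off_diag = ~torch.eye(len(sample), dtype=torch.bool)
print(cosines[off_diag].abs().mean())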
Now, let's take the cosine of each token with the textual inversion token embedding. What we get is the below
One observation is that the absolute values of the cosines are slightly smaller, meaning the textual inversion token is slightly more orthogonal/perpendicular to the rest of the vocabulary than a regular token is. That might give us a clue.
CLIP Attention
Now, I mentioned at the beginning that the normal tokens had cross-attention map norms of around 10~70, while the <cat-toy> token consistently had a norm of around 200.
However, there is one token I haven't mentioned yet: the start-of-sequence token (SOS token), the token that starts every prompt. For some reason, its norm in the cross-attention maps turned out to be in the 2000s, which exceeds every other token norm I have seen. So I formed two hypotheses: first, that the textual inversion token is taking on characteristics of the SOS token to overdominate the prompt with very high attention maps; and second, that textual inversion tokens mainly attack the text encoder to do this, and not so much the Stable Diffusion UNet itself.
So let us examine the attention maps of the CLIP text encoder layers for two prompts: one where we encode “A <cat-toy> next to a man with a friend” and another where we encode “A cat toy next to a man with a friend”.
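These maps are easy to pull straight out of the transformers CLIP text model (a sketch; the tokenizer and text encoder are assumed to already have the <cat-toy> embedding loaded for the first prompt):

import torch

prompts = ["A <cat-toy> next to a man with a friend",
           "A cat toy next to a man with a friend"]

for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        out = text_encoder(**inputs, output_attentions=True)
    # out.attentions holds one (batch, heads, seq_len, seq_len) tensor per layer;
    # averaging over the heads gives the per-layer maps shown below.
    first_layer = out.attentions[0].mean(dim=1)[0]
    print(prompt, first_layer.shape)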
Let's start with the normal prompt (without the textual inversion token). For each layer of the CLIP text transformer, we get attention maps like this
First layer
Second layer
Third layer
To read these attention maps, look at the numbers along the axes: if we go to 1 on the y-axis and 0 on the x-axis, we can see how much attention the token at position 1 is paying to the token at position 0. So as we go deeper into the layers, all the tokens end up paying attention only to the SOS token. This is a pretty well-known fact, at least for large language models. I first learned about it from the attention sinks paper, StreamingLLM (here is a blog post on it by Tom Aarsen), where the authors used this fact to extend the context length of LLMs!
However, one layer looks different from the others: the first layer, shown below, where the token at position y seems to pay attention to itself, the SOS token, and sometimes the tokens in between. My understanding is that the first layer of the CLIP text model is the one responsible for getting hold of all the words in the prompt and encoding them into the start token.
Now then let’s look at the first layer attention map for the textual inversion prompt
We see that at the location of the <cat-toy> token, index 2, the textual inversion token seems to pay attention only to itself. In fact, if we zoom in
We see that it's paying very little attention to the start token! So my hypothesis is that in the subsequent CLIP layers, the textual inversion token overdominates the rest of the prompt by skipping over the start token at its index. This would explain why there is comparably less noise for the textual inversion token: it has full control over generation while the other words are ignored. For hard numbers, the value at index (2, 2), where the textual inversion token attends to itself, is the highest value in the attention map (0.905) except for the value at index (0, 0), which is 1. From here on, we will call the value at index (2, 2) the textual inversion attention. So my current guess is that there must be some relation between the SOS token and our textual inversion token that produces this high textual inversion attention.
Checking relation with the SOS token
Let us see what happens if we replace our textual inversion embedding with the SOS token embedding. Interestingly, as the zoomed-in picture of the attention map below indicates, there doesn't seem to be much attention paid to our token anymore
The image generated is the following
This does seem to indicate that low attention scores at the token's index in the CLIP text encoder attention map mean less prompt destruction, which is what we were guessing. Another interesting finding is that the attention map does not simply assign more attention to tokens that are similar to the SOS token.
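For reference, the swap described above is essentially a one-line edit of the embedding matrix (a sketch; in the CLIP tokenizer the SOS token is <|startoftext|>):

import torch

with torch.no_grad():
    embeds = text_encoder.get_input_embeddings().weight
    sos_embed = embeds[tokenizer.bos_token_id].clone()   # <|startoftext|> embedding
    token_id = tokenizer.convert_tokens_to_ids("<cat-toy>")
    embeds[token_id] = sos_embed       # replace with the SOS embedding
    # embeds[token_id] = -sos_embed    # or its negation, tried further below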
When we take the cosine of our token with respect to the SOS token, we get -0.0821, which is not significant compared to the rest of the cosines with respect to the SOS token, shown below
But one hypothesis I formed was that perhaps the CLIP text encoder pays more attention to tokens that are dissimilar to the SOS token. To test this, I tried setting the input embedding of the textual inversion token to the negative of the SOS token embedding. The attention map I got is below
with around 0.88 textual inversion attention! To confirm this further, I did Spherical Linear Interpolation (SLERP) between the textual inversion token scaled down to a 0.385 norm and the SOS token. What I found was that if I rotate away from the SOS token, the textual inversion attention stays above roughly 0.87, which is consistently higher than for the other tokens. However, as I rotate towards the SOS token, the textual inversion attention quickly drops, and by a 50% interpolation all signs of any cat toys disappear. At around a 0.184 interpolation factor, with 0.77 textual inversion attention, I did get the image below
The above is the best image I got from this, and it shows that some image fidelity is already gone. One interesting part is that when we bring the rotated token back up to its original scale, the general trend of the attention score decreasing as the token nears the SOS token still holds, but the rate is slower. Also, at a certain point the images become very distorted, which indicates there are definitely some pieces missing from this puzzle.
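For completeness, the spherical interpolation used above can be written as a small helper (a sketch; cat_toy_embed is a hypothetical name for the <cat-toy> embedding rescaled to a 0.385 norm, and sos_embed is the SOS embedding from the earlier sketch):

import torch

def slerp(t: float, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Spherical linear interpolation from a (t=0) to b (t=1)."""
    a_dir, b_dir = a / a.norm(), b / b.norm()
    theta = torch.acos(torch.clamp(a_dir @ b_dir, -1.0, 1.0))  # angle between the two directions
    return (torch.sin((1 - t) * theta) * a + torch.sin(t * theta) * b) / torch.sin(theta)

# e.g. rotate the rescaled <cat-toy> embedding part of the way towards the SOS embedding
rotated = slerp(0.184, cat_toy_embed, sos_embed)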
My current guess is that when computing the first attention map, there is some operation that subtracts out the SOS-token component of each token and assigns how much was taken to the 0th column, and so on. A bit like PCA, but not orthogonal, since the negative SOS token worked. However, that might be a subject for a future blog.
Conclusion and Future Direction
The above mini research mainly aimed to highlight a pretty interesting phenomenon in textual inversion tokens. One low-hanging fruit is to see whether this is reproducible for other tokens, but one goal I think will be interesting is to see if we can somehow retain the fidelity while still using textual inversion tokens. For example, it might be interesting to manually set the textual inversion attention to around 0.7 without changing anything else about the token. But for now, I hope you all enjoyed!
Side Note: Multi-subject generation
During multi-subject generation, the image below is what I get when I combine two separate textual inversion tokens, one a cat and one a chair, in a single prompt
While individually they generate pretty well, like below
My guess for why this happens is that each token tries to monopolize all of the cross-attention, so they destructively interfere with each other during generation!