Generalization through memorization: Nearest neighbor language models

This paper [1] describes a method for augmenting an existing language model with an external memory to improve its performance without any additional training. A datastore is built from a given dataset and queried at inference time. The authors demonstrate performance improvements (measured as perplexity on a given dataset) even when the datastore is built from data the model was not trained on.

How it works

The method essentially involves “memorizing” the training set and using it to directly augment the model at inference time. The same procedure can be applied to data other than the training set and gives similar improvements.

Indexing or “Memorizing”

The datastore is a key-value database, similar to FAISS or Pinecone, where the keys are vectors of floating-point numbers and the values are arbitrary data. To “memorize” a dataset, the existing LM is evaluated on the data split into chunks of some size; the output of the network right before the final softmax layer is used as the key, while the subsequent token in the dataset (the one the model is supposed to predict) is stored as the value.

Assume that $x_0, x_1, \ldots, x_{n-1}$ are the tokens in a chunk of text from the dataset, and let $f$ be the function that converts this token sequence into the key vector. The paper examines the outputs of different layers for this purpose and shows that the output of the final layer, right before the softmax activation, is a good candidate. For a given token sequence, the key and value are then

$$ k = f([x_0, x_1, ... , x_{n-1}])\\ v = x_n $$

For a chunk of $n$ tokens, we can store up to $n-1$ such key-value pairs in the datastore, e.g. $f([x_0]) \rightarrow x_1$, $f([x_0, x_1]) \rightarrow x_2$, and so on.
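A minimal sketch of the indexing step, assuming a hypothetical `model.encode()` that returns the pre-softmax representation $f(\cdot)$ for a token prefix (in practice all prefix representations for a chunk come from a single forward pass rather than one pass per prefix):

```python
import numpy as np
import faiss

def build_datastore(token_chunks, model, dim):
    """Store one (key, value) pair per prefix: f([x_0..x_{i-1}]) -> x_i."""
    keys, values = [], []
    for chunk in token_chunks:                     # chunk: list of token ids
        for i in range(1, len(chunk)):
            keys.append(model.encode(chunk[:i]))   # key   = f([x_0, ..., x_{i-1}])
            values.append(chunk[i])                # value = x_i, the next token
    index = faiss.IndexFlatL2(dim)                 # exact Euclidean (L2) search
    index.add(np.stack(keys).astype("float32"))
    return index, np.array(values)
```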

Inference

At inference time, the input token sequence is run through the LM to get the probability distribution for the next token. The activations of the final layer prior to the softmax are then used to perform a k-nearest-neighbors search in the vector datastore created during indexing (the authors use Euclidean distance and k = 1024 for this search). The retrieved neighbors are then converted into a probability distribution of their own using their distances, as follows:

$$ P_{kNN}(y|x) \propto \sum_{(k_i, v_i) \in N} \mathbb{1}_{y=v_i} \exp\left( -d(k_i, f(x)) \right) $$

where $\mathbb{1}_{y=v_i}$ is an indicator that equals 1 when the candidate token $y$ matches the stored value $v_i$, and $d(k_i, f(x))$ is the distance between the retrieved key $k_i$ and the search query $f(x)$.
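Continuing the sketch above, the retrieved neighbors could be aggregated roughly like this (`index`, `values`, and `vocab_size` are names carried over from the indexing sketch, not the paper's code):

```python
import numpy as np

def knn_distribution(index, values, query, vocab_size, k=1024):
    """Aggregate the k nearest neighbors into a distribution over the vocabulary."""
    query = np.asarray(query, dtype="float32").reshape(1, -1)
    dists, idxs = index.search(query, k)   # FAISS IndexFlatL2 returns (squared) L2 distances
    probs = np.zeros(vocab_size)
    for d, i in zip(dists[0], idxs[0]):
        probs[values[i]] += np.exp(-d)     # 1_{y = v_i} * exp(-d(k_i, f(x)))
    return probs / probs.sum()             # normalize into P_kNN(y|x)
```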

$P_{LM}(y|x)$ is the probability distribution produced by the original model. The final probability distribution is then computed by linearly interpolating between the two distributions:

$$ P(y|x) = \lambda P_{kNN}(y|x) + (1 - \lambda) P_{LM}(y|x) $$

where $\lambda$ is a fixed interpolation coefficient. Reference [2] discusses selecting $\lambda$ based on the “semantic similarity” (the cosine distance?) between the closest key in the search results and the search query; its authors trained a model to predict the coefficient profile for a given dataset (i.e. how to map semantic similarity to the interpolation coefficient $\lambda$).
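Putting the pieces together, the interpolation itself is a one-liner; the default `lam` below is just a placeholder for a value tuned on validation data, not the learned, similarity-dependent coefficient from [2]:

```python
def knn_lm_distribution(p_lm, p_knn, lam=0.25):
    """Blend the model's own distribution with the retrieval-based one."""
    return lam * p_knn + (1.0 - lam) * p_lm
```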

References

[1] Khandelwal, U., Levy, O., Jurafsky, D., Zettlemoyer, L., & Lewis, M. (2019). Generalization through memorization: Nearest neighbor language models. arXiv preprint arXiv:1911.00172.

[2] Drozdov, A., Wang, S., Rahimi, R., McCallum, A., Zamani, H., & Iyyer, M. (2022). You can’t pick your neighbors, or can you? When and how to rely on retrieval in the kNN-LM. arXiv preprint arXiv:2210.15859.
