Ideas

  • Make a “squat” LM with large context and fewer params?

  • train adapter to use external weights as storage?

  • *** Khandelwal’s KNN-LM paper - 1911.00172.pdf ***

    • the input to the last FFN (after layer-norm) is cached as the “key”, and the next token is the value in the datastore
      • encode first (n-1) words of document as a vector using the output of selected layer
      • value is the n-th word
      • maybe recursive? and repeated for each word?
    • the model queries the datastore at inference time and retrieves the “k” nearest entries
    • uses Euclidean distance
    • the retrieved values are converted into a probability distribution via a softmax over negative distances, and the LM’s output is interpolated with it using a fixed coefficient lambda
    • p(y|x) = lambda * p_knn(y|x) + (1-lambda) * p_lm(y|x)
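A minimal sketch of that interpolation step as I read it (brute-force distances instead of a FAISS index; `keys`/`values` stand in for the datastore built from the training corpus):

```python
# Minimal sketch of the kNN-LM interpolation step (my reading of the paper).
# key = hidden state fed to the final FFN, value = the next token id.
import numpy as np

def knn_lm_probs(query, keys, values, p_lm, k=8, lam=0.25):
    """Interpolate LM probabilities with a kNN distribution over the datastore.

    query : (d,) hidden state for the current context
    keys  : (N, d) datastore keys
    values: (N,) next-token ids for each key
    p_lm  : (V,) LM probability distribution over the vocabulary
    """
    # Euclidean distance to every datastore key (brute force; ANN search in practice).
    dists = np.linalg.norm(keys - query, axis=1)
    nn = np.argsort(dists)[:k]

    # Softmax over negative distances of the k neighbours.
    w = np.exp(-dists[nn] - np.max(-dists[nn]))
    w /= w.sum()

    # Aggregate neighbour weights per vocabulary item.
    p_knn = np.zeros_like(p_lm)
    np.add.at(p_knn, values[nn], w)

    # p(y|x) = lambda * p_knn(y|x) + (1 - lambda) * p_lm(y|x)
    return lam * p_knn + (1.0 - lam) * p_lm
```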
  • Drozdov’s followup - 2210.15859.pdf

    • How to pick “lambda” automatically using “retrieval quality”
    • Uses semantic similarity of the retrieved context to base LM
    • Uses “buckets” of distance - one coefficient tuned per bucket
    • Use coefficient of the top/closest result
    • coefficient is “trained”?
      • could we just have some kind of exponential dropoff for this based on distance?
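Not from the paper, just toying with the exponential-dropoff idea from the last bullet; `lam_max` and `alpha` are made-up knobs:

```python
# Toy version of "set lambda from the distance of the closest neighbour"
# instead of learning per-bucket coefficients. Constants are arbitrary.
import math

def lambda_from_distance(top_distance: float, lam_max: float = 0.5, alpha: float = 0.1) -> float:
    """Closer top-1 neighbour -> trust the datastore more; far away -> fall back to the LM."""
    return lam_max * math.exp(-alpha * top_distance)

# e.g. top_distance = 2.0  -> lambda ~ 0.41; top_distance = 30.0 -> lambda ~ 0.02
```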
  • “Yogatama”

    • goes further and uses both short-term and long-term memory (somehow related to Transformer-XL caching)
    • not sure I understand this one fully
    • uses memory from a separately trained model according to another paper (arxiv:2205.12674)
  • Lample et al

    • replace FFN layers with a KNN lookup layer called PKM - “Product Key Memory”
    • based on Sukhbaatar’s observation that the FFN acts like attention if the ReLU is replaced with a softmax
    • very interesting - they added 2 to 3 kNN lookup layers instead of FFNs during training
    • seems to require some weird stuff at training time which makes it work
    • does not use a database - the storage is still part of the network (rough lookup sketch below)
    • eh
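Rough sketch of the product-key lookup trick as I understand it (multi-head memory, query batch-norm, and the learned query network are omitted; all names are mine):

```python
# Product-Key Memory (PKM) lookup: score N^2 memory slots with only 2*N
# sub-key comparisons by splitting the query in half.
import numpy as np

def pkm_lookup(q, subkeys1, subkeys2, values, k=4):
    """
    q        : (d,) query vector
    subkeys1 : (N, d//2) first-half sub-keys
    subkeys2 : (N, d//2) second-half sub-keys
    values   : (N*N, dv) learned value embeddings, slot (i, j) -> index i*N + j
    """
    d = q.shape[0]
    q1, q2 = q[: d // 2], q[d // 2 :]
    N = subkeys1.shape[0]

    # Top-k over each half independently (2N dot products instead of N^2).
    s1 = subkeys1 @ q1
    s2 = subkeys2 @ q2
    i1 = np.argsort(-s1)[:k]
    i2 = np.argsort(-s2)[:k]

    # Candidate product keys: k*k combinations, then overall top-k.
    cand_scores = s1[i1][:, None] + s2[i2][None, :]          # (k, k)
    flat = np.argsort(-cand_scores, axis=None)[:k]
    rows, cols = np.unravel_index(flat, (k, k))
    slots = i1[rows] * N + i2[cols]
    scores = cand_scores[rows, cols]

    # Softmax-weighted sum of the selected value vectors.
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return (w[:, None] * values[slots]).sum(axis=0)
```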
  • “Memory-efficient Transformers” - Gupta et al 2021

    • replace dense attention with kNN lookup
      • actually more like a computational trick for attention rather than using memory (toy sketch below)
    • doesn’t seem to include actual databases here though
    • eh
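Toy sketch of the general “attention via top-k key lookup” idea; not necessarily Gupta et al’s exact formulation, and a real version would use an ANN index instead of computing all the scores first:

```python
# Approximate attention: for each query, softmax only over its k best-matching keys.
import numpy as np

def topk_attention(Q, K, V, k=8):
    """Q: (m, d), K/V: (n, d). Returns (m, d) outputs using only k keys per query."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])      # (m, n); an ANN index would avoid this
    idx = np.argsort(-scores, axis=1)[:, :k]    # top-k key indices per query
    out = np.zeros_like(Q)
    for i in range(Q.shape[0]):
        s = scores[i, idx[i]]
        w = np.exp(s - s.max())
        w /= w.sum()
        out[i] = w @ V[idx[i]]
    return out
```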
  • Memorizing transformers - Wu et al 2022

    • during training, store the “keys and values” (from an attention layer?) of one of the last layers into a KV-db

    • only store “n” elements at a time, dropping the oldest entries

    • perform attention over both the local context (as usual) and the memory, substituting the retrieved keys and values for the usual keys and values while reusing the same queries as the local attention block

    • a learned “gate” g (analogous to lambda in kNN-LM) mixes local attention and memory attention (rough sketch after this list)

      • there is a per-head parameter that is passed through a sigmoid fn to get the value of “g”
    • “no position bias” for retrieved memories

    • normalizing the keys and queries helps mitigate some of the “staleness” during training

      • staleness: the model params keep changing during training, so older stored keys/values drift relative to the current queries
    • knn attention layer has its own weights separate from that used for local attention

      • each search query is of size “hidden_size”, and there are “seq_len” queries per sequence
    • non-memory model can be fine-tuned to use memory

    • very interesting stuff
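Rough sketch of the local/memory gating as I read it (causal masking, projections and the retrieval itself are omitted; shapes and names are my assumptions):

```python
# The same queries attend to the local context and to retrieved (key, value)
# pairs from external memory; a learned per-head gate g mixes the two results.
import torch

def gated_memory_attention(q, local_k, local_v, mem_k, mem_v, gate_logit):
    """
    q          : (heads, seq, dh) queries
    local_k/v  : (heads, seq, dh) keys/values from the local context
    mem_k/v    : (heads, seq, k, dh) top-k retrieved keys/values per query
    gate_logit : (heads,) learned per-head parameter
    """
    dh = q.shape[-1]

    # Ordinary local attention (causal mask omitted for brevity).
    local_scores = q @ local_k.transpose(-1, -2) / dh ** 0.5
    local_out = local_scores.softmax(-1) @ local_v

    # Attention over retrieved memories; no position bias for these.
    mem_scores = (q.unsqueeze(-2) @ mem_k.transpose(-1, -2)).squeeze(-2) / dh ** 0.5
    mem_out = (mem_scores.softmax(-1).unsqueeze(-2) @ mem_v).squeeze(-2)

    # Per-head sigmoid gate mixes memory and local results.
    g = torch.sigmoid(gate_logit).view(-1, 1, 1)
    return g * mem_out + (1.0 - g) * local_out
```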

  • “Training Language Models with Memory Augmentation”

    • use in-batch examples as memory during training?
    • three types of memory
      • local memory - from the recent past, obtained using attention
      • long-term memory - from the same doc but beyond the context size
      • external memory - stored data from training or other data
    • “By packing consecutive segments from the same document in one training batch, our model can access long-term memories beyond the attention context”
    • “training memories” - presumably used only during training and come from the same batch
    • loss function includes a term for memory - computed similarly to Khandelwal’s work (rough sketch after this list)
      • maximizes similarity between lastlayeroutput(c) and the same output for c_j from the batch memory
      • scaled dot product is used as the similarity function
    • “local memory” = using the current token sequence as the training memory for the above loss fn
      • simply incorporating local memory provides a notable gain on multiple benchmarks at little cost
    • “long-term memory” - use consecutive segments from the document in the batch
      • M_train includes tokens from previous segments as well as the preceding tokens in the same segment
    • “external memory” - pack segments that have large lexical overlap into the same batch using “BM25” scores
      • “Specifically, we start with a single segment and repeatedly add segments with highest BM25 scores into the same batch”
      • “To encourage the use of information from other segments, we exclude the local memory from M_train with a probability of p (=90%) during training”
    • training process - standard loss function for the first 5% of updates
    • inference - interpolate similar to kNN-LM
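Hedged sketch of the memory-augmented training term as I understand it: the target token pools probability mass from the output embedding and from in-batch memory positions whose next token matches, via scaled dot-product similarity. Names and the exact scaling are my guesses, not the paper’s notation:

```python
import torch

def memory_augmented_loss(c_t, target, output_emb, mem_reps, mem_targets):
    """
    c_t         : (d,)   last-layer representation of the current context
    target      : ()     id of the ground-truth next token
    output_emb  : (V, d) output embedding matrix
    mem_reps    : (M, d) last-layer representations of in-batch memory positions
    mem_targets : (M,)   next-token ids at those memory positions
    """
    scale = c_t.shape[0] ** 0.5

    vocab_scores = output_emb @ c_t            # (V,)
    mem_scores = (mem_reps @ c_t) / scale      # (M,) scaled dot product

    # Denominator: all vocabulary scores plus all memory scores.
    log_z = torch.logsumexp(torch.cat([vocab_scores, mem_scores]), dim=0)

    # Numerator: the target's vocab score plus memory entries with the same next token.
    pos = torch.cat([vocab_scores[target].unsqueeze(0),
                     mem_scores[mem_targets == target]])
    return -(torch.logsumexp(pos, dim=0) - log_z)
```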

  • kNN-Adapter - 2302.10879

    • improves on Khandelwal et al (rough sketch after this list)
    • train a small network to predict token-wise interpolation coefficients + context-specific correction coefficients
    • create the datastore from the training set, train the adapter on the validation set (possibly to minimize cross entropy with the ground-truth next-token prob?)
    • on the order of 10^4 parameters
    • much better than fine-tuning?
    • generalizes to other datasets even when using datastores built from different data
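Sketch of the adapter idea as I read the note: a tiny network predicts a per-token lambda from retrieval features and is trained on held-out data to minimize cross entropy of the interpolated distribution. Feeding it the top-k distances is my assumption, not from the paper:

```python
import torch
import torch.nn as nn

class KNNAdapter(nn.Module):
    """Tiny network that maps retrieval features to a token-wise interpolation coefficient."""

    def __init__(self, k: int = 8, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(k, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, topk_dists, p_lm, p_knn):
        """topk_dists: (B, k); p_lm, p_knn: (B, V). Returns interpolated probs."""
        lam = torch.sigmoid(self.net(topk_dists))      # (B, 1) token-wise lambda
        return lam * p_knn + (1.0 - lam) * p_lm

# Training on the validation set would just minimize cross entropy, e.g.
#   probs = adapter(dists, p_lm, p_knn)
#   loss = -torch.log(probs[torch.arange(B), gold_tokens]).mean()
```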