Treating a Matrix as a Database

Today’s neural networks mimic memory through optimization. They compress vast datasets into billions of parameters, encoding patterns indirectly through iterative adjustments during training. This process is powerful but also time-consuming, requiring substantial computational resources to fine-tune these parameters for each new task. But what if there were a way to store memories explicitly, bypassing this lengthy optimization process?

This article explores a technique that enables neural networks to embed and retrieve information directly in their structure, treating a matrix as a memory system akin to a key-value store. Unlike traditional training, this method integrates memory operations—writing and reading information—into the network’s architecture in real time. By leveraging this capability, a neural network could adapt more dynamically to new information, balancing the precision of computer memory with the flexibility of human learning.

How does AI memorize information today?

Generative AI systems are typically pretrained on large datasets with the goal of learning to replicate them. For language models, this involves an autoregressive training approach, where the model predicts the next token in a sequence based on previous tokens. The predicted outputs are compared to the actual next token in the dataset, generating an error signal that guides the optimization of the model’s parameters. This process effectively seeks a configuration of parameters capable of reproducing the source dataset as accurately as possible.

Achieving this level of precision inherently requires the model to encode, or “memorize,” the dataset’s information. While the memorization is not a verbatim storage of the data, the optimization process searches for a representation that captures the relationships, patterns, and structures necessary to regenerate the data. Thus, the act of pretraining on this objective can be understood as a sophisticated form of memorization, where the model internalizes the dataset in a way that allows it to recreate it during inference.
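The next-token objective described above can be sketched in a few lines of NumPy. This is only an illustration: the function name `next_token_loss`, the sizes, and the random logits standing in for a model's outputs are all assumptions, not part of any particular training framework.

```python
import numpy as np

def next_token_loss(logits, tokens):
    """Mean cross-entropy of predicting tokens[t + 1] from logits[t]."""
    logits = logits[:-1]                  # the last position has no next token
    targets = tokens[1:]                  # targets are the inputs shifted by one
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

rng = np.random.default_rng(0)
vocab_size, seq_len = 10, 6               # illustrative sizes
tokens = rng.integers(0, vocab_size, size=seq_len)
logits = rng.standard_normal((seq_len, vocab_size))   # stand-in for model outputs
print("loss:", next_token_loss(logits, tokens))
```

Training adjusts the parameters that produce `logits` so that this error signal shrinks across the whole dataset.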

A Matrix as a Key-Value Store

Let’s derive another way that a neural network might learn. Imagine a neural network capable not only of generating the learned dataset but also of storing and retrieving information directly, as though it were writing to a database. One way we might do this is to treat a weight matrix within the neural network as a key-value store. This approach opens the door to another way of managing memory in neural networks.

To achieve this, two core components are required:
A Read Mechanism: The system must retrieve stored information by mapping a query (or key) to its corresponding value.
A Write Mechanism: The system must update the stored information, encoding new key-value pairs.

This system must also be differentiable, so that we can still use backpropagation and gradient descent to train the resulting network.

The Read Mechanism

Let’s build a read mechanism for this memory system from a simple and familiar operation: matrix-vector multiplication. This design aligns with how neural networks typically process information, where activations are projected through weight matrices to generate meaningful outputs. Here, a query vector acts as a coordinate that determines which portion of the weight matrix is read during the operation. Let the query vector be $q \in \mathbb{R}^{d}$, where $d$ is the dimensionality of the query. The memory is stored in a weight matrix $W \in \mathbb{R}^{n \times d}$, whose $d$ columns each hold an $n$-dimensional stored entry. The read operation is defined by the equation:

$$y = W \cdot q$$

In this formulation:
$q$: The query vector, specifying the coordinate in query space.
$W$: The weight matrix, serving as the memory storage.
$y \in \mathbb{R}^{n}$: The resulting vector, representing the retrieved information.

The operation $W \cdot q$ computes a linear combination of the columns of $W$, weighted by the components of $q$. Each element of $q$ determines how much influence its corresponding column of $W$ has on the output. This approach is straightforward to implement in neural networks, where projection operations like these are standard for encoding relationships between activations.
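As a quick sanity check, the read can be written both as a matrix product and as an explicit weighted sum. With $W$ of shape $n \times d$, the entries being blended are the columns of $W$. The sizes and random values below are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3                      # illustrative sizes
W = rng.random((n, d))           # memory: d columns, each an n-dimensional entry
q = rng.random(d)                # query vector

y = W @ q                        # the read operation

# The same result as an explicit weighted sum of the columns of W.
y_columns = sum(q[j] * W[:, j] for j in range(d))

print(np.allclose(y, y_columns))  # True
```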

A Simplified Case: One-Hot Queries

To illustrate this process more concretely, consider the case where the query vector $q$ is a one-hot vector: a vector with a single element set to 1 and all others set to 0. For instance, let $q = [0, 0, 1, 0, \ldots, 0]^{\top}$. Substituting this into the read equation yields:

$$y = W \cdot q = w_{3}$$

Here, $w_{3}$ denotes the third column of $W$, which is retrieved directly. This case shows how the mechanism explicitly retrieves stored values when queries are localized, as each column of $W$ corresponds directly to an indexed value in memory.

This simplified scenario mirrors traditional memory lookup, where an index points to a specific value. Neural networks, however, extend this by allowing $q$ to take on continuous values, enabling the retrieval of weighted combinations of stored entries rather than individual columns.
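Here is a small sketch contrasting the two kinds of query. The matrix values are chosen only so the result is easy to read by eye; the name `q_soft` is just a label for a query that blends two entries.

```python
import numpy as np

W = np.arange(12, dtype=float).reshape(3, 4)   # n = 3, d = 4
q_onehot = np.array([0.0, 0.0, 1.0, 0.0])      # selects the third entry
q_soft = np.array([0.0, 0.5, 0.5, 0.0])        # equal blend of entries two and three

y_onehot = W @ q_onehot   # the third column of W: 2, 6, 10
y_soft = W @ q_soft       # halfway between columns two and three: 1.5, 5.5, 9.5

print(y_onehot)
print(y_soft)
```

The one-hot query behaves like an array index, while the soft query interpolates between stored entries.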

The Write Mechanism

Let’s derive a write mechanism to go along with this read mechanism by working backwards from it. As with a write in computer memory, we want to change the weights of the matrix so that reading with the desired key returns the desired value:

$$v = W \cdot k$$

where:
$k \in \mathbb{R}^{d}$: the key vector, specifying the query in memory space.
$W \in \mathbb{R}^{n \times d}$: the weight matrix, representing the memory.
$v \in \mathbb{R}^{n}$: the value vector, representing the output retrieved from the memory.

We can formulate this update process as computing an offset $\Delta W$ to $W$ such that the resulting matrix $W + \Delta W$ exactly projects key vector $k$ onto value vector $v$. After incorporating the update, the read mechanism becomes:

$$v = (W + \Delta W) \cdot k$$

To ensure that the updated memory produces the desired output $v$, we can solve for $\Delta W$. Expanding this equation gives:

$$v = W \cdot k + \Delta W \cdot k$$

Here:
$W \cdot k$: the current output of the read mechanism.
$\Delta W \cdot k$: the contribution of the update to the output.

Rearranging the equation, we obtain:

$$\Delta W \cdot k = v - W \cdot k$$

This shows that the adjustment $\Delta W$ must account for the difference between the desired value $v$ and the current output $W \cdot k$. On its own, this equation does not determine $\Delta W$ uniquely. To remove $k$ from the left side, we can apply the Moore-Penrose pseudoinverse of $k$, which for a non-zero vector is $k^{\top} / (k \cdot k)$. This picks the smallest update (in Frobenius norm) that satisfies the equation:

$$(\Delta W \cdot k) \otimes \frac{k}{k \cdot k} = (v - W \cdot k) \otimes \frac{k}{k \cdot k}$$

Which simplifies to our final equation:

$$\Delta W = \frac{(v - W \cdot k) \otimes k}{k \cdot k}$$

Where:
$v - W \cdot k$: the residual vector, representing the adjustment needed in the output space.
$k \cdot k$: the normalization term, ensuring that the update scales appropriately based on the magnitude of $k$.
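As a check on the derivation, we can substitute the final update back into the read equation. The only extra step is writing the outer product as $(v - W \cdot k)\,k^{\top}$, so that multiplying by $k$ produces the scalar $k^{\top} k = k \cdot k$:

```latex
\begin{aligned}
(W + \Delta W) \cdot k
  &= W \cdot k + \frac{(v - W \cdot k)\, k^{\top} k}{k \cdot k} \\
  &= W \cdot k + (v - W \cdot k) \\
  &= v
\end{aligned}
```

The normalization term cancels exactly, which is why reading with $k$ after the write returns $v$ regardless of the key's length.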

[Figure: visual representation of this update logic.]

Simplifying for Normalized k vectors

This write mechanism involves a normalization term, $k \cdot k$, to account for the magnitude of the key vector $k$. This ensures that the update $\Delta W$ scales appropriately with respect to the key’s length. However, this equation simplifies significantly if $k$ is a normalized vector, meaning its length equals 1. In such cases:

$$k \cdot k = \lVert k \rVert^{2} = 1^{2} = 1$$

Substituting this into the equation removes the normalization factor, resulting in:

$$\Delta W = (v - W \cdot k) \otimes k$$

This simplification eliminates the need to compute the squared norm of the key vector, which reduces the computational overhead of the operation.

A Simplified Case: One-Hot Keys

Again, we can illustrate this mechanism with a one-hot key vector. Let $k = [0, 0, 1, 0, \ldots, 0]^{\top}$. Since this vector has a length of 1, we can use the simplified update equation:

$$W + \Delta W = W + (v - W \cdot k) \otimes k$$

Since $k$ indexes the third column, $W \cdot k$ is equal to the value stored in the third column of $W$, which we call $w_{3}$. Substituting this into the equation gives us:

$$W + \Delta W = W + (v - w_{3}) \otimes k$$

The outer product with $k$ is zero everywhere except in the column indexed by $k$ (the third column), so that is the only column being modified. We can therefore write the update as an equation for this column alone:

$$w_{3} + \Delta w_{3} = w_{3} + (v - w_{3}) k_{3} = w_{3} + v - w_{3}$$

Which simplifies to:

$$w_{3} + \Delta w_{3} = v$$

This demonstrates that this does in fact do what we set out to do: insert value $v$ into matrix $W$ at key $k$.
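The one-hot case is easy to confirm numerically. In this sketch (sizes and values are arbitrary), the simplified write replaces exactly one column of the matrix and leaves every other column as it was.

```python
import numpy as np

rng = np.random.default_rng(0)
W_before = rng.random((3, 4))         # n = 3, d = 4
k = np.array([0.0, 0.0, 1.0, 0.0])    # one-hot key with length 1
v = np.array([1.0, 2.0, 3.0])

# Simplified write, valid because k has unit length.
W_after = W_before + np.outer(v - W_before @ k, k)

# The third column now holds v.
print(np.allclose(W_after[:, 2], v))
# Every other column is unchanged.
print(np.allclose(np.delete(W_after, 2, axis=1), np.delete(W_before, 2, axis=1)))
```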

A Demonstration

Now that we have established the read and write mechanisms, let’s see them in action. Using NumPy, we can implement the read and write operations and observe how the matrix updates with different key and value pairs.

Implementing the Read and Write Operations

The read operation retrieves information from the memory matrix $W$ based on a query vector $q$. This is implemented as a simple matrix multiplication. The write operation updates the memory matrix $W$ such that querying with a key vector $k$ yields the desired value vector $v$. Both operations are implemented as follows:

import numpy as np

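def read(W, q):
    # Retrieve the value stored at query q: y = W · q
    return np.dot(W, q)

def write(W, k, v):
    # Store v at key k: W += (v - W · k) ⊗ k / (k · k)
    k_norm_sq = np.dot(k, k)
    delta_W = np.outer(v - np.dot(W, k), k) / k_norm_sq
    W += delta_W  # updates W in place; pass W.copy() to keep the original
    return W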

Let’s Run It

We start by initializing a small memory matrix $W$, a key vector $k$, and a value vector $v$:

W = np.random.rand(3, 3)  # Small memory matrix
k = np.random.rand(3)  # Key vector (not normalized)
v = np.random.rand(3)  # Value vector
print("W =", W)
print("k =", k)
print("v =", v)

Output:
W = [[0.23400824 0.16200084 0.61989965]
 [0.70328459 0.44872138 0.13665879]
 [0.77664905 0.76927199 0.68632115]]
k = [0.23557364 0.78298785 0.11506011]
v = [0.46181898 0.08128806 0.67273326]

Let’s see what value is currently stored in $W$ at $k$:

read_before = read(W, k)
print("Read before write:", read_before)

Output:
Read before write: [0.25329658 0.53274268 0.86425685]

Next, we perform a write operation to update $W$ and show that querying with $k$ produces $v$:

W = write(W, k, v)
read_after = read(W, k)
print("Read after write:", read_after)

Output:
Read after write: [0.46181898 0.08128806 0.67273326]

The output shows that the write operation successfully updates $W$, making $v$ the value returned when reading with $k$.

Large-Scale Demonstration

This mechanism is not limited to small matrices. Using a large $W$, $k$, and $v$, we can verify that the operation scales effectively:

W = np.random.rand(1000, 1000)
k = np.random.rand(1000)
v = np.random.rand(1000)
print("v:", v[:5])

read_large_before = read(W, k)

print("Read before write:", read_large_before[:5])

W = write(W, k, v)
read_large_after = read(W, k)

print("Read after write:", read_large_after[:5])

Output:
v: [0.44185732 0.73252687 0.17086871 0.30669385 0.7451792 ]
Read before write: [245.86409235 257.01742979 251.3365419  245.5073116  256.29716251]
Read after write: [0.44185732 0.73252687 0.17086871 0.30669385 0.7451792 ]

What About the Simplified Write Method?

The write mechanism can be simplified if the key vector $k$ is already normalized. To demonstrate this, we implement a simplified write operation. Additionally, let’s implement a helper function to normalize vectors:

def write_simple(W, k, v):
    delta_W = np.outer(v - np.dot(W, k), k)
    W += delta_W
    return W

def normalize(vec):
    return vec / np.linalg.norm(vec)

First, let’s sample our $W$ matrix, $k$ vector, and $v$ vector, as well as compute a normalized version of $k$:

W = np.random.rand(3, 3)
k = np.random.rand(3)
v = np.random.rand(3)
print("W =", W)
print("k =", k)
print("v =", v)

k_normalized = normalize(k)
print("k_normalized =", k_normalized)

Output:
W = [[0.78186133 0.99731076 0.41638517]
 [0.99191986 0.47667637 0.79124389]
 [0.93051694 0.05489218 0.53498828]]
k = [0.0386618  0.94535449 0.52822756]
v = [0.26334417 0.74346171 0.98091659]
k_normalized = [0.03567865 0.87241079 0.48746944]

Now we can demonstrate that this normalized version of $k$ can correctly compute the update to $W$ using this simplified function:

W_w_normalize = write_simple(W.copy(), k_normalized, v)
print("Simplified write works:")
print("v =", v)
print("y =", read(W_w_normalize, k_normalized))

Output:
Simplified write works:
v = [0.26334417 0.74346171 0.98091659]
y = [0.26334417 0.74346171 0.98091659]

However, using a non-normalized $k$ with the simplified method introduces errors, as the update will not scale correctly:

W_wo_normalize = write_simple(W.copy(), k, v)
print("Simplified write doesn't work:")
print("v =", v)
print("y =", read(W_wo_normalize, k))

Output:
Simplified write doesn't work:
v = [0.26334417 0.74346171 0.98091659]
y = [0.10138726 0.71498244 1.08726618]
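The size of this error is predictable. Without the division by $k \cdot k$, the residual is applied $\lVert k \rVert^{2}$ times over, so the read returns $W \cdot k + (v - W \cdot k)\,\lVert k \rVert^{2}$ instead of $v$. The sketch below checks that relationship on fresh random values (the seed and variable names are arbitrary).

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.random((3, 3))
k = rng.random(3)                    # deliberately not normalized
v = rng.random(3)

old = W @ k                          # value stored at k before the write
W_new = W + np.outer(v - old, k)     # simplified write, no division by k · k
y = W_new @ k

# The residual (v - old) is scaled by the squared length of k.
predicted = old + (v - old) * np.dot(k, k)
print(np.allclose(y, predicted))     # True
```

When $\lVert k \rVert = 1$ the scale factor is 1 and the read returns $v$ exactly, which is the normalized case shown earlier.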

What Happens with Multiple Updates?

If we write to the same memory matrix multiple times, the updates can interfere with each other. Consider the following example, where two different key-value pairs are written to $W$:

W = np.random.rand(3, 3)
k1 = normalize(np.random.rand(3))
v1 = np.random.rand(3)
k2 = normalize(np.random.rand(3))
v2 = np.random.rand(3)
print("W =", W)
print("k1 =", k1)
print("v1 =", v1)
print("k2 =", k2)
print("v2 =", v2)

W1 = write(W.copy(), k1, v1)
W2 = write(W1, k2, v2)

print("Read at k1:", read(W2, k1), "\n  expected:", v1)
print("Read at k2:", read(W2, k2), "\n  expected:", v2)

Output:
W = [[0.31029006 0.15289519 0.89391077]
 [0.84189235 0.66320922 0.05878183]
 [0.41339753 0.38605187 0.50916015]]
k1 = [0.66955548 0.74075881 0.0545147 ]
v1 = [0.590489   0.42438511 0.37899409]
k2 = [0.34733479 0.42039853 0.83822647]
v2 = [0.36811081 0.24278476 0.9231165 ]
Read at k1: [0.18750555 0.42203165 0.56484184] 
  expected: [0.590489   0.42438511 0.37899409]
Read at k2: [0.36811081 0.24278476 0.9231165 ] 
  expected: [0.36811081 0.24278476 0.9231165 ]
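The drift at the first key can be quantified. After the second write, reading at $k_1$ returns $v_1$ plus the second write's residual scaled by the overlap $(k_1 \cdot k_2) / (k_2 \cdot k_2)$. The sketch below verifies this on its own random values; the local `write` is a non-mutating variant of the function above, defined here only to keep the example self-contained.

```python
import numpy as np

def write(W, k, v):
    # Same rule as the write above, but returning a new matrix.
    return W + np.outer(v - W @ k, k) / np.dot(k, k)

rng = np.random.default_rng(0)
W = rng.random((3, 3))
k1, v1 = rng.random(3), rng.random(3)
k2, v2 = rng.random(3), rng.random(3)

W1 = write(W, k1, v1)
W2 = write(W1, k2, v2)

# How much of the second update leaks into reads at k1.
overlap = np.dot(k1, k2) / np.dot(k2, k2)
predicted_k1 = v1 + (v2 - W1 @ k2) * overlap

print(np.allclose(W2 @ k1, predicted_k1))   # interference matches the formula
print(np.allclose(W2 @ k2, v2))             # the most recent write is exact
```

When $k_1 \cdot k_2 = 0$ the overlap term vanishes and the first value is left untouched.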

To minimize interference, we can ensure that the key vectors are orthogonal. By generating a random vector orthogonal to $k_1$ and using it as the second key, we see that writing with the second key does not affect the value retrieved with the first:

def sample_orthogonal(vec):
    random_vec = np.random.randn(vec.shape[0])
    projection = np.dot(random_vec, vec) / np.dot(vec, vec) * vec
    orthogonal_vec = random_vec - projection
    return normalize(orthogonal_vec)

k2 = sample_orthogonal(k1)
print("k2 =", k2)

W1 = write(W.copy(), k1, v1)
W2 = write(W1, k2, v2)

print("Read at k1:", read(W2, k1), "\n  expected:", v1)
print("Read at k2:", read(W2, k2), "\n  expected:", v2)

Output:
k2 = [ 0.5194568  -0.41453595 -0.7472112 ]
Read at k1: [0.590489   0.42438511 0.37899409] 
  expected: [0.590489   0.42438511 0.37899409]
Read at k2: [0.36811081 0.24278476 0.9231165 ] 
  expected: [0.36811081 0.24278476 0.9231165 ]

Observations

The update mechanism we’ve discussed enables the direct embedding of information within a neural network’s structure by treating a layer as a key-value store. Unlike traditional discrete key-value systems, this approach operates in a continuous key space, which brings its own properties and limitations. Notably, it has no inherent safeguard against later writes with overlapping keys: a subsequent write can modify a previously stored value even if the original key isn’t reused. This can be advantageous or detrimental, depending on the application: overlapping writes might integrate new observations into a world model, or they might cause unintended interference.

A notable property is that a write with a key orthogonal to an existing key leaves the value stored at that existing key unmodified. In a $d$-dimensional space, the maximum number of mutually orthogonal non-zero vectors is $d$. The matrix’s capacity to store distinct key-value pairs without interference is therefore bounded by the dimensionality of the key vectors: the number of orthogonal keys, and thus of interference-free entries, cannot exceed the key space’s dimension. This relationship highlights the significance of the key vector size with respect to the matrix’s capacity.
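The capacity bound can be seen in a small experiment. This sketch builds a full set of orthonormal keys from a QR decomposition (one convenient way to get an orthonormal basis, chosen here for illustration), stores one value per key, and then adds one more key than the space can hold. The sizes are arbitrary.

```python
import numpy as np

def write(W, k, v):
    # Non-mutating version of the write rule.
    return W + np.outer(v - W @ k, k) / np.dot(k, k)

rng = np.random.default_rng(0)
n, d = 5, 8                                  # value size, key size
W = np.zeros((n, d))

# The columns of Q form an orthonormal basis of the key space: d orthogonal keys.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
keys = Q.T                                   # one key per row
values = rng.random((d, n))                  # one value per key

for key, value in zip(keys, values):
    W = write(W, key, value)

# All d entries are read back exactly.
print(np.allclose(W @ keys.T, values.T))

# A (d + 1)-th key cannot be orthogonal to all of the others,
# so writing with it disturbs previously stored values.
extra_key = rng.standard_normal(d)
W_extra = write(W, extra_key, rng.random(n))
print(np.allclose(W_extra @ keys.T, values.T))
```

The first check passes and the second fails: with 8-dimensional keys, the ninth write has to overlap with something already stored.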

How to make the most of this mechanism?

This update mechanism is particularly relevant in the context of linear transformers, which are designed to approximate traditional softmax-based self-attention mechanisms while maintaining linear complexity. To approximate the softmax weighting mechanism, these models often use key and query projections that map them onto much larger state spaces, and that often have properties designed to ensure all comparisons between keys and queries are positive (such as all-positive activations or projections that make this guarantee).

One of the most compelling ways to enhance this system is by crafting architectures with large key dimensions. The dimensionality of the key vectors directly determines the capacity of the memory system, as it governs the maximum number of orthogonal keys that can be stored without interference. By increasing the key size, a network can store more distinct entries, enabling it to scale its memory capacity to match increasingly complex tasks. This property makes the update mechanism a powerful tool for building high-capacity memory systems in neural networks.

Linear transformers often use positive-only activations, such as DPFP, to ensure that the dot product between key and query vectors is always non-negative. This is designed to approximate the softmax weighting scheme as the key dimensionality approaches infinity. However, when applied to the update mechanism described here, these activations introduce a wrinkle. The use of positive-only activations imposes a significant limitation on the representation of orthogonal keys. Orthogonality, by definition, requires the dot product between two vectors to be zero. For positive-only vectors, this is achievable only if their non-zero dimensions do not overlap, an arrangement that is inherently sparse. Sparse vectors, with many zero elements, can satisfy the orthogonality condition in this context, but the activation count limits the number of orthogonal keys that can be represented within a given dimensionality. If the vectors are too sparse, the representation space is limited; if they are too dense, the key vectors become harder to separate. This sparsity constraint can restrict the effective memory capacity of the system.
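A tiny example makes the constraint concrete. The vectors below are hand-picked non-negative keys, not the output of DPFP or any real activation; they only illustrate that orthogonality between non-negative vectors requires disjoint sets of active dimensions.

```python
import numpy as np

d = 8   # key dimensionality (illustrative)

# Non-negative keys with disjoint active dimensions are exactly orthogonal.
k_a = np.array([1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
k_b = np.array([0.0, 0.0, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0])
dot_disjoint = np.dot(k_a, k_b)    # 0.0

# Sharing even one active dimension makes the dot product strictly positive.
k_c = np.array([0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
dot_overlap = np.dot(k_a, k_c)     # 4.0

# If every key uses exactly s active dimensions, at most d // s keys
# can have pairwise disjoint supports.
s = 2
max_orthogonal_keys = d // s       # 4

print(dot_disjoint, dot_overlap, max_orthogonal_keys)
```

With unconstrained signs, the same 8-dimensional space admits 8 mutually orthogonal keys, so the positivity constraint halves the interference-free capacity in this example.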

Despite these limitations, sparsity brings computational advantages that can be exploited in high-dimensional systems. Sparse representations allow for efficient matrix operations, reducing the computational burden of working with large matrices and high-dimensional vectors. These efficiencies are particularly valuable in memory-heavy applications, where both storage and computation must scale effectively to accommodate the model’s demands. By designing systems that intentionally leverage sparsity, we can balance the trade-offs between capacity and efficiency. We will discuss this in later posts, but there is a threshold where sparse computation becomes worthwhile on a GPU despite the inability to fully utilize tensor cores.

Normalization is one final consideration when designing systems that use this mechanism. Normalizing the key and query vectors before applying the mechanism removes any dependence on scale when reading from or writing to the memory matrix. It also removes some paradoxical behavior of the mechanism: a write with a large key vector stores small entries in the memory matrix, because the entries are scaled so that an equally large query reproduces the value. Combine this with the relative simplicity of the resulting update mechanism and it would be difficult to justify not integrating this into a network’s design.
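The scale effect is easy to reproduce. In this sketch the same direction is used as a key at three different magnitudes (the scales are arbitrary), and the Frobenius norm of the resulting memory matrix shrinks in proportion to the key's length.

```python
import numpy as np

def write(W, k, v):
    # Non-mutating version of the full (unnormalized-key) write rule.
    return W + np.outer(v - W @ k, k) / np.dot(k, k)

rng = np.random.default_rng(0)
k = rng.random(4)
k = k / np.linalg.norm(k)          # unit-length key direction
v = rng.random(3)
W0 = np.zeros((3, 4))              # empty memory

norms = []
for scale in (1.0, 10.0, 100.0):
    W = write(W0, scale * k, v)
    norms.append(np.linalg.norm(W))
    # Reading with the unit-length direction returns v shrunk by the scale.
    print(scale, norms[-1], np.allclose(W @ k, v / scale))
```

A key that is 10 times longer leaves entries that are 10 times smaller, so the stored magnitude depends on the key's scale rather than on the value alone. Normalizing keys and queries removes that dependence.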

Is This Used Today?

This update mechanism is central to some of the best-performing linear transformers under active research, including DeltaNet. At its core, DeltaNet employs a simplified version of the update mechanism we’ve discussed. Its rule is expressed as:

$$W_{t+1} = W_{t} + \Delta W_{t} = W_{t} + \beta_{t} (v_{t} - W_{t} \cdot k_{t}) k_{t}^{\top}$$

This formulation, first introduced in the 2021 paper Linear Transformers Are Secretly Fast Weight Programmers by Schlag et al., utilizes the same read mechanism but simplifies the update by skipping explicit normalization of the key vector. DeltaNet also introduces a gating mechanism, represented by $\beta \in [0, 1]$, which determines the degree of interpolation between the current state and the target value. This gate allows the system to adjust how far the update mechanism moves the stored value towards $v$, providing finer-grained control over the write operation.
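The interpolation is simple to observe. The sketch below applies the rule as written above with a unit-length key, which is an assumption made here so that the interpolation is exact; `delta_rule_write` is a hypothetical helper name, not a function from any DeltaNet codebase.

```python
import numpy as np

def delta_rule_write(W, k, v, beta):
    # Gated update: W + beta * (v - W · k) k^T, with beta in [0, 1].
    return W + beta * np.outer(v - W @ k, k)

rng = np.random.default_rng(0)
W = rng.random((3, 4))
k = rng.random(4)
k = k / np.linalg.norm(k)          # assumed unit length for this illustration
v = rng.random(3)
old = W @ k                        # value currently stored at k

reads = {}
for beta in (0.0, 0.5, 1.0):
    reads[beta] = delta_rule_write(W, k, v, beta) @ k
    # The new read is a blend of the old value and the target.
    print(beta, np.allclose(reads[beta], (1 - beta) * old + beta * v))
```

With $\beta = 0$ the memory is left alone, with $\beta = 1$ the value is fully overwritten, and values in between move the stored entry part of the way towards $v$.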

While DeltaNet does not perform L2 normalization before the update mechanism, it closely resembles the process outlined in this article. Interestingly, in 2024, a paper by Yang et al., Parallelizing Linear Transformers with the Delta Rule over Sequence Length, extended this mechanism to optimize chunk-wise execution across the sequence length dimension. This reformulation parallelizes computation across multiple timesteps at the cost of $O(n^{2})$ complexity within each chunk. In their work, the update mechanism was rewritten in a recurrent form that reveals a Householder matrix:

$$W_{t+1} = W_{t} (I - \beta_{t} k_{t} k_{t}^{\top}) + \beta_{t} v_{t} k_{t}^{\top}$$

They utilize this reformulation to derive another representation that allows for efficient parallelization while maintaining the fundamental characteristics of the update mechanism. Notably, Yang et al. incorporated L2 normalization into their ablation study, comparing its performance to the L1 normalization used in the original DeltaNet formulation. They found that L2 normalization improved the model’s performance, aligning with the findings presented earlier in this article and further validating the benefits of this normalization approach.
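The equivalence between the Householder-style form and the original delta rule follows from distributing $W_{t}$ over $(I - \beta_{t} k_{t} k_{t}^{\top})$, and it can be checked numerically. The values below are random and the variable names are only labels for the two expressions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 4
W = rng.random((n, d))
k = rng.random(d)
v = rng.random(n)
beta = 0.7

# Delta-rule form: W + beta * (v - W k) k^T
delta_form = W + beta * np.outer(v - W @ k, k)

# Recurrent form: W (I - beta k k^T) + beta v k^T
householder_form = W @ (np.eye(d) - beta * np.outer(k, k)) + beta * np.outer(v, k)

print(np.allclose(delta_form, householder_form))   # True
```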

We will delve deeper into the chunk-wise formulation and its implications for performance in a later post. For now, it is worth emphasizing that these advancements show the elegance of this paradigm of memory storage, as well as the fact that this update mechanism satisfies many of our criteria for infinite context:
– No information is decayed unless explicitly overwritten.
– The state matrix does not saturate as more updates are added.
– There is little risk of overflow as entries are replaced not summed.
– Uses linear time compute and constant memory (like any other RNN).
– Has the potential for large capacity (bounded by the key/query dimension; we will discuss this later).
– Can be parallelized across the time dimension (We will discuss this in the chunkwise article).
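To make the constant-memory point concrete, here is a minimal recurrent sketch. It assumes L2-normalized keys and queries and a write-then-read order at each step; both are choices made for this illustration, and `delta_memory_scan` is a hypothetical name.

```python
import numpy as np

def delta_memory_scan(keys, values, queries):
    """Process a sequence one step at a time with a fixed-size state matrix."""
    n, d = values.shape[1], keys.shape[1]
    W = np.zeros((n, d))                      # state size does not depend on T
    outputs = []
    for k, v, q in zip(keys, values, queries):
        k = k / np.linalg.norm(k)             # normalize the key
        q = q / np.linalg.norm(q)             # normalize the query
        W = W + np.outer(v - W @ k, k)        # write (simplified rule)
        outputs.append(W @ q)                 # read
    return np.stack(outputs), W

rng = np.random.default_rng(0)
T, n, d = 16, 3, 8                            # sequence length, value size, key size
keys = rng.standard_normal((T, d))
values = rng.random((T, n))

# Querying with the key that was just written returns the value just stored.
outputs, W_final = delta_memory_scan(keys, values, queries=keys)
print(np.allclose(outputs, values), W_final.shape)
```

The loop does a fixed amount of work per step, and the only state carried between steps is the $n \times d$ matrix, however long the sequence is.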
