Time Travel

Breaking the Temporal Bottleneck in Linear Transformers

To enable the development of scalable linear transformers capable of supporting infinite-context modeling, it is crucial to address two key challenges: achieving efficient parallelism across the time dimension and managing extremely large recurrent states. Here we examine an existing solution for accelerating the DeltaNet algorithm, which employs the delta rule to dynamically update memory states. While this formulation demonstrates hardware-efficient training through chunkwise parallelism, it remains constrained by the need to materialize key and query vectors during some operations. This eliminates the potential benefits of compressed key/query vector representations such as Symmetric Power Transformers or sparse vectors. We propose an updated formulation that resolves this limitation. This adjustment ensures that key and query vectors may remain unmaterialized during computation, except in cases where direct interaction with the recurrent state is required. Through theoretical analysis and empirical validation, we demonstrate that our approach preserves mathematical correctness while retaining the properties that make both chunkwise DeltaNet and compressed vectors compute- and memory-efficient.

Memory-Based Attention

The attention mechanism has emerged as a fundamental component for accurate sequence modeling due to its ability to capture dependencies across sequences of arbitrary lengths. Attention is computationally efficient during training, as it heavily relies on matrix multiplications that are well-suited to modern GPU architectures. However, the quadratic complexity of standard attention with respect to sequence length makes it prohibitively expensive for very-long-context tasks. Despite recent advancements aimed at scaling attention through hardware-aware optimizations, this still presents a significant limitation for scalable, infinite-context modeling.

Linear attention transformers address these issues by replacing the softmax attention mechanism with a dot-product computation over (possibly transformed) key and query vectors. This reformulation enables a constant-memory inference process and allows linear attention to be represented as a linear RNN with matrix-valued hidden states. While early variants of linear attention lagged behind softmax attention in tasks like language modeling, recent improvements, such as gated linear transformers and time-varying state space models like Mamba, have demonstrated competitive performance. Notable among recent improvements is chunkwise recurrent computation which allows for the computation of multiple timesteps simultaneously while only materializing the recurrent state at the end of each chunk.

DeltaNet, introduced by Schlag et al. (2021), builds on the linear attention framework by incorporating a delta rule-like update. This mechanism dynamically retrieves and updates memory states, enhancing associative recall over long contexts. Despite its effectiveness in small-scale settings, the original DeltaNet formulation relied on a sequential algorithm that could not parallelize across sequence length, resulting in inefficient training and challenges in scaling to larger models and datasets. Subsequent work by Yang et al. (2024) addressed this limitation by developing a chunkwise parallel algorithm for DeltaNet, leveraging a memory-efficient reparameterization using generalized Householder transformations. This approach enables scalable training and inference of DeltaNet by avoiding the need to materialize recurrent states during chunk computation.

Very Large Vectors and States

Empirical evidence demonstrates that the performance of linear transformers, and DeltaNet variants especially, improves dramatically with the size of the key and query vectors, and consequently with the size of the recurrent state. The number of separate values that can be stored in a DeltaNet state is equal to the number of mutually orthogonal keys the key space can support, which, when the key space is unconstrained, equals its dimensionality. This, among other benefits, has led to the development of compressed vector representations for use during linear transformer computation, the most notable being Symmetric Power Transformers, introduced by Buckman et al. (2024).

Linear transformers generally utilize a feature map \phi(\cdot) that projects key and query vectors before they are used within the attention mechanism. Symmetric Power Transformers treat this feature map as a modified outer product expansion. This enables them to compare keys and queries with each other using what is called the kernel trick: keys and queries can be compared with a simple dot product followed by raising the result to a power n, rather than computing the outer product expansion. Combined with only materializing the key and query vectors within a GPU kernel when interacting with the state, the full vector never needs to sit in memory, freeing up both space and bandwidth. Buckman et al. applied their approach to traditional linear transformers, which have a simpler chunkwise algorithm and integration process than DeltaNet.

The Linear Transformer

Transformers with linear attention reformulate the attention mechanism by replacing the softmax operation with a dot product over feature-transformed key and query vectors. This reformulation avoids the quadratic complexity of traditional transformers, enabling constant-memory inference by leveraging matrix-valued hidden states.

A typical linear transformer computes the output as:

o_{t} = S_{t}\phi(q_{t}),\quad S_{t} = S_{t-1} + v_{t}\phi(k_{t})^{T}

where S_{t} is the state matrix updated recurrently, and q_{t}, k_{t}, and v_{t} are the query, key, and value vectors at time step t. \phi(\cdot) represents a feature map, generally used to project q_{t} and k_{t} onto larger state spaces that preserve orthogonality. While this recurrent formulation is computationally efficient with O(Ld^{2}) complexity, it is inherently sequential, limiting parallelism.

To address this limitation, chunkwise parallelism has been introduced. In the chunkwise form, the sequence is divided into chunks of size C, enabling partial parallelization. The chunkwise update for state and output is expressed as:

S[t+1] = S[t] + V[t]^{T}K[t],\quad O[t] = Q[t]S[t]^{T} + \left( Q[t]K[t]^{T} \odot M \right)V[t]

where S[t], Q[t], K[t], and V[t] are the state, query, key, and value matrices for a given chunk, and M is the causal mask. This formulation strikes a balance between computational efficiency and sequence-level parallelism, significantly improving hardware utilization.
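
To make these two forms concrete, here is a minimal NumPy sketch of our own (with the feature map \phi taken as the identity for brevity) showing the recurrent update and a single chunk of the chunkwise form, along with a quick check that they agree:

import numpy as np

def linear_attn_recurrent(S, Q, K, V):
  # Reference: o_t = S_t q_t with S_t = S_{t-1} + v_t k_t^T; S is [value_size, key_size]
  O = []
  for q, k, v in zip(Q, K, V):
    S = S + np.outer(v, k)
    O.append(S @ q)
  return np.array(O), S

def linear_attn_chunk(S, Q, K, V):
  # One chunk: O = Q S^T + (Q K^T * M) V, then S <- S + V^T K
  # M is the causal mask including the diagonal
  M = np.tril(np.ones((Q.shape[0], K.shape[0]), dtype=K.dtype), 0)
  O = Q @ S.T + (Q @ K.T * M) @ V
  S = S + V.T @ K
  return O, S

Q = np.random.rand(5, 4)   # [chunk_size, key_size]
K = np.random.rand(5, 4)
V = np.random.rand(5, 2)   # [chunk_size, value_size]
S = np.zeros((2, 4))       # [value_size, key_size]

O_r, S_r = linear_attn_recurrent(S, Q, K, V)
O_c, S_c = linear_attn_chunk(S, Q, K, V)
print(np.allclose(O_r, O_c), np.allclose(S_r, S_c))  # expected: True True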

DeltaNet

DeltaNet improves upon the linear transformer by introducing a delta update rule, inspired by mechanisms for associative recall. The update rule dynamically adjusts the memory state S_{t} based on the interaction between the current key and the existing memory, enabling selective overwriting of past associations. We cover a similar mechanism in a previous article: “Treating a Matrix as a Database”. The delta update can be expressed in two forms:

  1. Recurrent Offset Form:

S_{t} = S_{t-1} + \Delta S_{t},\quad \Delta S_{t} = \beta_{t}\left( v_{t} - S_{t-1}\phi(k_{t}) \right)\phi(k_{t})^{T}

Where \beta_{t} \in (0,1) is a learnable gating parameter that controls the strength of the update. In the original formulation, \beta_{t} is a vector the size of v_{t}, which controls the gating per channel. If \beta_{t} is a scalar, we can rework the equation into a recurrent form.

  2. Recurrent Form:

S_{t} = S_{t-1}\left( I - \beta_{t}\phi(k_{t})\phi(k_{t})^{T} \right) + \beta_{t}v_{t}\phi(k_{t})^{T}

This form reparameterizes the update as a generalized Householder transformation, which was utilized by Yang et al. (2024) to derive a chunkwise formulation that is parallelizable across the time dimension despite dependent state updates.

Chunkwise DeltaNet

Yang et al. (2024) addressed DeltaNet’s sequential bottleneck by introducing a chunkwise parallel form. Building on the reparameterization of the delta rule, they derived a brilliant formulation that enables chunkwise updates while maintaining the algorithm’s expressivity.

The chunkwise update for DeltaNet is expressed as:

S[t+1] = S[t] + \left( U[t] - W[t]S[t]^{T} \right)^{T}K[t]

where:

W[t] = T[t]K[t],\quad U[t] = T[t]V[t]

And

T[t] = \left( I + \text{tril}\left( \text{Diag}\left( \beta_{[t]} \right)K[t]K[t]^{T}, -1 \right) \right)^{-1}\text{Diag}\left( \beta_{[t]} \right)

This formulation ensures that:

  1. The state and output updates for each chunk are computed without explicitly materializing the hidden state at all time steps.

  2. The computation of a chunk can be effectively parallelized, avoiding explicit recurrence.

  3. The computation stays linear over the larger sequence, despite the quadratic cost of operation within the chunk.

By extending the chunkwise parallelism introduced for linear transformers, Yang et al. (2024) made it possible to train DeltaNet efficiently on modern hardware. However, their approach requires materializing key and query vectors, introducing a scalability bottleneck when the dimensionality of these vectors grows.

Symmetric Power Transformers

Buckman et al. (2024) propose a framework for using symmetric power expansions as a feature map \phi(\cdot) within linear transformers. Their approach replaces standard dot products between key and query vectors with a polynomial kernel \phi(k) \cdot \phi(q) = (k \cdot q)^{n}, where n is an even degree to ensure all comparisons are positive. This kernel implicitly represents a dot product in a high-dimensional feature space, corresponding to all degree-n outer products of the input vector’s channels.

The feature map \phi(k) can be materialized by computing the products of all unique permutations of the channels in k, scaled by the number of duplicates in the repeated outer product. While this expansion grows exponentially with n, Buckman et al. showed that the kernel trick allows efficient comparisons without materializing \phi(k). Instead, compressed representations of \phi(k) suffice for comparisons, and the full feature map can be materialized on demand for matrix operations involving state updates. This technique, and others with similar properties, enables efficient manipulation of large state matrices while retaining computational feasibility in high-dimensional feature spaces.
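
As a quick sanity check of the kernel trick, here is a small sketch of our own that uses the full, uncompressed degree-2 expansion (via np.kron, rather than the compact symmetric form Buckman et al. describe) and confirms that the expanded dot product matches the kernel:

import numpy as np

d, n = 4, 2  # key size and (even) power; n = 2 keeps the expansion small

k = np.random.rand(d)
q = np.random.rand(d)

# Full degree-2 expansion: all products of channel pairs, size d ** 2
phi_k = np.kron(k, k)
phi_q = np.kron(q, q)

# Kernel trick: compare in the original space, then raise to the power n
explicit = phi_k @ phi_q
kernel = (k @ q) ** n

print(explicit, kernel)  # the two values agree to floating-point precision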

A Demonstration of Chunkwise DeltaNet

Instead of presenting a full derivation of chunkwise DeltaNet, we validate its correctness through a comparison with the recurrent DeltaNet formulation. This serves as both an empirical validation and a useful unit test framework for future modifications. First, we will demonstrate the equivalence of the original formulation and the recurrent form with the Householder matrix.

Let’s start by creating a function that executes DeltaNet’s core update mechanism. We will assume that S has shape [key_size, value_size], q has shape [key_size] and is L2-normalized, k has shape [key_size] and is L2-normalized, v has shape [value_size], and \beta is a scalar with domain [0, 1] represented by b. The original formulation of a single timestep has this form:

S_{t+1} = S_{t} + \beta_{t}\left( v_{t} - S_{t}k_{t} \right)k_{t}^{\top}

o_{t} = S_{t+1}q_{t}

and can be implemented in numpy with:

import numpy as np

def delta_net_step_ref(S, q, k, v, b):
  # Read what is currently stored in S at k
  vo = k @ S
  # Construct kernel and update the state
  S = S + np.outer(k, b * (v - vo))
  # Read what is stored in S at q
  o = q @ S
  return o, S

The recurrent form can be found by:

S_{t+1} = S_{t} + \beta_{t}v_{t}k_{t}^{\top} - \beta_{t}\left( S_{t}k_{t} \right)k_{t}^{\top}

S_{t+1} = S_{t} - S_{t}\beta_{t}k_{t}k_{t}^{\top} + \beta_{t}v_{t}k_{t}^{\top}

S_{t+1} = S_{t}\left( I - \beta_{t}k_{t}k_{t}^{\top} \right) + \beta_{t}v_{t}k_{t}^{\top}

o_{t} = S_{t+1}q_{t}

This formulation effectively removes the contribution of the current key vector k_{t} from the memory state S_{t}, replacing it with a new memory vector constructed from the corresponding value vector v_{t}. Here is a brief explanation: k_{t}k_{t}^{\top} is a rank-1 outer product that extracts the component of any vector aligned with k_{t}. S_{t}(I - \beta_{t}k_{t}k_{t}^{\top}) subtracts a scaled version of this component from the original state S_{t}, effectively removing the contribution of k_{t} from S_{t}. \beta_{t}v_{t}k_{t}^{\top} is a scaled outer product representing a new memory component aligned with k_{t}, added to the resulting state. This form might be implemented as:

def delta_net_householder_step(S, q, k, v, b):
  # Compute Householder-like transformation matrix
  A = (
    np.eye(k.shape[-1], dtype=k.dtype) -
    b * np.outer(k, k)
  )
  # Update the state S using A and add the new kernel
  S = A @ S + np.outer(k, b * v)
  # Read what is stored in S at q
  o = q @ S
  return o, S

A quick test demonstrates that these are equivalent:

S = np.random.rand(3, 3)
q = np.random.rand(3)
q = q / np.linalg.norm(q, 2)
k = np.random.rand(3)
k = k / np.linalg.norm(k, 2)
v = np.random.rand(3)
b = 0.9

o_r, S_r = delta_net_step_ref(S, q, k, v, b)
print("Output for reference: ", o_r)
o_h, S_h = delta_net_householder_step(S, q, k, v, b)
print("Output for householder: ", o_h)
print("Difference in S: ", np.linalg.norm(S_r - S_h))

Output:
Output for reference: [0.73660445 0.11329739 0.65290737]
Output for householder: [0.73660445 0.11329739 0.65290737]
Difference in S: 1.942890293094024e-16

Now we will verify the chunkwise formulation. We can implement a reference operation by simply iterating through a sequence of inputs:

def delta_net_ref(S, Q, K, V, B):
  O = []
  for i in range(K.shape[0]):
    o, S = delta_net_step_ref(S, Q[i], K[i], V[i], B[i])
    O.append(o)
  O = np.array(O)
  return O, S

For reference, here are the components of chunkwise DeltaNet:

T_{t} = \left( I + \text{tril}\left( \text{Diag}\left( \beta_{t} \right)K_{t}K_{t}^{\top}, -1 \right) \right)^{-1}\text{Diag}\left( \beta_{t} \right)

W_{t} = T_{t}K_{t}

U_{t} = T_{t}V_{t}

O_{t} = Q_{t}S_{t}^{\top} + \left( Q_{t}K_{t}^{\top} \odot M \right)\left( U_{t} - W_{t}S_{t}^{\top} \right)

S_{t+1} = S_{t} + \left( U_{t} - W_{t}S_{t}^{\top} \right)^{\top}K_{t}

This leads us to the following implementation:

def delta_net_chunk(S, Q, K, V, B):
  # Precompute before blockwise recurrent operation
  T = -np.tril(np.diag(B) @ K @ K.T, -1)
  # Compute the inverse of the lower triangular matrix
  # using recursive approach
  for i in range(1, K.shape[-2]):
    T[..., i, :i] = (
      T[..., i, :i] +
        (
          T[..., i, :, None].copy() *
          T[..., :, :i].copy()
        ).sum(-2)
    )
  T = T + np.eye(K.shape[-2], dtype=K.dtype)
  T = T @ np.diag(B)
  W = T @ K
  U = T @ V
  # Causal mask for comparing keys and queries
  M = np.tril(
    np.ones((Q.shape[-2], K.shape[-2]), dtype=K.dtype),
    0
  )
  A = Q @ K.T * M

  # Start recurrent loop
  O = Q @ S.T + A @ (U - W @ S.T)
  S = S + (U - W @ S.T).T @ K
  # End recurrent loop
  return O, S

To compute the matrix inverse within T_{t}, we use a recurrent implementation based on the demonstration by Yang et al. (2024). While this naive implementation demonstrates correctness, we use specialized GPU kernels to maximize parallelism and efficiency in research and production contexts. We can now verify that this formulation is correct with a quick demonstration of its equivalence to iterative DeltaNet:

CHUNK_SIZE = 3
KEY_SIZE = 3
VALUE_SIZE = 3

S = np.random.rand(KEY_SIZE, VALUE_SIZE)
q = np.random.rand(CHUNK_SIZE, KEY_SIZE)
q = q / np.linalg.norm(q, 2, axis=-1, keepdims=True)
k = np.random.rand(CHUNK_SIZE, KEY_SIZE)
k = k / np.linalg.norm(k, 2, axis=-1, keepdims=True)
v = np.random.rand(CHUNK_SIZE, VALUE_SIZE)
b = np.random.random(CHUNK_SIZE)

O_r, S_r = delta_net_ref(S, q, k, v, b)
O_c, S_c = delta_net_chunk(S.T, q, k, v, b)

print("Output for reference:\n", O_r)
print("Output for chunk:\n", O_c)
print("State for reference:\n", S_r)
print("State for chunk:\n", S_c.T)
print("Difference in S:\n", np.linalg.norm(S_r - S_c.T))

Output:
Output for reference:
  [[0.61825371 0.87760879 0.84799235]
   [0.34656147 0.85150314 0.83266842]
   [0.29352307 0.35111954 0.728127  ]]
Output for chunk:
  [[0.61825371 0.87760879 0.84799235]
   [0.34656147 0.85150314 0.83266842]
   [0.29352307 0.35111954 0.728127  ]]
State for reference:
  [[ 0.16640016 0.32539042 0.65276613]
   [ 0.41033506 0.12561812 0.34827096]
   [-0.32403035 0.21163771 0.28064045]]
State for chunk:
  [[ 0.16640016 0.32539042 0.65276613]
   [ 0.41033506 0.12561812 0.34827096]
   [-0.32403035 0.21163771 0.28064045]]
Difference in S:
  3.152427400121712e-16

Again, the results of these functions agree to within floating-point precision.
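
As a further check of our own (not part of the original demonstration), the same function can be driven across a longer sequence by iterating chunk by chunk, which is where the recurrent-loop comments in delta_net_chunk would sit:

def delta_net_chunked_sequence(S, Q, K, V, B, chunk_size):
  # Carries the [value_size, key_size] state S across chunks; assumes the
  # sequence length is a multiple of chunk_size
  O = []
  for s in range(0, K.shape[0], chunk_size):
    e = s + chunk_size
    o, S = delta_net_chunk(S, Q[s:e], K[s:e], V[s:e], B[s:e])
    O.append(o)
  return np.concatenate(O, axis=0), S

SEQ_LEN, CHUNK = 6, 3
S0 = np.random.rand(3, 3)  # [key_size, value_size] for the reference
Q = np.random.rand(SEQ_LEN, 3)
Q = Q / np.linalg.norm(Q, 2, axis=-1, keepdims=True)
K = np.random.rand(SEQ_LEN, 3)
K = K / np.linalg.norm(K, 2, axis=-1, keepdims=True)
V = np.random.rand(SEQ_LEN, 3)
B = np.random.random(SEQ_LEN)

O_r, S_r = delta_net_ref(S0, Q, K, V, B)
O_c, S_c = delta_net_chunked_sequence(S0.T, Q, K, V, B, CHUNK)
print(np.allclose(O_r, O_c), np.allclose(S_r, S_c.T))  # expected: True True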

Flow of Operations

Which operations are executed in this algorithm? Here we will examine each of them and identify whether they are compatible with compressed key formats.

K_{t}K_{t}^{\top}

A comparison of the key matrices. This can be pre-computed for the whole chunk before the chunkwise recurrence because it does not depend on the state. The shapes in the matmul are [chunk_size, key_size] x [key_size, chunk_size] = [chunk_size, chunk_size]. In a compressed key format this can be computed by a kernel that compares each key in the chunk to every other key. In a Symmetric Power Transformer architecture this would involve a dot product between the base vectors followed by raising the result to the power n.
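
As a minimal sketch of that last point (assuming symmetric power keys of even degree n, stored in their unexpanded base form; the helper name is ours), the chunk's key comparisons can be formed entirely through the kernel trick:

def key_key_chunk(K_base, n):
  # K_base: [chunk_size, key_size] unexpanded keys
  # Returns the [chunk_size, chunk_size] comparison phi(K) phi(K)^T without
  # materializing phi(K), using (k_i . k_j) ** n
  return (K_base @ K_base.T) ** n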

T_{t} = \left( I + \text{tril}\left( \text{Diag}\left( \beta_{t} \right)K_{t}K_{t}^{\top}, -1 \right) \right)^{-1}\text{Diag}\left( \beta_{t} \right)

T_{t} can be pre-computed chunkwise as well because it also does not depend on the state. Yang et al. (2024) provide a kernel for computing this from a materialized K_{t} matrix; one could also be constructed for other classes of compressed key representations, or one that consumes K_{t}K_{t}^{\top} to create flexibility around different comparison logic. T_{t} will have shape [chunk_size, chunk_size].
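
A sketch of that last option, reusing the inversion loop from delta_net_chunk but consuming a precomputed K_{t}K_{t}^{\top} (for example, the output of key_key_chunk above) so the comparison logic stays swappable; the helper name is ours:

import numpy as np

def compute_T_from_KKt(KKt, B):
  # KKt: [chunk_size, chunk_size] key comparisons, B: [chunk_size] betas
  # Builds T = (I + tril(Diag(B) @ KKt, -1))^{-1} @ Diag(B) with the same
  # row-by-row substitution used in delta_net_chunk
  T = -np.tril(np.diag(B) @ KKt, -1)
  for i in range(1, KKt.shape[-2]):
    T[..., i, :i] = (
      T[..., i, :i] +
        (
          T[..., i, :, None].copy() *
          T[..., :, :i].copy()
        ).sum(-2)
    )
  T = T + np.eye(KKt.shape[-2], dtype=KKt.dtype)
  return T @ np.diag(B)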

W_{t} = T_{t}K_{t}

U_{t} = T_{t}V_{t}

W_{t} and U_{t} can also be pre-computed before the recurrence, as again they don’t depend on the state. W_{t} creates a wrinkle for compressed key implementations, as it aggregates the key matrix K_{t} along the sequence dimension. This requires the expanded key dimension to be materialized without the possibility of re-compressing the representation.

W_{t}S_{t}^{\top}

This operation reads the transformed keys W_{t} from the state. If using compressed keys, this will not be performed using compressed key logic, but instead with dense matrix multiplications.

\left( U_{t} - W_{t}S_{t}^{\top} \right)

This simply subtracts what is read from the current state from the transformed values. Both of these tensors should be dense [chunk_size, value_size] representations.

Q_{t}S_{t}^{\top}

This is a read operation. Computed densely, this is a simple matrix multiplication between the matrix Q_{t} and the state. In compressed key implementations this would already need to be implemented as part of supporting traditional linear transformers. It results in a [chunk_size, value_size] tensor.

\left( Q_{t}K_{t}^{\top} \odot M \right)\left( U_{t} - W_{t}S_{t}^{\top} \right)

Noting that U_{t} - W_{t}S_{t}^{\top} has already been computed, this can be computed with a linear implementation of the flash attention kernel. A special kernel will need to be written for each compressed key format, or a two-stage process might be used instead, using the same kernel for K_{t}K_{t}^{\top} and Q_{t}K_{t}^{\top}. As it does not depend on the state, (Q_{t}K_{t}^{\top} \odot M) might be pre-computed before the chunkwise recurrence. This should already be implemented for compressed key systems that support linear transformers.

O_{t} = Q_{t}S_{t}^{\top} + \left( Q_{t}K_{t}^{\top} \odot M \right)\left( U_{t} - W_{t}S_{t}^{\top} \right)

This becomes a simple sum of the different output contributions.

S_{t+1} = S_{t} + \left( U_{t} - W_{t}S_{t}^{\top} \right)^{\top}K_{t}

This becomes a C = AB + C in-place update where the state update AB is constructed by \left( U_{t} - W_{t}S_{t}^{\top} \right)^{\top}K_{t} in a simple matrix multiplication. This is another kernel that must be implemented for any compressed key format.

Limitations of This Formulation

While the chunkwise DeltaNet formulation offers a significant improvement in parallelizing the recurrence across multiple timesteps, it inherently assumes that both the key and query vectors are fully materialized. This assumption becomes a bottleneck when considering memory and compute efficient compressed key representations.

Compressed key representations are designed to minimize the memory and computational overhead associated with storing and accessing the key and query vectors. Instead of materializing the full key or query matrix, these representations maintain a compact encoding that allows for efficient comparisons between keys and queries without expanding them into full-dimensional vectors. This compression is particularly effective when performing inner products, as it enables comparisons without ever creating the full key or query matrices in memory. Consequently, the key and query vectors only need to be expanded when interacting directly with the state matrix during read or write operations.
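
To make the division of labor concrete, here is a hypothetical illustration of our own (not Buckman et al.’s API), specialized to a degree-2 symmetric power feature map: comparisons go through the kernel trick, and the expanded feature map is only built when the state must be read or written:

import numpy as np

class PowerKeysSketch:
  # Hypothetical sketch: keys are held in their compact base form
  def __init__(self, K, n=2):
    assert n == 2, "this sketch only materializes the degree-2 expansion"
    self.K = K  # [chunk_size, key_size], never expanded for comparisons
    self.n = n

  def compare(self, other):
    # phi(k_i) . phi(q_j) == (k_i . q_j) ** n, no expansion required
    return (self.K @ other.K.T) ** self.n

  def expand(self):
    # Materialize phi(K) on demand for reads/writes against the state
    # (full tensor-product form; a compact symmetric form would be smaller)
    return np.stack([np.kron(k, k) for k in self.K])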

The original chunkwise DeltaNet formulation breaks this efficiency assumption due to how it computes the intermediate matrix W_{t}, defined as:

W_{t} = T_{t}K_{t}

The computation of W_{t} requires fully materializing the key matrix K_{t}, because the transformation matrix T_{t} aggregates across the chunk dimension within the key space. This results in W_{t} having shape [chunk_size, key_size], even when the original key vectors could have been stored in a compressed representation of size [chunk_size, compressed_key_size]. As a result, both the memory and compute overhead scale with the uncompressed key size, negating the benefits offered by compressed representations.

W_{t} and U_{t} appear in the formulation for several reasons. One is that they arise in the derivation as intermediate steps. Another is that they do not depend on the current state, meaning they can be pre-computed before the state recurrence, decreasing the number of operations that must be computed within the chunkwise recurrence.

In the following section, we will propose a reformulation that addresses this limitation by restructuring the update equations to avoid direct materialization of the key vectors wherever possible. This new approach retains the parallelization benefits of the chunkwise update while ensuring that the compressed representation of the key is only expanded when absolutely necessary, thus reducing both memory usage and computational cost for compressed key configurations.

Reformulation of the Chunkwise DeltaNet Update

The current formulation of the chunkwise DeltaNet introduces the intermediate matrices W_{t} and U_{t}, defined as:

W_{t} = T_{t}K_{t}\quad\text{and}\quad U_{t} = T_{t}V_{t}

These matrices emerge naturally from the derivation process and can be conveniently pre-computed, but they are not inherently required for the final state or output computation. Their introduction leads to the materialization of the full key matrix K_{t}, creating unnecessary computational and memory overhead in scenarios where compressed representations could be used. To address this, we propose a reformulation that eliminates these intermediate matrices by substituting their definitions directly into the state and output equations.

Reformulated State Update

Starting from the original update equation:

S_{t+1} = S_{t} + \left( U_{t} - W_{t}S_{t}^{\top} \right)^{\top}K_{t}

Substitute the definitions of W_{t} = T_{t}K_{t} and U_{t} = T_{t}V_{t}:

S_{t+1} = S_{t} + \left( T_{t}V_{t} - T_{t}K_{t}S_{t}^{\top} \right)^{\top}K_{t}

Factor out the common matrix T_{t}:

S_{t+1} = S_{t} + \left( T_{t}\left( V_{t} - K_{t}S_{t}^{\top} \right) \right)^{\top}K_{t}

Thus, the reformulated state update avoids explicitly computing W_{t} and reduces key materialization to the core matrix multiplication involving K_{t} and S_{t}^{\top}.

Reformulated Output Calculation

Similarly, the output computation from the original equation is:

O_{t} = Q_{t}S_{t}^{\top} + \left( Q_{t}K_{t}^{\top} \odot M \right)\left( U_{t} - W_{t}S_{t}^{\top} \right)

Substitute the definitions of W_{t} and U_{t}:

O_{t} = Q_{t}S_{t}^{\top} + \left( Q_{t}K_{t}^{\top} \odot M \right)\left( T_{t}V_{t} - T_{t}K_{t}S_{t}^{\top} \right)

Factor out T_{t}:

O_{t} = Q_{t}S_{t}^{\top} + \left( Q_{t}K_{t}^{\top} \odot M \right)T_{t}\left( V_{t} - K_{t}S_{t}^{\top} \right)

Demonstration

We provide a quick demonstration to show that this retains numerical accuracy:

def delta_net_rework(S, Q, K, V, B):
  # Precompute before blockwise recurrent operation
  T = -np.tril(np.diag(B) @ K @ K.T, -1)
  # Compute the inverse of the lower triangular matrix
  # using recursive approach
  for i in range(1, K.shape[-2]):
    T[..., i, :i] = (
      T[..., i, :i] +
        (
          T[..., i, :, None].copy() *
          T[..., :, :i].copy()
        ).sum(-2)
    )
  T = T + np.eye(K.shape[-2], dtype=Q.dtype)
  T = T @ np.diag(B)
  # Causal mask for comparing keys and queries
  M = np.tril(np.ones((Q.shape[-2], K.shape[-2]), dtype=K.dtype), 0)
  A = Q @ K.T * M

  # Start recurrent loop
  U = T @ (V - K @ S.T)
  O = Q @ S.T + A @ U
  S = S + U.T @ K
  # End recurrent loop

  return O, S

CHUNK_SIZE = 3
KEY_SIZE = 3
VALUE_SIZE = 3

S = np.random.rand(VALUE_SIZE, KEY_SIZE)
q = np.random.rand(CHUNK_SIZE, KEY_SIZE)
q = q / np.linalg.norm(q, 2, axis=-1, keepdims=True)
k = np.random.rand(CHUNK_SIZE, KEY_SIZE)
k = k / np.linalg.norm(k, 2, axis=-1, keepdims=True)
v = np.random.rand(CHUNK_SIZE, VALUE_SIZE)
b = np.random.random(CHUNK_SIZE)

O_r, S_r = delta_net_ref(S.T, q, k, v, b)
O_c, S_c = delta_net_rework(S, q, k, v, b)

print("Output for reference:\n ", O_r)
print("Output for chunk:\n ", O_c)
print("State for reference:\n ", S_r)
print("State for chunk:\n ", S_c)
print("Difference in S: ", np.linalg.norm(S_r - S_c))

Output:
Output for reference:
  [[1.02962676 0.96376011 0.36513722]
   [0.39607067 0.37464636 0.61936491]
   [0.89844703 0.65667506 0.56667131]]
Output for chunk:
  [[1.02962676 0.96376011 0.36513722]
   [0.39607067 0.37464636 0.61936491]
   [0.89844703 0.65667506 0.56667131]]
State for reference:
  [[0.33786606 0.93419013 0.49567793]
   [0.36359558 0.47879147 0.36560537]
   [0.43815287 0.06677817 0.46018492]]
State for chunk:
  [[0.33786606 0.93419013 0.49567793]
   [0.36359558 0.47879147 0.36560537]
   [0.43815287 0.06677817 0.46018492]]
Difference in S: 1.1188630228279524e-16

Here we can see that the refactor did not impact the mathematical correctness of the approach, and we have empirically validated that they are still computing the same thing.

Expanding Memory Size and Capabilities

The refactoring presented above highlights an important step in making DeltaNet more robust and expressive. By removing the need to explicitly materialize the key vectors K_{t} at various stages in the computation, we open the door to compressed key representations. These compressed representations can drastically reduce both storage and compute requirements, giving the model the ability to handle very high-dimensional (and thus more expressive) keys. Because the memory capacity of DeltaNet (how many unique associations it can store) depends heavily on the dimensionality of its key space, leveraging compressed keys is an essential strategy for pushing toward infinite-context sequence modeling.

In essence, using compressed key representations allows us to scale up the “size” of the memory space without the prohibitive overhead associated with storing enormous matrices in main memory or GPU memory. Instead, the model only materializes these vectors when interacting with the state.

Gated Formulations and State Decay

Another formulation proposed by Yang et al. (2024) is what they describe as Gated DeltaNet, which introduces a state decay mechanism to the recurrent state update. From the perspective of truly infinite context, a naive application of state decay goes against the idea of preserving all historical information indefinitely. So why are we interested in this formulation? That is a topic we’re excited to dive into in a future post, where we will explore connections to computational neuroscience.

Looking Ahead

These reparameterization and refactoring strategies pave the way for DeltaNet to manage extremely large state dimensions and maintain performance across long sequences. By pairing chunkwise parallelization with compressed key strategies, we are one step closer to building highly efficient, infinite-context transformers. Stay tuned for our next posts, where we plan to:

  1. Cover the gated formulation, examining its motivations and potential benefits, as well as how it can be adapted in a similar manner for compressed key formats.

  2. Reveal how these memory dynamics relate to theories of cortical function in neuroscience.

  3. Perform a study of vector expansion, projection, and key compression mechanisms and their effect on orthogonality with respect to the source vectors.

  4. Dive deep into the memory capacity of a DeltaNet style architecture.

As large-language-model-style AI continues to evolve, the ability to train and deploy transformers with effectively unbounded memory capacity stands to significantly expand the scope of tasks our models can perform. The refactorization showcased here is just one puzzle piece in that effort—but it’s an essential one if we want to keep pushing the boundaries of what infinite-context transformers can do.