Improving Token-Based World Models with Parallel Observation Prediction (2024)

Lior Cohen  Kaixin Wang  Bingyi Kang  Shie Mannor

Abstract

Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP in a novel TBWM agent named REM (Retentive Environment Model), showcasing a 15.4x faster imagination compared to prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at https://github.com/leor-c/REM.


1 Introduction

Sample efficiency remains a central challenge in reinforcement learning (RL) due to the substantial data demands of successful RL algorithms (Mnih et al., 2015; Silver et al., 2016; Schrittwieser et al., 2020; Berner et al., 2019; Vinyals et al., 2019). One prominent model-based approach for addressing this challenge is known as world models. In world models, the agent's learning is based solely on simulated interaction data produced by a learned model of the environment through a process called imagination. World models are gaining increasing popularity due to their effectiveness, particularly in visual environments (Hafner et al., 2023).


In recent years, attention-based sequence models, most notably the Transformer architecture (Vaswani et al., 2017), demonstrated unmatched performance in language modeling tasks (Devlin et al., 2019; Brown et al., 2020; Bubeck et al., 2023; Touvron et al., 2023). The notable success of these models when applied to sequences of discrete tokens sparked motivation to apply these architectures to other data modalities by learning appropriate token-based representations. In computer vision, discrete representations are becoming a mainstream approach for various tasks (van den Oord et al., 2017; Dosovitskiy et al., 2021; Esser et al., 2021; Li et al., 2023). In RL, token-based world models were recently explored in visual environments (Micheli et al., 2023). The visual perception module in these methods is called a tokenizer, as it maps image observations to sequences of discrete symbols. This way, agent interaction is translated into a language-like sequence of discrete tokens, which are processed individually by the world model.

During imagination, generating the tokens of the next observation with the auto-regressive model proceeds sequentially, token by token. This highly sequential computation results in a severe bottleneck that pronouncedly hinders token-based approaches. Consequently, the bottleneck practically caps the length of the observation token sequences, which in turn degrades performance. This limitation renders current token-based methods impractical for complex problems.

In this paper, we present Parallel Observation Prediction (POP), a novel mechanism that resolves the imagination bottleneck of token-based world models (TBWMs). With POP, the entire next-observation token sequence is generated in parallel during world model imagination. At its core, POP augments a Retentive Network (RetNet) sequence model (Sun et al., 2023) with a novel forward mode devised to retain world model training efficiency. Additionally, we present REM (Retentive Environment Model), a TBWM agent driven by a POP-augmented RetNet architecture.

Our main contributions are summarized as follows:

  • We propose Parallel Observation Prediction (POP), a novel mechanism that resolves the inference bottleneck of current token-based world models while retaining performance.

  • We introduce REM, the first world model approach that incorporates the RetNet architecture. Our experiments provide first evidence of RetNet’s performance in an RL setting.

  • We evaluate REM on the Atari 100K benchmark, demonstrating the effectiveness of POP. POP yields a 15.4x speed-up in imagination, and REM trains in under 12 hours while outperforming prior TBWMs.

2 Method

Notations. We consider the Partially Observable Markov Decision Process (POMDP) setting with image observations $\mathbf{o}_t \in \Omega \subseteq \mathbb{R}^{h \times w \times 3}$, discrete actions $a_t \in \mathcal{A}$, scalar rewards $r_t \in \mathbb{R}$, episode termination signals $d_t \in \{0, 1\}$, dynamics $\mathbf{o}_{t+1}, r_t, d_t \sim p(\mathbf{o}_{t+1}, r_t, d_t \mid \mathbf{o}_{\leq t}, a_{\leq t})$, and discount factor $\gamma$. The objective is to learn a policy $\pi$ such that for every situation the output $\pi(a_t \mid \mathbf{o}_{\leq t}, a_{<t})$ is optimal w.r.t. the expected discounted sum of rewards from that situation, $\mathbb{E}[\sum_{\tau=0}^{\infty} \gamma^{\tau} R_{t+\tau}]$, under the policy $\pi$.

2.1 Overview

REM builds on IRIS (Micheli et al., 2023) and, similar to most prior works on world models for pixel input (Hafner et al., 2021; Kaiser et al., 2020; Hafner et al., 2023), follows a $\mathcal{V}$-$\mathcal{M}$-$\mathcal{C}$ structure (Ha & Schmidhuber, 2018): a $\mathcal{V}$isual perception module that compresses observations into compact latent representations, a predictive $\mathcal{M}$odel that captures the environment's dynamics, and a $\mathcal{C}$ontroller that learns to act to maximize return. Additionally, a replay buffer is used to store environment interaction data. An overview of REM's training cycle is presented in Figure 2. A pseudo-code algorithm of REM is presented in Appendix A.2.

$\mathcal{V}$ - Tokenizer

We instantiate the visual perception component as a tokenizer, mapping input observations into latent tokens. Following Micheli et al. (2023), the tokenizer is a VQ-VAE discrete auto-encoder (van den Oord et al., 2017; Esser et al., 2021) comprised of an encoder, a decoder, and an embedding table. The embedding table $\mathcal{E} = \{\mathbf{e}_i\}_{i=1}^{N} \in \mathbb{R}^{N \times d}$ consists of $N$ trainable vectors. The encoder first maps an input image $\mathbf{o}_t$ to a sequence of $d$-dimensional latent vectors $(\mathbf{h}_t^1, \mathbf{h}_t^2, \ldots, \mathbf{h}_t^K)$. Then, each latent vector $\mathbf{h}_t^k \in \mathbb{R}^d$ is mapped to the index of the nearest embedding in $\mathcal{E}$, i.e., $z_t^k = \operatorname{arg\,min}_i \|\mathbf{h}_t^k - \mathbf{e}_i\|$. Such indices are called tokens. For an input image $\mathbf{o}_t$, its latent token sequence is denoted as $\mathbf{z}_t = (z_t^1, z_t^2, \ldots, z_t^K)$. To map a token sequence back to the input space, we first retrieve the embedding for each token, obtaining a sequence $(\hat{\mathbf{h}}_t^1, \hat{\mathbf{h}}_t^2, \ldots, \hat{\mathbf{h}}_t^K)$ where $\hat{\mathbf{h}}_t^k = \mathbf{e}_{z_t^k}$. Then, inverse to the encoding process, the decoder maps this sequence to a reconstructed observation $\hat{\mathbf{o}}_t$.

The tokenizer is trained on frames sampled uniformly from the replay buffer. Its optimization objective, architecture, and other details are deferred to Appendix A.1.1.
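To make the tokenize/detokenize round trip concrete, the following is a minimal sketch in PyTorch. The `encoder` and `decoder` modules and the function names are hypothetical placeholders, not our actual implementation; the real architecture and training objective are given in Appendix A.1.1.

```python
# A minimal sketch of VQ-VAE tokenization, assuming hypothetical
# `encoder`/`decoder` modules and an embedding table E of shape (N, d).
import torch

def tokenize(encoder, E, obs):
    """obs: (batch, 3, h, w) image -> (batch, K) integer tokens."""
    h = encoder(obs)  # (batch, K, d) latent vectors
    # Nearest-neighbor lookup: z^k = argmin_i ||h^k - e_i||.
    dists = torch.cdist(h, E.unsqueeze(0).expand(h.size(0), -1, -1))
    return dists.argmin(dim=-1)  # (batch, K) token indices

def detokenize(decoder, E, z):
    """z: (batch, K) tokens -> reconstructed observation."""
    h_hat = E[z]          # retrieve embeddings e_{z^k}, shape (batch, K, d)
    return decoder(h_hat)  # (batch, 3, h, w)
```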

$\mathcal{M}$ - World Model

At the core of a world model is the component that captures the dynamics of the environment and makes predictions based on historical observations. Here, $\mathcal{M}$ is learned entirely in the latent token space, modeling the following distributions at each step $t$:

Transition: $p(\hat{\mathbf{z}}_{t+1} \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_t, a_t)$,  (1)
Reward: $p(\hat{r}_t \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_t, a_t)$,  (2)
Termination: $p(\hat{d}_t \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_t, a_t)$.  (3)

To map observation tokens to embedding vectors, $\mathcal{M}$ uses the code vectors $\mathcal{E}$ learned by the tokenizer $\mathcal{V}$. Note that $\mathcal{E}$ is not updated by $\mathcal{M}$. In addition, $\mathcal{M}$ maintains dedicated embedding tables for mapping actions and special tokens (detailed in Section 2.3) to continuous vectors.

$\mathcal{C}$ - Controller

REM's actor-critic controller $\mathcal{C}$ is trained to maximize return entirely in imagination (Kaiser et al., 2020; Hafner et al., 2021; Micheli et al., 2023). $\mathcal{C}$ comprises a policy network $\pi$ and a value function estimator $V^{\pi}$, and operates on latent tokens and their embeddings. In each optimization step, $\mathcal{M}$ and $\mathcal{C}$ are initialized with a short trajectory segment sampled from the replay buffer. Subsequently, the agent interacts with the world model for $H$ steps. At each step $t$, the agent plays an action sampled from its policy $\pi(a_t \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_{t-1}, a_{t-1}, \mathbf{z}_t)$. The world model evolves accordingly, generating $\hat{r}_t$, $\hat{d}_t$, and $\hat{\mathbf{z}}_{t+1}$ by sampling from the appropriate distributions (Eqn. (1)-(3)). The resulting trajectories are then used to train the agent. Following Micheli et al. (2023), we adopt the actor-critic objectives of DreamerV2 (Hafner et al., 2021). We leave the full details of its architecture and optimization to Appendix A.1.3.
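As a structural sketch of the imagination loop just described, the following pseudo-Python rolls out $H$ steps inside the world model. All interfaces (`world_model.reset`, `world_model.step`, `policy.sample`) are hypothetical wrappers introduced for illustration, not our actual API.

```python
# A hedged sketch of the controller's imagination rollout, assuming
# hypothetical `world_model` and `policy` objects.
def imagine_trajectory(world_model, policy, context, H):
    """Roll out H steps in imagination from a sampled context segment."""
    state, z = world_model.reset(context)  # summarize the initial context
    trajectory = []
    for t in range(H):
        a = policy.sample(z)                             # a_t ~ pi(. | history)
        z, r, d, state = world_model.step(state, z, a)   # sample Eqns. (1)-(3)
        trajectory.append((z, a, r, d))
    return trajectory  # used to optimize pi and V^pi
```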


2.2 Retention Preliminaries

Similar to Transformers (Vaswani et al., 2017), a RetNet model (Sun et al., 2023) consists of a stack of layers, where each layer contains a multi-head Attention-like mechanism, called Retention, followed by a fully-connected network. A unique characteristic of the Retention mechanism is that it has a dual form of recurrence and parallelism, called "chunkwise", for improved efficiency when handling long sequences. This form allows splitting such sequences into smaller "chunks", where a parallel computation takes place within chunks and a sequential recurrent form is used between chunks, as shown in Figure 3. The information from previous chunks is summarized by a recurrent state $\mathbf{S} \in \mathbb{R}^{d \times d}$ maintained by the Retention mechanism.

Formally, consider a sequence of tokens $(x_1, x_2, \ldots, x_m)$. In our RL context, this sequence is a token trajectory composed of observation-action sub-sequences $(z_t^1, \ldots, z_t^K, a_t)$, which we call blocks. As such trajectories are typically long, we split them into chunks of $B$ tokens, where $B = c(K+1)$ is a multiple of $K+1$ so that each chunk only contains complete blocks. Here, the hyperparameter $c$ can be tuned according to the size of the models, the hardware, and other factors to maximize efficiency. Let $\mathbf{X} = (\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_m) \in \mathbb{R}^{m \times d}$ be the $d$-dimensional token embedding vectors. The Retention output $\mathbf{Y}_{[i]} = \mathrm{Retention}(\mathbf{X}_{[i]}, \mathbf{S}_{[i-1]}, i)$ of the $i$-th chunk is given by

$\mathbf{Y}_{[i]} = \left(\mathbf{Q}_{[i]} \mathbf{K}_{[i]}^{\top} \odot \mathbf{D}\right) \mathbf{V}_{[i]} + (\mathbf{Q}_{[i]} \mathbf{S}_{[i-1]}) \odot \boldsymbol{\xi}$,  (4)

where the bracketed subscript $[i]$ is used to index the $i$-th chunk, $\mathbf{Q} = (\mathbf{X}\mathbf{W}_Q) \odot \Theta$, $\mathbf{K} = (\mathbf{X}\mathbf{W}_K) \odot \bar{\Theta}$, $\mathbf{V} = \mathbf{X}\mathbf{W}_V$, and $\boldsymbol{\xi} \in \mathbb{R}^{B \times d}$ is a matrix with $\boldsymbol{\xi}_{ij} = \eta^{i+1}$. Here, $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{d \times d}$ are learnable weights, $\eta$ is an exponential decay factor, the matrix $\mathbf{D} \in \mathbb{R}^{B \times B}$ combines an auto-regressive mask with the temporal decay factor $\eta$, and the matrices $\Theta, \bar{\Theta} \in \mathbb{C}^{m \times d}$ are for relative position embedding (see Appendix A.3). Note the chunk index $i$ argument of the Retention operator, which controls positional embedding information through $\Theta$. The chunkwise update rule of the recurrent state is given by

$\mathbf{S}_{[i]} = (\mathbf{K}_{[i]} \odot \boldsymbol{\zeta})^{\top} \mathbf{V}_{[i]} + \eta^{B} \mathbf{S}_{[i-1]}$,  (5)

where $\mathbf{S}_{[0]} = \mathbf{S}_0 = 0$, and $\boldsymbol{\zeta} \in \mathbb{R}^{B \times d}$ is a matrix with $\boldsymbol{\zeta}_{ij} = \eta^{B-i-1}$. On the right-hand side of Equations 4 and 5, the first term corresponds to the computation within the chunk, while the second term incorporates the information from previous chunks, encapsulated by the recurrent state. Further details about the RetNet architecture are deferred to Appendix A.3.
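For intuition, the following is a minimal single-head NumPy sketch of the chunkwise form (Eqns. 4-5). It assumes a scalar decay $\eta$ and deliberately omits the complex-valued position embeddings $\Theta, \bar{\Theta}$, multi-head structure, normalization, and gating of the full RetNet layer; all names are illustrative.

```python
# A minimal single-head sketch of chunkwise Retention (Eqns. 4-5),
# omitting position embeddings and multi-head machinery for brevity.
import numpy as np

def chunkwise_retention(X_chunks, W_Q, W_K, W_V, eta):
    """X_chunks: list of (B, d) embedding chunks; returns list of (B, d) outputs."""
    d = W_Q.shape[0]
    B = X_chunks[0].shape[0]
    idx = np.arange(B)
    # D fuses the causal mask with intra-chunk decay: D[n, m] = eta^(n-m) for n >= m.
    D = np.where(idx[:, None] >= idx[None, :],
                 eta ** (idx[:, None] - idx[None, :]), 0.0)
    xi = (eta ** (idx + 1))[:, None]        # xi_ij = eta^(i+1): cross-chunk decay
    zeta = (eta ** (B - idx - 1))[:, None]  # zeta_ij = eta^(B-i-1): state-update decay
    S = np.zeros((d, d))                    # recurrent state, S_[0] = 0
    outputs = []
    for X in X_chunks:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        Y = (Q @ K.T * D) @ V + (Q @ S) * xi    # Eqn. (4): intra- + cross-chunk terms
        S = (K * zeta).T @ V + (eta ** B) * S   # Eqn. (5): recurrent state update
        outputs.append(Y)
    return outputs
```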


2.3 World Model Imagination

As the agent's training relies entirely on world model imagination, the efficiency of trajectory generation is critical. During imagination, predicting $\hat{\mathbf{z}}_{t+1}$ constitutes the primary non-trivial component and consumes the majority of processing time. In IRIS, the prediction of $\hat{\mathbf{z}}_{t+1}$ unfolds sequentially, as the model is limited to predicting only one token ahead at each step. This limitation arises because the identity of the next token, which remains unknown at the current step, is necessary for the prediction of later tokens. Thus, generating $H$ observations costs $KH$ sequential world model calls, leading to poor GPU utilization and long computation times.

To overcome this bottleneck, POP maintains a set of $K$ dedicated prediction tokens $\mathbf{u} = (u_1, \ldots, u_K)$ together with their corresponding embeddings $\mathcal{E}_{\mathbf{u}} \in \mathbb{R}^{K \times d}$. To generate $\hat{\mathbf{z}}_{t+1}$ in one pass, POP simply computes the RetNet outputs starting from $\mathbf{S}_t$ using $\mathbf{u}$ as its input sequence, as illustrated in Figure 4. Note that at imagination, the chunk size is limited to a single block, i.e., to $K+1$. Here, the notation $\mathbf{S}_t$ refers to the state that summarizes the first $t$ observation-action blocks. To obtain $\mathbf{S}_t$, we use RetNet's chunkwise forward to summarize an initial context segment of blocks sampled from the replay buffer. Essentially, for every $t$, POP models the following distribution for next-observation prediction:

$p(\hat{\mathbf{z}}_{t+1} \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_t, a_t, \mathbf{u})$

with

$p(\hat{\mathbf{z}}_{t+1}^{k} \mid \mathbf{z}_1, a_1, \ldots, \mathbf{z}_t, a_t, \mathbf{u}_{\leq k})$.

It is worth noting that the tokens $\mathbf{u}$ are only intended for observation token prediction and are never employed in the update of the recurrent state.

This approach effectively reduces the total number of world model calls during imagination from $KH$ to $2H$, eliminating the dependency on the number of observation tokens $K$. In fact, POP provides an additional generation mode that further reduces the number of sequential calls to $H$; we defer the details of this mode to Appendix A.1.2. Also, by using a recurrent state that summarizes long history sequences, POP improves efficiency further, as the per-token prediction cost is reduced. Effectively, POP offers improved scalability at the expense of a higher overall computational cost ($(2K+1)H$ compared to $(K+1)H$). Our approach adds to existing evidence suggesting that enhanced scalability is often favorable, even at the expense of additional computational costs, with Transformers (Vaswani et al., 2017) serving as a prominent example.
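To illustrate where the two world model calls per imagined step come from, here is a hedged sketch of one POP imagination step. The wrappers `step_block` (advances the recurrent state with one observation-action block) and `pop_predict` (runs the $K$ prediction tokens $\mathbf{u}$ against the frozen state without updating it) are hypothetical names, not the actual API.

```python
# A hedged sketch of one imagination step under POP, assuming hypothetical
# `model` methods; u is an input query only and never updates the state.
import torch
from torch.distributions import Categorical

def imagine_step(model, S, z_t, a_t):
    # Call 1: consume block (z_t, a_t), updating the recurrent state and
    # producing reward/termination predictions for step t.
    out, S = model.step_block(S, z_t, a_t)
    r_hat = model.reward_head(out)
    d_hat = Categorical(logits=model.termination_head(out)).sample()
    # Call 2: predict all K tokens of z_{t+1} in parallel from (S, u).
    logits = model.pop_predict(S, model.u_embeddings)  # (K, vocab)
    z_next = Categorical(logits=logits).sample()       # K tokens at once
    return z_next, r_hat, d_hat, S
```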


2.4 World Model Training

While applying POP during imagination is fairly straightforward, it requires modification of the training data. Consider an input trajectory segment $(\mathbf{z}_1, a_1, \ldots, \mathbf{z}_T, a_T)$ sampled from the replay buffer. To make meaningful observation predictions at imagination, the model should be trained to predict $\mathbf{z}_t$ given $(\mathbf{z}_1, a_1, \ldots, \mathbf{z}_{t-1}, a_{t-1}, \mathbf{u})$, for each time step $t$ of every input segment. Hence, for every $t$, the input sequence should contain $\mathbf{u}$ at block $t$. However, replacing $\mathbf{z}_t$ with $\mathbf{u}$ in the original sequence is inadequate, as the prediction of future observations, rewards, and termination signals depends on $\mathbf{z}_t$. Thus, the standard approach of computing all outputs from the same input sequence is not viable, as these two requirements contradict each other (Figure 5). The challenge then lies in devising an efficient method for computing the outputs for all time steps in parallel.

To tackle this challenge, we first note that each trajectory prefix can be summarized into a single recurrent state. For example, for the first input chunk $(\mathbf{z}_1, a_1, \ldots, \mathbf{z}_c, a_c)$, $(\mathbf{z}_1, a_1)$ can be summarized into $\mathbf{S}_{[1,1]}$ and $(\mathbf{z}_1, a_1, \mathbf{z}_2, a_2)$ can be summarized into $\mathbf{S}_{[1,2]}$. Here, we use the subscript $[i,j]$ to conveniently refer to the $j$-th block within the $i$-th chunk (this notation is demonstrated in Figure 3), with $\mathbf{S}_{[i,0]} = \mathbf{S}_{[i-1]}$ and $\mathbf{S}_{[i,c]} = \mathbf{S}_{[i]}$. Thus, our plan is to first compute all states $\mathbf{S}_{[i,1]}, \ldots, \mathbf{S}_{[i,c]}$ in parallel, and then predict all next observations from all $(\mathbf{S}_{[i,j]}, \mathbf{u})$ tuples.


To compute all recurrent states $\mathbf{S}_{[i,j]}$ in parallel, a two-step computation is carried out. First, intermediate states $\tilde{\mathbf{S}}_{[i,j]}$ are computed in parallel for all $j$ with

$\tilde{\mathbf{S}}_{[i,j]} = \left(\mathbf{K}_{[i,j]} \odot \boldsymbol{\zeta}\right)^{\top} \mathbf{V}_{[i,j]}$,  (6)

where $\boldsymbol{\zeta} \in \mathbb{R}^{(K+1) \times d}$ is a matrix with $\boldsymbol{\zeta}_{ij} = \eta^{K-i}$. Then, each recurrent state is computed sequentially by

$\mathbf{S}_{[i,j]} = \tilde{\mathbf{S}}_{[i,j]} + \eta^{K+1} \mathbf{S}_{[i,j-1]}$.  (7)

As the majority of the computational burden lies in the first step, the sequential computation in the second step has minimal impact on the overall speedup.
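The two-step computation translates directly into a batched contraction plus a cheap scan; the following NumPy sketch makes this concrete under stated assumptions (keys/values already projected and arranged per block, single head, names illustrative).

```python
# A minimal sketch of the two-step per-block state computation (Eqns. 6-7),
# assuming K_blocks/V_blocks hold the c blocks of a chunk, each (K+1, d).
import numpy as np

def per_block_states(K_blocks, V_blocks, S_prev, eta):
    c, kp1, d = K_blocks.shape  # kp1 = K + 1 tokens per block
    zeta = (eta ** (kp1 - 1 - np.arange(kp1)))[:, None]  # zeta_ij = eta^(K-i)
    # Step 1 (parallel over all blocks): intermediate states, Eqn. (6).
    S_tilde = np.einsum('ckd,cke->cde', K_blocks * zeta, V_blocks)
    # Step 2 (lightweight sequential scan): full states, Eqn. (7).
    states, S = [], S_prev
    for j in range(c):
        S = S_tilde[j] + (eta ** kp1) * S
        states.append(S)
    return states  # states[j] summarizes all blocks up to block j+1 of the chunk
```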

Once all states are ready, the output of $(\mathbf{S}_{[i,j]}, \mathbf{u})$ for all $1 \leq j \leq c$ is computed in parallel. Here, we stress that the existing Retention mechanism can only perform batched input computation with recurrent states $\mathbf{S}_t$ of the same time step $t$, due to the shared positional embedding information applied to every input sequence in the batch. To overcome this, we devise a mechanism which extends RetNet to support the batched computation of the $(\mathbf{S}_{[i,j]}, \mathbf{u})$ tuples while applying the appropriate positional encoding information. Pseudo-code for our POP extension of RetNet is given in Algorithms 1 and 2. The latter presents the core of the mechanism (described above), while the former describes the higher-level layer-by-layer computation with a final aggregation that combines the produced outputs. Figure 6 illustrates a simplified example of the POP forward mechanism (Algorithms 1 and 2) for a single-layer model. For brevity, our pseudo-code and illustrations only consider Retention layers, omitting other modules of RetNet (Appendix A.3).

Algorithm 1: POP Forward

1: Input: chunk size $1 \leq c \leq H$, token embeddings $\mathbf{X}_{[i]}$ of chunk $i$, per-layer recurrent states $\{\mathbf{S}_{[i-1]}^{l}\}_{l=1}^{L}$.
2: Initialize $\mathbf{A}_{[i]}^{0} \leftarrow \mathbf{X}_{[i]}$
3: Initialize $\mathbf{B}_{[i,1]}^{0}, \ldots, \mathbf{B}_{[i,c]}^{0} \leftarrow \mathcal{E}_{\mathbf{u}}, \ldots, \mathcal{E}_{\mathbf{u}}$
4: for $l = 1$ to $L$ do
5:   $\mathbf{A}_{[i]}^{l}, \mathbf{B}_{[i]}^{l}, \mathbf{S}_{[i]}^{l} \leftarrow \text{POPLayer}(\mathbf{A}_{[i]}^{l-1}, \mathbf{B}_{[i]}^{l-1}, \mathbf{S}_{[i-1]}^{l}, i)$
6: end for
7: for $j = 1$ to $c$ do
8:   $\mathbf{Y}_{[i,j]} \leftarrow \text{Concat}(\mathbf{B}_{[i,j]}^{L}, \mathbf{A}_{[i,j,K+1]}^{L})$
9: end for
10: Return $\mathbf{Y}, \{\mathbf{S}_{[i]}^{l}\}_{l=1}^{L}$


Algorithm 2: POPLayer

1: Input: chunk latents $\mathbf{A}_{[i]}$, observation prediction latents $\mathbf{B}_{[i]}$, recurrent state $\mathbf{S}_{[i-1]}$, chunk index $i$.
2: $\mathbf{A}_{[i]} \leftarrow \mathrm{Retention}(\mathbf{A}_{[i]}, \mathbf{S}_{[i-1]}, i)$ (Eqn. 4)
3: Compute $\tilde{\mathbf{S}}_{[i,1]}, \ldots, \tilde{\mathbf{S}}_{[i,c]}$ in parallel (Eqn. 6)
4: for $j = 1$ to $c$ do
5:   $\mathbf{S}_{[i,j]} \leftarrow \tilde{\mathbf{S}}_{[i,j]} + \eta^{K+1} \mathbf{S}_{[i,j-1]}$ (Eqn. 7)
6: end for
7: $\mathbf{S}_{[i]} \leftarrow \mathbf{S}_{[i,c]}$
8: $\mathbf{B}_{[i,j]} \leftarrow \mathrm{Retention}(\mathbf{B}_{[i,j]}, \mathbf{S}_{[i,j-1]}, [i,j])$ in parallel for $j = 1, \ldots, c$ (Eqn. 4)
9: Return $\mathbf{A}_{[i]}, \mathbf{B}_{[i]}, \mathbf{S}_{[i]}$

To train the world model, trajectory segments of $H$ steps from past experience are uniformly sampled from the replay buffer and translated into token sequences. These sequences are processed in chunks of $c$ observation-action blocks to produce the modeled distributions, as depicted in Figure 6. Optimization is carried out by minimizing the cross-entropy loss of the transition and termination outputs, and an appropriate loss for the reward outputs, depending on the task: for continuous rewards the mean-squared error loss is used, while for discrete ones cross-entropy is used instead.
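As a rough sketch of these losses, the following assumes discrete rewards (so all three heads are trained with cross-entropy) and hypothetical head outputs; MSE would replace the reward term for continuous-reward tasks, as noted above.

```python
# A rough sketch of the world model training loss, assuming hypothetical
# logits produced by the transition, reward, and termination heads.
import torch.nn.functional as F

def world_model_loss(obs_logits, reward_logits, done_logits,
                     z_target, r_target, d_target):
    # obs_logits: (T, K, vocab) next-token predictions from the u inputs.
    l_obs = F.cross_entropy(obs_logits.flatten(0, 1), z_target.flatten())
    l_rew = F.cross_entropy(reward_logits, r_target)  # discrete rewards
    l_done = F.cross_entropy(done_logits, d_target)
    return l_obs + l_rew + l_done
```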


Table 1: Average scores per game (over 100 evaluation episodes) on the 26 games of the Atari 100K benchmark. SimPLe, DreamerV3, TWM, and STORM are non-token-based; IRIS and REM are token-based.

| Game | Random | Human | SimPLe | DreamerV3 | TWM | STORM | IRIS | REM (ours) |
|---|---|---|---|---|---|---|---|---|
| Alien | 227.8 | 7127.7 | 616.9 | 959.4 | 674.6 | 983.6 | 420.0 | 607.2 |
| Amidar | 5.8 | 1719.5 | 74.3 | 139.1 | 121.8 | 204.8 | 143.0 | 95.3 |
| Assault | 222.4 | 742.0 | 527.2 | 705.6 | 682.6 | 801.0 | 1524.4 | 1764.2 |
| Asterix | 210.0 | 8503.3 | 1128.3 | 932.5 | 1116.6 | 1028.0 | 853.6 | 1637.5 |
| BankHeist | 14.2 | 753.1 | 34.2 | 648.7 | 466.7 | 641.2 | 53.1 | 19.2 |
| BattleZone | 2360.0 | 37187.5 | 4031.2 | 12250.0 | 5068.0 | 13540.0 | 13074.0 | 11826.0 |
| Boxing | 0.1 | 12.1 | 7.8 | 78.0 | 77.5 | 79.7 | 70.1 | 87.5 |
| Breakout | 1.7 | 30.5 | 16.4 | 31.1 | 20.0 | 15.9 | 83.7 | 90.7 |
| ChopperCommand | 811.0 | 7387.8 | 979.4 | 410.0 | 1697.4 | 1888.0 | 1565.0 | 2561.2 |
| CrazyClimber | 10780.5 | 35829.4 | 62583.6 | 97190.0 | 71820.4 | 66776.0 | 59324.2 | 76547.6 |
| DemonAttack | 152.1 | 1971.0 | 208.1 | 303.3 | 350.2 | 164.6 | 2034.4 | 5738.6 |
| Freeway | 0.0 | 29.6 | 16.7 | 0.0 | 24.3 | 0.0 | 31.1 | 32.3 |
| Frostbite | 65.2 | 4334.7 | 236.9 | 909.4 | 1475.6 | 1316.0 | 259.1 | 240.5 |
| Gopher | 257.6 | 2412.5 | 596.8 | 3730.0 | 1674.8 | 8239.6 | 2236.1 | 5452.4 |
| Hero | 1027.0 | 30826.4 | 2656.6 | 11160.5 | 7254.0 | 11044.3 | 7037.4 | 6484.8 |
| Jamesbond | 29.0 | 302.8 | 100.5 | 444.6 | 362.4 | 509.0 | 462.7 | 391.2 |
| Kangaroo | 52.0 | 3035.0 | 51.2 | 4098.3 | 1240.0 | 4208.0 | 838.2 | 467.6 |
| Krull | 1598.0 | 2665.5 | 2204.8 | 7781.5 | 6349.2 | 8412.6 | 6616.4 | 4017.7 |
| KungFuMaster | 258.5 | 22736.3 | 14862.5 | 21420.0 | 24554.6 | 26182.0 | 21759.8 | 25172.2 |
| MsPacman | 307.3 | 6951.6 | 1480.0 | 1326.9 | 1588.4 | 2673.5 | 999.1 | 962.5 |
| Pong | -20.7 | 14.6 | 12.8 | 18.4 | 18.8 | 11.3 | 14.6 | 18.0 |
| PrivateEye | 24.9 | 69571.3 | 35.0 | 881.6 | 86.6 | 7781.0 | 100.0 | 99.6 |
| Qbert | 163.9 | 13455.0 | 1288.8 | 3405.1 | 3330.8 | 4522.5 | 745.7 | 743.0 |
| RoadRunner | 11.5 | 7845.0 | 5640.6 | 15565.0 | 9109.0 | 17564.0 | 9614.6 | 14060.2 |
| Seaquest | 68.4 | 42054.7 | 683.3 | 618.0 | 774.4 | 525.2 | 661.3 | 1036.7 |
| UpNDown | 533.4 | 11693.2 | 3350.3 | 7567.1 | 15981.7 | 7985.0 | 3546.2 | 3757.6 |
| #Superhuman (↑) | 0 | N/A | 1 | 9 | 8 | 9 | 10 | 12 |
| Mean (↑) | 0.000 | 1.000 | 0.332 | 1.124 | 0.956 | 1.222 | 1.046 | 1.222 |
| Median (↑) | 0.000 | 1.000 | 0.134 | 0.485 | 0.505 | 0.425 | 0.289 | 0.280 |
| IQM (↑) | 0.000 | 1.000 | 0.130 | 0.487 | 0.459 | 0.561 | 0.501 | 0.673 |
| Optimality Gap (↓) | 1.000 | 0.000 | 0.729 | 0.510 | 0.513 | 0.472 | 0.512 | 0.482 |

3 Experiments

We follow most prior works on world models and evaluate REM on the widely recognized Atari 100K benchmark (Kaiser et al., 2020) for sample-efficient reinforcement learning. The Atari 100K benchmark considers a subset of 26 Atari games. For each game, the agent is limited to 100K interaction steps, corresponding to 400K game frames due to the standard frame-skip of 4. In total, this amounts to roughly 2 hours of gameplay. To put this in perspective, the original Atari benchmark allows agents to collect 200M steps, that is, 500 times more experience.

Experimental Setup

The full details of the architectures and hyper-parameters used in our experiments are presented in Appendix A.1. Notably, our tokenizer uses $K = 64$ (i.e., a grid of $8 \times 8$ latent tokens per observation), whereas IRIS uses only $K = 4 \times 4 = 16$. To ensure a meaningful comparison of the run times of REM and IRIS, REM's configuration was chosen so that the amount of computation carried out by each component at each epoch remains roughly equivalent to that of the corresponding component in IRIS. For benchmarking agent run times, we used a workstation with an Nvidia RTX 4090 GPU. The rest of our experiments were conducted on Nvidia V100 GPUs.

Baselines

Since the contributions of this paper relate to token-based approaches, and to IRIS in particular, our evaluation focuses on token-based methods. To enrich our results, as well as to facilitate future research, we include the following additional baselines: SimPLe (Kaiser et al., 2020), DreamerV3 (Hafner et al., 2023), TWM (Robine et al., 2023), and STORM (Zhang et al., 2023). In these approaches, each observation is processed as a single sequence element by the world model. Following prior works on world models, lookahead search methods such as MuZero (Schrittwieser et al., 2020) and EfficientZero (Ye et al., 2021) are not included, as lookahead search operates on top of a world model; our aim here is to improve the world model component itself.

3.1 Results

On Atari, it is standard to use human-normalized scores (HNS) (Mnih et al., 2015), calculated as $\frac{\text{agent score} - \text{random score}}{\text{human score} - \text{random score}}$, rather than raw game scores. Here, the final score of each training run is computed as an average over 100 episodes collected at the end of training. Agarwal et al. (2021) found discrepancies between conclusions drawn from point-estimate statistics, such as the mean and median, and those of a more thorough statistical analysis that also considers the uncertainty in the results. Adhering to their established protocol and utilizing their toolkit, we report the mean, median, and interquartile mean (IQM) human-normalized scores, and the optimality gap, with 95% stratified bootstrap confidence intervals in Figure 7. Performance profiles are presented in Figure 8. Average scores of individual games are reported in Table 1.
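For reference, the normalization is a one-liner; the example value below is computed from Table 1.

```python
def hns(agent_score, random_score, human_score):
    """Human-normalized score, as defined above."""
    return (agent_score - random_score) / (human_score - random_score)

# e.g., REM on Breakout (Table 1): hns(90.7, 1.7, 30.5) ~ 3.09
```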


REM attains an IQM human-normalized score of 0.673, outperforming all baselines. Additionally, REM improves over IRIS on 3 out of the 4 metrics (i.e., mean, optimality gap, and IQM), while being comparable in terms of its median score. Remarkably, REM achieves superhuman performance on 12 games, more than any other baseline (Table 1). REM also exhibits state-of-the-art scores on several games, including Assault, Boxing, and Chopper Command. These findings support our claim that REM performs on par with or better than previous token-based approaches while running significantly faster.

3.2 Ablation Studies

To analyze the impact of different components of our approach on REM's performance, we conduct a series of ablation studies. For each component, we compare the final algorithm to a version in which the component of interest is disabled. Due to computational resource constraints, the evaluation is performed on a subset of 8 games from the Atari 100K benchmark using 5 random seeds per game. This subset includes games with large score differences between IRIS and REM, as we are particularly interested in studying the impact of each component in these games. Concretely, the subset comprises Assault, Asterix, Chopper Command, Crazy Climber, Demon Attack, Gopher, Krull, and Road Runner. We performed ablation studies on the following aspects: the POP mechanism, the latent-space architecture of $\mathcal{C}$ and its action inputs, the latent resolution of $\mathcal{V}$, and the observation token embeddings used by $\mathcal{M}$.

The probability of improvement (Agarwal et al., 2021) and IQM human-normalized scores are presented in Figure 9. Figure 10 compares the training times of REM with those of its efficiency-related ablations.

Analyzing POP

To study the impact of POP on REM's performance, we replaced the POP-augmented RetNet of $\mathcal{M}$ with a vanilla RetNet. In this version, denoted "No POP", the prediction of next-observation tokens is performed sequentially, token by token, as done in IRIS.

Our results suggest that POP retains the agent's performance (Figure 9) while significantly reducing the overall computation time (Figure 10). In Appendix A.4, we provide additional results indicating that the world model's performance is also retained. POP achieves a lower total computation time by expediting the actor-critic learning phase, despite the increased computational cost implied by the observation prediction tokens.

Actor-Critic Architecture and Action Inputs

For $\mathcal{C}$, we considered an incremental ablation. First, we replaced the architecture of REM's controller $\mathcal{C}$ with that of IRIS (denoted "$\mathcal{C}_{\text{IRIS}}$"). In contrast to REM, this version processes fully reconstructed pixel frames and does not incorporate action inputs. Formally, $\mathcal{C}_{\text{IRIS}}$ models $\pi(a_t \mid \hat{\mathbf{o}}_{\leq t})$ and $V^{\pi}(\hat{\mathbf{o}}_{\leq t})$. In the second ablation, REM was modified so that only the action inputs of $\mathcal{C}$ were disabled; this ablation corresponds to $\pi(a_t \mid \hat{\mathbf{z}}_{\leq t})$ and $V^{\pi}(\hat{\mathbf{z}}_{\leq t})$.

Our findings indicate that both the latent-code-based architecture and the added action inputs contribute to the final performance of REM (Figure 9). Additionally, the latent-code-based architecture of $\mathcal{C}$ leads to reduced computational overhead and shorter actor-critic learning times (Figure 10).

Tokenizer Resolution

Here, we compare REM to a version with a reduced latent resolution of $4 \times 4$, similar to that of IRIS. The results in Figure 9 provide clear evidence that the latent resolution of the tokenizer has a significant impact on the agent's performance. Our results demonstrate that POP enables REM to utilize higher latent resolutions while incurring shorter computation times than prior token-based approaches.

World Model Embeddings

In REM, $\mathcal{M}$ translates observation tokens to embedding vectors using the embedding table $\mathcal{E}$ learned by $\mathcal{V}$. These embeddings encode the visual information as learned by $\mathcal{V}$. In contrast, IRIS maintains a separate embedding table, learned by the world model, for that purpose. Here, the results in Figure 9 provide empirical evidence indicating that leveraging this encoded visual information leads to improved performance. In Appendix A.4, we provide additional evidence suggesting that the world model's next-observation predictions are also improved.


4 Related Work

Model-based reinforcement learning (RL), with its roots in the tabular setting (Sutton, 1991), has been a focus of extensive research in recent decades. The deep RL agent of Ha & Schmidhuber (2018) leveraged an LSTM (Hochreiter & Schmidhuber, 1997) sequence model with a VAE (Kingma & Welling, 2014) to model the dynamics in visual environments, demonstrating that successful policies can be learned entirely from simulated data. This approach, commonly known as world models, was later applied to Atari games (Kaiser et al., 2020) with the PPO (Schulman et al., 2017) RL algorithm. Later, a series of works (Hafner et al., 2020, 2021, 2023) proposed the Dreamer algorithms, which are based on a recurrent state-space model (RSSM) (Hafner et al., 2019) to model dynamics. The latest, DreamerV3, was evaluated on a variety of challenging environments, providing further evidence of the promising potential of world models. In contrast to token-based approaches, where each token serves as a standalone sequence element, Dreamer encodes each frame as a vector of categorical variables, which are processed at once by the RSSM.

Following the success of the Transformer architecture (Vaswani et al., 2017) in language modeling (Brown et al., 2020), and motivated by its favorable scaling properties compared to RNNs, Transformers were recently explored in RL (Parisotto et al., 2020; Chen et al., 2021; Reed et al., 2022; Shridhar et al., 2023). World model approaches also adopted the Transformer architecture. Micheli et al. (2023) blazed the trail for token-based world models with IRIS, representing agent trajectories as language-like sequences. By treating each observation as a sequence, its Transformer-based world model gains an explicit sub-observation attention resolution. Despite IRIS's high performance, its imagination bottleneck results in a substantial disadvantage.

In addition to IRIS, non-token-based world models driven by Transformers were proposed. TWM (Robine et al., 2023) utilizes the Transformer-XL architecture (Dai et al., 2019) and a non-uniform data sampling. STORM (Zhang et al., 2023) proposes an efficient Transformer-based world model agent which sets state-of-the-art results on the Atari 100K benchmark. STORM uses a significantly smaller 2-layer Transformer compared to the 10-layer models of TWM and IRIS, demonstrating drastically reduced training times and improved agent performance.

5 Conclusions

In this work, we presented a novel Parallel Observation Prediction (POP) mechanism that augments Retentive Networks with a dedicated forward mode to improve the efficiency of token-based world models (TBWMs). POP effectively solves the imagination bottleneck of TBWMs and enables them to deal with longer observation sequences. Additionally, we introduced REM, a TBWM agent equipped with POP. REM is the first world model agent driven by the RetNet architecture. Empirically, we demonstrated the superiority of REM over prior TBWMs on the Atari 100K benchmark, rendering REM competitive with the state-of-the-art, both in terms of agent performance and overall run time.

Our work opens up many promising avenues for future research by making TBWMs practical and cost-efficient. One direction is to explore a modification of REM in which the recurrent state of the world model summarizes the entire history of the agent. Similarly, a history-preserving RetNet architecture should be considered for the controller. Another promising avenue is to leverage the independent optimization of the tokenizer so that REM can use pretrained visual perception models in environments where visual data is abundant, for example, the real world. Such perceptual models could be trained at scale and would allow REM to store only compressed observations in its replay buffer, further improving its efficiency. Lastly, token-based methods for video generation could benefit from the POP mechanism by generating entire frames in parallel conditioned on past context. We believe this is an exciting avenue to explore, with potentially high impact.

Acknowledgements

This project has received funding from the European Union’s Horizon Europe Programme under grant agreement No. 101070568.

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

• Agarwal, R., Schwarzer, M., Castro, P. S., Courville, A. C., and Bellemare, M. Deep reinforcement learning at the edge of the statistical precipice. In Advances in Neural Information Processing Systems, volume 34, pp. 29304–29320. Curran Associates, Inc., 2021. URL https://proceedings.neurips.cc/paper_files/paper/2021/file/f514cec81cb148559cf475e7426eed5e-Paper.pdf.
• Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
• Berner, C., Brockman, G., Chan, B., Cheung, V., Debiak, P., Dennison, C., Farhi, D., Fischer, Q., Hashme, S., Hesse, C., Józefowicz, R., Gray, S., Olsson, C., Pachocki, J., Petrov, M., de Oliveira Pinto, H. P., Raiman, J., Salimans, T., Schlatter, J., Schneider, J., Sidor, S., Sutskever, I., Tang, J., Wolski, F., and Zhang, S. Dota 2 with large scale deep reinforcement learning. CoRR, abs/1912.06680, 2019. URL http://arxiv.org/abs/1912.06680.
• Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners. In Advances in Neural Information Processing Systems, volume 33, pp. 1877–1901. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf.
• Bubeck, S., Chandrasekaran, V., Eldan, R., Gehrke, J., Horvitz, E., Kamar, E., Lee, P., Lee, Y. T., Li, Y., Lundberg, S. M., Nori, H., Palangi, H., Ribeiro, M. T., and Zhang, Y. Sparks of artificial general intelligence: Early experiments with GPT-4. CoRR, abs/2303.12712, 2023. URL https://doi.org/10.48550/arXiv.2303.12712.
• Chen, L., Lu, K., Rajeswaran, A., Lee, K., Grover, A., Laskin, M., Abbeel, P., Srinivas, A., and Mordatch, I. Decision Transformer: Reinforcement learning via sequence modeling. In Advances in Neural Information Processing Systems, 2021.
• Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., and Salakhutdinov, R. Transformer-XL: Attentive language models beyond a fixed-length context. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (ACL), 2020. doi: 10.18653/v1/p19-1285.
• Devlin, J., Chang, M. W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT 2019, volume 1, 2019.
• Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=YicbFdNTTy.
• Esser, P., Rombach, R., and Ommer, B. Taming transformers for high-resolution image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12873–12883, 2021.
• Ha, D. and Schmidhuber, J. Recurrent world models facilitate policy evolution. In Advances in Neural Information Processing Systems 31, pp. 2451–2463. Curran Associates, Inc., 2018. URL https://papers.nips.cc/paper/7512-recurrent-world-models-facilitate-policy-evolution. See also https://worldmodels.github.io.
• Hafner, D., Lillicrap, T., Fischer, I., Villegas, R., Ha, D., Lee, H., and Davidson, J. Learning latent dynamics for planning from pixels. In Proceedings of the 36th International Conference on Machine Learning (ICML), 2019.
• Hafner, D., Lillicrap, T., Ba, J., and Norouzi, M. Dream to control: Learning behaviors by latent imagination. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1lOTC4tDS.
• Hafner, D., Lillicrap, T. P., Norouzi, M., and Ba, J. Mastering Atari with discrete world models. In 9th International Conference on Learning Representations (ICLR), 2021. URL https://openreview.net/forum?id=0oabwyZbOu.
• Hafner, D., Pasukonis, J., Ba, J., and Lillicrap, T. Mastering diverse domains through world models. arXiv preprint arXiv:2301.04104, 2023.
• Hendrycks, D. and Gimpel, K. Bridging nonlinearities and stochastic regularizers with Gaussian error linear units, 2017. URL https://openreview.net/forum?id=Bk0MRI5lg.
• Hochreiter, S. and Schmidhuber, J. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.
• Kaiser, Ł., Babaeizadeh, M., Miłos, P., Osiński, B., Campbell, R. H., Czechowski, K., Erhan, D., Finn, C., Kozakowski, P., Levine, S., Mohiuddin, A., Sepassi, R., Tucker, G., and Michalewski, H. Model based reinforcement learning for Atari. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1xCPJHtDB.
• Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations (ICLR), 2014.
• Li, T., Chang, H., Mishra, S., Zhang, H., Katabi, D., and Krishnan, D. MAGE: Masked generative encoder to unify representation learning and image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2142–2152, 2023.
• Micheli, V., Alonso, E., and Fleuret, F. Transformers are sample-efficient world models. In The Eleventh International Conference on Learning Representations (ICLR), 2023. URL https://openreview.net/pdf?id=vhFu1Acb0xb.
• Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., Graves, A., Riedmiller, M., Fidjeland, A. K., Ostrovski, G., et al. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015.
• Parisotto, E., Song, F., Rae, J., Pascanu, R., Gulcehre, C., Jayakumar, S., Jaderberg, M., Kaufman, R. L., Clark, A., Noury, S., Botvinick, M., Heess, N., and Hadsell, R. Stabilizing transformers for reinforcement learning. In Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pp. 7487–7498. PMLR, 2020. URL https://proceedings.mlr.press/v119/parisotto20a.html.
• Ramachandran, P., Zoph, B., and Le, Q. V. Searching for activation functions, 2018. URL https://openreview.net/forum?id=SkBYYyZRZ.
• Reed, S., Zolna, K., Parisotto, E., Colmenarejo, S. G., Novikov, A., Barth-Maron, G., Giménez, M., Sulsky, Y., Kay, J., Springenberg, J. T., Eccles, T., Bruce, J., Razavi, A., Edwards, A., Heess, N., Chen, Y., Hadsell, R., Vinyals, O., Bordbar, M., and de Freitas, N. A generalist agent. Transactions on Machine Learning Research, 2022. ISSN 2835-8856. URL https://openreview.net/forum?id=1ikK0kHjvj.
• Robine, J., Höftmann, M., Uelwer, T., and Harmeling, S. Transformer-based world models are happy with 100k interactions. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=TdBaDGCpjly.
• Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., Lillicrap, T., and Silver, D. Mastering Atari, Go, chess and shogi by planning with a learned model. Nature, 588(7839):604–609, 2020. doi: 10.1038/s41586-020-03051-4. URL https://www.nature.com/articles/s41586-020-03051-4.
• Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. Proximal policy optimization algorithms. arXiv, abs/1707.06347, 2017. URL https://api.semanticscholar.org/CorpusID:28695052.
• Shridhar, M., Manuelli, L., and Fox, D. Perceiver-Actor: A multi-task transformer for robotic manipulation. In Proceedings of the 6th Conference on Robot Learning, volume 205 of Proceedings of Machine Learning Research, pp. 785–799. PMLR, 2023. URL https://proceedings.mlr.press/v205/shridhar23a.html.
• Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., van den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, 529:484–503, 2016. URL http://www.nature.com/nature/journal/v529/n7587/full/nature16961.html.
• Sun, Y., Dong, L., Huang, S., Ma, S., Xia, Y., Xue, J., Wang, J., and Wei, F. Retentive network: A successor to Transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
• Sutton, R. S. Dyna, an integrated architecture for learning, planning, and reacting. ACM SIGART Bulletin, 2(4), 1991. doi: 10.1145/122344.122377.
• Sutton, R. S. and Barto, A. G. Reinforcement Learning: An Introduction. A Bradford Book, Cambridge, MA, USA, 2018. ISBN 0262039249.
• Touvron, H., Martin, L., Stone, K. R., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., Bikel, D. M., Blecher, L., Ferrer, C. C., Chen, M., Cucurull, G., Esiobu, D., Fernandes, J., Fu, J., Fu, W., Fuller, B., Gao, C., Goswami, V., Goyal, N., Hartshorn, A. S., Hosseini, S., Hou, R., Inan, H., Kardas, M., Kerkez, V., Khabsa, M., Kloumann, I. M., Korenev, A. V., Koura, P. S., Lachaux, M.-A., Lavril, T., Lee, J., Liskovich, D., Lu, Y., Mao, Y., Martinet, X., Mihaylov, T., Mishra, P., Molybog, I., Nie, Y., Poulton, A., Reizenstein, J., Rungta, R., Saladi, K., Schelten, A., Silva, R., Smith, E. M., Subramanian, R., Tan, X., Tang, B., Taylor, R., Williams, A., Kuan, J. X., Xu, P., Yan, Z., Zarov, I., Zhang, Y., Fan, A., Kambadur, M., Narang, S., Rodriguez, A., Stojnic, R., Edunov, S., and Scialom, T. Llama 2: Open foundation and fine-tuned chat models. arXiv, abs/2307.09288, 2023. URL https://api.semanticscholar.org/CorpusID:259950998.
• van den Oord, A., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf.
• Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
• Vinyals, O., Babuschkin, I., Czarnecki, W. M., Mathieu, M., Dudzik, A., Chung, J., Choi, D. H., Powell, R., Ewalds, T., Georgiev, P., Oh, J., Horgan, D., Kroiss, M., Danihelka, I., Huang, A., Sifre, L., Cai, T., Agapiou, J. P., Jaderberg, M., Vezhnevets, A. S., Leblond, R., Pohlen, T., Dalibard, V., Budden, D., Sulsky, Y., Molloy, J., Paine, T. L., Gulcehre, C., Wang, Z., Pfaff, T., Wu, Y., Ring, R., Yogatama, D., Wünsch, D., McKinney, K., Smith, O., Schaul, T., Lillicrap, T., Kavukcuoglu, K., Hassabis, D., Apps, C., and Silver, D. Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature, 575(7782), 2019. doi: 10.1038/s41586-019-1724-z.
• Ye, W., Liu, S., Kurutach, T., Abbeel, P., and Gao, Y. Mastering Atari games with limited data. In Advances in Neural Information Processing Systems, 2021.
• Zhang, W., Wang, G., Sun, J., Yuan, Y., and Huang, G. STORM: Efficient stochastic transformer based world models for reinforcement learning. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=WxnrX42rnS.

Appendix A Appendix

A.1 Models and Hyperparameters

Tables 2 and 3 detail hyperparameters of the optimization and environment, as well as hyperparameters shared by multiple components.


Table 2: Optimization, environment, and shared hyperparameters.

Description | Symbol | Value
Horizon | H | 10
Tokens per observation | K | 64
Tokenizer vocabulary size | N | 512
Epochs | - | 600
Experience collection epochs | - | 500
Environment steps per epoch | - | 200
Collection epsilon-greedy | - | 0.01
Eval sampling temperature | - | 0.5
Optimizer | - | AdamW
AdamW β1 | - | 0.9
AdamW β2 | - | 0.999
Frame resolution | - | 64×64
Frame skip | - | 4
Max no-ops (train, test) | - | (30, 1)
Max episode steps (train, test) | - | (20K, 108K)
Terminate on life loss (train, test) | - | (No, Yes)

Table 3: Per-component optimization hyperparameters.

Description | Symbol | Tokenizer | World Model | Actor-Critic
Learning rate | - | 0.0001 | 0.0002 | 0.0001
Batch size | - | 128 | 64 | 128
Gradient clipping threshold | - | 10 | 100 | 3
Start after epochs | - | 5 | 25 | 50
Training steps per epoch | - | 200 | 200 | 100
AdamW weight decay | - | 0.01 | 0.05 | 0.01

A.1.1 Tokenizer ($\mathcal{V}$)

Tokenizer Architecture

Our tokenizer is based on the implementation of VQ-GAN (Esser et al., 2021), with simplified encoder and decoder architectures. The architectures of the encoder and decoder networks are described in Table 4.


Table 4: Tokenizer encoder and decoder architectures.

Module | Output shape

Encoder:
Input | 3×64×64
Conv(3, 1, 1) | 32×64×64
EncoderBlock1 | 64×32×32
EncoderBlock2 | 128×16×16
EncoderBlock3 | 256×8×8
GN | 256×8×8
SiLU | 256×8×8
Conv(3, 1, 1) | 256×8×8

EncoderBlock:
Input | c×h×w
GN | c×h×w
SiLU | c×h×w
Conv(3, 2, Asym.) | 2c×(h/2)×(w/2)
Conv(3, 1, 1) | 2c×(h/2)×(w/2)

Decoder:
Input | 256×8×8
Conv(3, 1, 1) | 256×8×8
DecoderBlock1 | 128×16×16
DecoderBlock2 | 64×32×32
DecoderBlock3 | 32×64×64
GN | 32×64×64
SiLU | 32×64×64
Conv(3, 1, 1) | 3×64×64

DecoderBlock:
Input | c×h×w
GN | c×h×w
SiLU | c×h×w
Interpolate | c×2h×2w
Conv(3, 1, 1) | (c/2)×2h×2w
Conv(3, 1, 1) | (c/2)×2h×2w
Tokenizer Learning

Following IRIS (Micheli et al., 2023), our tokenizer is a VQ-VAE (van den Oord et al., 2017) based on the implementation of Esser et al. (2021) (without the discriminator). The training objective is given by

\[
\mathcal{L}(E, D, \mathcal{E}) = \|x - D(z)\|_1 + \|\mathrm{sg}(E(x)) - \mathcal{E}(z)\|_2^2 + \|\mathrm{sg}(\mathcal{E}(z)) - E(x)\|_2^2 + \mathcal{L}_{\text{perceptual}}(x, D(z)) \tag{8}
\]

where $E$ and $D$ are the encoder and decoder models, respectively, and $\mathrm{sg}(\cdot)$ is the stop-gradient operator. The first term on the right-hand side of Equation 8 is the reconstruction loss, the second and third terms correspond to the commitment loss, and the last term is the perceptual loss.
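For concreteness, here is a minimal PyTorch sketch of this objective. The names `encoder`, `decoder`, `codebook`, and `perceptual_loss` are hypothetical stand-ins for the corresponding VQ-GAN components, not the released implementation:

```python
import torch
import torch.nn.functional as F

def tokenizer_loss(x, encoder, decoder, codebook, perceptual_loss):
    # Encode the image into a sequence of latent vectors (B, K, d).
    e = encoder(x)
    # Vector quantization: nearest codebook entry per latent vector.
    dists = torch.cdist(e, codebook.weight[None])   # (B, K, N)
    z = dists.argmin(dim=-1)                        # token indices (B, K)
    q = codebook(z)                                 # quantized embeddings (B, K, d)
    # Straight-through estimator: reconstruction gradients reach the encoder.
    q_st = e + (q - e).detach()
    x_hat = decoder(q_st)
    recon = (x - x_hat).abs().mean()                # L1 reconstruction term
    # The two commitment terms, each with a stop-gradient on one side.
    commit = F.mse_loss(e.detach(), q) + F.mse_loss(q.detach(), e)
    return recon + commit + perceptual_loss(x, x_hat)
```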

A.1.2 Retentive World Model ($\mathcal{M}$)

The hyperparameters of $\mathcal{M}$ are presented in Table 5.

Implementation Details

We use the "Yet-Another-RetNet" RetNet implementation (https://github.com/fkodom/yet-another-retnet), as its code is simple and convenient while its performance remains competitive with the official implementation in terms of run time and efficiency.

In the original IRIS algorithm, the world model is given a single observation as context for forward predictions. Our implementation uses a context of two frames.


Table 5: World model hyperparameters.

Description | Symbol | Value
Number of layers | - | 5
Number of Retention heads | - | 4
Embedding dimension | d | 256
Dropout | - | 0.1
RetNet feed-forward dimension | - | 1024
RetNet LayerNorm epsilon | - | 1e-6
Blocks per chunk | c | 3
World model context length | - | 2
POP Observation-Generation Modes

The simpler observation generation mode presented in Section 2.3 requires two sequential world model calls to generate the next observation. The first call consumes the previous observation-action block to compute the recurrent state at the current time step, while the second call uses $\mathbf{u}$ to generate the next observation tokens. Note that the next observation tokens sampled in the second call have not been processed by the world model at this point. To incorporate these tokens into the recurrent state, an additional world model call is required. Here, the cost induced by the first world model call is $K+1$, while the second call costs $K$.

Alternatively, it is possible to combine these two calls into one by concatenating the previous observation-action block and $\mathbf{u}$. However, to avoid incorporating $\mathbf{u}$ into the recurrent state computed by this call, a modified forward call must be used. Concretely, the resulting recurrent state should summarize only the previous observation-action block, neglecting the suffix $\mathbf{u}$. In practice, we use the backbone of the POP chunkwise forward mode (Alg. 1) for this computation. This alternative mode induces only $H$ sequential world model calls, where each call processes $2K+1$ tokens. Hence, this alternative reduces the number of sequential calls while maintaining the same total cost.

In practice, the optimal mode to use depends on the configuration, model sizes, and hardware.In our configuration, we opted for a larger batch size for the imagination phase.Hence, we found the first (simpler) mode to be slightly more efficient in this case.However, the second mode could be more efficient in other settings.
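To make the two modes concrete, the following schematic Python sketch contrasts a single imagination step in each mode. The methods `forward_chunk` and `pop_forward` and the helper `sample` are hypothetical names standing in for the corresponding forward modes of $\mathcal{M}$; this illustrates the call structure and per-call costs, not our exact implementation:

```python
import torch

def sample(logits):
    # Categorical sampling over the token vocabulary.
    return torch.distributions.Categorical(logits=logits).sample()

def imagine_step_default(wm, obs_tokens, action, u, state):
    # Call 1 (cost K + 1): fold the previous observation-action block
    # into the recurrent state.
    _, state = wm.forward_chunk(obs_tokens + [action], state)
    # Call 2 (cost K): predict all K next-observation tokens in parallel
    # from the prediction tokens u; the sampled tokens are not yet part
    # of `state` and require a further call to incorporate.
    logits, _ = wm.forward_chunk(u, state)
    return sample(logits), state

def imagine_step_combined(wm, obs_tokens, action, u, state):
    # One call (cost 2K + 1): the POP chunkwise forward (Alg. 1) updates
    # `state` from the observation-action block only, ignoring the suffix u.
    logits, state = wm.pop_forward(obs_tokens + [action] + u, state,
                                   suffix_len=len(u))
    return sample(logits), state
```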


Table 6: Computational costs of the POP modes compared to token-by-token generation.

Algorithm | Observation prediction cost | World model training cost | Imagination sequential calls | Imagination cost per call
POP (default mode) | 2K | 2KH | 2H | K
POP (alternative mode) | 2K | 2KH | H | 2K
No POP | K | KH | KH | 1

A.1.3 Controller ($\mathcal{C}$)

Actor-Critic Learning

Our learning algorithm follows IRIS and Dreamer (Micheli et al., 2023; Hafner et al., 2020, 2021), which use $\lambda$-returns defined recursively as

\[
G_t = \begin{cases}
\hat{r}_t + \gamma (1 - \hat{d}_t)\left((1 - \lambda) V^{\pi}(\hat{\mathbf{z}}_{t+1}) + \lambda G_{t+1}\right) & t < H \\
V^{\pi}(\hat{\mathbf{z}}_H) & t = H
\end{cases}
\]

where $V^{\pi}$ is the value network learned by the critic, and $(\hat{\mathbf{z}}_0, a_0, \hat{r}_0, \hat{d}_0, \ldots, \hat{\mathbf{z}}_{H-1}, a_{H-1}, \hat{r}_{H-1}, \hat{d}_{H-1}, \hat{\mathbf{z}}_H)$ is a trajectory obtained through world model imagination.
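For illustration, this recursion can be evaluated with a single backward pass over the imagined trajectory. Below is a minimal sketch in the notation above (with $\hat{r}$, $\hat{d}$, and the critic values precomputed, and the Table 7 hyperparameter values as defaults); it is not the exact training code:

```python
import torch

def lambda_returns(rewards, dones, values, gamma=0.995, lam=0.95):
    """rewards, dones: (H,); values: (H + 1,) with values[H] = V(z_H)."""
    H = rewards.shape[0]
    G = torch.empty(H + 1)
    G[H] = values[H]  # bootstrap: G_H = V(z_H)
    for t in reversed(range(H)):
        # G_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V(z_{t+1}) + lam * G_{t+1})
        G[t] = rewards[t] + gamma * (1.0 - dones[t]) * (
            (1.0 - lam) * values[t + 1] + lam * G[t + 1]
        )
    return G[:H]
```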

To optimize $V^{\pi}$, the following loss is minimized:

\[
\mathcal{L}_{V^{\pi}} = \mathbb{E}_{\pi}\left[\sum_{t=0}^{H-1} \left(V^{\pi}(\hat{\mathbf{z}}_t) - \mathrm{sg}(G_t)\right)^2\right]
\]

The policy optimization follows a simple REINFORCE (Sutton & Barto, 2018) objective, with $V^{\pi}$ used as a baseline for variance reduction. The objective is given by

\[
\mathcal{L}_{\pi} = -\mathbb{E}_{\pi}\left[\sum_{t=0}^{H-1} \log\left(\pi(a_t \mid \hat{\mathbf{z}}_1, a_1, \ldots, \hat{\mathbf{z}}_{t-1}, a_{t-1}, \hat{\mathbf{z}}_t)\right) \mathrm{sg}\left(G_t - V^{\pi}(\hat{\mathbf{z}}_t)\right) + \alpha \mathcal{H}\left(\pi(a_t \mid \hat{\mathbf{z}}_1, a_1, \ldots, \hat{\mathbf{z}}_{t-1}, a_{t-1}, \hat{\mathbf{z}}_t)\right)\right]
\]

where $\alpha$ is the entropy loss weight and $\mathcal{H}(\cdot)$ denotes the entropy. The values of the hyperparameters used in our experiments are detailed in Table 7.


Table 7: Controller hyperparameters.

Description | Symbol | Value
Discount factor | γ | 0.995
λ-return coefficient | λ | 0.95
Entropy loss weight | α | 0.001
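As an illustration of the two objectives above, a minimal sketch is given below, assuming the λ-returns $G_t$ and critic values have already been computed along the imagined trajectory (the entropy weight follows Table 7); it is a schematic rendering, not the released training code:

```python
import torch
import torch.nn.functional as F

def controller_losses(logits, actions, values, returns, alpha=0.001):
    """logits: (H, A) policy outputs; actions: (H,) sampled actions;
    values: (H,) critic estimates V(z_t); returns: (H,) lambda-returns G_t."""
    G = returns.detach()                      # sg(G_t): targets carry no gradient
    value_loss = F.mse_loss(values, G)        # critic regression to lambda-returns
    dist = torch.distributions.Categorical(logits=logits)
    log_pi = dist.log_prob(actions)
    advantage = (G - values).detach()         # sg(G_t - V(z_t)) baseline
    policy_loss = -(log_pi * advantage + alpha * dist.entropy()).mean()
    return value_loss, policy_loss
```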

Table 8: Controller observation-processing network.

Module | Output shape
Input | 256×8×8
Conv(3, 1, 1) | 128×8×8
SiLU | 128×8×8
Conv(3, 1, 1) | 64×8×8
SiLU | 64×8×8
Flatten | 4096
Linear | 512
SiLU | 512
Agent Architecture

The architecture of the agent module comprises a shared backbone and two linear maps for the actor and critic heads, respectively. The shared backbone first maps the input to a latent representation in the form of a 512-dimensional vector. For action token inputs, a learned embedding table maps the token to its latent representation. For observation inputs, the $K$ tokens are first mapped to their corresponding code vectors learned by the tokenizer and reshaped according to their original spatial order. Then, the resulting tensor is processed by a convolutional neural network followed by a fully connected network. The architecture details of these networks are presented in Table 8. Lastly, a long short-term memory (LSTM) (Hochreiter & Schmidhuber, 1997) network of dimension 512 maps the processed input vector to a history-dependent latent vector, which serves as the output of the shared backbone.
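For concreteness, a PyTorch sketch of the shared backbone described above (following Tables 5 and 8) is shown below; it is a schematic rendering of the described architecture with hypothetical module names, not the released code:

```python
import torch
import torch.nn as nn

class ControllerBackbone(nn.Module):
    def __init__(self, codebook, num_actions, d=512):
        super().__init__()
        self.codebook = codebook                 # tokenizer embedding table
        self.action_emb = nn.Embedding(num_actions, d)
        self.obs_net = nn.Sequential(            # Table 8
            nn.Conv2d(256, 128, 3, 1, 1), nn.SiLU(),
            nn.Conv2d(128, 64, 3, 1, 1), nn.SiLU(),
            nn.Flatten(), nn.Linear(64 * 8 * 8, d), nn.SiLU(),
        )
        self.lstm = nn.LSTM(d, d, batch_first=True)

    def embed_obs(self, tokens):
        # tokens: (B, 64) -> code vectors -> spatial grid (B, 256, 8, 8)
        codes = self.codebook(tokens)            # (B, 64, 256)
        grid = codes.view(-1, 8, 8, 256).permute(0, 3, 1, 2)
        return self.obs_net(grid)                # (B, 512)

    def forward(self, latents, hidden=None):
        # latents: (B, T, 512) interleaved observation/action embeddings
        out, hidden = self.lstm(latents, hidden)  # history-dependent features
        return out, hidden
```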

Action-Dependent Actor-Critic

In the IRIS algorithm, the actor and critic networks share an LSTM (Hochreiter & Schmidhuber, 1997) backbone and model $\pi(a_t \mid \mathbf{o}_{\le t})$ and $V^{\pi}(\mathbf{o}_{\le t})$. Notice that the output of the policy models the distribution of actions at step $t$. Importantly, the model has no information about the sampled actions. In REM, the input of $\mathcal{C}$ contains the sampled actions, i.e., our algorithm models $\pi(a_t \mid \hat{\mathbf{z}}_1, a_1, \ldots, \hat{\mathbf{z}}_{t-1}, a_{t-1}, \hat{\mathbf{z}}_t)$ and $V^{\pi}(\hat{\mathbf{z}}_1, a_1, \ldots, \hat{\mathbf{z}}_{t-1}, a_{t-1}, \hat{\mathbf{z}}_t)$.

A.2 REM Algorithm

Here, we present pseudo-code for REM. The high-level loop is presented in Algorithm 3, while the pseudo-code for training each component is presented in Algorithms 4-7.

Algorithm 3: REM training loop
repeat
  collect_experience()  (Alg. 4)
  train_V()  (Alg. 5)
  train_M()  (Alg. 6)
  train_C()  (Alg. 7)
until stopping criterion is met

Algorithm 4: collect_experience()
o_1 ← env.reset()
for t = 1 to T do
  z_t ← V_Enc(o_t)
  a_t ∼ π(a_t | z_1, a_1, …, z_{t−1}, a_{t−1}, z_t)
  o_{t+1}, r_t, d_t ← env.step(a_t)
  if d_t = 1 then
    o_{t+1} ← env.reset()
  end if
end for
replay_buffer.store({o_t, a_t, r_t, d_t}_{t=1..T})

Algorithm 5: train_V() (tokenizer)
o ← replay_buffer.sample_obs()
z ← V_Enc(o)
ô ← V_Dec(z)
Compute loss (Eqn. 8)
Update V

Algorithm 6: train_M() (world model)
{o_t, a_t, r_t, d_t}_{t=1..H} ← replay_buffer.sample()
for t = 1 to H do
  z_t, h_t ← V_Enc(o_t)
  a_t ← M.embed_action(a_t)
end for
X ← (h_1, a_1, …, h_H, a_H)
S_0^1, …, S_0^L ← 0, …, 0
for i = 1 to ⌈H/c⌉ do
  Y_[i], {S_i^l}_{l=1..L} ← POP_forward(X_[i], {S_{i−1}^l}_{l=1..L})  (Alg. 1)
end for
(ẑ_1, …, ẑ_H) ← M.obs_pred_head(Y[:, :−1])
(r̂_1, d̂_1, …, r̂_H, d̂_H) ← M.reward_done_head(Y[:, −1])
Compute losses and update M

Algorithm 7: train_C() (controller, via imagination)
{o_t, a_t, r_t, d_t}_{t=1..H} ← replay_buffer.sample()
for t = 1 to H do
  z_t, h_t ← V_Enc(o_t)
end for
S ← 0
c ← 1
Initialize context τ ← (z_1, a_1, …, z_H)
for t′ = H+1 to 2H do
  a_{t′} ∼ π(a_{t′} | z_1, a_1, …, z_{t′})
  V_{t′} ← V(z_1, a_1, …, z_{t′})
  Y, S ← M.retnet_chunkwise_forward((τ, a_{t′}), S, t′)
  r_{t′}, d_{t′} ∼ M.reward_done_head(Y[−1])
  Y, _ ← M.retnet_chunkwise_forward(u, S, t′+1)
  z_{t′+1} ∼ M.obs_pred_head(Y[:−1])
  τ ← z_{t′+1}
end for
Update π, V (detailed in Section A.1.3)

A.3 Retentive Networks

In this section, we provide detailed information on the RetNet architecture for completeness. For convenience, we adopt the notation of Sun et al. (2023) rather than the notation presented in the main text.

RetNet (Sun et al., 2023) is a recent alternative to Transformers (Vaswani et al., 2017). It is highly parallelizable, has lower-cost inference than Transformers, and is empirically claimed to perform competitively on language modeling tasks. The RetNet model is a stack of $L$ identical layers. Here, we denote the output of the $l$-th layer by $\mathbf{Y}^l$. Given an embedded input sequence $\mathbf{X} = \mathbf{Y}^0 \in \mathbb{R}^{m \times d}$ of $m$ $d$-dimensional vectors, each RetNet layer can be described as

\[
\mathbf{X}^l = \mathrm{MSR}(\mathrm{LN}(\mathbf{Y}^l)) + \mathbf{Y}^l \tag{9}
\]
\[
\mathbf{Y}^{l+1} = \mathrm{FFN}(\mathrm{LN}(\mathbf{X}^l)) + \mathbf{X}^l \tag{10}
\]

where $\mathrm{LN}(\cdot)$ is layer normalization (Ba et al., 2016), $\mathrm{FFN}(\mathbf{X}) = \mathrm{gelu}(\mathbf{X}\mathbf{W}_1)\mathbf{W}_2$ is a feed-forward network (FFN), and $\mathrm{MSR}(\cdot)$ is a multi-scale retention (MSR) module with multiple Retention heads. The output of the RetNet model is given by $\mathrm{RetNet}(\mathbf{Y}^0) = \mathbf{Y}^L$.

As presented in the main text, the chunkwise equations are

\[
\mathbf{S}_{[i]} = (\mathbf{K}_{[i]} \odot \boldsymbol{\zeta})^{\top} \mathbf{V}_{[i]} + \eta^{B} \mathbf{S}_{[i-1]}
\]
\[
\mathbf{Y}_{[i]} = \left(\mathbf{Q}_{[i]} \mathbf{K}_{[i]}^{\top} \odot \mathbf{D}\right) \mathbf{V}_{[i]} + \left(\mathbf{Q}_{[i]} \mathbf{S}_{[i-1]}\right) \odot \boldsymbol{\xi}
\]

where $\mathbf{Q} = (\mathbf{X}\mathbf{W}_Q) \odot \Theta$, $\mathbf{K} = (\mathbf{X}\mathbf{W}_K) \odot \bar{\Theta}$, $\mathbf{V} = \mathbf{X}\mathbf{W}_V$, and $\boldsymbol{\xi} \in \mathbb{R}^{B \times d}$ is a matrix with $\boldsymbol{\xi}_{ij} = \eta^{i+1}$. Here, $\Theta_n = e^{in\theta}$, $\mathbf{D}_{n,m} = \eta^{n-m}$ if $n \ge m$ and $0$ if $n < m$, and $\theta, \eta \in \mathbb{R}^d$, where $\eta$ is an exponential decay factor, the matrices $\Theta, \bar{\Theta} \in \mathbb{C}^{m \times d}$ provide relative position encoding, and $\mathbf{D} \in \mathbb{R}^{B \times B}$ combines an auto-regressive mask with the temporal decay factor $\eta$.
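The chunkwise recurrence can be read directly off these equations. Below is a single-head sketch in PyTorch; since $\boldsymbol{\zeta}$ is not spelled out above, we assume the definition from Sun et al. (2023), $\zeta_{ij} = \eta^{B-1-i}$:

```python
import torch

def chunkwise_retention(Q, K, V, S_prev, eta):
    """Single-head chunkwise retention for one chunk.
    Q, K, V: (B_chunk, d); S_prev: (d, d) recurrent state; eta: decay in (0, 1)."""
    B = Q.shape[0]
    idx = torch.arange(B, dtype=Q.dtype)
    # D combines the causal mask with the decay eta^(n - m).
    D = torch.tril(eta ** (idx[:, None] - idx[None, :]))
    zeta = (eta ** (B - 1 - idx)).unsqueeze(1)  # intra-chunk key decay (assumed)
    xi = (eta ** (idx + 1)).unsqueeze(1)        # cross-chunk decay xi_ij = eta^(i+1)
    inner = (Q @ K.T * D) @ V                   # parallel, intra-chunk part
    cross = (Q @ S_prev) * xi                   # recurrent, cross-chunk part
    S = (K * zeta).T @ V + (eta ** B) * S_prev  # updated recurrent state
    return inner + cross, S
```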

In each RetNet layer, $h = \frac{d}{d_{\text{head}}}$ heads are used, where $d_{\text{head}}$ is the dimension of each head. Each Retention head uses different parameters $W_K, W_Q, W_V$. Additionally, each Retention head uses a different value of $\eta$; the values of $\eta$ are fixed across RetNet layers. Each layer is defined as follows:

\[
\eta = 1 - 2^{-5-\mathrm{arange}(0,h)} \in \mathbb{R}^{h}
\]
\[
\mathrm{head}_i = \mathrm{Retention}(X, \eta_i)
\]
\[
Y = \mathrm{GroupNorm}_h(\mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h))
\]
\[
\mathrm{MSR}(X) = (\mathrm{swish}(X W_G) \odot Y) W_O
\]

where $W_G, W_O \in \mathbb{R}^{d \times d}$ are learnable parameters.
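A corresponding sketch of the MSR block follows, with `retention` a hypothetical callable standing in for the per-head Retention operator:

```python
import torch
import torch.nn.functional as F

def multi_scale_retention(X, retention, W_G, W_O, h):
    """X: (m, d); retention(X, eta_i) -> (m, d / h) is a hypothetical
    per-head Retention operator; W_G, W_O: (d, d) learnable matrices."""
    # Fixed per-head decay values, shared across layers.
    eta = 1 - 2 ** (-5 - torch.arange(h, dtype=torch.float32))
    heads = [retention(X, eta[i]) for i in range(h)]
    Y = torch.cat(heads, dim=-1)              # (m, d)
    Y = F.group_norm(Y, num_groups=h)         # GroupNorm_h over head channels
    return (F.silu(X @ W_G) * Y) @ W_O        # swish gating and output projection
```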

A.4 Additional Results

In addition to comparing the run times of REM and IRIS, we also compared against an improved version of IRIS that uses REM's configuration. These results are presented in Figure 11 and clearly show the effectiveness of our novel POP mechanism.


The probability-of-improvement results from our Atari 100K benchmark experiment are presented in Figure 12. Importantly, REM outperforms previous token-based methods, namely IRIS, while remaining competitive with all baselines except STORM on this metric. We highlight that our main contributions address the computational bottleneck of token-based methods, and thus we focus on comparing REM to these approaches.


The complete set of ablation results is presented in Figure 13 and Table 9. The performance profiles for the ablations are presented in Figure 14.


Table 9: Complete ablation results on the 8 ablation games, with human-normalized aggregates.

Game | Random | Human | REM | IRIS | No POP | Separate $\mathcal{M}$ emb. | 4×4 tokenizer | $\mathcal{C}_{\text{IRIS}}$ | $\mathcal{C}$ w/o action inputs
Assault | 222.4 | 742.0 | 1764.2 | 1524.4 | 1472.2 | 1269.2 | 1288.9 | 1221.5 | 1498.5
Asterix | 210.0 | 8503.3 | 1637.5 | 853.6 | 1603.4 | 1185.9 | 909.6 | 1376.4 | 1656.3
ChopperCommand | 811.0 | 7387.8 | 2561.2 | 1565.0 | 1848.0 | 1928.3 | 1958.9 | 2517.6 | 2302.7
CrazyClimber | 10780.5 | 35829.4 | 76547.6 | 59324.2 | 62964.8 | 74791.3 | 57814.7 | 30952.7 | 42441.2
DemonAttack | 152.1 | 1971.0 | 5738.6 | 2034.4 | 12316.0 | 4389.9 | 3863.3 | 5159.0 | 5827.0
Gopher | 257.6 | 2412.5 | 5452.4 | 2236.1 | 5338.4 | 3764.2 | 2174.9 | 2891.2 | 4365.3
Krull | 1598.0 | 2665.5 | 4017.7 | 6616.4 | 5138.6 | 5779.9 | 4612.8 | 3866.2 | 3659.6
RoadRunner | 11.5 | 7845.0 | 14060.2 | 9614.6 | 13161.6 | 11723.5 | 6161.7 | 11692.9 | 11692.9
#Superhuman (↑) | 0 | N/A | 6 | 5 | 6 | 6 | 4 | 5 | 6
Mean (↑) | 0.000 | 1.000 | 1.947 | 1.564 | 2.357 | 1.778 | 1.341 | 1.340 | 1.571
Median (↑) | 0.000 | 1.000 | 2.339 | 1.130 | 2.221 | 1.821 | 1.384 | 1.357 | 1.699
IQM (↑) | 0.000 | 1.000 | 2.201 | 1.191 | 2.068 | 1.794 | 1.026 | 1.234 | 1.535
Optimality Gap (↓) | 1.000 | 0.000 | 0.198 | 0.298 | 0.209 | 0.214 | 0.289 | 0.261 | 0.208
Ablations: World Model Observation Prediction Losses

To investigate the contribution of each ablation to the quality of world model observation predictions, we measured the corresponding loss values during training and during test episodes every 50 epochs. The results are presented in Figure 15, including results for each of the 8 games used in our ablation studies.


A.5 Setup in Freeway

For a fair comparison, we followed the actor-critic configurations of IRIS (Micheli et al., 2023) for Freeway. Specifically, the sampling temperature of the agent is modified from 1 to 0.01, a heuristic that guides the agent towards non-zero reward trajectories. We highlight that different methods use other mechanisms, such as epsilon-greedy schedules and "argmax" action selection policies, to overcome this exploration challenge (Micheli et al., 2023).
