Lior Cohen Kaixin Wang Bingyi Kang Shie Mannor
Abstract
Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP in a novel TBWM agent named REM (Retentive Environment Model), showcasing a 15.4x faster imagination compared to prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at https://github.com/leor-c/REM.
Machine Learning, ICML, World Model, Reinforcement Learning
1 Introduction
Sample efficiency remains a central challenge in reinforcement learning (RL) due to the substantial data demands of successful RL algorithms (Mnih et al., 2015; Silver et al., 2016; Schrittwieser et al., 2020; Berner et al., 2019; Vinyals et al., 2019). One prominent model-based approach for addressing this challenge is known as world models. In world models, the agent’s learning is based solely on simulated interaction data produced by a learned model of the environment through a process called imagination. World models are gaining increasing popularity due to their effectiveness, particularly in visual environments (Hafner et al., 2023).
In recent years, attention-based sequence models, most notably the Transformer architecture (Vaswani et al., 2017), demonstrated unmatched performance in language modeling tasks (Devlin et al., 2019; Brown et al., 2020; Bubeck et al., 2023; Touvron et al., 2023). The notable success of these models when applied to sequences of discrete tokens sparked motivation to employ these architectures on other data modalities by learning appropriate token-based representations. In computer vision, discrete representations are becoming a mainstream approach for various tasks (van den Oord et al., 2017; Dosovitskiy et al., 2021; Esser et al., 2021; Li et al., 2023). In RL, token-based world models were recently explored in visual environments (Micheli et al., 2023). The visual perception module in these methods is called a tokenizer, as it maps image observations to sequences of discrete symbols. This way, agent interaction is translated into a language-like sequence of discrete tokens, which are processed individually by the world model.
During imagination, to generate the tokens of the next observation with the auto-regressive model, the prediction is carried out sequentially, token-by-token. Effectively, this highly sequential computation results in a severe bottleneck that markedly hinders token-based approaches. Consequently, this bottleneck practically caps the length of the observation token sequences, which in turn degrades performance. This limitation renders current token-based methods impractical for complex problems.
In this paper, we present Parallel Observation Prediction (POP), a novel mechanism that resolves the imagination bottleneck of token-based world models (TBWMs). With POP, the entire next observation token sequence is generated in parallel during world model imagination. At its core, POP augments a Retentive Network (RetNet) sequence model (Sun et al., 2023) with a novel forward mode devised to retain world model training efficiency. Additionally, we present REM (Retentive Environment Model), a TBWM agent driven by a POP-augmented RetNet architecture.
Our main contributions are summarized as follows:
- •
We propose Parallel Observation Prediction (POP), a novel mechanism that resolves the inference bottleneck of current token-based world models while retaining performance.
- •
We introduce REM, the first world model approach that incorporates the RetNet architecture. Our experiments provide first evidence of RetNet’s performance in an RL setting.
- •
We evaluate REM on the Atari 100K benchmark, demonstrating the effectiveness of POP. POP yields a 15.4x speed-up in imagination, and REM trains in under 12 hours while outperforming prior TBWMs.
2 Method
Notations. We consider the Partially Observable Markov Decision Process (POMDP) setting with image observations $o_t \in \mathcal{O}$, discrete actions $a_t \in \mathcal{A}$, scalar rewards $r_t \in \mathbb{R}$, episode termination signals $d_t \in \{0, 1\}$, dynamics $o_{t+1}, r_t, d_t \sim p(o_{t+1}, r_t, d_t \mid o_{\le t}, a_{\le t})$, and a discount factor. The objective is to learn a policy $\pi$ such that for every situation the output is optimal w.r.t. the expected discounted sum of rewards from that situation under the policy $\pi$.
2.1 Overview
REM builds on IRIS (Micheli et al., 2023), and similar to most prior works on world models for pixel input (Hafner et al., 2021; Kaiser et al., 2020; Hafner et al., 2023), REM follows the V–M–C structure (Ha & Schmidhuber, 2018): a Visual perception module $V$ that compresses observations into compact latent representations, a predictive Model $M$ that captures the environment’s dynamics, and a Controller $C$ that learns to act to maximize return. Additionally, a replay buffer is used to store environment interaction data. An overview of REM’s training cycle is presented in Figure 2. A pseudo-code algorithm of REM is presented in Appendix A.2.
- Tokenizer
We instantiate the visual perception component $V$ as a tokenizer, mapping input observations into latent tokens. Following (Micheli et al., 2023), the tokenizer is a VQ-VAE discrete auto-encoder (van den Oord et al., 2017; Esser et al., 2021) comprised of an encoder, a decoder, and an embedding table $E = \{e_i\}_{i=1}^{N}$. The embedding table consists of $N$ trainable $d$-dimensional vectors. The encoder first maps an input image $o_t$ to a sequence of $K$ $d$-dimensional latent vectors $h^1_t, \ldots, h^K_t$. Then, each latent vector $h^k_t$ is mapped to the index of the nearest embedding in $E$, i.e., $z^k_t = \arg\min_i \|h^k_t - e_i\|$. Such indices are called tokens. For an input image $o_t$, its latent token sequence is denoted as $z_t = (z^1_t, \ldots, z^K_t)$. To map a token sequence back to the input space, we first retrieve the embedding for each token and obtain a sequence $(e_{z^1_t}, \ldots, e_{z^K_t})$ where $e_{z^k_t} \in E$. Then, inverse to the encoding process, the decoder is responsible for mapping this sequence to a reconstructed observation $\hat{o}_t$.
The tokenizer is trained on frames sampled uniformly from the replay buffer.Its optimization objective, architecture, and other details are deferred to Appendix A.1.1.
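The nearest-embedding lookup above can be sketched in a few lines. This is an illustrative toy (the sizes, random codebook, and variable names are our own, not REM's actual configuration): each encoder latent is quantized to the index of its closest codebook vector, and decoding starts by retrieving the embeddings of those indices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, K = 512, 16, 64                # codebook size, embedding dim, tokens per frame (illustrative)
codebook = rng.normal(size=(N, d))   # E: embedding table (trainable in practice, frozen here)
latents = rng.normal(size=(K, d))    # stand-in for the encoder output of one observation

# Quantization: map each latent vector to the index of its nearest embedding.
dists = ((latents[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (K, N) squared distances
tokens = dists.argmin(axis=1)        # z_t: K discrete tokens

# Decoding begins by retrieving the embedding of each token.
retrieved = codebook[tokens]         # (K, d), fed to the decoder in a full VQ-VAE
```

In a full VQ-VAE the straight-through estimator and commitment losses make this lookup trainable; the sketch only shows the forward mapping.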
- World Model
At the core of a world model is the component $M$ that captures the dynamics of the environment and makes predictions based on historical observations. Here, $M$ is learned entirely in the latent token space, modeling the following distributions at each step $t$:

Transition: $p_M(\hat{z}_{t+1} \mid z_{\le t}, a_{\le t})$ (1)

Reward: $p_M(\hat{r}_t \mid z_{\le t}, a_{\le t})$ (2)

Termination: $p_M(\hat{d}_t \mid z_{\le t}, a_{\le t})$ (3)
To map observation tokens to embedding vectors, $M$ uses the code vectors $E$ learned by the tokenizer $V$. Note that $E$ is not updated by $M$. In addition, $M$ maintains dedicated embedding tables for mapping actions and special tokens (detailed in Section 2.3) to continuous vectors.
- Controller
REM’s actor-critic controller $C$ is trained to maximize return entirely in imagination (Kaiser et al., 2020; Hafner et al., 2021; Micheli et al., 2023). $C$ comprises a policy network $\pi$ and a value function estimator $v$, and operates on latent tokens and their embeddings. In each optimization step, $M$ and $C$ are initialized with a short trajectory segment sampled from the replay buffer. Subsequently, the agent interacts with the world model for $H$ steps. At each step $t$, the agent plays an action $a_t$ sampled from its policy $\pi$. The world model evolves accordingly, generating $\hat{z}_{t+1}$, $\hat{r}_t$, and $\hat{d}_t$ by sampling from the appropriate distributions (Eqn. (1-3)). The resulting trajectories are then used to train the agent. Following (Micheli et al., 2023), we adopted the actor-critic objectives of DreamerV2 (Hafner et al., 2021). We leave the full details of its architecture and optimization to Appendix A.1.3.
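The imagination loop described above can be sketched as follows. Both the policy and the world-model step are random stubs here (the names `policy`, `world_model_step`, and all sizes are illustrative assumptions, not REM's actual interfaces): the point is only the control flow of acting inside the learned model for $H$ steps until termination.

```python
import numpy as np

rng = np.random.default_rng(0)
K, H, n_actions = 64, 5, 6           # tokens per observation, horizon, action space (illustrative)

def policy(obs_tokens):              # stub actor: uniform random policy
    return int(rng.integers(n_actions))

def world_model_step(obs_tokens, action):  # stub dynamics model, mirrors Eqns. (1-3)
    next_obs = rng.integers(512, size=K)   # sampled next-observation tokens
    reward = float(rng.normal())           # sampled reward
    done = bool(rng.random() < 0.05)       # sampled termination signal
    return next_obs, reward, done

obs = rng.integers(512, size=K)      # initial context, sampled from the replay buffer
trajectory = []
for t in range(H):
    a = policy(obs)
    obs, r, d = world_model_step(obs, a)
    trajectory.append((a, r, d))
    if d:
        break
```

The collected `trajectory` is what the actor-critic objectives would be computed on; no environment interaction occurs in this loop.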
2.2 Retention Preliminaries
Similar to Transformers (Vaswani et al., 2017), a RetNet model (Sun et al., 2023) consists of a stack of layers, where each layer contains a multi-head Attention-like mechanism, called Retention, followed by a fully-connected network. A unique characteristic of the Retention mechanism is that it has a dual form of recurrence and parallelism, called “chunkwise”, for improved efficiency when handling long sequences. This form allows splitting such sequences into smaller “chunks”, where a parallel computation takes place within chunks and a sequential recurrent form is used between chunks, as shown in Figure 3. The information from previous chunks is summarized by a recurrent state maintained by the Retention mechanism.
Formally, consider a sequence of $T$ tokens $x_1, \ldots, x_T$. In our RL context, this sequence is a token trajectory composed of observation-action sub-sequences we call blocks, each consisting of $K+1$ tokens. As such trajectories are typically long, we split them into chunks of $T_c$ tokens, where $T_c$ is a multiple of $K+1$ so that each chunk only contains complete blocks. Here, the hyperparameter $T_c$ can be tuned according to the size of the models, the hardware, and other factors to maximize efficiency. Let $X \in \mathbb{R}^{T \times d}$ be the $d$-dimensional token embedding vectors. The Retention output of the $i$-th chunk is given by
$$\mathrm{Ret}(X_{[i]}, S_{i-1}; i) = \big(Q_{[i]} K_{[i]}^\top \odot D\big) V_{[i]} + \big(\xi \odot Q_{[i]}\big) S_{i-1} \tag{4}$$
where the bracketed subscript is used to index the $i$-th chunk, $Q_{[i]} = (X_{[i]} W_Q) \odot \Theta$, $K_{[i]} = (X_{[i]} W_K) \odot \bar{\Theta}$, $V_{[i]} = X_{[i]} W_V$, $\xi$ is a matrix with $\xi_{jk} = \gamma^{j+1}$, and $D \in \mathbb{R}^{T_c \times T_c}$ is a matrix with $D_{nm} = \gamma^{n-m}$ if $n \ge m$ and $0$ otherwise. Here, $W_Q, W_K, W_V$ are learnable weights, $\gamma \in (0, 1)$ is an exponential decay factor, the matrix $D$ combines an auto-regressive mask with the temporal decay factor $\gamma$, and the matrices $\Theta, \bar{\Theta}$ are for relative position embedding (see Appendix A.3). Note the chunk index argument of the Retention operator, which controls positional embedding information through $\Theta$. The chunkwise update rule of the recurrent state is given by
$$S_i = K_{[i]}^\top \big(\zeta \odot V_{[i]}\big) + \gamma^{T_c} S_{i-1} \tag{5}$$
where $S_0 = 0$, and $\zeta$ is a matrix with $\zeta_{jk} = \gamma^{T_c - j - 1}$. On the right-hand side of Equations 4 and 5, the first term corresponds to the computation within the chunk, while the second term incorporates the information from previous chunks, encapsulated by the recurrent state. Further details about the RetNet architecture are deferred to Appendix A.3.
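The parallel/chunkwise duality can be checked numerically. The sketch below implements a single Retention head with scalar decay, omitting the position embeddings $\Theta$, multi-head structure, and normalization; all sizes are illustrative. It computes the fully parallel form and the chunkwise form (Eqns. 4-5) and verifies they agree.

```python
import numpy as np

rng = np.random.default_rng(0)
T, B, d, gamma = 8, 4, 4, 0.9        # sequence length, chunk size, head dim, decay
Q, K_, V = (rng.normal(size=(T, d)) for _ in range(3))

# Fully parallel retention: O = (Q K^T ⊙ D) V with D[n, m] = γ^(n-m) for n ≥ m, else 0.
idx = np.arange(T)
D = np.where(idx[:, None] >= idx[None, :],
             gamma ** (idx[:, None] - idx[None, :]), 0.0)
O_parallel = (Q @ K_.T * D) @ V

# Chunkwise retention: parallel within each chunk, recurrent state across chunks.
S = np.zeros((d, d))                 # recurrent state S_0 = 0
outs = []
j = np.arange(B)
D_c = np.where(j[:, None] >= j[None, :], gamma ** (j[:, None] - j[None, :]), 0.0)
xi = gamma ** (j + 1)                # decay applied to cross-chunk queries
zeta = gamma ** (B - 1 - j)          # decay applied when folding a chunk into the state
for i in range(T // B):
    Qc, Kc, Vc = Q[i*B:(i+1)*B], K_[i*B:(i+1)*B], V[i*B:(i+1)*B]
    inner = (Qc @ Kc.T * D_c) @ Vc               # within-chunk term of Eqn. (4)
    cross = (xi[:, None] * Qc) @ S               # contribution of all previous chunks
    outs.append(inner + cross)
    S = Kc.T @ (zeta[:, None] * Vc) + gamma ** B * S   # state update, Eqn. (5)
O_chunk = np.concatenate(outs)

assert np.allclose(O_parallel, O_chunk)          # the two forms coincide
```

The cross-chunk term is what makes long sequences cheap: each chunk touches only a fixed-size state rather than the whole history.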
2.3 World Model Imagination
As the agent’s training relies entirely on world model imagination, the efficiency of trajectory generation is critical. During imagination, predicting the next observation tokens $\hat{z}_{t+1}$ constitutes the primary non-trivial component and consumes the majority of processing time. In IRIS, the prediction of $\hat{z}_{t+1}$ unfolds sequentially, as the model is limited to predicting only one token ahead at each step. This limitation arises since the identity of the next token, which remains unknown at the current step, is necessary for the prediction of later tokens. Thus, generating each observation costs $K$ sequential world model calls. This leads to poor GPU utilization and long computation times.
To overcome this bottleneck, POP maintains a set of $K$ dedicated prediction tokens $u = (u^1, \ldots, u^K)$ together with their corresponding embeddings. To generate $\hat{z}_{t+1}$ in one pass, POP simply computes the RetNet outputs starting from the recurrent state $S_t$ using $u$ as its input sequence, as illustrated in Figure 4. Note that at imagination, the chunk size is limited to a single block, i.e., to $K+1$ tokens. Here, the notation $S_t$ refers to the state that summarizes the first $t$ observation-action blocks. To obtain the initial state, we use RetNet’s chunkwise forward to summarize an initial context segment of blocks sampled from the replay buffer. Essentially, POP models the next observation distribution $p_M(\hat{z}_{t+1} \mid z_{\le t}, a_{\le t})$ in a single forward pass, with each prediction token $u^k$ responsible for predicting the corresponding token $\hat{z}^k_{t+1}$. It is worth noting that the tokens $u$ are only intended for observation token predictions and are never employed in the update of the recurrent state.
This approach effectively reduces the total number of sequential world model calls during imagination from $\mathcal{O}(HK)$ to $\mathcal{O}(H)$, eliminating the dependency on the number of observation tokens $K$. In fact, POP provides an additional generation mode that further reduces the number of sequential calls. We defer the details of this mode to Appendix A.1.2. Also, by using a recurrent state that summarizes long history sequences, POP improves efficiency further, as the per-token prediction cost reduces. Effectively, POP offers improved scalability at the expense of a higher overall computational cost. Our approach adds to existing evidence suggesting that enhanced scalability is often favorable, even at the expense of additional computational costs, with Transformers (Vaswani et al., 2017) serving as a prominent example.
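The call-count argument can be made concrete with a toy decoder. The "model" below is a random stub (its name and all sizes are illustrative, not REM's actual network): sequential decoding must call it once per token because each predicted token feeds the next prediction, while a POP-style pass feeds $K$ prediction tokens together and samples all $K$ next-observation tokens from one call.

```python
import numpy as np

rng = np.random.default_rng(0)
K, vocab = 64, 512                    # tokens per observation, vocabulary size (illustrative)
calls = {"sequential": 0, "pop": 0}

def model_forward(inputs):            # stub network call: per-position logits
    return rng.normal(size=(len(inputs), vocab))

# Sequential decoding: each predicted token is required as input for the next one.
tokens_seq = []
context = [0]                         # a placeholder context token
for _ in range(K):
    logits = model_forward(context + tokens_seq)
    calls["sequential"] += 1
    tokens_seq.append(int(logits[-1].argmax()))

# POP-style decoding: K dedicated prediction tokens go in together; one forward pass.
prediction_tokens = list(range(K))    # placeholder inputs standing in for u^1..u^K
logits = model_forward(prediction_tokens)
calls["pop"] += 1
tokens_pop = logits.argmax(axis=1)    # all K next-observation tokens at once
```

With $H$ imagined steps this is the $HK$-versus-$H$ gap in sequential calls; the parallel pass does more arithmetic per call but keeps the GPU saturated.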
2.4 World Model Training
While applying POP during imagination is fairly straightforward, it requires a modification of the training data. Consider an input trajectory segment sampled from the replay buffer. To make meaningful observation predictions at imagination, the model should be trained to predict $\hat{z}_{t+1}$ given $z_{\le t}, a_{\le t}$, for each time step $t$ of every input segment. Hence, for every $t$, the input sequence should contain the prediction tokens $u$ at block $t+1$. However, replacing $z_{t+1}$ with $u$ in the original sequence is inadequate, as the prediction of future observations, rewards, and termination signals depends on $z_{t+1}$. Thus, the standard approach of computing all outputs from the same input sequence is not viable, as in this case these two requirements contradict each other (Figure 5). The challenge then lies in devising an efficient method for computing the outputs of all time steps in parallel.
To tackle this challenge, we first note that each trajectory prefix can be summarized into a single recurrent state. For example, for the first input chunk, the first block can be summarized into $S_{(1,1)}$ and the first two blocks into $S_{(1,2)}$. Here, we use the subscript $(i, j)$ to conveniently refer to the $j$-th block within the $i$-th chunk (this notation is demonstrated in Figure 3). Thus, our plan is to first compute all states in parallel, and then predict all next observations from all state and prediction-token tuples.
To compute all recurrent states in parallel, a two-step computation is carried out. First, intermediate block states are computed in parallel for all blocks $j$ with

$$\tilde{S}_{(i,j)} = K_{(i,j)}^\top \big(\zeta \odot V_{(i,j)}\big) \tag{6}$$

where $\zeta$ is a matrix with $\zeta_{mk} = \gamma^{K - m}$. Then, each recurrent state is computed sequentially by

$$S_{(i,j)} = \tilde{S}_{(i,j)} + \gamma^{K+1} S_{(i,j-1)} \tag{7}$$
As the majority of the computational burden lies in the first step, the sequential computation in the second step has minimal impact on the overall speedup.
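The two-step state computation can be sketched and checked against a fully sequential recurrence. This is an illustrative single-head numpy toy (sizes and names are our own assumptions): step 1 computes every block's partial state with one batched contraction, step 2 folds in the decayed previous state with a cheap scan, and the result matches the token-by-token recurrence exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
n_blocks, B, d, gamma = 4, 5, 3, 0.9    # blocks per chunk, block size K+1, head dim, decay
Ks = rng.normal(size=(n_blocks, B, d))  # per-block keys
Vs = rng.normal(size=(n_blocks, B, d))  # per-block values
zeta = gamma ** (B - 1 - np.arange(B))  # within-block decay, as in Eqn. (5)

# Step 1 (parallel over all blocks): per-block partial states, ignoring history.
partial = np.einsum('jbd,jbe->jde', Ks * zeta[None, :, None], Vs)   # analogue of Eqn. (6)

# Step 2 (cheap sequential scan): fold in the decayed previous block state.
S = np.zeros((d, d))
states = []
for j in range(n_blocks):
    S = partial[j] + gamma ** B * S     # analogue of Eqn. (7)
    states.append(S)

# Reference: the fully token-sequential recurrence yields the same block states.
S_ref = np.zeros((d, d))
for j in range(n_blocks):
    for t in range(B):
        S_ref = gamma * S_ref + np.outer(Ks[j, t], Vs[j, t])
    assert np.allclose(S_ref, states[j])
```

The heavy `einsum` is a single parallel kernel; the scan touches only $d \times d$ matrices, which is why the sequential step is negligible.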
Once we have all states ready, the output of every block is computed in parallel. Here, we stress that the existing Retention mechanism can only perform batched input computation with recurrent states of the same time step. This is due to the shared positional embedding information applied to every input sequence in the batch. To overcome this, we devise a mechanism which extends RetNet to support the batched computation of state and prediction-token tuples, while applying the appropriate positional encoding information. A pseudo-code of our novel POP extension of RetNet is given in Algorithms 1 and 2. The latter presents the core of the mechanism (described above), while the former describes the higher-level layer-by-layer computation with a final aggregation for combining the produced outputs. Figure 6 illustrates a simplified example of the POP forward mechanism (Algorithms 1 and 2) for a single-layer model. For brevity, our pseudo-code and illustrations only consider Retention layers, omitting other modules of RetNet (Appendix A.3).
Algorithm 1 (POP Forward, layer-by-layer)

1: Input: chunk size $T_c$, token embeddings $X_{[i]}$ of chunk $i$, per-layer recurrent states $S^1_{i-1}, \ldots, S^L_{i-1}$.
2: Initialize $H \leftarrow X_{[i]}$
3: Initialize $U$ with the prediction token embeddings
4: for $l = 1$ to $L$ do
5:  $H, U, S^l_i \leftarrow$ POPRetention($H, U, S^l_{i-1}, i$)
6: end for
7: for $j = 1$ to $T_c / (K+1)$ do
8:  aggregate the outputs of block $j$
9: end for
10: Return the aggregated outputs and the updated states $S^1_i, \ldots, S^L_i$
Algorithm 2 (POPRetention, core mechanism)

1: Input: chunk latents $X_{[i]}$, observation prediction latents $U$, recurrent state $S_{i-1}$, chunk index $i$.
2: $Y_{[i]} \leftarrow \mathrm{Ret}(X_{[i]}, S_{i-1}; i)$ (Eqn. 4)
3: Compute $\tilde{S}_{(i,j)}$ in parallel for all $j$ (Eqn. 6)
4: for $j = 1$ to $T_c / (K+1)$ do
5:  $S_{(i,j)} \leftarrow \tilde{S}_{(i,j)} + \gamma^{K+1} S_{(i,j-1)}$ (Eqn. 7)
6: end for
7: $S_i \leftarrow S_{(i, T_c/(K+1))}$
8: $Y'_{(i,j)} \leftarrow \mathrm{Ret}(U, S_{(i,j)}; (i,j))$ in parallel for all $j$ (Eqn. 4)
9: Return $Y_{[i]}$, $Y'_{(i,\cdot)}$, $S_i$
To train the world model, trajectory segments from past experience are uniformly sampled from the replay buffer and translated into token sequences. These sequences are processed in chunks of observation-action blocks to produce the modeled distributions, as depicted in Figure 6. Optimization is carried out by minimizing the cross-entropy loss of the transition and termination outputs, and an appropriate loss for the reward outputs, depending on the task: for continuous rewards, the mean-squared error loss is used, while for discrete rewards cross-entropy is used instead.
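The loss terms above can be sketched as follows. Shapes and values are illustrative stand-ins (random logits and targets, not REM's actual heads): cross-entropy for the next-token and termination predictions, mean-squared error for a continuous reward head.

```python
import numpy as np

rng = np.random.default_rng(0)
T, vocab = 10, 512                    # time steps, token vocabulary (illustrative)

def cross_entropy(logits, targets):
    # Numerically stable mean negative log-likelihood of the target classes.
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

token_logits = rng.normal(size=(T, vocab))   # transition head outputs
token_targets = rng.integers(vocab, size=T)  # ground-truth next tokens
done_logits = rng.normal(size=(T, 2))        # termination head outputs
done_targets = rng.integers(2, size=T)       # ground-truth termination signals
reward_pred = rng.normal(size=T)             # continuous reward predictions
reward_true = rng.normal(size=T)             # ground-truth rewards

loss = (cross_entropy(token_logits, token_targets)
        + cross_entropy(done_logits, done_targets)
        + ((reward_pred - reward_true) ** 2).mean())  # MSE for continuous rewards
```

For tasks with discrete rewards, the MSE term would be swapped for another cross-entropy over reward bins.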
*Non-token-based baselines: SimPLe, DreamerV3, TWM, STORM. Token-based: IRIS, REM.*

| Game | Random | Human | SimPLe | DreamerV3 | TWM | STORM | IRIS | REM (ours) |
|---|---|---|---|---|---|---|---|---|
Alien | 227.8 | 7127.7 | 616.9 | 959.4 | 674.6 | 983.6 | 420.0 | 607.2 |
Amidar | 5.8 | 1719.5 | 74.3 | 139.1 | 121.8 | 204.8 | 143.0 | 95.3 |
Assault | 222.4 | 742.0 | 527.2 | 705.6 | 682.6 | 801.0 | 1524.4 | 1764.2 |
Asterix | 210.0 | 8503.3 | 1128.3 | 932.5 | 1116.6 | 1028.0 | 853.6 | 1637.5 |
BankHeist | 14.2 | 753.1 | 34.2 | 648.7 | 466.7 | 641.2 | 53.1 | 19.2 |
BattleZone | 2360.0 | 37187.5 | 4031.2 | 12250.0 | 5068.0 | 13540.0 | 13074.0 | 11826.0 |
Boxing | 0.1 | 12.1 | 7.8 | 78.0 | 77.5 | 79.7 | 70.1 | 87.5 |
Breakout | 1.7 | 30.5 | 16.4 | 31.1 | 20.0 | 15.9 | 83.7 | 90.7 |
ChopperCommand | 811.0 | 7387.8 | 979.4 | 410.0 | 1697.4 | 1888.0 | 1565.0 | 2561.2 |
CrazyClimber | 10780.5 | 35829.4 | 62583.6 | 97190.0 | 71820.4 | 66776.0 | 59324.2 | 76547.6 |
DemonAttack | 152.1 | 1971.0 | 208.1 | 303.3 | 350.2 | 164.6 | 2034.4 | 5738.6 |
Freeway | 0.0 | 29.6 | 16.7 | 0.0 | 24.3 | 0.0 | 31.1 | 32.3 |
Frostbite | 65.2 | 4334.7 | 236.9 | 909.4 | 1475.6 | 1316.0 | 259.1 | 240.5 |
Gopher | 257.6 | 2412.5 | 596.8 | 3730.0 | 1674.8 | 8239.6 | 2236.1 | 5452.4 |
Hero | 1027.0 | 30826.4 | 2656.6 | 11160.5 | 7254.0 | 11044.3 | 7037.4 | 6484.8 |
Jamesbond | 29.0 | 302.8 | 100.5 | 444.6 | 362.4 | 509.0 | 462.7 | 391.2 |
Kangaroo | 52.0 | 3035.0 | 51.2 | 4098.3 | 1240.0 | 4208.0 | 838.2 | 467.6 |
Krull | 1598.0 | 2665.5 | 2204.8 | 7781.5 | 6349.2 | 8412.6 | 6616.4 | 4017.7 |
KungFuMaster | 258.5 | 22736.3 | 14862.5 | 21420.0 | 24554.6 | 26182.0 | 21759.8 | 25172.2 |
MsPacman | 307.3 | 6951.6 | 1480.0 | 1326.9 | 1588.4 | 2673.5 | 999.1 | 962.5 |
Pong | -20.7 | 14.6 | 12.8 | 18.4 | 18.8 | 11.3 | 14.6 | 18.0 |
PrivateEye | 24.9 | 69571.3 | 35.0 | 881.6 | 86.6 | 7781.0 | 100.0 | 99.6 |
Qbert | 163.9 | 13455.0 | 1288.8 | 3405.1 | 3330.8 | 4522.5 | 745.7 | 743.0 |
RoadRunner | 11.5 | 7845.0 | 5640.6 | 15565.0 | 9109.0 | 17564.0 | 9614.6 | 14060.2 |
Seaquest | 68.4 | 42054.7 | 683.3 | 618.0 | 774.4 | 525.2 | 661.3 | 1036.7 |
UpNDown | 533.4 | 11693.2 | 3350.3 | 7567.1 | 15981.7 | 7985.0 | 3546.2 | 3757.6 |
#Superhuman (↑) | 0 | N/A | 1 | 9 | 8 | 9 | 10 | 12 |
Mean (↑) | 0.000 | 1.000 | 0.332 | 1.124 | 0.956 | 1.222 | 1.046 | 1.222 |
Median (↑) | 0.000 | 1.000 | 0.134 | 0.485 | 0.505 | 0.425 | 0.289 | 0.280 |
IQM (↑) | 0.000 | 1.000 | 0.130 | 0.487 | 0.459 | 0.561 | 0.501 | 0.673 |
Optimality Gap (↓) | 1.000 | 0.000 | 0.729 | 0.510 | 0.513 | 0.472 | 0.512 | 0.482 |
3 Experiments
We follow most prior works on world models and evaluate REM on the widely recognized Atari 100K benchmark (Kaiser et al., 2020) for sample-efficient reinforcement learning. The Atari 100K benchmark considers a subset of 26 Atari games. For each game, the agent is limited to 100K interaction steps, corresponding to 400K game frames due to the standard frame-skip of 4. In total, this amounts to roughly 2 hours of gameplay. To put this in perspective, the original Atari benchmark allows agents to collect 200M steps, that is, 500 times more experience.
Experimental Setup
The full details of the architectures and hyper-parameters used in our experiments are presented in Appendix A.1. Notably, our tokenizer uses $K = 64$ (i.e., a grid of $8 \times 8$ latent tokens per observation), whereas IRIS uses only $K = 16$. To ensure a meaningful comparison of the run times of REM and IRIS, REM’s configuration was chosen so that the amount of computation carried out by each component at each epoch remains (roughly) equivalent to that of the corresponding component in IRIS. For benchmarking agent run times, we used a workstation with an Nvidia RTX 4090 GPU. The rest of our experiments were conducted on Nvidia V100 GPUs.
Baselines
Since the contributions of this paper relate to token-based approaches, and to IRIS in particular, our evaluation focuses on token-based methods. To enrich our results, as well as to facilitate future research, we include the following additional baselines: SimPLe (Kaiser et al., 2020), DreamerV3 (Hafner et al., 2023), TWM (Robine et al., 2023), and STORM (Zhang et al., 2023). In these approaches, each observation is processed as a single sequence element by the world model. Following prior works on world models, lookahead search methods such as MuZero (Schrittwieser et al., 2020) and EfficientZero (Ye et al., 2021) are not included, as lookahead search operates on top of a world model. Here, our aim is to improve the world model component itself.
3.1 Results
On Atari, it is standard to use human-normalized scores (HNS) (Mnih et al., 2015), calculated as $(\text{score}_{\text{agent}} - \text{score}_{\text{random}}) / (\text{score}_{\text{human}} - \text{score}_{\text{random}})$, rather than raw game scores. Here, the final score of each training run is computed as an average over 100 episodes collected at the end of training. Agarwal et al. (2021) found discrepancies between conclusions drawn from point-estimate statistics, such as the mean and median, and those of a more thorough statistical analysis that also considers the uncertainty in the results. Adhering to their established protocol and utilizing their toolkit, we report the mean, median, and interquartile mean (IQM) human-normalized scores, and the optimality gap, with 95% stratified bootstrap confidence intervals in Figure 7. Performance profiles are presented in Figure 8. Average scores of individual games are reported in Table 1.
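The normalization above is simple to compute; as a worked example, plugging in the Breakout row of Table 1 shows how a raw score maps to the HNS scale.

```python
# Human-normalized score: HNS = (score - random) / (human - random).
def hns(score, random, human):
    return (score - random) / (human - random)

# Example values taken from Table 1 (Breakout row): random 1.7, human 30.5, REM 90.7.
breakout = hns(90.7, random=1.7, human=30.5)

assert breakout > 1.0                            # above human level
assert abs(hns(30.5, 1.7, 30.5) - 1.0) < 1e-9    # human play normalizes to 1
assert abs(hns(1.7, 1.7, 30.5)) < 1e-9           # random play normalizes to 0
```

An HNS above 1 is exactly the "superhuman" criterion counted in the #Superhuman row of Table 1.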
REM attains an IQM human-normalized score of 0.673, outperforming all baselines. Additionally, REM improves over IRIS on 3 out of the 4 metrics (i.e., mean, optimality gap, and IQM), while being comparable in terms of its median score. Remarkably, REM achieves superhuman performance on 12 games, more than any other baseline (Table 1). REM also exhibits state-of-the-art scores on several games, including Assault, Boxing, and Chopper Command. These findings support our claim that REM performs on par with or better than previous token-based approaches while running significantly faster.
3.2 Ablation Studies
To analyze the impact of different components of our approach on REM’s performance, we conduct a series of ablation studies. For each component, we compare the final algorithm to a version where the component of interest is disabled. Due to computational resource constraints, the evaluation is performed on a subset of 8 games from the Atari 100K benchmark using 5 random seeds for each game. This subset includes games with large score differences between IRIS and REM, as we are particularly interested in studying the impact of each component in these games. Concretely, this subset includes the games “Assault”, “Asterix”, “Chopper Command”, “Crazy Climber”, “Demon Attack”, “Gopher”, “Krull”, and “Road Runner”. We performed ablation studies on the following aspects: the POP mechanism, the latent space architecture of $C$ and its action inputs, the latent resolution of $V$, and the observation token embeddings used by $M$.
The probability of improvement (Agarwal etal., 2021) and IQM human-normalized scores are presented in Figure 9.Figure 10 offers a comparison of the training times, juxtaposing REM with its efficiency-related ablations.
Analyzing POP
To study the impact of POP on REM’s performance, we replaced the POP-augmented RetNet of $M$ with a vanilla RetNet. In this version, denoted “No POP”, the prediction of next observation tokens is performed sequentially, token-by-token, as done in IRIS.
Our results suggest that POP retains the agent’s performance (Figure 9) while significantly reducing the overall computation time (Figure 10). In Appendix A.4, we provide additional results indicating that the world model’s performance is also retained. POP achieves a lower total computation time by expediting the actor-critic learning phase, despite the increased computational cost implied by the observation prediction tokens.
Actor-Critic Architecture and Action Inputs
For $C$, we considered an incremental ablation. First, we replaced the architecture of REM’s controller with that of IRIS. In contrast to REM, this version processes fully reconstructed pixel frames and does not incorporate action inputs; formally, its policy models $\pi(a_t \mid \hat{o}_{\le t})$. In the second ablation, REM was modified so that only the action inputs of $C$ were disabled.
Our findings indicate that both the latent-code-based architecture and the added action inputs contribute to the final performance of REM (Figure 9). Additionally, the latent-code-based architecture of $C$ leads to reduced computational overhead and shorter actor-critic learning times (Figure 10).
Tokenizer Resolution
Here, we compare REM to a version with a reduced latent resolution of $K = 16$, similar to that of IRIS. The results in Figure 9 provide clear evidence that the latent resolution of the tokenizer has a significant impact on the agent’s performance. Our results demonstrate that POP enables REM to utilize higher latent resolutions while incurring shorter computation times than prior token-based approaches.
World Model Embeddings
In REM, $M$ translates observation tokens to embedding vectors using the embedding table $E$ learned by $V$. These embeddings encode the visual information as learned by $V$. In contrast, IRIS maintains a separate embedding table, learned by the world model, for that purpose. Here, the results in Figure 9 provide empirical evidence indicating that leveraging this encoded visual information leads to improved performance. In Appendix A.4, we provide additional evidence suggesting that the world model’s next-observation predictions are also improved.
4 Related Work
Model-based reinforcement learning (RL), with its roots in the tabular setting (Sutton, 1991), has been a focus of extensive research in recent decades. The deep RL agent of Ha & Schmidhuber (2018) leveraged an LSTM (Hochreiter & Schmidhuber, 1997) sequence model with a VAE (Kingma & Welling, 2014) to model the dynamics in visual environments, demonstrating that successful policies can be learned entirely from simulated data. This approach, commonly known as world models, was later applied to Atari games (Kaiser et al., 2020) with the PPO (Schulman et al., 2017) RL algorithm. Later, a series of works (Hafner et al., 2020, 2021, 2023) proposed the Dreamer algorithms, which are based on a recurrent state space model (RSSM) (Hafner et al., 2019) to model dynamics. The latest, DreamerV3, was evaluated on a variety of challenging environments, providing further evidence of the promising potential of world models. In contrast to token-based approaches, where each token serves as a standalone sequence element, Dreamer encodes each frame as a vector of categorical variables, which are processed at once by the RSSM.
Following the success of the Transformer architecture (Vaswani et al., 2017) in language modeling (Brown et al., 2020), and motivated by their favorable scaling properties compared to RNNs, Transformers were recently explored in RL (Parisotto et al., 2020; Chen et al., 2021; Reed et al., 2022; Shridhar et al., 2023). World model approaches have also adopted the Transformer architecture. Micheli et al. (2023) blazed the trail for token-based world models with IRIS, representing agent trajectories as language-like sequences. By treating each observation as a sequence, its Transformer-based world model gains an explicit sub-observation attention resolution. Despite IRIS’s high performance, its imagination bottleneck results in a substantial disadvantage.
In addition to IRIS, non-token-based world models driven by Transformers have been proposed. TWM (Robine et al., 2023) utilizes the Transformer-XL architecture (Dai et al., 2020) and non-uniform data sampling. STORM (Zhang et al., 2023) proposes an efficient Transformer-based world model agent that sets state-of-the-art results on the Atari 100K benchmark. STORM has a significantly smaller 2-layer Transformer compared to the 10-layer models of TWM and IRIS, demonstrating drastically reduced training times and improved agent performance.
5 Conclusions
In this work, we presented Parallel Observation Prediction (POP), a novel mechanism that augments Retentive Networks with a dedicated forward mode to improve the efficiency of token-based world models (TBWMs). POP effectively solves the imagination bottleneck of TBWMs and enables them to handle longer observation sequences. Additionally, we introduced REM, a TBWM agent equipped with POP. REM is the first world model agent driven by the RetNet architecture. Empirically, we demonstrated the superiority of REM over prior TBWMs on the Atari 100K benchmark, rendering REM competitive with the state of the art, both in terms of agent performance and overall run time.
Our work opens up many promising avenues for future research by making TBWMs practical and cost-efficient. One such direction could be to explore a modification of REM where the recurrent state of the world model summarizes the entire history of the agent. Similarly, a history-preserving RetNet architecture should be considered for the controller as well. Another promising avenue would be to leverage the independent optimization of the tokenizer to enable REM to use pretrained visual perception models in environments where visual data is abundant, for example, the real world. Such perceptual models could be trained at scale, and would allow REM to store only compressed observations in its replay buffer, further improving its efficiency. Lastly, token-based methods for video generation tasks could benefit from using the POP mechanism to generate entire frames in parallel conditioned on past context. We believe that this is an exciting avenue to explore, with potentially high impact.
Acknowledgements
This project has received funding from the European Union’s Horizon Europe Programme under grant agreement No. 101070568.
Impact Statement
This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.
References
- Agarwal et al. (2021) Agarwal, R., Schwarzer, M., Castro, P. S., Courville, A. C., and Bellemare, M. Deep reinforcement learning at the edge of the statistical precipice. In Ranzato, M., Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, volume 34, pp. 29304–29320. Curran Associates, Inc., 2021. URL https://proceedings.neurips.cc/paper_files/paper/2021/file/f514cec81cb148559cf475e7426eed5e-Paper.pdf.
- Ba et al. (2016) Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
- Berner et al. (2019) Berner, C., Brockman, G., Chan, B., Cheung, V., Debiak, P., Dennison, C., Farhi, D., Fischer, Q., Hashme, S., Hesse, C., Józefowicz, R., Gray, S., Olsson, C., Pachocki, J., Petrov, M., de Oliveira Pinto, H. P., Raiman, J., Salimans, T., Schlatter, J., Schneider, J., Sidor, S., Sutskever, I., Tang, J., Wolski, F., and Zhang, S. Dota 2 with large scale deep reinforcement learning. CoRR, abs/1912.06680, 2019. URL http://arxiv.org/abs/1912.06680.
- Brown et al. (2020) Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners. In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., and Lin, H. (eds.), Advances in Neural Information Processing Systems, volume 33, pp. 1877–1901. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf.
- Bubeck et al. (2023) Bubeck, S., Chandrasekaran, V., Eldan, R., Gehrke, J., Horvitz, E., Kamar, E., Lee, P., Lee, Y. T., Li, Y., Lundberg, S. M., Nori, H., Palangi, H., Ribeiro, M. T., and Zhang, Y. Sparks of artificial general intelligence: Early experiments with GPT-4. CoRR, abs/2303.12712, 2023. doi: 10.48550/ARXIV.2303.12712. URL https://doi.org/10.48550/arXiv.2303.12712.
- Chen et al. (2021) Chen, L., Lu, K., Rajeswaran, A., Lee, K., Grover, A., Laskin, M., Abbeel, P., Srinivas, A., and Mordatch, I. Decision Transformer: Reinforcement Learning via Sequence Modeling. In Advances in Neural Information Processing Systems, volume 18, 2021.
- Dai et al. (2020) Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., and Salakhutdinov, R. Transformer-XL: Attentive language models beyond a fixed-length context. In ACL 2019 - 57th Annual Meeting of the Association for Computational Linguistics, Proceedings of the Conference, 2020. doi: 10.18653/v1/p19-1285.
- Devlin et al. (2019) Devlin, J., Chang, M. W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In NAACL HLT 2019 - 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies - Proceedings of the Conference, volume 1, 2019.
- Dosovitskiy et al. (2021) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=YicbFdNTTy.
- Esser et al. (2021) Esser, P., Rombach, R., and Ommer, B. Taming transformers for high-resolution image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12873–12883, 2021.
- Ha & Schmidhuber (2018) Ha, D. and Schmidhuber, J. Recurrent world models facilitate policy evolution. In Advances in Neural Information Processing Systems 31, pp. 2451–2463. Curran Associates, Inc., 2018. URL https://papers.nips.cc/paper/7512-recurrent-world-models-facilitate-policy-evolution. https://worldmodels.github.io.
- Hafner et al. (2019) Hafner, D., Lillicrap, T., Fischer, I., Villegas, R., Ha, D., Lee, H., and Davidson, J. Learning latent dynamics for planning from pixels. In 36th International Conference on Machine Learning, ICML 2019, volume 2019-June, 2019.
- Hafner et al. (2020) Hafner, D., Lillicrap, T., Ba, J., and Norouzi, M. Dream to control: Learning behaviors by latent imagination. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1lOTC4tDS.
- Hafner et al. (2021) Hafner, D., Lillicrap, T. P., Norouzi, M., and Ba, J. Mastering Atari with discrete world models. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021. URL https://openreview.net/forum?id=0oabwyZbOu.
- Hafner et al. (2023) Hafner, D., Pasukonis, J., Ba, J., and Lillicrap, T. Mastering diverse domains through world models. arXiv preprint arXiv:2301.04104, 2023.
- Hendrycks & Gimpel (2017) Hendrycks, D. and Gimpel, K. Bridging nonlinearities and stochastic regularizers with Gaussian error linear units, 2017. URL https://openreview.net/forum?id=Bk0MRI5lg.
- Hochreiter & Schmidhuber (1997) Hochreiter, S. and Schmidhuber, J. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.
- Kaiser et al. (2020) Kaiser, Ł., Babaeizadeh, M., Miłos, P., Osiński, B., Campbell, R. H., Czechowski, K., Erhan, D., Finn, C., Kozakowski, P., Levine, S., Mohiuddin, A., Sepassi, R., Tucker, G., and Michalewski, H. Model based reinforcement learning for Atari. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1xCPJHtDB.
- Kingma & Welling (2014) Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations, ICLR 2014 - Conference Track Proceedings, 2014.
- Li et al. (2023) Li, T., Chang, H., Mishra, S., Zhang, H., Katabi, D., and Krishnan, D. MAGE: Masked generative encoder to unify representation learning and image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2142–2152, 2023.
- Micheli et al. (2023) Micheli, V., Alonso, E., and Fleuret, F. Transformers are sample-efficient world models. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=vhFu1Acb0xb.
- Mnih et al. (2015) Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., Graves, A., Riedmiller, M., Fidjeland, A. K., Ostrovski, G., et al. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015.
- Parisotto et al. (2020) Parisotto, E., Song, F., Rae, J., Pascanu, R., Gulcehre, C., Jayakumar, S., Jaderberg, M., Kaufman, R. L., Clark, A., Noury, S., Botvinick, M., Heess, N., and Hadsell, R. Stabilizing transformers for reinforcement learning. In III, H. D. and Singh, A. (eds.), Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pp. 7487–7498. PMLR, 13–18 Jul 2020. URL https://proceedings.mlr.press/v119/parisotto20a.html.
- Ramachandran et al. (2018) Ramachandran, P., Zoph, B., and Le, Q. V. Searching for activation functions, 2018. URL https://openreview.net/forum?id=SkBYYyZRZ.
- Reed et al. (2022) Reed, S., Zolna, K., Parisotto, E., Colmenarejo, S. G., Novikov, A., Barth-maron, G., Giménez, M., Sulsky, Y., Kay, J., Springenberg, J. T., Eccles, T., Bruce, J., Razavi, A., Edwards, A., Heess, N., Chen, Y., Hadsell, R., Vinyals, O., Bordbar, M., and de Freitas, N. A generalist agent. Transactions on Machine Learning Research, 2022. ISSN 2835-8856. URL https://openreview.net/forum?id=1ikK0kHjvj.
- Robine et al. (2023) Robine, J., Höftmann, M., Uelwer, T., and Harmeling, S. Transformer-based world models are happy with 100k interactions. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=TdBaDGCpjly.
- Schrittwieser et al. (2020) Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., Lillicrap, T., and Silver, D. Mastering Atari, Go, chess and shogi by planning with a learned model. Nature, 588(7839):604–609, Dec 2020. ISSN 1476-4687. doi: 10.1038/s41586-020-03051-4. URL https://www.nature.com/articles/s41586-020-03051-4.
- Schulman et al. (2017) Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. Proximal policy optimization algorithms. ArXiv, abs/1707.06347, 2017. URL https://api.semanticscholar.org/CorpusID:28695052.
- Shridhar et al. (2023) Shridhar, M., Manuelli, L., and Fox, D. Perceiver-Actor: A multi-task transformer for robotic manipulation. In Liu, K., Kulic, D., and Ichnowski, J. (eds.), Proceedings of The 6th Conference on Robot Learning, volume 205 of Proceedings of Machine Learning Research, pp. 785–799. PMLR, 14–18 Dec 2023. URL https://proceedings.mlr.press/v205/shridhar23a.html.
- Silver et al. (2016) Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., van den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, 529:484–503, 2016. URL http://www.nature.com/nature/journal/v529/n7587/full/nature16961.html.
- Sun et al. (2023) Sun, Y., Dong, L., Huang, S., Ma, S., Xia, Y., Xue, J., Wang, J., and Wei, F. Retentive network: A successor to Transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
- Sutton (1991) Sutton, R. S. Dyna, an integrated architecture for learning, planning, and reacting. ACM SIGART Bulletin, 2(4), 1991. ISSN 0163-5719. doi: 10.1145/122344.122377.
- Sutton & Barto (2018) Sutton, R. S. and Barto, A. G. Reinforcement Learning: An Introduction. A Bradford Book, Cambridge, MA, USA, 2018. ISBN 0262039249.
- Touvron et al. (2023) Touvron, H., Martin, L., Stone, K. R., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., Bikel, D. M., Blecher, L., Ferrer, C. C., Chen, M., Cucurull, G., Esiobu, D., Fernandes, J., Fu, J., Fu, W., Fuller, B., Gao, C., Goswami, V., Goyal, N., Hartshorn, A. S., Hosseini, S., Hou, R., Inan, H., Kardas, M., Kerkez, V., Khabsa, M., Kloumann, I. M., Korenev, A. V., Koura, P. S., Lachaux, M.-A., Lavril, T., Lee, J., Liskovich, D., Lu, Y., Mao, Y., Martinet, X., Mihaylov, T., Mishra, P., Molybog, I., Nie, Y., Poulton, A., Reizenstein, J., Rungta, R., Saladi, K., Schelten, A., Silva, R., Smith, E. M., Subramanian, R., Tan, X., Tang, B., Taylor, R., Williams, A., Kuan, J. X., Xu, P., Yan, Z., Zarov, I., Zhang, Y., Fan, A., Kambadur, M., Narang, S., Rodriguez, A., Stojnic, R., Edunov, S., and Scialom, T. Llama 2: Open foundation and fine-tuned chat models. ArXiv, abs/2307.09288, 2023. URL https://api.semanticscholar.org/CorpusID:259950998.
- van den Oord et al. (2017) van den Oord, A., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/7a98af17e63a0ac09ce2e96d03992fbc-Paper.pdf.
- Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
- Vinyals et al. (2019) Vinyals, O., Babuschkin, I., Czarnecki, W. M., Mathieu, M., Dudzik, A., Chung, J., Choi, D. H., Powell, R., Ewalds, T., Georgiev, P., Oh, J., Horgan, D., Kroiss, M., Danihelka, I., Huang, A., Sifre, L., Cai, T., Agapiou, J. P., Jaderberg, M., Vezhnevets, A. S., Leblond, R., Pohlen, T., Dalibard, V., Budden, D., Sulsky, Y., Molloy, J., Paine, T. L., Gulcehre, C., Wang, Z., Pfaff, T., Wu, Y., Ring, R., Yogatama, D., Wünsch, D., McKinney, K., Smith, O., Schaul, T., Lillicrap, T., Kavukcuoglu, K., Hassabis, D., Apps, C., and Silver, D. Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature, 575(7782), 2019. ISSN 1476-4687. doi: 10.1038/s41586-019-1724-z.
- Ye et al. (2021) Ye, W., Liu, S., Kurutach, T., Abbeel, P., and Gao, Y. Mastering Atari Games with Limited Data. In Advances in Neural Information Processing Systems, volume 30, 2021.
- Zhang et al. (2023) Zhang, W., Wang, G., Sun, J., Yuan, Y., and Huang, G. STORM: Efficient stochastic transformer based world models for reinforcement learning. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=WxnrX42rnS.
Appendix A Appendix
A.1 Models and Hyperparameters
Tables 2 and 3 detail hyperparameters of the optimization and environment, as well as hyperparameters shared by multiple components.
| Description | Symbol | Value |
|---|---|---|
| Horizon | H | 10 |
| Tokens per observation | K | 64 |
| Tokenizer vocabulary size | N | 512 |
| Epochs | - | 600 |
| Experience collection epochs | - | 500 |
| Environment steps per epoch | - | 200 |
| Collection epsilon-greedy | - | 0.01 |
| Eval sampling temperature | - | 0.5 |
| Optimizer | - | AdamW |
| AdamW β₁ | - | 0.9 |
| AdamW β₂ | - | 0.999 |
| Frame resolution | - | 64×64 |
| Frame skip | - | 4 |
| Max no-ops (train, test) | - | (30, 1) |
| Max episode steps (train, test) | - | (20K, 108K) |
| Terminate on live loss (train, test) | - | (No, Yes) |
| Description | Symbol | Tokenizer | World Model | Actor-Critic |
|---|---|---|---|---|
| Learning rate | - | 0.0001 | 0.0002 | 0.0001 |
| Batch size | - | 128 | 64 | 128 |
| Gradient clipping threshold | - | 10 | 100 | 3 |
| Start after epochs | - | 5 | 25 | 50 |
| Training steps per epoch | - | 200 | 200 | 100 |
| AdamW weight decay | - | 0.01 | 0.05 | 0.01 |
A.1.1 Tokenizer (V)
Tokenizer Architecture
Our tokenizer is based on the implementation of VQ-GAN (Esser et al., 2021). However, we simplified the architectures of the encoder and decoder networks. The architectures of the encoder and decoder networks are described in Table 4.
| Module | Output Shape |
|---|---|
| Encoder | |
| Input | |
| Conv(3, 1, 1) | |
| EncoderBlock1 | |
| EncoderBlock2 | |
| EncoderBlock3 | |
| GN | |
| SiLU | |
| Conv(3, 1, 1) | |
| EncoderBlock | |
| Input | |
| GN | |
| SiLU | |
| Conv(3, 2, Asym.) | |
| Conv(3, 1, 1) | |
| Decoder | |
| Input | |
| Conv(3, 1, 1) | |
| DecoderBlock1 | |
| DecoderBlock2 | |
| DecoderBlock3 | |
| GN | |
| SiLU | |
| Conv(3, 1, 1) | |
| DecoderBlock | |
| Input | |
| GN | |
| SiLU | |
| Interpolate | |
| Conv(3, 1, 1) | |
| Conv(3, 1, 1) | |
Tokenizer Learning
Following IRIS (Micheli et al., 2023), our tokenizer is a VQ-VAE (van den Oord et al., 2017) based on the implementation of Esser et al. (2021) (without the discriminator). The training objective is given by

$$\mathcal{L}(E, D) = \lVert x - D(z) \rVert_1 + \lVert \mathrm{sg}(E(x)) - z \rVert_2^2 + \lVert \mathrm{sg}(z) - E(x) \rVert_2^2 + \mathcal{L}_{\text{perceptual}}(x, D(z)) \qquad (8)$$

where E and D are the encoder and decoder models, respectively, z is the quantized code of E(x), and sg(·) is the stop-gradient operator. The first term on the right-hand side of Equation 8 is the reconstruction loss, the second and third terms correspond to the commitment loss, and the last term is the perceptual loss.
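To make the quantization step and the forward value of this objective concrete, here is a minimal NumPy sketch. It is an illustrative reconstruction, not the paper's implementation: `quantize` and `vq_loss_value` are hypothetical names, the perceptual term is omitted, and since sg(·) only affects gradients, both quantization terms share the same forward value here.

```python
import numpy as np

def quantize(z_e, codebook):
    """Nearest-neighbor lookup of encoder outputs z_e (K, d) against a
    codebook (N, d); returns token indices and the selected code vectors."""
    d2 = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    tokens = d2.argmin(axis=1)
    return tokens, codebook[tokens]

def vq_loss_value(x, x_rec, z_e, z_q):
    """Forward value of the objective in Equation 8, without the perceptual
    term. sg(.) is a no-op in the forward pass, so it does not appear here."""
    recon = np.abs(x - x_rec).sum()           # ||x - D(z)||_1
    codebook_term = ((z_e - z_q) ** 2).sum()  # ||sg(E(x)) - z||_2^2
    commit_term = ((z_q - z_e) ** 2).sum()    # ||sg(z) - E(x)||_2^2
    return recon + codebook_term + commit_term
```

In an autodiff framework, the two quadratic terms would differ only in which argument is detached from the gradient computation.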
A.1.2 Retentive World Model (M)
The hyperparameters of the world model are presented in Table 5.
Implementation Details
We use the “Yet-Another-RetNet” RetNet implementation (https://github.com/fkodom/yet-another-retnet), as its code is simple and convenient while its performance remains competitive with the official implementation in terms of run time and efficiency.
Originally, the IRIS algorithm provides the world model with a single observation to make forward predictions. Our implementation considers a context of two frames for making forward predictions.
| Description | Symbol | Value |
|---|---|---|
| Number of layers | - | 5 |
| Number of Retention heads | - | 4 |
| Embedding dimension | d | 256 |
| Dropout | - | 0.1 |
| RetNet feed-forward dimension | - | 1024 |
| RetNet LayerNorm epsilon | - | 1e-6 |
| Blocks per chunk | c | 3 |
| World model context length | - | 2 |
POP Observation-Generation Modes
The simpler observation generation mode presented in Section 2.3 requires two sequential world model calls to generate the next observation. The first call consumes the previous observation-action block to compute the recurrent state at the current time step, while the second call uses this state to generate the next observation tokens. Note that the next observation tokens sampled in the second call have not been processed by the world model at this point. To incorporate these tokens into the recurrent state, an additional world model call is required. Here, the first world model call pays the cost of processing the K + 1 tokens of the previous observation-action block, while the second call pays the cost of generating the K tokens of the next observation.
Alternatively, it is also possible to combine these two calls into one by concatenating the previous observation-action block with the inputs used to generate the next observation. However, to avoid having the latter incorporated in the recurrent state computed by this call, a modified forward call should be used. Concretely, the resulting recurrent state should only summarize the previous observation-action block, neglecting the generation-input suffix. In practice, we use the backbone of the POP chunkwise forward mode (Alg. 1) for this computation. This alternative mode induces only a single sequential world model call per generated observation, with each call processing a longer concatenated sequence. Hence, this alternative reduces the number of sequential calls while maintaining the same total cost.
In practice, the optimal mode depends on the configuration, model sizes, and hardware. In our configuration, we opted for a larger batch size for the imagination phase, and found the first (simpler) mode to be slightly more efficient in this case. However, the second mode could be more efficient in other settings.
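Under this reading of the two modes (each observation-action block spans K + 1 tokens: K observation tokens plus one action token), the per-call token counts can be sketched as follows. The function name and accounting are illustrative, not taken from the released code.

```python
def tokens_per_call(K, mode="simple"):
    """Tokens processed by each sequential world model call when generating
    one next observation, under an illustrative accounting of the two modes."""
    if mode == "simple":
        # call 1 summarizes the previous (K + 1)-token observation-action
        # block; call 2 generates the K next-observation tokens in parallel
        return [K + 1, K]
    elif mode == "combined":
        # a single call over the concatenated block and generation inputs
        return [2 * K + 1]
    raise ValueError(f"unknown mode: {mode}")
```

Both modes process the same total number of tokens per step; the combined mode simply trades two short sequential calls for one longer call.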
| Algorithm | Observation Prediction Cost | World Model Training Cost | Imagination Sequential Calls | Imagination Cost per Call |
|---|---|---|---|---|
| POP (default mode) | | | | |
| POP (alternative mode) | | | | |
| No POP | | | | |
A.1.3 Controller (C)
Actor-Critic Learning
Our learning algorithm follows IRIS and Dreamer (Micheli et al., 2023; Hafner et al., 2020, 2021), which use λ-returns defined recursively as

$$\Lambda_t = \begin{cases} r_t + \gamma (1 - d_t) \left[ (1 - \lambda) V(x_{t+1}) + \lambda \Lambda_{t+1} \right] & \text{if } t < H \\ V(x_H) & \text{if } t = H \end{cases}$$

where V is the value network learned by the critic, and $(x_0, a_0, r_0, d_0, \ldots, x_H)$ is a trajectory of observations, actions, rewards, and termination signals obtained through world model imagination.
To optimize V, the following loss is minimized:

$$\mathcal{L}_V = \sum_{t=0}^{H-1} \left( V(x_t) - \mathrm{sg}(\Lambda_t) \right)^2$$

The policy optimization follows a simple REINFORCE (Sutton & Barto, 2018) objective, with V used as a baseline for variance reduction. The objective is given by

$$\mathcal{L}_\pi = - \sum_{t=0}^{H-1} \log \pi(a_t \mid x_{\le t}) \, \mathrm{sg}\!\left( \Lambda_t - V(x_t) \right) - \eta \, \mathcal{H}\!\left( \pi(\cdot \mid x_{\le t}) \right)$$

where $\mathcal{H}$ denotes the entropy and η is the entropy loss weight. The values of the hyperparameters used in our experiments are detailed in Table 7.
| Description | Symbol | Value |
|---|---|---|
| Discount factor | γ | 0.995 |
| λ-return coefficient | λ | 0.95 |
| Entropy loss weight | η | 0.001 |
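As a sanity check of the λ-return recursion above, here is a minimal Python sketch computing the returns backward from the bootstrap value. Variable names are illustrative and not taken from the released code.

```python
def lambda_returns(rewards, values, dones, gamma=0.995, lam=0.95):
    """Backward recursion for lambda-returns.
    `values` has length H + 1, where values[H] is the bootstrap value V(x_H);
    `dones` holds termination indicators d_t in {0, 1}."""
    H = len(rewards)
    lam_ret = values[H]  # Lambda_H = V(x_H)
    out = [0.0] * H
    for t in reversed(range(H)):
        lam_ret = rewards[t] + gamma * (1.0 - dones[t]) * (
            (1.0 - lam) * values[t + 1] + lam * lam_ret
        )
        out[t] = lam_ret
    return out
```

With λ = 1 and γ = 1 this reduces to the Monte-Carlo return plus the bootstrap value, while λ = 0 recovers one-step temporal-difference targets.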
| Module | Output Shape |
|---|---|
| Input | |
| Conv(3, 1, 1) | |
| SiLU | |
| Conv(3, 1, 1) | |
| SiLU | |
| Flatten | 4096 |
| Linear | 512 |
| SiLU | 512 |
Agent Architecture
The architecture of the agent module comprises a shared backbone and two linear maps for the actor and critic heads, respectively. The shared backbone first maps the input to a latent representation in the form of a 512-dimensional vector. For action token inputs, a learned embedding table is used to map the token to its latent representation. For observation inputs, the tokens are first mapped to their corresponding code vectors learned by the tokenizer and reshaped according to their original spatial order. The resulting tensor is then processed by a convolutional neural network followed by a fully connected network, whose architecture details are presented in Table 8. Lastly, a long short-term memory (LSTM) (Hochreiter & Schmidhuber, 1997) network maps the processed input vector to a history-dependent latent vector, which serves as the output of the shared backbone.
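The observation branch's shapes can be traced with simple arithmetic. In the sketch below, the grid size follows from K = 64 tokens (an 8 × 8 layout), and the convolution channel count of 64 is an assumption inferred from the flattened size of 4096 in Table 8; the code-vector dimension is likewise an assumed placeholder.

```python
import math

def observation_branch_shapes(K=64, code_dim=64, conv_channels=64):
    """Hypothetical shape trace of the observation branch of the shared
    backbone. 3x3 stride-1 pad-1 convolutions preserve the spatial size,
    so the flattened size is g * g * conv_channels."""
    g = math.isqrt(K)
    assert g * g == K, "K tokens are assumed to form a square grid"
    return [
        ("codes", (code_dim, g, g)),            # token code vectors on a grid
        ("conv3x3 + SiLU", (conv_channels, g, g)),
        ("conv3x3 + SiLU", (conv_channels, g, g)),
        ("flatten", (g * g * conv_channels,)),  # 4096 for K=64, channels=64
        ("linear + SiLU", (512,)),
    ]
```

This arithmetic is the reason the flattened size of 4096 pins down 64 channels on an 8 × 8 grid.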
Action-Dependent Actor-Critic
In the IRIS algorithm, the actor and critic networks share an LSTM (Hochreiter & Schmidhuber, 1997) backbone and model $\pi(a_t \mid x_{\le t})$ and $V(x_{\le t})$. Notice that the output of the policy models the distribution of actions at step t. Importantly, the model has no information about the sampled actions. In REM, the input of the shared backbone contains the sampled actions, i.e., our algorithm models $\pi(a_t \mid x_{\le t}, a_{< t})$ and $V(x_{\le t}, a_{< t})$.
A.2 REM Algorithm
Here, we present pseudo-code for REM. The high-level loop is presented in Algorithm 3, while the pseudo-code for training each component is presented in Algorithms 4-7.
Input:
repeat
collect_experience() (Alg. 4)
train_V() (Alg. 5)
train_M() (Alg. 6)
train_C() (Alg. 7)
until stopping criterion is met
Input:
for to do
if … then
endif
endfor
Compute loss (Eqn. 8)
Update
for to do
endfor
for to do
(Alg. 1)
endfor
Compute Losses and update
for to do
endfor
Initialize context
for to do
endfor
Update (detailed in Section A.1.3)
A.3 Retentive Networks
In this section, we give detailed information regarding the RetNet architecture for completeness. For convenience, we adopt the notation of Sun et al. (2023) rather than the notation presented in the main text.
RetNet (Sun et al., 2023) is a recent alternative to Transformers (Vaswani et al., 2017). It is highly parallelizable, has lower-cost inference than Transformers, and is empirically claimed to perform competitively on language modelling tasks. The RetNet model is a stack of L identical layers, and we denote the output of the l-th layer by $X^l$. Given an embedded input sequence $X^0$ of T d-dimensional vectors, each RetNet layer can be described as
$$Y^l = \mathrm{MSR}(\mathrm{LN}(X^l)) + X^l \qquad (9)$$
$$X^{l+1} = \mathrm{FFN}(\mathrm{LN}(Y^l)) + Y^l \qquad (10)$$
where LN is layer normalization (Ba et al., 2016), FFN is a feed-forward network, and MSR is a multi-scale retention module with multiple Retention heads. The output of the RetNet model is given by $X^L$.
As presented in the main text, the chunkwise equations are

$$R_i = K_{[i]}^\top \left( V_{[i]} \odot \zeta \right) + \gamma^B R_{i-1}, \qquad \zeta_{bj} = \gamma^{B - b - 1}$$
$$\mathrm{Retention}(X_{[i]}) = \left( Q_{[i]} K_{[i]}^\top \odot D \right) V_{[i]} + \left( Q_{[i]} R_{i-1} \right) \odot \xi, \qquad \xi_{bj} = \gamma^{b + 1}$$

where $X_{[i]}$ denotes the i-th length-B chunk of the input, $Q_{[i]}, K_{[i]}, V_{[i]}$ are the corresponding chunks of Q, K, V, and D is a $B \times B$ matrix with $D_{nm} = \gamma^{n-m}$ if $n \ge m$ and 0 otherwise. Here, $Q = (X W_Q) \odot \Theta$, $K = (X W_K) \odot \bar{\Theta}$, and $V = X W_V$,
where $\gamma$ is an exponential decay factor, the matrices $\Theta, \bar{\Theta}$ (with $\bar{\Theta}$ the complex conjugate of $\Theta$) are for relative position encoding, and D combines an auto-regressive mask with the temporal decay factor $\gamma$.
In each RetNet layer, $h = d / d_{\mathrm{head}}$ Retention heads are used, where $d_{\mathrm{head}}$ is the dimension of each head. Each Retention head uses different parameters $W_Q, W_K, W_V$. Additionally, each Retention head uses a different value of $\gamma$. Among different RetNet layers, the values of $\gamma$ are fixed. Each MSR module is defined as follows:

$$\gamma = 1 - 2^{-5 - \mathrm{arange}(0, h)} \in \mathbb{R}^h$$
$$\mathrm{head}_i = \mathrm{Retention}(X, \gamma_i)$$
$$Y = \mathrm{GroupNorm}_h\!\left( \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h) \right)$$
$$\mathrm{MSR}(X) = \left( \mathrm{swish}(X W_G) \odot Y \right) W_O$$

where $W_G, W_O \in \mathbb{R}^{d \times d}$ are learnable parameters.
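To make the decay matrix D and the equivalence between the parallel and recurrent views of retention concrete, here is a minimal single-head sketch in NumPy. It omits the Θ position encoding, per-head γ values, GroupNorm, and the gating projections; the function names and shapes are illustrative.

```python
import numpy as np

def parallel_retention(Q, K, V, gamma):
    """Parallel form: (Q K^T ⊙ D) V, with D_nm = gamma^(n-m) for n >= m
    and 0 otherwise (auto-regressive mask combined with temporal decay)."""
    T = Q.shape[0]
    n = np.arange(T)
    D = np.where(n[:, None] >= n[None, :],
                 gamma ** (n[:, None] - n[None, :]), 0.0)
    return (Q @ K.T * D) @ V

def recurrent_retention(Q, K, V, gamma):
    """Recurrent form: S_t = gamma * S_{t-1} + k_t v_t^T, output q_t S_t.
    Produces the same outputs as the parallel form, one step at a time."""
    S = np.zeros((K.shape[1], V.shape[1]))
    out = []
    for t in range(Q.shape[0]):
        S = gamma * S + np.outer(K[t], V[t])
        out.append(Q[t] @ S)
    return np.stack(out)
```

The chunkwise mode interpolates between these two views: the inner-chunk term is computed in parallel with D, while the state $R_{i-1}$ carries information across chunks.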
A.4 Additional Results
In addition to comparing the run times of REM and IRIS, we also conducted a comparison to an improved version of IRIS that uses REM's configuration. These results, presented in Figure 11, clearly show the effectiveness of our novel POP mechanism.
![Improving Token-Based World Models with Parallel Observation Prediction (12) Improving Token-Based World Models with Parallel Observation Prediction (12)](https://i0.wp.com/arxiv.org/html/2402.05643v5/x12.png)
The probability of improvement results from our Atari 100K benchmark experiment are presented in Figure 12. Importantly, REM outperforms previous token-based methods, namely IRIS, while remaining competitive with all baselines except STORM on this metric. We highlight that our main contributions address the computational bottleneck of token-based methods, and we thus focus on comparing REM to these approaches.
![Improving Token-Based World Models with Parallel Observation Prediction (13) Improving Token-Based World Models with Parallel Observation Prediction (13)](https://i0.wp.com/arxiv.org/html/2402.05643v5/x13.png)
The complete set of ablation results is presented in Figure 13 and Table 9. The performance profiles for the ablations are presented in Figure 14.
![Improving Token-Based World Models with Parallel Observation Prediction (14) Improving Token-Based World Models with Parallel Observation Prediction (14)](https://i0.wp.com/arxiv.org/html/2402.05643v5/x14.png)
![Improving Token-Based World Models with Parallel Observation Prediction (15) Improving Token-Based World Models with Parallel Observation Prediction (15)](https://i0.wp.com/arxiv.org/html/2402.05643v5/x15.png)
| Game | Random | Human | REM | IRIS | No POP | Separate emb. | tokenizer | w/o action inputs | |
|---|---|---|---|---|---|---|---|---|---|
| Assault | 222.4 | 742.0 | 1764.2 | 1524.4 | 1472.2 | 1269.2 | 1288.9 | 1221.5 | 1498.5 |
| Asterix | 210.0 | 8503.3 | 1637.5 | 853.6 | 1603.4 | 1185.9 | 909.6 | 1376.4 | 1656.3 |
| ChopperCommand | 811.0 | 7387.8 | 2561.2 | 1565.0 | 1848.0 | 1928.3 | 1958.9 | 2517.6 | 2302.7 |
| CrazyClimber | 10780.5 | 35829.4 | 76547.6 | 59324.2 | 62964.8 | 74791.3 | 57814.7 | 30952.7 | 42441.2 |
| DemonAttack | 152.1 | 1971.0 | 5738.6 | 2034.4 | 12316.0 | 4389.9 | 3863.3 | 5159.0 | 5827.0 |
| Gopher | 257.6 | 2412.5 | 5452.4 | 2236.1 | 5338.4 | 3764.2 | 2174.9 | 2891.2 | 4365.3 |
| Krull | 1598.0 | 2665.5 | 4017.7 | 6616.4 | 5138.6 | 5779.9 | 4612.8 | 3866.2 | 3659.6 |
| RoadRunner | 11.5 | 7845.0 | 14060.2 | 9614.6 | 13161.6 | 11723.5 | 6161.7 | 11692.9 | 11692.9 |
| #Superhuman (↑) | 0 | N/A | 6 | 5 | 6 | 6 | 4 | 5 | 6 |
| Mean (↑) | 0.000 | 1.000 | 1.947 | 1.564 | 2.357 | 1.778 | 1.341 | 1.340 | 1.571 |
| Median (↑) | 0.000 | 1.000 | 2.339 | 1.130 | 2.221 | 1.821 | 1.384 | 1.357 | 1.699 |
| IQM (↑) | 0.000 | 1.000 | 2.201 | 1.191 | 2.068 | 1.794 | 1.026 | 1.234 | 1.535 |
| Optimality Gap (↓) | 1.000 | 0.000 | 0.198 | 0.298 | 0.209 | 0.214 | 0.289 | 0.261 | 0.208 |
Ablations: World Model Observation Prediction Losses
To investigate the effect of each ablation on the quality of world model observation predictions, we measured the corresponding loss values during training and during test episodes at a frequency of 50 epochs. The results are presented in Figure 15, including results for each of the 8 games used in our ablation studies.
A.5 Setup in Freeway
For a fair comparison, we followed the actor-critic configuration of IRIS (Micheli et al., 2023) for Freeway. Specifically, the sampling temperature of the agent is modified from 1 to 0.01, a heuristic that guides the agent towards non-zero reward trajectories. We highlight that different methods use other mechanisms, such as epsilon-greedy schedules and “argmax” action selection policies, to overcome this exploration challenge (Micheli et al., 2023).