Deep Q-Learning (DQN)
Overview
As an extension of Q-learning, DQN's main technical contribution is the use of an experience replay buffer and a target network, both of which improve the stability of the algorithm.
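Below is a minimal, self-contained sketch of how those two components fit together. It is for illustration only: the toy dimensions, the `update` helper, and the plain `deque` buffer are illustrative stand-ins, not CleanRL's actual code.

```python
import copy
import random
from collections import deque

import torch
import torch.nn as nn

obs_dim, n_actions, gamma = 4, 2, 0.99

q_network = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
target_network = copy.deepcopy(q_network)   # frozen copy, synced only periodically
optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-4)
replay_buffer = deque(maxlen=10_000)        # stores (s, a, r, s', done) transitions


def update(batch_size=32):
    # Sampling i.i.d. from the replay buffer breaks the correlation between
    # consecutive transitions, which is one source of instability.
    batch = random.sample(replay_buffer, batch_size)
    s, a, r, s2, d = (torch.tensor(x, dtype=torch.float32) for x in zip(*batch))
    with torch.no_grad():
        # The Bellman target is computed with the *target* network, so the
        # regression target does not move with every gradient step.
        y = r + gamma * target_network(s2).max(dim=1).values * (1 - d)
    q_sa = q_network(s).gather(1, a.long().unsqueeze(1)).squeeze(1)
    loss = nn.functional.mse_loss(q_sa, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# every `target_network_frequency` environment steps:
# target_network.load_state_dict(q_network.state_dict())
```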
Original papers:

- Human-level control through deep reinforcement learning (Mnih et al., 2015)1
Implemented Variants
| Variants Implemented | Description |
|---|---|
| `dqn_atari.py`, docs | For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques. |
| `dqn.py`, docs | For classic control tasks like `CartPole-v1`. |
Below are our single-file implementations of DQN:
dqn_atari.py
The `dqn_atari.py` has the following features:

- For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques.
- Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
- Works with the `Discrete` action space
Usage
```bash
poetry install -E atari
python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/dqn_atari.py --env-id PongNoFrameskip-v4
```
Explanation of the logged metrics
Running `python cleanrl/dqn_atari.py` will automatically record various metrics such as episodic returns and losses in TensorBoard. Below is the documentation for these metrics:

- `charts/episodic_return`: episodic return of the game
- `charts/SPS`: number of steps per second
- `losses/td_loss`: the mean squared error (MSE) between the Q values at timestep \(t\) and the Bellman update target estimated using the reward \(r_t\) and the Q values at timestep \(t+1\), thus minimizing the one-step temporal difference. Formally, it can be expressed by the equation below.
  $$ J(\theta^{Q}) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \big[ (Q(s, a) - y)^2 \big], $$
  where the Bellman update target is \(y = r + \gamma \, \max_{a'} Q^{'}(s', a')\) and \(\mathcal{D}\) is the replay buffer.
- `losses/q_values`: implemented as `qf1(data.observations, data.actions).view(-1)`; it is the average Q value of the sampled data in the replay buffer, useful for gauging whether under- or over-estimation happens.
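For concreteness, here is a sketch of how these two logged quantities can be computed for one sampled minibatch. It assumes an SB3-style batch object `data` with `observations`, `actions`, `rewards`, `next_observations`, and `dones` fields; the function name and exact tensor shapes are illustrative rather than CleanRL's exact code.

```python
import torch
import torch.nn.functional as F


def td_loss_and_q_values(q_network, target_network, data, gamma=0.99):
    """Compute the quantities logged as losses/td_loss and losses/q_values."""
    with torch.no_grad():
        # Bellman update target y = r + gamma * max_a' Q'(s', a'),
        # evaluated with the target network.
        target_max, _ = target_network(data.next_observations).max(dim=1)
        td_target = data.rewards.flatten() + gamma * target_max * (1 - data.dones.flatten())
    # Q(s, a) for the actions stored in the replay buffer
    q_values = q_network(data.observations).gather(1, data.actions.long()).squeeze()
    td_loss = F.mse_loss(td_target, q_values)   # logged as losses/td_loss
    return td_loss, q_values.mean()             # mean is logged as losses/q_values
```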
Implementation details
`dqn_atari.py` is based on (Mnih et al., 2015)1 but presents a few implementation differences:

1. `dqn_atari.py` uses slightly different hyperparameters. Specifically, `dqn_atari.py` uses the more popular Adam optimizer with `--learning-rate=1e-4` as follows:

    ```python
    optim.Adam(q_network.parameters(), lr=1e-4)
    ```

    whereas (Mnih et al., 2015)1 (Extended Data Table 1) uses the RMSProp optimizer with `--learning-rate=2.5e-4`, gradient momentum `0.95`, squared gradient momentum `0.95`, and min squared gradient `0.01` as follows:

    ```python
    optim.RMSprop(
        q_network.parameters(),
        lr=2.5e-4,
        momentum=0.95,
        # PyTorch's RMSprop does not directly support
        # squared gradient momentum and min squared gradient,
        # so we are not sure what to put here.
    )
    ```
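    If one nevertheless wants to approximate those two settings, PyTorch's `torch.optim.RMSprop` does expose `alpha` (the smoothing constant of the squared-gradient moving average) and `eps` (added to the denominator), which are the closest analogues. The mapping below is our own unverified guess, not something used by `dqn_atari.py` or prescribed by (Mnih et al., 2015)1:

    ```python
    import torch.optim as optim

    # Rough, unverified approximation of the Nature-DQN RMSProp settings.
    optim.RMSprop(
        q_network.parameters(),
        lr=2.5e-4,
        momentum=0.95,
        alpha=0.95,  # plays roughly the role of "squared gradient momentum"
        eps=0.01,    # plays roughly the role of "min squared gradient"
    )
    ```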
2. `dqn_atari.py` uses `--learning-starts=80000` whereas (Mnih et al., 2015)1 (Extended Data Table 1) uses `--learning-starts=50000`.
3. `dqn_atari.py` uses `--target-network-frequency=1000` whereas (Mnih et al., 2015)1 (Extended Data Table 1) uses `--target-network-frequency=10000`.
4. `dqn_atari.py` uses `--total-timesteps=10000000` (i.e., 10M timesteps = 40M frames because of frame-skipping) whereas (Mnih et al., 2015)1 uses `--total-timesteps=50000000` (i.e., 50M timesteps = 200M frames) (see "Training details" under "METHODS" on page 6 and the related source code run_gpu#L32, dqn/train_agent.lua#L81-L82, and dqn/train_agent.lua#L165-L169).
5. `dqn_atari.py` uses `--end-e=0.01` (the final exploration epsilon) whereas (Mnih et al., 2015)1 (Extended Data Table 1) uses `--end-e=0.1`.
6. `dqn_atari.py` uses `--exploration-fraction=0.1` whereas (Mnih et al., 2015)1 (Extended Data Table 1) uses `--exploration-fraction=0.02` (this corresponds to 250000 steps, or 1M frames, being the point at which epsilon is fully annealed to `--end-e=0.1`), as sketched below.
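    To make the `--end-e` and `--exploration-fraction` semantics concrete, here is a sketch of a linear annealing schedule. The function name is ours, `start_e=1.0` is an assumed initial epsilon (it is not discussed above), and the shape of the schedule is inferred from the flag names rather than copied from CleanRL's code:

    ```python
    def linear_epsilon(step, total_timesteps=10_000_000,
                       start_e=1.0, end_e=0.01, exploration_fraction=0.1):
        """Linearly anneal epsilon from start_e to end_e over the first
        exploration_fraction * total_timesteps steps, then hold it constant."""
        duration = exploration_fraction * total_timesteps
        slope = (end_e - start_e) / duration
        return max(end_e, start_e + slope * step)
    ```

    At every environment step, a random action would then be taken with probability `linear_epsilon(global_step)` and the greedy action otherwise.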
7. `dqn_atari.py` handles truncation and termination properly, like (Mnih et al., 2015)1, by using SB3's replay buffer with `handle_timeout_termination=True`.
8. `dqn_atari.py` uses a self-contained evaluation scheme: `dqn_atari.py` reports the episodic returns obtained throughout training, whereas (Mnih et al., 2015)1 is trained with `--end-e=0.1` but reports episodic returns using a separate evaluation process with `--end-e=0.01` (see "Evaluation procedure" under "METHODS" on page 6).
9. `dqn_atari.py` rescales the gradient so that its global norm does not exceed `0.5`, as done in PPO (ppo2/model.py#L102-L108); see the sketch after this list.
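Regarding the last item, global gradient-norm clipping is typically done in PyTorch with `torch.nn.utils.clip_grad_norm_`. The snippet below is a generic sketch of that pattern (assuming the usual `loss`, `q_network`, and `optimizer` from a training loop), not an excerpt from `dqn_atari.py`:

```python
loss.backward()
# Rescale gradients in place so that their global L2 norm is at most 0.5.
torch.nn.utils.clip_grad_norm_(q_network.parameters(), max_norm=0.5)
optimizer.step()
```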
Experiment results
PR vwxyzjn/cleanrl#124 tracks our effort to conduct experiments, and the reproduction instructions can be found at vwxyzjn/cleanrl/benchmark/dqn.
Below are the average episodic returns for `dqn_atari.py`.
| Environment | `dqn_atari.py` 10M steps | (Mnih et al., 2015)1 50M steps | (Hessel et al., 2017, Figure 5)3 |
|---|---|---|---|
| BreakoutNoFrameskip-v4 | 337.64 ± 69.47 | 401.2 ± 26.9 | ~230 at 10M steps, ~300 at 50M steps |
| PongNoFrameskip-v4 | 20.293 ± 0.37 | 18.9 ± 1.3 | ~20 at 10M steps, ~20 at 50M steps |
| BeamRiderNoFrameskip-v4 | 6207.41 ± 1019.96 | 6846 ± 1619 | ~6000 at 10M steps, ~7000 at 50M steps |
Note that we save computational time by reducing the training budget from 50M to 10M timesteps, yet our `dqn_atari.py` scores the same as or higher than (Mnih et al., 2015)1 within 10M steps.
Learning curves:
Tracked experiments and game play videos:
dqn.py
The `dqn.py` has the following features:

- Works with the `Box` observation space of low-level features
- Works with the `Discrete` action space
- Works with envs like `CartPole-v1`
Usage
```bash
python cleanrl/dqn.py --env-id CartPole-v1
```
Explanation of the logged metrics
See the related docs for `dqn_atari.py`.
Implementation details
The `dqn.py` shares the same implementation details as `dqn_atari.py`, except that `dqn.py` runs with different hyperparameters and a different neural network architecture. Specifically,

1. `dqn.py` uses a simpler neural network, as follows (see the action-selection sketch after this list):

    ```python
    self.network = nn.Sequential(
        nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, env.single_action_space.n),
    )
    ```
2. `dqn.py` runs with different hyperparameters. See vwxyzjn/cleanrl/benchmark/dqn.
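As referenced above, acting greedily with such an MLP amounts to an argmax over its Q-value outputs. The snippet below is a self-contained sketch with stand-in dimensions for `CartPole-v1`; the `q_network` and `obs` names here are illustrative, not CleanRL's exact code:

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in network mirroring the architecture above,
# with CartPole-v1's dimensions (4 observations, 2 actions).
q_network = nn.Sequential(
    nn.Linear(4, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 2),
)

obs = np.zeros(4, dtype=np.float32)  # placeholder observation
with torch.no_grad():
    q_values = q_network(torch.as_tensor(obs).unsqueeze(0))  # shape (1, n_actions)
    action = int(torch.argmax(q_values, dim=1).item())       # greedy action
```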
Experiment results
PR vwxyzjn/cleanrl#157 tracks our effort to conduct experiments, and the reproduction instructions can be found at vwxyzjn/cleanrl/benchmark/dqn.
Below are the average episodic returns for `dqn.py`.
Learning curves:
Tracked experiments and game play videos:
1. Mnih, V., Kavukcuoglu, K., Silver, D. et al. Human-level control through deep reinforcement learning. Nature 518, 529–533 (2015). https://doi.org/10.1038/nature14236
2. [Proposal] Formal API handling of truncation vs termination. https://github.com/openai/gym/issues/2510
3. Hessel, M., Modayil, J., van Hasselt, H., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M.G., & Silver, D. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.