Commit

improve PPO
qlan3 committed Apr 12, 2024
1 parent d6fdbcc commit 39a5910
Showing 22 changed files with 146 additions and 158 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -11,17 +11,16 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n
## TODO

- Add more descriptions about slurm and hyper-parameter comparison.
- Add more descriptions about slurm and experimental result analysis.
- Add more algorithms, such as DQN for Atari games.


## Installation

- Python: 3.11
- [Jax](https://jax.readthedocs.io/en/latest/installation.html): >=0.4.20
- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install gymnasium==0.29.1`
- [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install mujoco==2.3.7`
- [Gymnasium(mujoco)](https://gymnasium.farama.org/environments/mujoco/): `pip install gymnasium[mujoco]`
- [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install 'mujoco>=2.3.6,<3.0'`
- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install 'gymnasium[box2d,mujoco]>=0.29.1,<1.0'`
- [Gym Games](https://github.com/qlan3/gym-games): >=2.0.0.
- Others: `pip install -r requirements.txt`.

@@ -124,4 +123,5 @@ If you find this repo useful to your research, please cite this repo:
- [Explorer](https://github.com/qlan3/Explorer)
- [Jax RL](https://github.com/ikostrikov/jaxrl)
- [CleanRL](https://github.com/vwxyzjn/cleanrl)
- [PureJaxRL](https://github.com/luchris429/purejaxrl)
- [PureJaxRL](https://github.com/luchris429/purejaxrl)
- [rl-basics](https://github.com/vcharraut/rl-basics)
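
Following the pinned versions above, a quick sanity check that the installation yields a working MuJoCo task; this snippet is an illustration, not part of the repo, and the environment id is only an example.

```python
# Hypothetical post-install check (not from this repository).
import jax
import mujoco
import gymnasium as gym

print(jax.__version__, mujoco.__version__, gym.__version__)

env = gym.make('HalfCheetah-v4')  # any MuJoCo task from the gymnasium[mujoco] extra
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```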
4 changes: 2 additions & 2 deletions agents/DDPG.py
@@ -72,7 +72,7 @@ def critic_loss(params):
critic_state = critic_state.apply_gradients(grads=grads)
# Soft-update target network
critic_state = critic_state.replace(
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau'])
)
return critic_state

@@ -89,6 +89,6 @@ def actor_loss(params):
actor_state = actor_state.apply_gradients(grads=grads)
# Soft-update target network
actor_state = actor_state.replace(
target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['agent']['tau'])
)
return actor_state, None
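
This change (mirrored in NAF, SAC, and TD3 below) only moves `tau` from the top level of the config to `cfg['agent']['tau']`; the soft update itself is untouched. As a reminder of what `optax.incremental_update` computes, a minimal sketch with made-up numbers:

```python
# Polyak averaging: new_target = tau * online + (1 - tau) * target.
import jax.numpy as jnp
import optax

tau = 0.005  # illustrative; the agents read this from cfg['agent']['tau']
online_params = {'w': jnp.array([1.0, 2.0])}
target_params = {'w': jnp.array([0.0, 0.0])}

new_target = optax.incremental_update(online_params, target_params, tau)
# new_target['w'] -> [0.005, 0.01]
```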
3 changes: 0 additions & 3 deletions agents/MaxminDQN.py
@@ -1,7 +1,4 @@
import jax
import time
import random
from tqdm import tqdm
import flax.linen as nn
import jax.numpy as jnp
from functools import partial
2 changes: 1 addition & 1 deletion agents/NAF.py
@@ -69,6 +69,6 @@ def critic_loss(params):
critic_state = critic_state.apply_gradients(grads=grads)
# Soft-update target network
critic_state = critic_state.replace(
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau'])
)
return critic_state
98 changes: 36 additions & 62 deletions agents/PPO.py
@@ -2,35 +2,11 @@
import time
import math
import distrax
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
import flax.linen as nn
from functools import partial

import gymnasium as gym
from gymnasium import wrappers
from envs.wrappers import UniversalSeed

def ppo_make_env(env_name, gamma=0.99, deque_size=1, **kwargs):
""" Make env for PPO. """
env = gym.make(env_name, **kwargs)
# Episode statistics wrapper: set it before reward wrappers
env = wrappers.RecordEpisodeStatistics(env, deque_size=deque_size)
# Action wrapper
env = wrappers.ClipAction(wrappers.RescaleAction(env, min_action=-1, max_action=1))
# Obs wrapper
env = wrappers.FlattenObservation(env) # For dm_control
env = wrappers.NormalizeObservation(env)
env = wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
# Reward wrapper
env = wrappers.NormalizeReward(env, gamma=gamma)
env = wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
# Seed wrapper: must be the last wrapper to be effective
env = UniversalSeed(env)
return env


from envs.env import ppo_make_env
from agents.BaseAgent import BaseAgent, TrainState
from components.replay import FiniteReplay
from components.networks import MLPVCritic, MLPGaussianActor, MLPCategoricalActor
@@ -66,7 +42,6 @@ def linear_schedule(count):
return frac * lr
dummy_obs = self.env['Train'].observation_space.sample()[None,]
self.seed, actor_seed, critic_seed = jax.random.split(self.seed, 3)
# Set actor network
if self.action_type == 'DISCRETE':
actor_net = MLPCategoricalActor
elif self.action_type == 'CONTINUOUS':
@@ -102,12 +77,13 @@ def run_steps(self, mode='Train'):
# Take a env step
next_obs, reward, terminated, truncated, info = self.env[mode].step(action)
# Save experience
mask = self.discount * (1 - terminated)
done = terminated or truncated
mask = self.discount * (1.0 - done)
self.save_experience(obs, action, reward, mask, v, log_pi)
# Update observation
obs = next_obs
# Record and reset
if terminated or truncated:
if done:
result_dict = {
'Env': self.env_name,
'Agent': self.agent_name,
@@ -148,20 +124,20 @@ def get_action(self, step, obs, mode='Train'):

@partial(jax.jit, static_argnames=['self'])
def random_action(self, actor_state, critic_state, obs, seed):
seed, action_seed = jax.random.split(seed, 2)
action_mean, action_log_std = actor_state.apply_fn(actor_state.params, obs)
pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std))
v = critic_state.apply_fn(critic_state.params, obs)
seed, action_seed = jax.random.split(seed)
action_mean, action_std = actor_state.apply_fn(actor_state.params, obs)
pi = distrax.Normal(loc=action_mean, scale=action_std)
action = pi.sample(seed=action_seed)
log_pi = pi.log_prob(action)
log_pi = pi.log_prob(action).sum(-1)
v = critic_state.apply_fn(critic_state.params, obs)
return action, v, log_pi, seed

@partial(jax.jit, static_argnames=['self'])
def optimal_action(self, actor_state, critic_state, obs, seed):
action_mean, action_log_std = actor_state.apply_fn(actor_state.params, obs)
pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std))
action_mean, action_std = actor_state.apply_fn(actor_state.params, obs)
pi = distrax.Normal(loc=action_mean, scale=action_std)
log_pi = pi.log_prob(action_mean).sum(-1)
v = critic_state.apply_fn(critic_state.params, obs)
log_pi = pi.log_prob(action_mean)
return action_mean, v, log_pi, seed

@partial(jax.jit, static_argnames=['self'])
@@ -195,10 +171,11 @@ def _calculate_gae(gae_and_next_value, transition):
init = (0.0, last_v),
xs = trajectory,
length = self.cfg['agent']['collect_steps'],
reverse = True,
reverse = True
)
trajectory['adv'] = adv
trajectory['v_target'] = adv + trajectory['v']
# Normalize advantage
trajectory['adv'] = (adv - adv.mean()) / (adv.std() + 1e-8)
return trajectory

@partial(jax.jit, static_argnames=['self'])
@@ -215,44 +192,41 @@ def update_epoch(self, carry, _):
lambda x: jnp.reshape(x, (-1, self.cfg['batch_size']) + x.shape[1:]),
shuffled_trajectory,
)
carry = (actor_state, critic_state)
carry, _ = jax.lax.scan(
(actor_state, critic_state), _ = jax.lax.scan(
f = self.update_batch,
init = carry,
init = (actor_state, critic_state),
xs = batches
)
actor_state, critic_state = carry
carry = (actor_state, critic_state, trajectory, seed)
return carry, None

@partial(jax.jit, static_argnames=['self'])
def update_batch(self, carry, batch):
actor_state, critic_state = carry
adv = (batch['adv'] - batch['adv'].mean()) / (batch['adv'].std() + 1e-8)
# Set loss function
def compute_loss(params):
actor_param, critic_param = params
# Compute critic loss
v = critic_state.apply_fn(critic_param, batch['obs'])
v_clipped = batch['v'] + (v - batch['v']).clip(-self.cfg['agent']['clip_ratio'], self.cfg['agent']['clip_ratio'])
critic_loss_unclipped = jnp.square(v - batch['v_target'])
critic_loss_clipped = jnp.square(v_clipped - batch['v_target'])
critic_loss = 0.5 * jnp.maximum(critic_loss_unclipped, critic_loss_clipped).mean()
v = critic_state.apply_fn(params['critic'], batch['obs'])
critic_loss = jnp.square(v - batch['v_target']).mean()
# Compute actor loss
action_mean, action_log_std = actor_state.apply_fn(actor_param, batch['obs'])
pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std))
log_pi = pi.log_prob(batch['action'])
action_mean, action_std = actor_state.apply_fn(params['actor'], batch['obs'])
pi = distrax.Normal(loc=action_mean, scale=action_std)
log_pi = pi.log_prob(batch['action']).sum(-1)
ratio = jnp.exp(log_pi - batch['log_pi'])
obj = ratio * adv
obj_clipped = jnp.clip(ratio, 1.0-self.cfg['agent']['clip_ratio'], 1.0+self.cfg['agent']['clip_ratio']) * adv
obj = ratio * batch['adv']
obj_clipped = jnp.clip(ratio, 1.0-self.cfg['agent']['clip_ratio'], 1.0+self.cfg['agent']['clip_ratio']) * batch['adv']
actor_loss = -jnp.minimum(obj, obj_clipped).mean()
# Compute entropy
entropy = pi.entropy().mean()
total_loss = actor_loss + self.cfg['agent']['vf_coef'] * critic_loss - self.cfg['agent']['ent_coef'] * entropy
# Compute entropy loss
entropy_loss = pi.entropy().sum(-1).mean()
total_loss = actor_loss + self.cfg['agent']['vf_coef'] * critic_loss - self.cfg['agent']['ent_coef'] * entropy_loss
return total_loss

grads = jax.grad(compute_loss)((actor_state.params, critic_state.params))
actor_grads, critic_grads = grads
actor_state = actor_state.apply_gradients(grads=actor_grads)
critic_state = critic_state.apply_gradients(grads=critic_grads)
# Update actor_state and critic_state with a joint gradient over both parameter sets
params = {
'actor': actor_state.params,
'critic': critic_state.params
}
grads = jax.grad(compute_loss)(params)
actor_state = actor_state.apply_gradients(grads=grads['actor'])
critic_state = critic_state.apply_gradients(grads=grads['critic'])
carry = (actor_state, critic_state)
return carry, None
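
For readers following the `_calculate_gae` scan above, here is a self-contained sketch of the recursion it implements. The function and argument names and the example numbers are illustrative rather than the repo's exact code; it assumes the convention from `run_steps` that `mask = gamma * (1 - done)`, so the discount never appears separately.

```python
import jax
import jax.numpy as jnp

def compute_gae(rewards, values, masks, last_v, gae_lambda=0.95):
    """rewards, values, masks: [T] arrays with masks = gamma * (1 - done); last_v bootstraps V(s_T)."""
    last_v = jnp.asarray(last_v, dtype=values.dtype)
    def step(carry, x):
        gae, next_v = carry
        reward, value, mask = x
        delta = reward + mask * next_v - value  # TD error; gamma is folded into mask
        gae = delta + gae_lambda * mask * gae   # lambda-weighted sum of future TD errors
        return (gae, value), gae
    _, adv = jax.lax.scan(
        step,
        init=(jnp.zeros_like(last_v), last_v),
        xs=(rewards, values, masks),
        reverse=True,  # scan backwards in time; outputs keep forward order
    )
    v_target = adv + values                        # critic regression target
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)  # normalize over the whole rollout, as in the new code
    return adv, v_target

# Illustrative 3-step rollout:
rewards = jnp.array([1.0, 0.0, 1.0])
values  = jnp.array([0.5, 0.4, 0.3])
masks   = jnp.array([0.99, 0.99, 0.0])  # the last transition terminated, so its mask is 0
adv, v_target = compute_gae(rewards, values, masks, last_v=0.2)
```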
3 changes: 1 addition & 2 deletions agents/SAC.py
@@ -1,7 +1,6 @@
import jax
import time
import optax
import numpy as np
from tqdm import tqdm
import jax.numpy as jnp
import flax.linen as nn
@@ -151,7 +150,7 @@ def critic_loss(params):
critic_state = critic_state.apply_gradients(grads=grads)
# Soft-update target network
critic_state = critic_state.replace(
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau'])
)
return critic_state

4 changes: 2 additions & 2 deletions agents/TD3.py
@@ -29,7 +29,7 @@ def critic_loss(params):
critic_state = critic_state.apply_gradients(grads=grads)
# Soft-update target network
critic_state = critic_state.replace(
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau'])
)
return critic_state

@@ -46,6 +46,6 @@ def actor_loss(params):
actor_state = actor_state.apply_gradients(grads=grads)
# Soft-update target network
actor_state = actor_state.replace(
target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['tau'])
target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['agent']['tau'])
)
return actor_state, None
9 changes: 8 additions & 1 deletion analysis.py
@@ -62,22 +62,29 @@ def get_csv_result_dict(result, config_idx, mode='Train', ci=95, method='percent
def analyze(exp, runs=1):
cfg['exp'] = exp
cfg['runs'] = runs
'''
sweep_keys_dict = dict(
dqn = ['optim/kwargs/learning_rate'],
ddqn = ['optim/kwargs/learning_rate'],
maxmin = ['optim/kwargs/learning_rate', 'agent/critic_num'],
td = ['optim/kwargs/learning_rate'],
sac = ['optim/kwargs/learning_rate'],
naf = ['optim/kwargs/learning_rate'],
ppo = ['optim/kwargs/learning_rate'],
ddpg = ['optim/kwargs/learning_rate'],
)
algo = exp.rstrip('0123456789').split('_')[-1]
cfg['sweep_keys'] = sweep_keys_dict[algo]
'''
plotter = Plotter(cfg)

plotter.csv_merged_results('Train', get_csv_result_dict, get_process_result_dict)
plotter.plot_results(mode='Train', indexes='all')

# plotter.csv_unmerged_results('Train', get_process_result_dict)
# group_keys = ['optim/kwargs/learning_rate', 'agent/critic_num']
# plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=2, markdown=False)
# group_keys = ['Env']
# plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=0, markdown=False)

# Hyper-parameter Comparison
# plotter.csv_unmerged_results('Train', get_process_result_dict)
Expand Down
62 changes: 17 additions & 45 deletions components/networks.py
@@ -84,15 +84,16 @@ class MinAtarQNet(nn.Module):
net_cfg: FrozenDict = FrozenDict({'feature_dim': 128, 'hidden_act': 'ReLU'})
kernel_init: Initializer = default_init
action_size: int = 10
last_w_scale: float = -1.0

def setup(self):
self.Q_net = nn.Sequential([
nn.Conv(16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.kernel_init),
nn.Conv(16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.kernel_init()),
activations[self.net_cfg['hidden_act']],
lambda x: x.reshape((x.shape[0], -1)), # flatten
nn.Dense(self.net_cfg['feature_dim'], kernel_init=self.kernel_init),
nn.Dense(self.net_cfg['feature_dim'], kernel_init=self.kernel_init()),
activations[self.net_cfg['hidden_act']],
nn.Dense(self.action_size, kernel_init=self.kernel_init)
nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale))
])

def __call__(self, obs):
@@ -162,57 +163,28 @@ class MLPGaussianActor(nn.Module):
""" MLP actor network with Guassian policy N(mu, std). """
action_size: int = 4
net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'})
log_std_min: float = -20.0
log_std_max: float = 2.0
kernel_init: Initializer = default_init
last_w_scale: float = -1.0

def setup(self):
self.actor_net = MLP(
layer_dims = list(self.net_cfg['hidden_dims'])+[self.action_size],
self.actor_feature = MLP(
layer_dims = list(self.net_cfg['hidden_dims']),
hidden_act = self.net_cfg['hidden_act'],
output_act = self.net_cfg['hidden_act'],
kernel_init = self.kernel_init,
last_w_scale = self.last_w_scale
)
self.action_log_std = self.param('log_std', zeros_init(), (self.action_size,))

def __call__(self, obs):
u_mean = self.actor_net(obs)
u_log_std = jnp.clip(self.action_log_std, self.log_std_min, self.log_std_max)
u_log_std = jnp.broadcast_to(u_log_std, u_mean.shape)
return u_mean, u_log_std


class PPONet(nn.Module):
action_size: int = 4
actor_net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'})
critic_net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'})
log_std_min: float = -20.0
log_std_max: float = 2.0
kernel_init: Initializer = default_init

def setup(self):
self.actor_net = MLP(
layer_dims = list(self.actor_net_cfg['hidden_dims'])+[self.action_size],
hidden_act = self.actor_net_cfg['hidden_act'],
kernel_init = self.kernel_init,
last_w_scale = 0.01
last_w_scale = -1.0
)
self.action_log_std = self.param('log_std', zeros_init(), (self.action_size,))
self.critic_net = MLP(
layer_dims = list(self.critic_net_cfg['hidden_dims']) + [1],
hidden_act = self.critic_net_cfg['hidden_act'],
kernel_init = self.kernel_init,
last_w_scale = 1.0
)
self.actor_mean = nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale))
self.actor_std = nn.Sequential([
nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale)),
nn.sigmoid
])

def __call__(self, obs):
u_mean = self.actor_net(obs)
u_log_std = jnp.clip(self.action_log_std, self.log_std_min, self.log_std_max)
return u_mean, u_log_std, self.critic_net(obs).squeeze(-1)

def get_v(self, obs):
return self.critic_net(obs).squeeze(-1)
feature = self.actor_feature(obs)
u_mean = self.actor_mean(feature)
u_std = self.actor_std(feature)
return u_mean, u_std


class MLPGaussianTanhActor(nn.Module):
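
To see how the reshaped actor head is consumed, a hedged usage sketch follows; the observation size, seeding, and import path are assumptions, and PPO.py does the equivalent inside its jitted action functions. Note that the sigmoid head bounds each standard deviation to (0, 1), replacing the earlier clipped, state-independent `log_std` parameter.

```python
import jax
import distrax
import jax.numpy as jnp
from components.networks import MLPGaussianActor  # the class rewritten above

key = jax.random.PRNGKey(0)
obs = jnp.ones((1, 17))                    # hypothetical observation batch

actor = MLPGaussianActor(action_size=4)
params = actor.init(key, obs)
mean, std = actor.apply(params, obs)       # std in (0, 1) via the sigmoid output layer

pi = distrax.Normal(loc=mean, scale=std)   # independent Gaussian per action dimension
action = pi.sample(seed=key)
log_pi = pi.log_prob(action).sum(-1)       # sum per-dimension log-probs, as PPO.py does
```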