From 39a5910832dc455ccf6c505504ccdcb17adfdb5c Mon Sep 17 00:00:00 2001 From: qlan3 Date: Thu, 11 Apr 2024 18:25:02 -0600 Subject: [PATCH] improve PPO --- README.md | 10 ++-- agents/DDPG.py | 4 +- agents/MaxminDQN.py | 3 -- agents/NAF.py | 2 +- agents/PPO.py | 98 ++++++++++++++----------------------- agents/SAC.py | 3 +- agents/TD3.py | 4 +- analysis.py | 9 +++- components/networks.py | 62 +++++++---------------- configs/minatar_ddqn.json | 2 +- configs/minatar_maxmin.json | 2 +- configs/mujoco_ddpg.json | 6 +-- configs/mujoco_naf.json | 6 +-- configs/mujoco_ppo.json | 3 +- configs/mujoco_sac.json | 6 +-- configs/mujoco_td3.json | 6 +-- envs/env.py | 22 ++++++++- requirements.txt | 7 +-- results/performance.md | 24 ++++++++- sbatch_m.sh | 6 +-- sbatch_s.sh | 8 ++- submit.py | 11 ++--- 22 files changed, 146 insertions(+), 158 deletions(-) diff --git a/README.md b/README.md index be17ea7..99ebbe2 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n ## TODO -- Add more descriptions about slurm and hyper-parameter comparison. +- Add more descriptions about slurm and experimental result analysis. - Add more algorithms, such as DQN for Atari games. @@ -19,9 +19,8 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n - Python: 3.11 - [Jax](https://jax.readthedocs.io/en/latest/installation.html): >=0.4.20 -- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install gymnasium==0.29.1` -- [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install mujoco==2.3.7` -- [Gymnasium(mujoco)](https://gymnasium.farama.org/environments/mujoco/): `pip install gymnasium[mujoco]` +- [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install 'mujoco>=2.3.6,<3.0'` +- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install 'gymnasium[box2d,mujoco]>=0.29.1,<1.0'` - [Gym Games](https://github.com/qlan3/gym-games): >=2.0.0. - Others: `pip install -r requirements.txt`. 
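A quick sanity check (not part of the repo; the `HalfCheetah-v4` task name is only an example taken from the MuJoCo results further down) can confirm that the pinned packages above resolve together:

```python
# Minimal install sanity check: print versions and roll one MuJoCo step.
import jax
import gymnasium as gym
import mujoco

print("jax:", jax.__version__)          # expected >= 0.4.20
print("gymnasium:", gym.__version__)    # expected >= 0.29.1, < 1.0
print("mujoco:", mujoco.__version__)    # expected >= 2.3.6, < 3.0

env = gym.make("HalfCheetah-v4")        # any MuJoCo task works here
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```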
@@ -124,4 +123,5 @@ If you find this repo useful to your research, please cite this repo: - [Explorer](https://github.com/qlan3/Explorer) - [Jax RL](https://github.com/ikostrikov/jaxrl) - [CleanRL](https://github.com/vwxyzjn/cleanrl) -- [PureJaxRL](https://github.com/luchris429/purejaxrl) \ No newline at end of file +- [PureJaxRL](https://github.com/luchris429/purejaxrl) +- [rl-basics](https://github.com/vcharraut/rl-basics) \ No newline at end of file diff --git a/agents/DDPG.py b/agents/DDPG.py index 5ad6d1f..5796ed2 100644 --- a/agents/DDPG.py +++ b/agents/DDPG.py @@ -72,7 +72,7 @@ def critic_loss(params): critic_state = critic_state.apply_gradients(grads=grads) # Soft-update target network critic_state = critic_state.replace( - target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau']) ) return critic_state @@ -89,6 +89,6 @@ def actor_loss(params): actor_state = actor_state.apply_gradients(grads=grads) # Soft-update target network actor_state = actor_state.replace( - target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['agent']['tau']) ) return actor_state, None \ No newline at end of file diff --git a/agents/MaxminDQN.py b/agents/MaxminDQN.py index 295539d..0f9fee8 100644 --- a/agents/MaxminDQN.py +++ b/agents/MaxminDQN.py @@ -1,7 +1,4 @@ import jax -import time -import random -from tqdm import tqdm import flax.linen as nn import jax.numpy as jnp from functools import partial diff --git a/agents/NAF.py b/agents/NAF.py index 4c4d015..16cf1e5 100644 --- a/agents/NAF.py +++ b/agents/NAF.py @@ -69,6 +69,6 @@ def critic_loss(params): critic_state = critic_state.apply_gradients(grads=grads) # Soft-update target network critic_state = critic_state.replace( - target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau']) ) return critic_state \ No newline at end of file diff --git a/agents/PPO.py b/agents/PPO.py index 3912a51..683657a 100644 --- a/agents/PPO.py +++ b/agents/PPO.py @@ -2,35 +2,11 @@ import time import math import distrax -import numpy as np from tqdm import tqdm import jax.numpy as jnp -import flax.linen as nn from functools import partial -import gymnasium as gym -from gymnasium import wrappers -from envs.wrappers import UniversalSeed - -def ppo_make_env(env_name, gamma=0.99, deque_size=1, **kwargs): - """ Make env for PPO. 
""" - env = gym.make(env_name, **kwargs) - # Episode statistics wrapper: set it before reward wrappers - env = wrappers.RecordEpisodeStatistics(env, deque_size=deque_size) - # Action wrapper - env = wrappers.ClipAction(wrappers.RescaleAction(env, min_action=-1, max_action=1)) - # Obs wrapper - env = wrappers.FlattenObservation(env) # For dm_control - env = wrappers.NormalizeObservation(env) - env = wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) - # Reward wrapper - env = wrappers.NormalizeReward(env, gamma=gamma) - env = wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) - # Seed wrapper: must be the last wrapper to be effective - env = UniversalSeed(env) - return env - - +from envs.env import ppo_make_env from agents.BaseAgent import BaseAgent, TrainState from components.replay import FiniteReplay from components.networks import MLPVCritic, MLPGaussianActor, MLPCategoricalActor @@ -66,7 +42,6 @@ def linear_schedule(count): return frac * lr dummy_obs = self.env['Train'].observation_space.sample()[None,] self.seed, actor_seed, critic_seed = jax.random.split(self.seed, 3) - # Set actor network if self.action_type == 'DISCRETE': actor_net = MLPCategoricalActor elif self.action_type == 'CONTINUOUS': @@ -102,12 +77,13 @@ def run_steps(self, mode='Train'): # Take a env step next_obs, reward, terminated, truncated, info = self.env[mode].step(action) # Save experience - mask = self.discount * (1 - terminated) + done = terminated or truncated + mask = self.discount * (1.0 - done) self.save_experience(obs, action, reward, mask, v, log_pi) # Update observation obs = next_obs # Record and reset - if terminated or truncated: + if done: result_dict = { 'Env': self.env_name, 'Agent': self.agent_name, @@ -148,20 +124,20 @@ def get_action(self, step, obs, mode='Train'): @partial(jax.jit, static_argnames=['self']) def random_action(self, actor_state, critic_state, obs, seed): - seed, action_seed = jax.random.split(seed, 2) - action_mean, action_log_std = actor_state.apply_fn(actor_state.params, obs) - pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std)) - v = critic_state.apply_fn(critic_state.params, obs) + seed, action_seed = jax.random.split(seed) + action_mean, action_std = actor_state.apply_fn(actor_state.params, obs) + pi = distrax.Normal(loc=action_mean, scale=action_std) action = pi.sample(seed=action_seed) - log_pi = pi.log_prob(action) + log_pi = pi.log_prob(action).sum(-1) + v = critic_state.apply_fn(critic_state.params, obs) return action, v, log_pi, seed @partial(jax.jit, static_argnames=['self']) def optimal_action(self, actor_state, critic_state, obs, seed): - action_mean, action_log_std = actor_state.apply_fn(actor_state.params, obs) - pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std)) + action_mean, action_std = actor_state.apply_fn(actor_state.params, obs) + pi = distrax.Normal(loc=action_mean, scale=action_std) + log_pi = pi.log_prob(action_mean).sum(-1) v = critic_state.apply_fn(critic_state.params, obs) - log_pi = pi.log_prob(action_mean) return action_mean, v, log_pi, seed @partial(jax.jit, static_argnames=['self']) @@ -195,10 +171,11 @@ def _calculate_gae(gae_and_next_value, transition): init = (0.0, last_v), xs = trajectory, length = self.cfg['agent']['collect_steps'], - reverse = True, + reverse = True ) - trajectory['adv'] = adv trajectory['v_target'] = adv + trajectory['v'] + # Normalize advantage + trajectory['adv'] = (adv - adv.mean()) / (adv.std() + 1e-8) return trajectory @partial(jax.jit, 
static_argnames=['self']) @@ -215,44 +192,41 @@ def update_epoch(self, carry, _): lambda x: jnp.reshape(x, (-1, self.cfg['batch_size']) + x.shape[1:]), shuffled_trajectory, ) - carry = (actor_state, critic_state) - carry, _ = jax.lax.scan( + (actor_state, critic_state), _ = jax.lax.scan( f = self.update_batch, - init = carry, + init = (actor_state, critic_state), xs = batches ) - actor_state, critic_state = carry carry = (actor_state, critic_state, trajectory, seed) return carry, None @partial(jax.jit, static_argnames=['self']) def update_batch(self, carry, batch): actor_state, critic_state = carry - adv = (batch['adv'] - batch['adv'].mean()) / (batch['adv'].std() + 1e-8) + # Set loss function def compute_loss(params): - actor_param, critic_param = params # Compute critic loss - v = critic_state.apply_fn(critic_param, batch['obs']) - v_clipped = batch['v'] + (v - batch['v']).clip(-self.cfg['agent']['clip_ratio'], self.cfg['agent']['clip_ratio']) - critic_loss_unclipped = jnp.square(v - batch['v_target']) - critic_loss_clipped = jnp.square(v_clipped - batch['v_target']) - critic_loss = 0.5 * jnp.maximum(critic_loss_unclipped, critic_loss_clipped).mean() + v = critic_state.apply_fn(params['critic'], batch['obs']) + critic_loss = jnp.square(v - batch['v_target']).mean() # Compute actor loss - action_mean, action_log_std = actor_state.apply_fn(actor_param, batch['obs']) - pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std)) - log_pi = pi.log_prob(batch['action']) + action_mean, action_std = actor_state.apply_fn(params['actor'], batch['obs']) + pi = distrax.Normal(loc=action_mean, scale=action_std) + log_pi = pi.log_prob(batch['action']).sum(-1) ratio = jnp.exp(log_pi - batch['log_pi']) - obj = ratio * adv - obj_clipped = jnp.clip(ratio, 1.0-self.cfg['agent']['clip_ratio'], 1.0+self.cfg['agent']['clip_ratio']) * adv + obj = ratio * batch['adv'] + obj_clipped = jnp.clip(ratio, 1.0-self.cfg['agent']['clip_ratio'], 1.0+self.cfg['agent']['clip_ratio']) * batch['adv'] actor_loss = -jnp.minimum(obj, obj_clipped).mean() - # Compute entropy - entropy = pi.entropy().mean() - total_loss = actor_loss + self.cfg['agent']['vf_coef'] * critic_loss - self.cfg['agent']['ent_coef'] * entropy + # Compute entropy loss + entropy_loss = pi.entropy().sum(-1).mean() + total_loss = actor_loss + self.cfg['agent']['vf_coef'] * critic_loss - self.cfg['agent']['ent_coef'] * entropy_loss return total_loss - - grads = jax.grad(compute_loss)((actor_state.params, critic_state.params)) - actor_grads, critic_grads = grads - actor_state = actor_state.apply_gradients(grads=actor_grads) - critic_state = critic_state.apply_gradients(grads=critic_grads) + # Update train_state and critic_state + params = { + 'actor': actor_state.params, + 'critic': critic_state.params + } + grads = jax.grad(compute_loss)(params) + actor_state = actor_state.apply_gradients(grads=grads['actor']) + critic_state = critic_state.apply_gradients(grads=grads['critic']) carry = (actor_state, critic_state) return carry, None \ No newline at end of file diff --git a/agents/SAC.py b/agents/SAC.py index c86170a..a413dd3 100644 --- a/agents/SAC.py +++ b/agents/SAC.py @@ -1,7 +1,6 @@ import jax import time import optax -import numpy as np from tqdm import tqdm import jax.numpy as jnp import flax.linen as nn @@ -151,7 +150,7 @@ def critic_loss(params): critic_state = critic_state.apply_gradients(grads=grads) # Soft-update target network critic_state = critic_state.replace( - target_params = optax.incremental_update(critic_state.params, 
critic_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau']) ) return critic_state diff --git a/agents/TD3.py b/agents/TD3.py index 9d0720d..92fcd2d 100644 --- a/agents/TD3.py +++ b/agents/TD3.py @@ -29,7 +29,7 @@ def critic_loss(params): critic_state = critic_state.apply_gradients(grads=grads) # Soft-update target network critic_state = critic_state.replace( - target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(critic_state.params, critic_state.target_params, self.cfg['agent']['tau']) ) return critic_state @@ -46,6 +46,6 @@ def actor_loss(params): actor_state = actor_state.apply_gradients(grads=grads) # Soft-update target network actor_state = actor_state.replace( - target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['tau']) + target_params = optax.incremental_update(actor_state.params, actor_state.target_params, self.cfg['agent']['tau']) ) return actor_state, None \ No newline at end of file diff --git a/analysis.py b/analysis.py index 96b3e5e..5257ba0 100644 --- a/analysis.py +++ b/analysis.py @@ -62,14 +62,20 @@ def get_csv_result_dict(result, config_idx, mode='Train', ci=95, method='percent def analyze(exp, runs=1): cfg['exp'] = exp cfg['runs'] = runs + ''' sweep_keys_dict = dict( dqn = ['optim/kwargs/learning_rate'], ddqn = ['optim/kwargs/learning_rate'], maxmin = ['optim/kwargs/learning_rate', 'agent/critic_num'], + td = ['optim/kwargs/learning_rate'], + sac = ['optim/kwargs/learning_rate'], + naf = ['optim/kwargs/learning_rate'], ppo = ['optim/kwargs/learning_rate'], + ddpg = ['optim/kwargs/learning_rate'], ) algo = exp.rstrip('0123456789').split('_')[-1] cfg['sweep_keys'] = sweep_keys_dict[algo] + ''' plotter = Plotter(cfg) plotter.csv_merged_results('Train', get_csv_result_dict, get_process_result_dict) @@ -77,7 +83,8 @@ def analyze(exp, runs=1): # plotter.csv_unmerged_results('Train', get_process_result_dict) # group_keys = ['optim/kwargs/learning_rate', 'agent/critic_num'] - # plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=2, markdown=False) + # group_keys = ['Env'] + # plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=0, markdown=False) # Hyper-parameter Comparison # plotter.csv_unmerged_results('Train', get_process_result_dict) diff --git a/components/networks.py b/components/networks.py index 3f7c60c..32a9004 100644 --- a/components/networks.py +++ b/components/networks.py @@ -84,15 +84,16 @@ class MinAtarQNet(nn.Module): net_cfg: FrozenDict = FrozenDict({'feature_dim': 128, 'hidden_act': 'ReLU'}) kernel_init: Initializer = default_init action_size: int = 10 + last_w_scale: float = -1.0 def setup(self): self.Q_net = nn.Sequential([ - nn.Conv(16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.kernel_init), + nn.Conv(16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.kernel_init()), activations[self.net_cfg['hidden_act']], lambda x: x.reshape((x.shape[0], -1)), # flatten - nn.Dense(self.net_cfg['feature_dim'], kernel_init=self.kernel_init), + nn.Dense(self.net_cfg['feature_dim'], kernel_init=self.kernel_init()), activations[self.net_cfg['hidden_act']], - nn.Dense(self.action_size, kernel_init=self.kernel_init) + nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale)) ]) def 
__call__(self, obs): @@ -162,57 +163,28 @@ class MLPGaussianActor(nn.Module): """ MLP actor network with Guassian policy N(mu, std). """ action_size: int = 4 net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'}) - log_std_min: float = -20.0 - log_std_max: float = 2.0 kernel_init: Initializer = default_init last_w_scale: float = -1.0 def setup(self): - self.actor_net = MLP( - layer_dims = list(self.net_cfg['hidden_dims'])+[self.action_size], + self.actor_feature = MLP( + layer_dims = list(self.net_cfg['hidden_dims']), hidden_act = self.net_cfg['hidden_act'], + output_act = self.net_cfg['hidden_act'], kernel_init = self.kernel_init, - last_w_scale = self.last_w_scale - ) - self.action_log_std = self.param('log_std', zeros_init(), (self.action_size,)) - - def __call__(self, obs): - u_mean = self.actor_net(obs) - u_log_std = jnp.clip(self.action_log_std, self.log_std_min, self.log_std_max) - u_log_std = jnp.broadcast_to(u_log_std, u_mean.shape) - return u_mean, u_log_std - - -class PPONet(nn.Module): - action_size: int = 4 - actor_net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'}) - critic_net_cfg: FrozenDict = FrozenDict({'hidden_dims': [32,32], 'hidden_act': 'ReLU'}) - log_std_min: float = -20.0 - log_std_max: float = 2.0 - kernel_init: Initializer = default_init - - def setup(self): - self.actor_net = MLP( - layer_dims = list(self.actor_net_cfg['hidden_dims'])+[self.action_size], - hidden_act = self.actor_net_cfg['hidden_act'], - kernel_init = self.kernel_init, - last_w_scale = 0.01 + last_w_scale = -1.0 ) - self.action_log_std = self.param('log_std', zeros_init(), (self.action_size,)) - self.critic_net = MLP( - layer_dims = list(self.critic_net_cfg['hidden_dims']) + [1], - hidden_act = self.critic_net_cfg['hidden_act'], - kernel_init = self.kernel_init, - last_w_scale = 1.0 - ) + self.actor_mean = nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale)) + self.actor_std = nn.Sequential([ + nn.Dense(self.action_size, kernel_init=self.kernel_init(self.last_w_scale)), + nn.sigmoid + ]) def __call__(self, obs): - u_mean = self.actor_net(obs) - u_log_std = jnp.clip(self.action_log_std, self.log_std_min, self.log_std_max) - return u_mean, u_log_std, self.critic_net(obs).squeeze(-1) - - def get_v(self, obs): - return self.critic_net(obs).squeeze(-1) + feature = self.actor_feature(obs) + u_mean = self.actor_mean(feature) + u_std = self.actor_std(feature) + return u_mean, u_std class MLPGaussianTanhActor(nn.Module): diff --git a/configs/minatar_ddqn.json b/configs/minatar_ddqn.json index 4bd813b..a087f05 100644 --- a/configs/minatar_ddqn.json +++ b/configs/minatar_ddqn.json @@ -33,6 +33,6 @@ "test_episodes": [5], "discount": [0.99], "seed": [1], - "device": ["cuda"], + "device": ["cpu"], "generate_random_seed": [true] } \ No newline at end of file diff --git a/configs/minatar_maxmin.json b/configs/minatar_maxmin.json index fd32cf3..2ff8a41 100644 --- a/configs/minatar_maxmin.json +++ b/configs/minatar_maxmin.json @@ -34,6 +34,6 @@ "test_episodes": [5], "discount": [0.99], "seed": [1], - "device": ["cuda"], + "device": ["cpu"], "generate_random_seed": [true] } \ No newline at end of file diff --git a/configs/mujoco_ddpg.json b/configs/mujoco_ddpg.json index 96f62fc..14055c3 100644 --- a/configs/mujoco_ddpg.json +++ b/configs/mujoco_ddpg.json @@ -16,7 +16,8 @@ "critic_net_cfg": [{ "hidden_dims": [[256,256]], "hidden_act": ["ReLU"] - }] + }], + "tau": [0.005] }], "optim": [{ "name": ["Adam"], @@ -27,10 +28,9 @@ 
"display_interval": [1e4], "show_progress": [false], "ckpt_interval": [1e5], - "test_interval": [1e4], + "test_interval": [-1], "test_episodes": [5], "discount": [0.99], - "tau": [0.005], "seed": [1], "device": ["cuda"], "generate_random_seed": [true] diff --git a/configs/mujoco_naf.json b/configs/mujoco_naf.json index 5a38305..647c21d 100644 --- a/configs/mujoco_naf.json +++ b/configs/mujoco_naf.json @@ -20,7 +20,8 @@ "L_net_cfg": [{ "hidden_dims": [[256,256]], "hidden_act": ["Tanh"] - }] + }], + "tau": [0.005] }], "optim": [{ "name": ["Adam"], @@ -31,10 +32,9 @@ "display_interval": [1e4], "show_progress": [false], "ckpt_interval": [1e5], - "test_interval": [1e4], + "test_interval": [-1], "test_episodes": [5], "discount": [0.99], - "tau": [0.005], "seed": [1], "device": ["cuda"], "generate_random_seed": [true] diff --git a/configs/mujoco_ppo.json b/configs/mujoco_ppo.json index 6d08c76..335ccb1 100644 --- a/configs/mujoco_ppo.json +++ b/configs/mujoco_ppo.json @@ -20,8 +20,7 @@ }], "optim": [{ "name": ["Adam"], - "anneal_lr": [true], - "kwargs": [{"learning_rate": [3e-4], "anneal_lr": [true], "eps": [1e-5], "max_grad_norm": [0.5]}] + "kwargs": [{"learning_rate": [3e-4], "max_grad_norm": [0.5], "anneal_lr": [true]}] }], "batch_size": [64], "display_interval": [1], diff --git a/configs/mujoco_sac.json b/configs/mujoco_sac.json index 96236d5..adce2c7 100644 --- a/configs/mujoco_sac.json +++ b/configs/mujoco_sac.json @@ -15,7 +15,8 @@ "critic_net_cfg": [{ "hidden_dims": [[256,256]], "hidden_act": ["ReLU"] - }] + }], + "tau": [0.005] }], "optim": [{ "name": ["Adam"], @@ -25,10 +26,9 @@ "display_interval": [1e4], "show_progress": [false], "ckpt_interval": [1e5], - "test_interval": [1e4], + "test_interval": [-1], "test_episodes": [5], "discount": [0.99], - "tau": [0.005], "seed": [1], "device": ["cuda"], "generate_random_seed": [true] diff --git a/configs/mujoco_td3.json b/configs/mujoco_td3.json index 469206f..6bce5cb 100644 --- a/configs/mujoco_td3.json +++ b/configs/mujoco_td3.json @@ -18,7 +18,8 @@ "critic_net_cfg": [{ "hidden_dims": [[256,256]], "hidden_act": ["ReLU"] - }] + }], + "tau": [0.005] }], "optim": [{ "name": ["Adam"], @@ -29,10 +30,9 @@ "display_interval": [1e4], "show_progress": [false], "ckpt_interval": [1e5], - "test_interval": [1e4], + "test_interval": [-1], "test_episodes": [5], "discount": [0.99], - "tau": [0.005], "seed": [1], "device": ["cuda"], "generate_random_seed": [true] diff --git a/envs/env.py b/envs/env.py index 83b184d..6765390 100644 --- a/envs/env.py +++ b/envs/env.py @@ -1,6 +1,7 @@ +import numpy as np import gymnasium as gym from gymnasium import spaces -from gymnasium.wrappers import ClipAction, RescaleAction, RecordEpisodeStatistics +from gymnasium.wrappers import ClipAction, RescaleAction, RecordEpisodeStatistics, FlattenObservation, NormalizeObservation, NormalizeReward, TransformObservation, TransformReward from envs.wrappers import UniversalSeed import gym_pygame @@ -23,6 +24,25 @@ def make_env(env_name, deque_size=1, **kwargs): return env +def ppo_make_env(env_name, gamma=0.99, deque_size=1, **kwargs): + """ Make env for PPO. 
""" + env = gym.make(env_name, **kwargs) + # Episode statistics wrapper: set it before reward wrappers + env = RecordEpisodeStatistics(env, deque_size=deque_size) + # Action wrapper + env = ClipAction(RescaleAction(env, min_action=-1, max_action=1)) + # Obs wrapper + env = FlattenObservation(env) # For dm_control + env = NormalizeObservation(env) + env = TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) + # Reward wrapper + env = NormalizeReward(env, gamma=gamma) + env = TransformReward(env, lambda reward: np.clip(reward, -10, 10)) + # Seed wrapper: must be the last wrapper to be effective + env = UniversalSeed(env) + return env + + def make_vec_env(env_name, num_envs=1, asynchronous=False, deque_size=1, max_episode_steps=None, **kwargs): env = gym.make(env_name, **kwargs) wrappers = [] diff --git a/requirements.txt b/requirements.txt index d69757b..dd25b36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ jax>=0.4.20 -mujoco==2.3.7 -gymnasium>=0.29.1 +mujoco>=2.3.6,<3.0 +gymnasium>=0.29.1,<1.0 flax>=0.7.5 distrax>=0.1.5 matplotlib>=3.8.1 @@ -11,4 +11,5 @@ psutil>=5.9.6 seaborn>=0.13.0 pyarrow gpustat -tqdm \ No newline at end of file +tqdm +tabulate \ No newline at end of file diff --git a/results/performance.md b/results/performance.md index 2a637f4..340f575 100644 --- a/results/performance.md +++ b/results/performance.md @@ -1,8 +1,10 @@ This document records the best hyper-parameter setup for each (task, algorithm) pair, in terms of training performance. -All performance results are averaged over 5 runs with standard deviation reported as well. ## Gym Classic Control +- Performance metric: average returns of last 10 (for training) or 5 (for testing) episodes. +- Reported performance: average performance over 5 runs with 1 standard deviation. + ### MountainCar-v0 | Algo\Param | train perf | test perf | cfg file | cfg index | lr | critic_num | @@ -42,6 +44,9 @@ All performance results are averaged over 5 runs with standard deviation reporte ## Gym Box2D +- Performance metric: average returns of last 10 (for training) or 5 (for testing) episodes. +- Reported performance: average performance over 5 runs with 1 standard deviation. + ### LunarLander-v2 | Algo\Param | train perf | test perf | cfg file | cfg index | lr | critic_num | @@ -57,6 +62,9 @@ All performance results are averaged over 5 runs with standard deviation reporte ## PyGame Learning Environment +- Performance metric: average returns of last 10 (for training) or 5 (for testing) episodes. +- Reported performance: average performance over 5 runs with 1 standard deviation. + ### Catcher-PLE-v0 | Algo\Param | train perf | test perf | cfg file | cfg index | lr | critic_num | @@ -81,6 +89,20 @@ All performance results are averaged over 5 runs with standard deviation reporte | MaxminDQN | 56 $\pm$ 3 | 56 $\pm$ 1 | pygame_maxmin | 32 | 0.0001 | 8 | +## MuJoCo + +- Performance metric: average training returns of last 10% episodes. +- Reported performance: bootstrapped average performance over 5 runs with 95% confidence interval. 
+ +| Task\Algo | PPO | SAC | DDPG | TD3 | +| -------------- | -------------- | --------------- | --------------- | -------------- | +| Ant-v4 | 1920 $\pm$ 393 | 4989 $\pm$ 332 | 1411 $\pm$ 576 | 2780 $\pm$ 244 | +| HalfCheetah-v4 | 3868 $\pm$ 585 | 10469 $\pm$ 269 | 9441 $\pm$ 553 | 8587 $\pm$ 453 | +| Hopper-v4 | 2334 $\pm$ 147 | 2459 $\pm$ 352 | 1750 $\pm$ 283 | 2437 $\pm$ 730 | +| Humanoid-v4 | 670 $\pm$ 23 | 5141 $\pm$ 125 | 3190 $\pm$ 1013 | 4948 $\pm$ 299 | +| Swimmer-v4 | 68 $\pm$ 11 | 62 $\pm$ 7 | 99 $\pm$ 23 | 83 $\pm$ 20 | +| Walker2d-v4 | 2857 $\pm$ 379 | 4285 $\pm$ 538 | 2424 $\pm$ 595 | 3959 $\pm$ 444 | + ## MinAtar diff --git a/sbatch_m.sh b/sbatch_m.sh index 2ba26ea..79395ec 100644 --- a/sbatch_m.sh +++ b/sbatch_m.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Ask SLURM to send the USR1 signal 300 seconds before end of the time limit +# Ask SLURM to send the USR1 signal 60 seconds before end of the time limit #SBATCH --signal=B:USR1@60 #SBATCH --output=output/%x/%a.txt #SBATCH --mail-type=ALL @@ -26,8 +26,8 @@ cleanup() trap 'cleanup' USR1 EXIT # --------------------------------------------------------------------- # export OMP_NUM_THREADS=1 -module load StdEnv/2023 gcc/12.3 cudacore/.12.2.2 cudnn/8.9 cuda/12.2 mujoco/3.0.1 python/3.11 scipy-stack arrow -source ~/envs/invert/bin/activate +module load StdEnv/2020 gcc/9.3.0 cudacore/.11.4.2 cudnn/8.2.0 cuda/11.4 mujoco/2.3.6 python/3.11 scipy-stack/2022a arrow +source ~/envs/jaxplorer/bin/activate parallel --ungroup --jobs procfile python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx {1} --slurm_dir $SLURM_TMPDIR :::: job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt # parallel --eta --ungroup --jobs procfile python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx {1} --slurm_dir $SLURM_TMPDIR :::: job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt diff --git a/sbatch_s.sh b/sbatch_s.sh index c2a3e51..fe1b14d 100644 --- a/sbatch_s.sh +++ b/sbatch_s.sh @@ -1,9 +1,8 @@ #!/bin/bash -# Ask SLURM to send the USR1 signal 300 seconds before end of the time limit +# Ask SLURM to send the USR1 signal 60 seconds before end of the time limit #SBATCH --signal=B:USR1@60 #SBATCH --output=output/%x/%a.txt #SBATCH --mail-type=ALL -#SBATCH --exclude=bc11202,bc11203,bc11357,bc11322,bc11234 # --------------------------------------------------------------------- echo "Current working directory: `pwd`" @@ -26,9 +25,8 @@ cleanup() # Call `cleanup` once we receive USR1 or EXIT signal trap 'cleanup' USR1 EXIT # --------------------------------------------------------------------- -# export OMP_NUM_THREADS=1 -module load StdEnv/2023 gcc/12.3 cudacore/.12.2.2 cudnn/8.9 cuda/12.2 mujoco/3.0.1 python/3.11 scipy-stack arrow -source ~/envs/invert/bin/activate +module load StdEnv/2020 gcc/9.3.0 cudacore/.11.4.2 cudnn/8.2.0 cuda/11.4 mujoco/2.3.6 python/3.11 scipy-stack/2022a arrow +source ~/envs/jaxplorer/bin/activate python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID --slurm_dir $SLURM_TMPDIR # python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID diff --git a/submit.py b/submit.py index 484970f..7a8eab4 100644 --- a/submit.py +++ b/submit.py @@ -23,8 +23,8 @@ def main(argv): # 'account': 'def-ashique', 'account': 'rrg-ashique', # Job name - # 'job-name': 'naf1', # 1/2GPU, 8G, 75min - 'job-name': 'sac1', # 1/2GPU, 8G, 75min + 'job-name': 'mujoco_sac', # 1/2GPU, 1CPU, 8G, 75min + 'job-name': 'mujoco_sac', # 1GPU, 1CPU, 8G, 50min # Job time 
'time': '0-01:15:00', # Email notification @@ -39,8 +39,7 @@ def main(argv): 'cluster_capacity': 996, # Job indexes list # 'job-list': np.array([1,2]) - # 'job-list': np.array(range(3, 40+1)) - 'job-list': np.array(range(1, 5+1)) + 'job-list': np.array(range(1, 60+1)) } make_dir(f"output/{sbatch_cfg['job-name']}") @@ -50,7 +49,7 @@ def main(argv): # Max number of parallel jobs in one task max_parallel_jobs = 2 mem_per_job = 8 # in GB - cpu_per_job = 4 # Larger cpus_per_job increases speed + cpu_per_job = 1 # Larger cpus_per_job increases speed mem_per_cpu = int(ceil(max_parallel_jobs*mem_per_job/cpu_per_job)) # Write to procfile for Parallel with open('procfile', 'w') as f: @@ -66,7 +65,7 @@ def main(argv): submitter.multiple_submit() elif args.job_type == 'S': mem_per_job = 8 # in GB - cpu_per_job = 4 # Larger cpus_per_job increases speed + cpu_per_job = 1 # Larger cpus_per_job increases speed mem_per_cpu = int(ceil(mem_per_job/cpu_per_job)) sbatch_cfg['gres'] = 'gpu:1' # GPU type sbatch_cfg['cpus-per-task'] = cpu_per_job
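For reference, the objective assembled by the `compute_loss` closure in `agents/PPO.py` above can be restated as a standalone function. The sketch below mirrors the patched code: a diagonal `distrax.Normal` policy, per-dimension log-probabilities summed over the action axis, a clipped surrogate plus a value loss minus an entropy bonus, with advantages assumed to be pre-normalized at the trajectory level as in `compute_advantage`. The helper name `ppo_loss`, the default coefficients, and the dummy shapes in the usage lines are illustrative only, not values taken from the configs.

```python
# Standalone restatement of the clipped PPO objective used in agents/PPO.py (illustrative).
import jax
import jax.numpy as jnp
import distrax

def ppo_loss(action_mean, action_std, v, batch, clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0):
    pi = distrax.Normal(loc=action_mean, scale=action_std)
    log_pi = pi.log_prob(batch['action']).sum(-1)            # sum log-probs over action dims
    ratio = jnp.exp(log_pi - batch['log_pi'])                 # importance ratio vs. old policy
    obj = ratio * batch['adv']
    obj_clipped = jnp.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * batch['adv']
    actor_loss = -jnp.minimum(obj, obj_clipped).mean()        # clipped surrogate
    critic_loss = jnp.square(v - batch['v_target']).mean()    # plain squared error, no value clipping
    entropy = pi.entropy().sum(-1).mean()                     # entropy bonus
    return actor_loss + vf_coef * critic_loss - ent_coef * entropy

# Usage with dummy shapes (batch size 64, action dimension 6):
B, A = 64, 6
key = jax.random.PRNGKey(0)
batch = {
    'action':   jax.random.normal(key, (B, A)),
    'log_pi':   jnp.zeros(B),
    'adv':      jax.random.normal(key, (B,)),   # assumed already normalized over the trajectory
    'v_target': jnp.zeros(B),
}
loss = ppo_loss(jnp.zeros((B, A)), jnp.ones((B, A)), jnp.zeros(B), batch)
```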