From d6fdbcc87434d2a035d04a1e7e3dbaf1d4d04ef2 Mon Sep 17 00:00:00 2001 From: qlan3 Date: Sat, 6 Apr 2024 23:07:08 -0600 Subject: [PATCH] minor update --- .gitignore | 2 +- README.md | 4 +- agents/BaseAgent.py | 2 +- agents/DDPG.py | 2 +- agents/DQN.py | 2 +- agents/PPO.py | 34 ++-------------- agents/SAC.py | 2 +- agents/TD3.py | 2 +- analysis.py | 27 ++++++------- requirements.txt | 4 +- utils/plotter.py | 97 ++++++++++++++++++++++++++++++++------------- 11 files changed, 97 insertions(+), 81 deletions(-) diff --git a/.gitignore b/.gitignore index cd74430..685067b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ logfile *DS_Store* job_idx_* procfile -slacker_msger.py +slacker.py backup run.sh diff --git a/README.md b/README.md index 3213c30..be17ea7 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n ## TODO - Add more descriptions about slurm and hyper-parameter comparison. -- Improve the performance of PPO in MuJoCo tasks. - Add more algorithms, such as DQN for Atari games. @@ -23,7 +22,8 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n - [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install gymnasium==0.29.1` - [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install mujoco==2.3.7` - [Gymnasium(mujoco)](https://gymnasium.farama.org/environments/mujoco/): `pip install gymnasium[mujoco]` -- Others: Please check `requirements.txt`. +- [Gym Games](https://github.com/qlan3/gym-games): >=2.0.0. +- Others: `pip install -r requirements.txt`. ## Implemented algorithms diff --git a/agents/BaseAgent.py b/agents/BaseAgent.py index 0f32e8b..8b13c31 100644 --- a/agents/BaseAgent.py +++ b/agents/BaseAgent.py @@ -55,7 +55,7 @@ def set_optim(self, optim_name, optim_kwargs, schedule=None): grad_clip = optim_kwargs['grad_clip'] max_grad_norm = optim_kwargs['max_grad_norm'] del optim_kwargs['anneal_lr'], optim_kwargs['grad_clip'], optim_kwargs['max_grad_norm'] - assert not (grad_clip > 0 and max_grad_norm > 0), 'Either grad_clip or max_grad_norm should be set.' + assert not (grad_clip > 0 and max_grad_norm > 0), 'Cannot apply both grad_clip and max_grad_norm at the same time.' if anneal_lr and schedule is not None: optim_kwargs['learning_rate'] = schedule if grad_clip > 0: diff --git a/agents/DDPG.py b/agents/DDPG.py index c4f2dad..5ad6d1f 100644 --- a/agents/DDPG.py +++ b/agents/DDPG.py @@ -11,7 +11,7 @@ class DDPG(SAC): """ - Implementation of DDPG (Deep Deterministic Policy Gradient) + Implementation of Deep Deterministic Policy Gradient. """ def createNN(self): # Create train_states and nets of actor, critic, and temperature diff --git a/agents/DQN.py b/agents/DQN.py index 4fb548d..e208f5c 100644 --- a/agents/DQN.py +++ b/agents/DQN.py @@ -12,7 +12,7 @@ class DQN(BaseAgent): """ - Implementation of DQN. + Implementation of Deep Q-Learning. """ def __init__(self, cfg): super().__init__(cfg) diff --git a/agents/PPO.py b/agents/PPO.py index 7e3059e..3912a51 100644 --- a/agents/PPO.py +++ b/agents/PPO.py @@ -1,7 +1,6 @@ import jax import time import math -import optax import distrax import numpy as np from tqdm import tqdm @@ -39,7 +38,7 @@ def ppo_make_env(env_name, gamma=0.99, deque_size=1, **kwargs): class PPO(BaseAgent): """ - Implementation of PPO. + Implementation of Proximal Policy Optimization. """ def __init__(self, cfg): super().__init__(cfg) @@ -58,31 +57,6 @@ def __init__(self, cfg): # Set networks self.createNN() - def set_optim(self, optim_name, optim_kwargs, schedule=None): - optim_kwargs.setdefault('anneal_lr', False) - optim_kwargs.setdefault('grad_clip', -1) - optim_kwargs.setdefault('max_grad_norm', -1) - anneal_lr = optim_kwargs['anneal_lr'] - grad_clip = optim_kwargs['grad_clip'] - max_grad_norm = optim_kwargs['max_grad_norm'] - del optim_kwargs['anneal_lr'], optim_kwargs['grad_clip'], optim_kwargs['max_grad_norm'] - assert not (grad_clip > 0 and max_grad_norm > 0), 'Cannot apply both grad_clip and max_grad_norm at the same time.' - if anneal_lr and schedule is not None: - optim_kwargs['learning_rate'] = schedule - if grad_clip > 0: - optim = optax.chain( - optax.clip(grad_clip), - getattr(optax, optim_name.lower())(**optim_kwargs) - ) - elif max_grad_norm > 0: - optim = optax.chain( - optax.clip_by_global_norm(max_grad_norm), - getattr(optax, optim_name.lower())(**optim_kwargs) - ) - else: - optim = getattr(optax, optim_name.lower())(**optim_kwargs) - return optim - def createNN(self): # Create nets and train_states lr = self.cfg['optim']['kwargs']['learning_rate'] @@ -128,9 +102,7 @@ def run_steps(self, mode='Train'): # Take a env step next_obs, reward, terminated, truncated, info = self.env[mode].step(action) # Save experience - # done = terminated or truncated - # mask = self.discount * (1-done) - mask = self.discount * (1-terminated) + mask = self.discount * (1 - terminated) self.save_experience(obs, action, reward, mask, v, log_pi) # Update observation obs = next_obs @@ -264,7 +236,7 @@ def compute_loss(params): v_clipped = batch['v'] + (v - batch['v']).clip(-self.cfg['agent']['clip_ratio'], self.cfg['agent']['clip_ratio']) critic_loss_unclipped = jnp.square(v - batch['v_target']) critic_loss_clipped = jnp.square(v_clipped - batch['v_target']) - critic_loss = 0.5 * jnp.maximum(critic_loss_unclipped, critic_loss_clipped).mean() + critic_loss = 0.5 * jnp.maximum(critic_loss_unclipped, critic_loss_clipped).mean() # Compute actor loss action_mean, action_log_std = actor_state.apply_fn(actor_param, batch['obs']) pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std)) diff --git a/agents/SAC.py b/agents/SAC.py index f835ea9..c86170a 100644 --- a/agents/SAC.py +++ b/agents/SAC.py @@ -15,7 +15,7 @@ class SAC(BaseAgent): """ - Implementation of SAC (Soft Actor-Critic). + Implementation of Soft Actor-Critic. """ def __init__(self, cfg): cfg['agent'].setdefault('actor_update_steps', 1) diff --git a/agents/TD3.py b/agents/TD3.py index a0f599a..9d0720d 100644 --- a/agents/TD3.py +++ b/agents/TD3.py @@ -8,7 +8,7 @@ class TD3(DDPG): """ - Implementation of TD3 (Twin Delayed Deep Deterministic Policy Gradients) + Implementation of Twin Delayed Deep Deterministic Policy Gradients. """ @partial(jax.jit, static_argnames=['self']) def update_critic(self, actor_state, critic_state, temp_state, batch, seed): diff --git a/analysis.py b/analysis.py index 019722c..96b3e5e 100644 --- a/analysis.py +++ b/analysis.py @@ -12,7 +12,7 @@ def get_process_result_dict(result, config_idx, mode='Train'): 'Env': result['Env'][0], 'Agent': result['Agent'][0], 'Config Index': config_idx, - 'Return (mean)': result['Return'][-10:].mean(skipna=True) if mode=='Train' else result['Return'][-5:].mean(skipna=True) + 'Return (mean)': result['Return'][-1*int(len(result['Return'])*0.1):].mean(skipna=True), # mean of last 10% } return result_dict @@ -66,7 +66,7 @@ def analyze(exp, runs=1): dqn = ['optim/kwargs/learning_rate'], ddqn = ['optim/kwargs/learning_rate'], maxmin = ['optim/kwargs/learning_rate', 'agent/critic_num'], - ppo = ['optim/kwargs/learning_rate', 'agent/actor_net_cfg/hidden_dims'], + ppo = ['optim/kwargs/learning_rate'], ) algo = exp.rstrip('0123456789').split('_')[-1] cfg['sweep_keys'] = sweep_keys_dict[algo] @@ -75,29 +75,28 @@ def analyze(exp, runs=1): plotter.csv_merged_results('Train', get_csv_result_dict, get_process_result_dict) plotter.plot_results(mode='Train', indexes='all') - # plotter.csv_merged_results('Test', get_csv_result_dict, get_process_result_dict) - # plotter.plot_results(mode='Test', indexes='all') - - # plotter.get_top_result(group_keys=group_keys, group_fn='mean_std', top_n=1, mode='Test', markdown=False, dn=3) + # plotter.csv_unmerged_results('Train', get_process_result_dict) + # group_keys = ['optim/kwargs/learning_rate', 'agent/critic_num'] + # plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=2, markdown=False) # Hyper-parameter Comparison # plotter.csv_unmerged_results('Train', get_process_result_dict) # plotter.csv_unmerged_results('Test', get_process_result_dict) # constraints = [('agent/name', ['NAF'])] - # # constraints = [] + # constraints = [] # for param_name in cfg['sweep_keys']: - # plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Test', kde=False) - # plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Test', kde=True) + # plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Train', kde=False) + # plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Train', kde=True) if __name__ == "__main__": runs = 10 - mujoco_list = ['mujoco_sac', 'mujoco_ddpg', 'mujoco_td3', 'mujoco_naf'] + mujoco_list = ['mujoco_sac', 'mujoco_ddpg', 'mujoco_td3', 'mujoco_ppo', 'mujoco_naf'] dqn_list = ['classic_dqn', 'lunar_dqn', 'pygame_dqn', 'minatar_dqn'] ddqn_list = ['classic_ddqn', 'lunar_ddqn', 'pygame_ddqn', 'minatar_ddqn'] maxmin_list = ['classic_maxmin', 'lunar_maxmin', 'pygame_maxmin', 'minatar_maxmin'] - for exp in ['mujoco_ppo']: - # unfinished_index(exp, runs=runs) - # memory_info(exp, runs=runs) - # time_info(exp, runs=runs) + for exp in mujoco_list: + unfinished_index(exp, runs=runs) + memory_info(exp, runs=runs) + time_info(exp, runs=runs) analyze(exp, runs=runs) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cffd2eb..d69757b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ jax>=0.4.20 mujoco==2.3.7 gymnasium>=0.29.1 flax>=0.7.5 +distrax>=0.1.5 matplotlib>=3.8.1 numpy>=1.25.2 optax>=0.1.7 @@ -9,4 +10,5 @@ pandas>=2.1 psutil>=5.9.6 seaborn>=0.13.0 pyarrow -gpustat \ No newline at end of file +gpustat +tqdm \ No newline at end of file diff --git a/utils/plotter.py b/utils/plotter.py index 1830868..7be9159 100644 --- a/utils/plotter.py +++ b/utils/plotter.py @@ -320,22 +320,22 @@ def csv_unmerged_results(self, mode, get_process_result_dict): results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv' results.to_csv(results_file, index=False) - def compare_parameter(self, param_name, perf_name=None, image_name=None, constraints=[], mode='Train', stat='count', kde=False): + def compare_parameter(self, param, perf=None, image_name=None, constraints=[], mode='Train', stat='count', kde=False): ''' Plot histograms for hyper-parameter selection. - perf_name: the performance metric from results_{mode}.csv, such as Return (mean). - param_name: the name of considered hyper-parameter, such lr. + perf: the performance metric from results_{mode}.csv, such as Return (mean). + param: the name of considered hyper-parameter, such lr. image_name: the name of the plotted image. constraints: a list of tuple (k, [x,y,...]). We only consider index with config_dict[k] in [x,y,...]. mode: Train or Test. stat: for seaborn plot function kde: if True, plot all kdes (kernel density estimations) in one figure; o.w. plot histograms in different subfigures ''' - param_name_short = param_name.split('/')[-1] + param_name_short = param.split('/')[-1] if image_name is None: image_name = param_name_short - if perf_name is None: - perf_name = f'{self.y_label} (mean)' + if perf is None: + perf = f'{self.y_label} (mean)' config_file = f'./configs/{self.exp}.json' results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv' if kde: @@ -346,28 +346,28 @@ def compare_parameter(self, param_name, perf_name=None, image_name=None, constra assert os.path.exists(config_file), f'{config_file} does not exist.' # Load all results results = pd.read_csv(results_file) - # Select results based on the constraints and param_name + # Select results based on the constraints and param for k, vs in constraints: results = results.loc[lambda df: df[k].isin(vs), :] - results = results.loc[:, [perf_name, param_name]] - results.rename(columns={param_name: param_name_short}, inplace=True) + results = results.loc[:, [perf, param]] + results.rename(columns={param: param_name_short}, inplace=True) # Plot param_values = sorted(list(set(results[param_name_short]))) if len(param_values) == 1 and param_values[0] == '/': return if kde: # Plot all kdes in one figure fig, ax = plt.subplots() - # sns.histplot(data=results, x=perf_name, hue=param_name_short, kde=True, stat=stat, palette='bright', discrete=True) - sns.kdeplot(data=results, x=perf_name, hue=param_name_short, palette='bright') + # sns.histplot(data=results, x=perf, hue=param_name_short, kde=True, stat=stat, palette='bright', discrete=True) + sns.kdeplot(data=results, x=perf, hue=param_name_short, palette='bright') ax.grid(axis='y') else: # Plot histograms in different subfigures fig, axs = plt.subplots(len(param_values), 1, sharex=True, sharey=True, figsize=(7, 3*len(param_values))) if len(param_values) == 1: axs = [axs] for i, param_v in enumerate(param_values): - sns.histplot(data=results[results[param_name_short]==param_v], x=perf_name, hue=param_name_short, kde=False, stat=stat, palette='bright', ax=axs[i], discrete=True) + sns.histplot(data=results[results[param_name_short]==param_v], x=perf, hue=param_name_short, kde=False, stat=stat, palette='bright', ax=axs[i], discrete=True) axs[i].grid(axis='y') - plt.xlabel(perf_name) + plt.xlabel(perf) plt.tight_layout() plt.savefig(image_path) if self.show: @@ -376,34 +376,77 @@ def compare_parameter(self, param_name, perf_name=None, image_name=None, constra plt.cla() # clear axis plt.close() # close window - def get_top_result(self, group_keys, group_fn='mean_std', perf_name=None, ascending=False, top_n=None, mode='Test', dn=3, markdown=True): + def get_topn_result(self, group_keys, group_fn='mean_std', perf=None, ascending=False, topn=None, mode='Test', nd=2, markdown=True): ''' Print averaged top results group_keys: keys to group all results. - perf_name: the performance metric from results_{mode}.csv, such as Return (mean). + group_fn: min_max or mean_std. + perf: the performance metric from results_{mode}.csv, such as Return (mean). ascending: sort results with ascending/decending order. - top_n: select top_n results. When it is None, select all results. + topn: select topn results. When it is None, select all results. mode: Train or Test. - dn: number of decimal digits. - ''' - if perf_name is None: - perf_name = f'{self.y_label} (mean)' + nd: number of decimal digits to display. + ''' + if perf is None: + perf = f'{self.y_label} (bmean)' config_file = f'./configs/{self.exp}.json' results_file = f'./logs/{self.exp}/0/results_{mode}_merged.csv' assert os.path.exists(results_file), f'{results_file} does not exist. Please generate it first with csv_unmerged_results.' assert os.path.exists(config_file), f'{config_file} does not exist.' # Load all results results = pd.read_csv(results_file) - # Select results based on the constraints and param_name + # Select results based on the constraints and param def min_max(group): - l = group.sort_values(by=perf_name, ascending=ascending)[:top_n][perf_name].to_numpy(na_value=0) - return pd.DataFrame({'min-->max': [f'({l.min():.{dn}f}, {l.max():.{dn}f})']}) + l = group.sort_values(by=perf, ascending=ascending)[:topn][perf].to_numpy(na_value=0) + return pd.DataFrame({'min-->max': [f'{l.min():.{nd}f}-->{l.max():.{nd}f}']}) def mean_std(group): - l = group.sort_values(by=perf_name, ascending=ascending)[:top_n][perf_name].to_numpy(na_value=0) - return pd.DataFrame({'mean+/-std': [f'({l.mean():.{dn}f}, {l.std():.{dn}f})']}) + l = group.sort_values(by=perf, ascending=ascending)[:topn][perf].to_numpy(na_value=0) + return pd.DataFrame({'mean+/-std': [f'{l.mean():.{nd}f}+/-{l.std():.{nd}f}']}) + # Get the intersection of two key list + group_keys = list(set(group_keys).intersection(results.columns.tolist())) grouped_results = results.groupby(by=group_keys).apply(eval(group_fn)).reset_index(level=-1, drop=True) - print('Performance measurement:', perf_name) - print(f'Top {top_n} results:') + print('Performance measurement:', perf) + print(f'Top {topn} results:') + print(grouped_results) + print('-'*20) + if markdown: + markdown_table = grouped_results.to_markdown(tablefmt='github') + print('Markdown Table:') + print(markdown_table) + print('-'*20) + + def get_top1_result(self, group_keys, perf=None, errorbar=None, ascending=False, mode='Test', nd=2, markdown=True): + ''' + Print top 1 result + group_keys: keys to group all results. + perf: the performance metric from results_{mode}.csv, such as Return (bmean). + errorbar: the error bar from results_{mode}.csv, such as Return (ci=95). + ascending: sort results with ascending/decending order. + mode: Train or Test. + nd: number of decimal digits. + ''' + topn = 1 + if perf is None: + perf = f'{self.y_label} (bmean)' + if errorbar is None: + perf = f'{self.y_label} (ci=95)' + config_file = f'./configs/{self.exp}.json' + results_file = f'./logs/{self.exp}/0/results_{mode}_merged.csv' + assert os.path.exists(results_file), f'{results_file} does not exist. Please generate it first with csv_unmerged_results.' + assert os.path.exists(config_file), f'{config_file} does not exist.' + # Load all results + results = pd.read_csv(results_file) + # Select results based on the constraints and param + def top1(group): + l = group.sort_values(by=perf, ascending=ascending)[:1] + mean_value = l[perf].to_numpy(na_value=0)[0] + errorbar_value = l[errorbar].to_numpy(na_value=0)[0] + return pd.DataFrame({'bmean+/-ci': [f'{mean_value:.{nd}f}+/-{errorbar_value:.{nd}f}']}) + # Get the intersection of two key list + group_keys = list(set(group_keys).intersection(results.columns.tolist())) + grouped_results = results.groupby(by=group_keys).apply(top1).reset_index(level=-1, drop=True) + print('Performance measurement:', perf) + print(f'Top {topn} results:') print(grouped_results) print('-'*20) if markdown: