
Commit

minor update
qlan3 committed Apr 7, 2024
1 parent ab08db0 commit d6fdbcc
Showing 11 changed files with 97 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -9,7 +9,7 @@ logfile
*DS_Store*
job_idx_*
procfile
slacker_msger.py
slacker.py
backup
run.sh

4 changes: 2 additions & 2 deletions README.md
@@ -12,7 +12,6 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n
## TODO

- Add more descriptions about slurm and hyper-parameter comparison.
- Improve the performance of PPO in MuJoCo tasks.
- Add more algorithms, such as DQN for Atari games.


@@ -23,7 +22,8 @@ Jaxplorer is a **Jax** reinforcement learning (RL) framework for **exploring** n
- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium): `pip install gymnasium==0.29.1`
- [MuJoCo](https://github.com/google-deepmind/mujoco): `pip install mujoco==2.3.7`
- [Gymnasium(mujoco)](https://gymnasium.farama.org/environments/mujoco/): `pip install gymnasium[mujoco]`
- Others: Please check `requirements.txt`.
- [Gym Games](https://github.com/qlan3/gym-games): >=2.0.0.
- Others: `pip install -r requirements.txt`.


## Implemented algorithms
2 changes: 1 addition & 1 deletion agents/BaseAgent.py
@@ -55,7 +55,7 @@ def set_optim(self, optim_name, optim_kwargs, schedule=None):
grad_clip = optim_kwargs['grad_clip']
max_grad_norm = optim_kwargs['max_grad_norm']
del optim_kwargs['anneal_lr'], optim_kwargs['grad_clip'], optim_kwargs['max_grad_norm']
assert not (grad_clip > 0 and max_grad_norm > 0), 'Either grad_clip or max_grad_norm should be set.'
assert not (grad_clip > 0 and max_grad_norm > 0), 'Cannot apply both grad_clip and max_grad_norm at the same time.'
if anneal_lr and schedule is not None:
optim_kwargs['learning_rate'] = schedule
if grad_clip > 0:
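The tightened assertion enforces that at most one gradient-clipping mode is active. As a hedged illustration of how the two branches map onto optax (the Adam choice and config values here are hypothetical; the repo actually resolves the optimizer via `getattr(optax, optim_name.lower())`, as the PPO.py hunk below shows):

```python
import optax

# Hypothetical config; at most one clipping option may be positive.
optim_kwargs = {'learning_rate': 3e-4}
grad_clip, max_grad_norm = -1, 0.5

assert not (grad_clip > 0 and max_grad_norm > 0), 'Cannot apply both grad_clip and max_grad_norm at the same time.'
if grad_clip > 0:
  # Element-wise clipping: each gradient entry is clipped to [-grad_clip, grad_clip].
  optim = optax.chain(optax.clip(grad_clip), optax.adam(**optim_kwargs))
elif max_grad_norm > 0:
  # Global-norm clipping: rescale the gradient so its overall norm is <= max_grad_norm.
  optim = optax.chain(optax.clip_by_global_norm(max_grad_norm), optax.adam(**optim_kwargs))
else:
  optim = optax.adam(**optim_kwargs)
```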
2 changes: 1 addition & 1 deletion agents/DDPG.py
@@ -11,7 +11,7 @@

class DDPG(SAC):
"""
Implementation of DDPG (Deep Deterministic Policy Gradient)
Implementation of Deep Deterministic Policy Gradient.
"""
def createNN(self):
# Create train_states and nets of actor, critic, and temperature
2 changes: 1 addition & 1 deletion agents/DQN.py
@@ -12,7 +12,7 @@

class DQN(BaseAgent):
"""
Implementation of DQN.
Implementation of Deep Q-Learning.
"""
def __init__(self, cfg):
super().__init__(cfg)
34 changes: 3 additions & 31 deletions agents/PPO.py
@@ -1,7 +1,6 @@
import jax
import time
import math
import optax
import distrax
import numpy as np
from tqdm import tqdm
@@ -39,7 +38,7 @@ def ppo_make_env(env_name, gamma=0.99, deque_size=1, **kwargs):

class PPO(BaseAgent):
"""
Implementation of PPO.
Implementation of Proximal Policy Optimization.
"""
def __init__(self, cfg):
super().__init__(cfg)
@@ -58,31 +57,6 @@ def __init__(self, cfg):
# Set networks
self.createNN()

def set_optim(self, optim_name, optim_kwargs, schedule=None):
optim_kwargs.setdefault('anneal_lr', False)
optim_kwargs.setdefault('grad_clip', -1)
optim_kwargs.setdefault('max_grad_norm', -1)
anneal_lr = optim_kwargs['anneal_lr']
grad_clip = optim_kwargs['grad_clip']
max_grad_norm = optim_kwargs['max_grad_norm']
del optim_kwargs['anneal_lr'], optim_kwargs['grad_clip'], optim_kwargs['max_grad_norm']
assert not (grad_clip > 0 and max_grad_norm > 0), 'Cannot apply both grad_clip and max_grad_norm at the same time.'
if anneal_lr and schedule is not None:
optim_kwargs['learning_rate'] = schedule
if grad_clip > 0:
optim = optax.chain(
optax.clip(grad_clip),
getattr(optax, optim_name.lower())(**optim_kwargs)
)
elif max_grad_norm > 0:
optim = optax.chain(
optax.clip_by_global_norm(max_grad_norm),
getattr(optax, optim_name.lower())(**optim_kwargs)
)
else:
optim = getattr(optax, optim_name.lower())(**optim_kwargs)
return optim

def createNN(self):
# Create nets and train_states
lr = self.cfg['optim']['kwargs']['learning_rate']
@@ -128,9 +102,7 @@ def run_steps(self, mode='Train'):
# Take a env step
next_obs, reward, terminated, truncated, info = self.env[mode].step(action)
# Save experience
# done = terminated or truncated
# mask = self.discount * (1-done)
mask = self.discount * (1-terminated)
mask = self.discount * (1 - terminated)
self.save_experience(obs, action, reward, mask, v, log_pi)
# Update observation
obs = next_obs
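The rewritten mask drops the commented-out `done = terminated or truncated` logic: only a true termination zeroes the discount, so time-limit truncations still bootstrap from the next state's value. A hedged sketch of the resulting one-step target (the function name and scalar inputs are illustrative, not from the repo):

```python
import jax.numpy as jnp

def td_target(reward, terminated, v_next, discount=0.99):
  # mask is 0 only on true termination; on truncation it stays at
  # `discount`, so the target still bootstraps from v_next.
  mask = discount * (1.0 - terminated)
  return reward + mask * v_next

# Truncated (not terminated) step: target = 1 + 0.99 * 2 = 2.98
print(td_target(1.0, terminated=0.0, v_next=jnp.asarray(2.0)))
```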
@@ -264,7 +236,7 @@ def compute_loss(params):
v_clipped = batch['v'] + (v - batch['v']).clip(-self.cfg['agent']['clip_ratio'], self.cfg['agent']['clip_ratio'])
critic_loss_unclipped = jnp.square(v - batch['v_target'])
critic_loss_clipped = jnp.square(v_clipped - batch['v_target'])
critic_loss = 0.5 * jnp.maximum(critic_loss_unclipped, critic_loss_clipped).mean()
# Compute actor loss
action_mean, action_log_std = actor_state.apply_fn(actor_param, batch['obs'])
pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(action_log_std))
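For reference, a self-contained sketch of the clipped value loss in the hunk above, with `clip_ratio` standing in for `self.cfg['agent']['clip_ratio']`:

```python
import jax.numpy as jnp

def clipped_value_loss(v, v_old, v_target, clip_ratio=0.2):
  # Keep the new value prediction within clip_ratio of the old one,
  # then take the pessimistic (larger) of the two squared errors.
  v_clipped = v_old + jnp.clip(v - v_old, -clip_ratio, clip_ratio)
  loss_unclipped = jnp.square(v - v_target)
  loss_clipped = jnp.square(v_clipped - v_target)
  return 0.5 * jnp.maximum(loss_unclipped, loss_clipped).mean()
```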
2 changes: 1 addition & 1 deletion agents/SAC.py
@@ -15,7 +15,7 @@

class SAC(BaseAgent):
"""
Implementation of SAC (Soft Actor-Critic).
Implementation of Soft Actor-Critic.
"""
def __init__(self, cfg):
cfg['agent'].setdefault('actor_update_steps', 1)
2 changes: 1 addition & 1 deletion agents/TD3.py
@@ -8,7 +8,7 @@

class TD3(DDPG):
"""
Implementation of TD3 (Twin Delayed Deep Deterministic Policy Gradients)
Implementation of Twin Delayed Deep Deterministic Policy Gradients.
"""
@partial(jax.jit, static_argnames=['self'])
def update_critic(self, actor_state, critic_state, temp_state, batch, seed):
27 changes: 13 additions & 14 deletions analysis.py
@@ -12,7 +12,7 @@ def get_process_result_dict(result, config_idx, mode='Train'):
'Env': result['Env'][0],
'Agent': result['Agent'][0],
'Config Index': config_idx,
'Return (mean)': result['Return'][-10:].mean(skipna=True) if mode=='Train' else result['Return'][-5:].mean(skipna=True)
'Return (mean)': result['Return'][-1*int(len(result['Return'])*0.1):].mean(skipna=True), # mean of last 10%
}
return result_dict
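The new metric averages the last 10% of recorded returns rather than a fixed window of 10 (Train) or 5 (Test) entries. A minimal pandas sketch of the slicing with made-up numbers (note the edge case: with fewer than 10 rows, `int(len * 0.1)` is 0 and `[-0:]` selects the whole series):

```python
import pandas as pd

returns = pd.Series([10.0, 12.0, 11.0, 14.0, 15.0, 16.0, 18.0, 17.0, 19.0, 20.0])
n_tail = int(len(returns) * 0.1)   # 10% of 10 rows -> 1
tail = returns[-n_tail:]           # positional slice: the last 10%
print(tail.mean(skipna=True))      # 20.0
```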

@@ -66,7 +66,7 @@ def analyze(exp, runs=1):
dqn = ['optim/kwargs/learning_rate'],
ddqn = ['optim/kwargs/learning_rate'],
maxmin = ['optim/kwargs/learning_rate', 'agent/critic_num'],
ppo = ['optim/kwargs/learning_rate', 'agent/actor_net_cfg/hidden_dims'],
ppo = ['optim/kwargs/learning_rate'],
)
algo = exp.rstrip('0123456789').split('_')[-1]
cfg['sweep_keys'] = sweep_keys_dict[algo]
@@ -75,29 +75,28 @@
plotter.csv_merged_results('Train', get_csv_result_dict, get_process_result_dict)
plotter.plot_results(mode='Train', indexes='all')

# plotter.csv_merged_results('Test', get_csv_result_dict, get_process_result_dict)
# plotter.plot_results(mode='Test', indexes='all')

# plotter.get_top_result(group_keys=group_keys, group_fn='mean_std', top_n=1, mode='Test', markdown=False, dn=3)
# plotter.csv_unmerged_results('Train', get_process_result_dict)
# group_keys = ['optim/kwargs/learning_rate', 'agent/critic_num']
# plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)', errorbar='Return (ci=95)', mode='Train', nd=2, markdown=False)

# Hyper-parameter Comparison
# plotter.csv_unmerged_results('Train', get_process_result_dict)
# plotter.csv_unmerged_results('Test', get_process_result_dict)
# constraints = [('agent/name', ['NAF'])]
# # constraints = []
# constraints = []
# for param_name in cfg['sweep_keys']:
# plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Test', kde=False)
# plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Test', kde=True)
# plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Train', kde=False)
# plotter.compare_parameter(param_name=param_name, constraints=constraints, mode='Train', kde=True)


if __name__ == "__main__":
runs = 10
mujoco_list = ['mujoco_sac', 'mujoco_ddpg', 'mujoco_td3', 'mujoco_naf']
mujoco_list = ['mujoco_sac', 'mujoco_ddpg', 'mujoco_td3', 'mujoco_ppo', 'mujoco_naf']
dqn_list = ['classic_dqn', 'lunar_dqn', 'pygame_dqn', 'minatar_dqn']
ddqn_list = ['classic_ddqn', 'lunar_ddqn', 'pygame_ddqn', 'minatar_ddqn']
maxmin_list = ['classic_maxmin', 'lunar_maxmin', 'pygame_maxmin', 'minatar_maxmin']
for exp in ['mujoco_ppo']:
# unfinished_index(exp, runs=runs)
# memory_info(exp, runs=runs)
# time_info(exp, runs=runs)
for exp in mujoco_list:
unfinished_index(exp, runs=runs)
memory_info(exp, runs=runs)
time_info(exp, runs=runs)
analyze(exp, runs=runs)
4 changes: 3 additions & 1 deletion requirements.txt
@@ -2,11 +2,13 @@ jax>=0.4.20
mujoco==2.3.7
gymnasium>=0.29.1
flax>=0.7.5
distrax>=0.1.5
matplotlib>=3.8.1
numpy>=1.25.2
optax>=0.1.7
pandas>=2.1
psutil>=5.9.6
seaborn>=0.13.0
pyarrow
gpustat
gpustat
tqdm
97 changes: 70 additions & 27 deletions utils/plotter.py
@@ -320,22 +320,22 @@ def csv_unmerged_results(self, mode, get_process_result_dict):
results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv'
results.to_csv(results_file, index=False)

def compare_parameter(self, param_name, perf_name=None, image_name=None, constraints=[], mode='Train', stat='count', kde=False):
def compare_parameter(self, param, perf=None, image_name=None, constraints=[], mode='Train', stat='count', kde=False):
'''
Plot histograms for hyper-parameter selection.
perf_name: the performance metric from results_{mode}.csv, such as Return (mean).
param_name: the name of considered hyper-parameter, such lr.
perf: the performance metric from results_{mode}.csv, such as Return (mean).
param: the name of the considered hyper-parameter, such as lr.
image_name: the name of the plotted image.
constraints: a list of tuple (k, [x,y,...]). We only consider index with config_dict[k] in [x,y,...].
mode: Train or Test.
stat: the stat argument passed to the seaborn plot function.
kde: if True, plot all KDEs (kernel density estimates) in one figure; otherwise, plot histograms in separate subfigures.
'''
param_name_short = param_name.split('/')[-1]
param_name_short = param.split('/')[-1]
if image_name is None:
image_name = param_name_short
if perf_name is None:
perf_name = f'{self.y_label} (mean)'
if perf is None:
perf = f'{self.y_label} (mean)'
config_file = f'./configs/{self.exp}.json'
results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv'
if kde:
@@ -346,28 +346,28 @@ def compare_parameter(param, perf=None, image_name=None, constra
assert os.path.exists(config_file), f'{config_file} does not exist.'
# Load all results
results = pd.read_csv(results_file)
# Select results based on the constraints and param_name
# Select results based on the constraints and param
for k, vs in constraints:
results = results.loc[lambda df: df[k].isin(vs), :]
results = results.loc[:, [perf_name, param_name]]
results.rename(columns={param_name: param_name_short}, inplace=True)
results = results.loc[:, [perf, param]]
results.rename(columns={param: param_name_short}, inplace=True)
# Plot
param_values = sorted(list(set(results[param_name_short])))
if len(param_values) == 1 and param_values[0] == '/':
return
if kde: # Plot all kdes in one figure
fig, ax = plt.subplots()
# sns.histplot(data=results, x=perf_name, hue=param_name_short, kde=True, stat=stat, palette='bright', discrete=True)
sns.kdeplot(data=results, x=perf_name, hue=param_name_short, palette='bright')
# sns.histplot(data=results, x=perf, hue=param_name_short, kde=True, stat=stat, palette='bright', discrete=True)
sns.kdeplot(data=results, x=perf, hue=param_name_short, palette='bright')
ax.grid(axis='y')
else: # Plot histograms in different subfigures
fig, axs = plt.subplots(len(param_values), 1, sharex=True, sharey=True, figsize=(7, 3*len(param_values)))
if len(param_values) == 1:
axs = [axs]
for i, param_v in enumerate(param_values):
sns.histplot(data=results[results[param_name_short]==param_v], x=perf_name, hue=param_name_short, kde=False, stat=stat, palette='bright', ax=axs[i], discrete=True)
sns.histplot(data=results[results[param_name_short]==param_v], x=perf, hue=param_name_short, kde=False, stat=stat, palette='bright', ax=axs[i], discrete=True)
axs[i].grid(axis='y')
plt.xlabel(perf_name)
plt.xlabel(perf)
plt.tight_layout()
plt.savefig(image_path)
if self.show:
@@ -376,34 +376,77 @@ def compare_parameter(param, perf=None, image_name=None, constra
plt.cla() # clear axis
plt.close() # close window
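For context, a hedged, self-contained sketch of the `kde=True` branch above with fabricated results (the column names mirror the renamed `perf` and `param` arguments):

```python
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Fabricated unmerged results: one row per run, grouped by learning rate.
results = pd.DataFrame({
  'Return (mean)': [100, 120, 90, 200, 210, 190],
  'learning_rate': [1e-3, 1e-3, 1e-3, 3e-4, 3e-4, 3e-4],
})
fig, ax = plt.subplots()
sns.kdeplot(data=results, x='Return (mean)', hue='learning_rate', palette='bright')
ax.grid(axis='y')
plt.tight_layout()
plt.savefig('learning_rate_kde.png')
```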

def get_top_result(self, group_keys, group_fn='mean_std', perf_name=None, ascending=False, top_n=None, mode='Test', dn=3, markdown=True):
def get_topn_result(self, group_keys, group_fn='mean_std', perf=None, ascending=False, topn=None, mode='Test', nd=2, markdown=True):
'''
Print averaged top results.
group_keys: keys to group all results.
perf_name: the performance metric from results_{mode}.csv, such as Return (mean).
group_fn: min_max or mean_std.
perf: the performance metric from results_{mode}.csv, such as Return (mean).
ascending: sort results with ascending/descending order.
top_n: select top_n results. When it is None, select all results.
topn: select topn results. When it is None, select all results.
mode: Train or Test.
dn: number of decimal digits.
'''
if perf_name is None:
perf_name = f'{self.y_label} (mean)'
nd: number of decimal digits to display.
'''
if perf is None:
perf = f'{self.y_label} (bmean)'
config_file = f'./configs/{self.exp}.json'
results_file = f'./logs/{self.exp}/0/results_{mode}_merged.csv'
assert os.path.exists(results_file), f'{results_file} does not exist. Please generate it first with csv_merged_results.'
assert os.path.exists(config_file), f'{config_file} does not exist.'
# Load all results
results = pd.read_csv(results_file)
# Select results based on the constraints and param_name
# Select results based on the constraints and param
def min_max(group):
l = group.sort_values(by=perf_name, ascending=ascending)[:top_n][perf_name].to_numpy(na_value=0)
return pd.DataFrame({'min-->max': [f'({l.min():.{dn}f}, {l.max():.{dn}f})']})
l = group.sort_values(by=perf, ascending=ascending)[:topn][perf].to_numpy(na_value=0)
return pd.DataFrame({'min-->max': [f'{l.min():.{nd}f}-->{l.max():.{nd}f}']})
def mean_std(group):
l = group.sort_values(by=perf_name, ascending=ascending)[:top_n][perf_name].to_numpy(na_value=0)
return pd.DataFrame({'mean+/-std': [f'({l.mean():.{dn}f}, {l.std():.{dn}f})']})
l = group.sort_values(by=perf, ascending=ascending)[:topn][perf].to_numpy(na_value=0)
return pd.DataFrame({'mean+/-std': [f'{l.mean():.{nd}f}+/-{l.std():.{nd}f}']})
# Get the intersection of the two key lists
group_keys = list(set(group_keys).intersection(results.columns.tolist()))
grouped_results = results.groupby(by=group_keys).apply(eval(group_fn)).reset_index(level=-1, drop=True)
print('Performance measurement:', perf_name)
print(f'Top {top_n} results:')
print('Performance measurement:', perf)
print(f'Top {topn} results:')
print(grouped_results)
print('-'*20)
if markdown:
markdown_table = grouped_results.to_markdown(tablefmt='github')
print('Markdown Table:')
print(markdown_table)
print('-'*20)
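The grouping logic reduces to a pandas groupby-apply. A hedged toy run of the `mean_std` branch with fabricated numbers (`topn=None`, `nd=2`):

```python
import pandas as pd

results = pd.DataFrame({
  'optim/kwargs/learning_rate': [1e-3, 1e-3, 3e-4, 3e-4],
  'Return (bmean)': [100.0, 110.0, 200.0, 190.0],
})

def mean_std(group):
  # Sort, select, and format one summary row per group.
  l = group.sort_values(by='Return (bmean)', ascending=False)['Return (bmean)'].to_numpy(na_value=0)
  return pd.DataFrame({'mean+/-std': [f'{l.mean():.2f}+/-{l.std():.2f}']})

grouped = results.groupby(by=['optim/kwargs/learning_rate']).apply(mean_std).reset_index(level=-1, drop=True)
print(grouped)  # 0.0003 -> 195.00+/-5.00, 0.001 -> 105.00+/-5.00
```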

def get_top1_result(self, group_keys, perf=None, errorbar=None, ascending=False, mode='Test', nd=2, markdown=True):
'''
Print the top-1 result.
group_keys: keys to group all results.
perf: the performance metric from results_{mode}.csv, such as Return (bmean).
errorbar: the error bar from results_{mode}.csv, such as Return (ci=95).
ascending: sort results with ascending/descending order.
mode: Train or Test.
nd: number of decimal digits.
'''
topn = 1
if perf is None:
perf = f'{self.y_label} (bmean)'
if errorbar is None:
errorbar = f'{self.y_label} (ci=95)'
config_file = f'./configs/{self.exp}.json'
results_file = f'./logs/{self.exp}/0/results_{mode}_merged.csv'
assert os.path.exists(results_file), f'{results_file} does not exist. Please generate it first with csv_merged_results.'
assert os.path.exists(config_file), f'{config_file} does not exist.'
# Load all results
results = pd.read_csv(results_file)
# Select results based on the constraints and param
def top1(group):
l = group.sort_values(by=perf, ascending=ascending)[:1]
mean_value = l[perf].to_numpy(na_value=0)[0]
errorbar_value = l[errorbar].to_numpy(na_value=0)[0]
return pd.DataFrame({'bmean+/-ci': [f'{mean_value:.{nd}f}+/-{errorbar_value:.{nd}f}']})
# Get the intersection of the two key lists
group_keys = list(set(group_keys).intersection(results.columns.tolist()))
grouped_results = results.groupby(by=group_keys).apply(top1).reset_index(level=-1, drop=True)
print('Performance measurement:', perf)
print(f'Top {topn} results:')
print(grouped_results)
print('-'*20)
if markdown:
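A usage sketch for the new `get_top1_result`, mirroring the commented-out workflow added to analysis.py above (it assumes a configured `plotter` and that `results_Train_merged.csv` has already been generated):

```python
# Mirrors the commented-out calls in analysis.py.
plotter.csv_unmerged_results('Train', get_process_result_dict)
group_keys = ['optim/kwargs/learning_rate', 'agent/critic_num']
plotter.get_top1_result(group_keys=group_keys, perf='Return (bmean)',
                        errorbar='Return (ci=95)', mode='Train', nd=2, markdown=False)
```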
