Skip to content

Commit

Permalink
modify weighting example for the state weighted formulation
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Apr 9, 2024
1 parent 81dda39 commit 8303d95
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions notebooks/weighted-cost-func.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"fhn = fhn.FitzHughNagumo()\n",
"training_data = fhn.solve_ivps(\n",
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
" tspan=[0.0, 10.0],\n",
" tspan=[0.0, 6.0],\n",
" sampling_period=0.1\n",
")"
]
Expand All @@ -66,15 +66,27 @@
" \n",
"\n",
" # weight good trajectory by its 1 norm\n",
" w = np.sum(traj.abs().states, axis=1)\n",
" #w = np.sum(traj.abs().states, axis=1)\n",
" w = np.ones(traj.states.shape)\n",
" weights.append(w)\n",
"\n",
" # weight garbage trajectory to zero\n",
" w = np.zeros(len(traj.states))\n",
" #w = np.zeros(len(traj.states))\n",
" w = np.zeros(traj.states.shape)\n",
" weights.append(w)\n",
"\n",
"# you can also use a dict to name the trajectories if using TrajectoriesData (numpy arrays are named by their index number)\n",
"weights = {idx: w for idx, w in enumerate(weights)}"
"#weights = {idx: w for idx, w in enumerate(weights)}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "280b4bb3-4f7d-4a94-a983-663c6255bc83",
"metadata": {},
"outputs": [],
"source": [
"weights[1].shape"
]
},
{
Expand All @@ -93,7 +105,7 @@
" learning_weights=weights, # weight the eDMD algorithm objectives\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=200, # maximum number of observables to try\n",
" n_obs=40, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand All @@ -117,7 +129,7 @@
" learning_weights=None, # don't use eDMD weighting\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=200, # maximum number of observables to try\n",
" n_obs=40, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand Down Expand Up @@ -178,9 +190,10 @@
"plt.figure(figsize=(10, 6))\n",
"\n",
"# plot the results\n",
"plt.plot(*true_trajectory.states.T, linewidth=2, label='Ground Truth')\n",
"plt.plot(*trajectory.states.T, label='Weighted Trajectory Prediction')\n",
"plt.plot(*trajectory_uw.states.T, label='Trajectory Prediction')\n",
"plt.plot(*true_trajectory.states.T, label='Ground Truth')\n",
"\n",
"\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$\")\n",
Expand Down

0 comments on commit 8303d95

Please sign in to comment.