Skip to content

Commit

Permalink
add SW-eDMD implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Apr 12, 2024
1 parent 8303d95 commit c51b7db
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 19 deletions.
40 changes: 28 additions & 12 deletions autokoopman/estimator/koopman.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,43 @@ def wdmdc(X, Xp, U, r, W):

def swdmdc(X, Xp, U, r, Js, W):
"""State Weighted Dynamic Mode Decomposition with Control (wDMDC)"""
from sklearn.preprocessing import normalize
import cvxpy as cp
assert len(W.shape) == 2, "weights must be 2D for snapshot x state"

if U is not None:
Y = np.hstack((X, U)).T
Y = np.hstack((X, U))
else:
Y = X.T
Yp = Xp.T
Y = X
Yp = Xp
state_size = Yp.shape[1]

# compute observables weights from state weights
Wy = np.vstack([(np.abs(J) @ np.atleast_2d(w).T).T for J, w in zip(Js, W)])
n_snap, n_obs = X.shape
_, n_states = Js[0].shape

# apply weights element-wise
Y, Yp = Wy.T * Y, Wy.T * Yp
state_size = Yp.shape[0]
# so the objective isn't numerically unstable
sf = (1.0 / n_snap)

# compute Atilde
U, Sigma, V = np.linalg.svd(Y, False)
U, Sigma, V = U[:, :r], np.diag(Sigma)[:r, :r], V.conj().T[:, :r]
# koopman operator
K = cp.Variable((n_obs, n_obs))

# SW-eDMD objective
objective = cp.Minimize(sum([
cp.sum_squares(sf * cp.multiply((np.abs(J) @ w)[:n_obs], (xpi[:n_obs] - K @ xi[:n_obs]))) for J, w, xpi, xi in zip(Js, W, Xp, X)
]))

# unconstrained problem
constraints = None

# SW-eDMD problem
prob = cp.Problem(objective, constraints)

# solve for the SW-eDMD Koopman operator
result = prob.solve()
# TODO: check the result

# get the transformation
Atilde = Yp @ V @ np.linalg.inv(Sigma) @ U.conj().T
Atilde = K.value
return Atilde[:, :state_size], Atilde[:, state_size:]


Expand Down
94 changes: 87 additions & 7 deletions notebooks/weighted-cost-func.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@
" # garbage trajectory\n",
" trajectories.append(np.random.rand(*traj.states.shape))\n",
" \n",
"\n",
" # weight good trajectory by its 1 norm\n",
" #w = np.sum(traj.abs().states, axis=1)\n",
" #w = 1/(traj.abs().states+1.0)\n",
" w = np.ones(traj.states.shape)\n",
" w[:, 0] = 0.2\n",
" weights.append(w)\n",
"\n",
" # weight garbage trajectory to zero\n",
Expand All @@ -86,7 +87,7 @@
"metadata": {},
"outputs": [],
"source": [
"weights[1].shape"
"weights[0].shape"
]
},
{
Expand All @@ -105,7 +106,7 @@
" learning_weights=weights, # weight the eDMD algorithm objectives\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=40, # maximum number of observables to try\n",
" n_obs=30, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand All @@ -129,7 +130,7 @@
" learning_weights=None, # don't use eDMD weighting\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=40, # maximum number of observables to try\n",
" n_obs=30, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand Down Expand Up @@ -160,7 +161,7 @@
"model_uw = experiment_results_unweighted['tuned_model']\n",
"\n",
"# simulate using the learned model\n",
"iv = [0.5, 0.1]\n",
"iv = [0.1, 0.5]\n",
"trajectory = model.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
Expand Down Expand Up @@ -209,14 +210,93 @@
"id": "b1458259-6c92-46e5-91a3-f56e53633b35",
"metadata": {},
"outputs": [],
"source": []
"source": [
"plt.plot(true_trajectory.states[:, 0], linewidth=2, label='Ground Truth')\n",
"plt.plot(trajectory.states[:, 0], label='Weighted Trajectory Prediction')\n",
"plt.plot(trajectory_uw.states[:, 0], label='Trajectory Prediction')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9626e981-a8b3-40e4-b70c-d95e6dbee7ef",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(true_trajectory.states[:, 1], linewidth=2, label='Ground Truth')\n",
"plt.plot(trajectory.states[:, 1], label='Weighted Trajectory Prediction')\n",
"plt.plot(trajectory_uw.states[:, 1], label='Trajectory Prediction')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbe8ef2b-3169-476f-ad8c-003f92469247",
"metadata": {},
"outputs": [],
"source": [
"from casadi import *\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b656fb35-a5e0-4405-9564-332d427bfe47",
"metadata": {},
"outputs": [],
"source": [
"X = np.vstack([t.states[:-1] for t in training_data])\n",
"Xp = np.vstack([t.states[1:] for t in training_data])\n",
"W = [w[:-1] for w in weights]\n",
"g = experiment_results['tuned_model'].obs_func\n",
"gd = experiment_results['tuned_model'].obs.obs_grad\n",
"G = np.vstack([np.atleast_2d(g(x)).T for x in X])\n",
"Gp = np.vstack([np.atleast_2d(g(x)).T for x in Xp])\n",
"\n",
"G, Gd = G.T, Gp.T\n",
"Js = [gd(xi) for xi in X]\n",
"Wy = np.vstack(\n",
" [(np.abs(J) @ w.T).T \n",
" for J, w in zip(Js, W)\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "861fef48-3b9c-4686-ab7d-11183a80eff0",
"metadata": {},
"outputs": [],
"source": [
"(Js[0] @ W[0].T).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed116b64-098f-481b-bc5c-958b6a12c096",
"metadata": {},
"outputs": [],
"source": [
"Wy.T * G"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51ac108b-7e59-4248-adf0-4c0042dac586",
"metadata": {},
"outputs": [],
"source": [
"J.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b82faf32-98db-45cf-a188-0479e424272d",
"metadata": {},
"outputs": [],
"source": []
}
],
Expand All @@ -236,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down

0 comments on commit c51b7db

Please sign in to comment.