Skip to content

Commit

Permalink
update cartpole example to use getZeroState()
Browse files Browse the repository at this point in the history
Also update cartpole dynamics to use the state/control enums.
  • Loading branch information
bogidude committed Sep 25, 2024
1 parent 03e6bcc commit a4928e1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 32 deletions.
1 change: 1 addition & 0 deletions cmake/MPPIGenericToolsConfig.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Autodetect Cuda Architecture on system and add to executables
# More info for autodetection:
# https://stackoverflow.com/questions/35485087/determining-which-gencode-compute-arch-values-i-need-for-nvcc-within-cmak
set(MPPI_ARCH_FLAGS "-arch=compute_52")
if (NOT DEFINED MPPI_ARCH_FLAGS)
CUDA_SELECT_NVCC_ARCH_FLAGS(MPPI_ARCH_FLAGS ${CUDA_ARCH_LIST})

Expand Down
6 changes: 3 additions & 3 deletions examples/cartpole_example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ int main(int argc, char** argv)
controller_params.cost_rollout_dim_ = dim3(64, 4, 1);
CartpoleController->setParams(controller_params);

CartpoleDynamics::state_array current_state = CartpoleDynamics::state_array::Zero();
CartpoleDynamics::state_array next_state = CartpoleDynamics::state_array::Zero();
CartpoleDynamics::state_array current_state = model->getZeroState();
CartpoleDynamics::state_array next_state = model->getZeroState();
CartpoleDynamics::output_array output = CartpoleDynamics::output_array::Zero();

int time_horizon = 5000;

CartpoleDynamics::state_array xdot = CartpoleDynamics::state_array::Zero();
CartpoleDynamics::state_array xdot = model->getZeroState();

auto time_start = std::chrono::system_clock::now();
for (int i = 0; i < time_horizon; ++i)
Expand Down
56 changes: 27 additions & 29 deletions include/mppi/dynamics/cartpole/cartpole_dynamics.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,21 @@ void CartpoleDynamics::computeDynamics(const Eigen::Ref<const state_array>& stat
const Eigen::Ref<const control_array>& control,
Eigen::Ref<state_array> state_der)
{
const float theta = state(2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
float theta_dot = state(3);
float force = control(0);
float m_c = this->params_.cart_mass;
float m_p = this->params_.pole_mass;
float l_p = this->params_.pole_length;
const float sin_theta = sinf(state(S_INDEX(THETA)));
const float cos_theta = cosf(state(S_INDEX(THETA)));
float theta_dot = state[S_INDEX(THETA_DOT)];
float force = control[C_INDEX(FORCE)];
const float m_c = this->params_.cart_mass;
const float m_p = this->params_.pole_mass;
const float l_p = this->params_.pole_length;

// TODO WAT?
state_der(0) = state(S_INDEX(VEL_X));
state_der(1) =
1.0f / (m_c + m_p * SQ(sin_theta)) * (force + m_p * sin_theta * (l_p * SQ(theta_dot) + gravity_ * cos_theta));
state_der(2) = theta_dot;
state_der(3) =
1.0f / (l_p * (m_c + m_p * SQ(sin_theta))) *
(-force * cos_theta - m_p * l_p * SQ(theta_dot) * cos_theta * sin_theta - (m_c + m_p) * gravity_ * sin_theta);
state_der(S_INDEX(POS_X)) = state(S_INDEX(VEL_X));
state_der(S_INDEX(VEL_X)) =
(force + m_p * sin_theta * (l_p * SQ(theta_dot) + gravity_ * cos_theta)) / (m_c + m_p * SQ(sin_theta));
state_der(S_INDEX(THETA)) = theta_dot;
state_der(S_INDEX(THETA_DOT)) =
(-force * cos_theta - m_p * l_p * SQ(theta_dot) * cos_theta * sin_theta - (m_c + m_p) * gravity_ * sin_theta) /
(l_p * (m_c + m_p * SQ(sin_theta)));
}

void CartpoleDynamics::printState(const Eigen::Ref<const state_array>& state)
Expand All @@ -88,22 +86,22 @@ void CartpoleDynamics::printParams()

__device__ void CartpoleDynamics::computeDynamics(float* state, float* control, float* state_der, float* theta_s)
{
float theta = angle_utils::normalizeAngle(state[2]);
float theta = angle_utils::normalizeAngle(state[S_INDEX(THETA)]);
const float sin_theta = __sinf(theta);
const float cos_theta = __cosf(theta);
float theta_dot = state[3];
float force = control[0];
float m_c = this->params_.cart_mass;
float m_p = this->params_.pole_mass;
float l_p = this->params_.pole_length;
float theta_dot = state[S_INDEX(THETA_DOT)];
float force = control[C_INDEX(FORCE)];
const float m_c = this->params_.cart_mass;
const float m_p = this->params_.pole_mass;
const float l_p = this->params_.pole_length;

state_der[0] = state[1];
state_der[1] =
1.0f / (m_c + m_p * SQ(sin_theta)) * (force + m_p * sin_theta * (l_p * SQ(theta_dot) + gravity_ * cos_theta));
state_der[2] = theta_dot;
state_der[3] =
1.0f / (l_p * (m_c + m_p * SQ(sin_theta))) *
(-force * cos_theta - m_p * l_p * SQ(theta_dot) * cos_theta * sin_theta - (m_c + m_p) * gravity_ * sin_theta);
state_der[S_INDEX(POS_X)] = state[S_INDEX(VEL_X)];
state_der[S_INDEX(VEL_X)] =
(force + m_p * sin_theta * (l_p * SQ(theta_dot) + gravity_ * cos_theta)) / (m_c + m_p * SQ(sin_theta));
state_der[S_INDEX(THETA)] = theta_dot;
state_der[S_INDEX(THETA_DOT)] =
(-force * cos_theta - m_p * l_p * SQ(theta_dot) * cos_theta * sin_theta - (m_c + m_p) * gravity_ * sin_theta) /
(l_p * (m_c + m_p * SQ(sin_theta)));
}

Dynamics<CartpoleDynamics, CartpoleDynamicsParams>::state_array
Expand Down

0 comments on commit a4928e1

Please sign in to comment.