Skip to content

Commit bfaa135

Browse files
author
Lui
committed
make LQR ILQR work with attitude interface
1 parent 94f6a78 commit bfaa135

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

tutorials/LQR_ILQR.ipynb

+37-44
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@
176176
"\n",
177177
"print(\"A shape:\", A.shape) # Should be (12, 12)\n",
178178
"print(\"B shape:\", B.shape) # Should be (12, 4)\n",
179-
"print(\"A :\\n\", A)\n",
180-
"print(\"B :\\n\", B)"
179+
"# print(\"A :\\n\", A)\n",
180+
"# print(\"B :\\n\", B)"
181181
]
182182
},
183183
{
@@ -271,23 +271,18 @@
271271
"metadata": {},
272272
"outputs": [],
273273
"source": [
274-
"Ad, Bd = discretize_linear_system(A, B, dt) #, exact=True)\n",
275-
"# print(\"A :\\n\", Ad)\n",
276-
"# print(\"B :\\n\", Bd)\n",
277-
"P = scipy.linalg.solve_discrete_are(Ad, Bd, Q_lqr, R_lqr)\n",
274+
"Ad, Bd = discretize_linear_system(A, B, dt)\n",
278275
"\n",
276+
"P = scipy.linalg.solve_discrete_are(Ad, Bd, Q_lqr, R_lqr)\n",
279277
"btp = np.dot(Bd.T, P)\n",
280-
"\n",
281278
"gain_lqr = np.dot(np.linalg.inv(R_lqr + np.dot(btp, Bd)),\n",
282279
" np.dot(btp, Ad))\n",
283280
"\n",
284-
"# We can also comment out the above two lines of code \n",
285-
"# and use the following line instead to compute for the continuous-time case\n",
281+
"### We can also comment out the above lines of code \n",
282+
"### and use the following lines instead to compute for the continuous-time case\n",
286283
"# P = scipy.linalg.solve_continuous_are(A, B, Q_lqr, R_lqr)\n",
287284
"# gain_lqr = np.dot(np.linalg.inv(R_lqr), np.dot(B.T, P))\n",
288285
"\n",
289-
"# print(\"A (discretized):\\n\", Ad)\n",
290-
"# print(\"B (discretized):\\n\", Bd)\n",
291286
"print(\"gain:\\n\", gain_lqr)\n",
292287
"print(\"shape of gain:\", gain_lqr.shape)\n"
293288
]
@@ -309,7 +304,7 @@
309304
"source": [
310305
"SEED = 42\n",
311306
"\n",
312-
"obs, info = envs.reset()#seed=SEED)\n",
307+
"obs, info = envs.reset(seed=SEED)\n",
313308
"state = obs_to_state(obs)\n",
314309
"# print(obs)\n",
315310
"# Step through the environment\n",
@@ -324,7 +319,7 @@
324319
" goal = np.array([0, 0, 5, 0, 1, 0, 0, 0, 0, 0, 0, 0]) # set goal state\n",
325320
"\n",
326321
" control_input = -gain_lqr @ (state - goal) + u_op\n",
327-
" # control_input = np.clip(control_input, MIN_THRUST, MAX_THRUST) \n",
322+
"\n",
328323
" control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high)\n",
329324
" action = control_input.reshape(1,4).astype(np.float32)\n",
330325
" # print(action)\n",
@@ -353,7 +348,7 @@
353348
"cell_type": "markdown",
354349
"metadata": {},
355350
"source": [
356-
"### 2.5.6 Plots"
351+
"### 2.1.6 Plots"
357352
]
358353
},
359354
{
@@ -522,6 +517,13 @@
522517
"### 2.2.3 Recursion"
523518
]
524519
},
520+
{
521+
"cell_type": "markdown",
522+
"metadata": {},
523+
"source": [
524+
"It will take some time. :)"
525+
]
526+
},
525527
{
526528
"cell_type": "code",
527529
"execution_count": null,
@@ -540,7 +542,7 @@
540542
" input_pre = input_stack\n",
541543
"\n",
542544
" # Forward / \"rollout\" of the current policy\n",
543-
" obs, info = envs.reset() #seed=SEED)\n",
545+
" obs, info = envs.reset(seed=SEED)\n",
544546
" state = obs_to_state(obs) # (12,)\n",
545547
"\n",
546548
" for step in range(2500):\n",
@@ -551,7 +553,7 @@
551553
" # Clip the control input to the specified range\n",
552554
" control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high) \n",
553555
" \n",
554-
" # Convert to np.ndarray\n",
556+
" # Reshape and Convert to np.ndarray\n",
555557
" action = control_input.reshape(1,4).astype(np.float32) # (1, 4)\n",
556558
"\n",
557559
" # Save rollout data.\n",
@@ -570,7 +572,7 @@
570572
" \n",
571573
" envs.close()\n",
572574
" \n",
573-
" # TODO: Compute cost to see if it could converge\n",
575+
" # TODO: Compute cost to see if it diverse or converge\n",
574576
" # cost_curr = 0\n",
575577
" # for i in range(state_stack.shape[0]):\n",
576578
"\n",
@@ -665,15 +667,6 @@
665667
" iter += 1"
666668
]
667669
},
668-
{
669-
"cell_type": "code",
670-
"execution_count": null,
671-
"metadata": {},
672-
"outputs": [],
673-
"source": [
674-
"print(Sm)"
675-
]
676-
},
677670
{
678671
"cell_type": "code",
679672
"execution_count": null,
@@ -696,11 +689,14 @@
696689
" # Compute control action (force) using the iLQR gain\n",
697690
" control_input = input_ff[:, i] + gains_fb[i].dot(state) # gains_fb[:, i].dot(state) + input_ff[i]\n",
698691
"\n",
699-
" # Clip the control iptput to the specified range\n",
700-
" control_input = np.clip(control_input, 0.028161688, 0.14834145)\n",
692+
" # Clip the control input to the specified range\n",
693+
" control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high) \n",
701694
" \n",
702695
" # Convert to np.ndarray\n",
703-
" action = np.array([control_input], dtype=np.float32) \n",
696+
" action = control_input.reshape(1,4).astype(np.float32) # (1, 4)\n",
697+
" \n",
698+
" # # Convert to np.ndarray\n",
699+
" # action = np.array([control_input], dtype=np.float32) \n",
704700
"\n",
705701
" # Take a step in the environment with the computed action\n",
706702
" obs, reward, terminated, truncated, _ = envs.step(action)\n",
@@ -716,7 +712,11 @@
716712
" if terminated or truncated:\n",
717713
" print(\"Episode ended at step:\", i)\n",
718714
" break\n",
719-
" envs.render()\n",
715+
"\n",
716+
" if (i * fps) % envs.sim.freq < fps:\n",
717+
" envs.render()\n",
718+
" time.sleep(1 / fps)\n",
719+
" # envs.render()\n",
720720
"# Close the environment\n",
721721
"envs.sim.close()\n",
722722
"envs.close()"
@@ -736,9 +736,9 @@
736736
"plt.plot(time_log, x_log_ilqr, label=\"x(iLQR)\", color=\"blue\")\n",
737737
"plt.plot(time_log, y_log_ilqr, label=\"y(iLQR)\", color=\"green\")\n",
738738
"plt.plot(time_log, z_log_ilqr, label=\"z(iLQR)\", color=\"red\")\n",
739-
"# plt.plot(time_log, x_log, label=\"x(LQR)\", color=\"blue\", linestyle=\"--\")\n",
740-
"# plt.plot(time_log, y_log, label=\"y(LQR)\", color=\"green\", linestyle=\"--\")\n",
741-
"# plt.plot(time_log, z_log, label=\"z(LQR)\", color=\"red\", linestyle=\"--\")\n",
739+
"plt.plot(time_log, x_log, label=\"x(LQR)\", color=\"blue\", linestyle=\"--\")\n",
740+
"plt.plot(time_log, y_log, label=\"y(LQR)\", color=\"green\", linestyle=\"--\")\n",
741+
"plt.plot(time_log, z_log, label=\"z(LQR)\", color=\"red\", linestyle=\"--\")\n",
742742
"plt.xlabel(\"Time (s)\")\n",
743743
"plt.ylabel(\"position\")\n",
744744
"plt.title(\"position vs Time\")\n",
@@ -763,10 +763,10 @@
763763
"plt.plot(time_log, thrust_values_ilqr[:, 2], label=\"Motor 3(iLQR)\", color=\"green\")\n",
764764
"plt.plot(time_log, thrust_values_ilqr[:, 3], label=\"Motor 4(iLQR)\", color=\"red\")\n",
765765
"\n",
766-
"# plt.plot(time_log, thrust_values[:, 0], label=\"Motor 1(LQR)\", color=\"blue\", linestyle=\"--\" )\n",
767-
"# plt.plot(time_log, thrust_values[:, 1], label=\"Motor 2(LQR)\", color=\"orange\", linestyle=\"--\")\n",
768-
"# plt.plot(time_log, thrust_values[:, 2], label=\"Motor 3(LQR)\", color=\"green\", linestyle=\"--\")\n",
769-
"# plt.plot(time_log, thrust_values[:, 3], label=\"Motor 4(LQR)\", color=\"red\", linestyle=\"--\")\n",
766+
"plt.plot(time_log, thrust_values[:, 0], label=\"Motor 1(LQR)\", color=\"blue\", linestyle=\"--\" )\n",
767+
"plt.plot(time_log, thrust_values[:, 1], label=\"Motor 2(LQR)\", color=\"orange\", linestyle=\"--\")\n",
768+
"plt.plot(time_log, thrust_values[:, 2], label=\"Motor 3(LQR)\", color=\"green\", linestyle=\"--\")\n",
769+
"plt.plot(time_log, thrust_values[:, 3], label=\"Motor 4(LQR)\", color=\"red\", linestyle=\"--\")\n",
770770
"\n",
771771
"plt.xlabel('Time Steps')\n",
772772
"plt.ylabel('Thrust (N)')\n",
@@ -775,13 +775,6 @@
775775
"plt.grid()\n",
776776
"plt.show()"
777777
]
778-
},
779-
{
780-
"cell_type": "code",
781-
"execution_count": null,
782-
"metadata": {},
783-
"outputs": [],
784-
"source": []
785778
}
786779
],
787780
"metadata": {

0 commit comments

Comments
 (0)