|
176 | 176 | "\n",
|
177 | 177 | "print(\"A shape:\", A.shape) # Should be (12, 12)\n",
|
178 | 178 | "print(\"B shape:\", B.shape) # Should be (12, 4)\n",
|
179 |
| - "print(\"A :\\n\", A)\n", |
180 |
| - "print(\"B :\\n\", B)" |
| 179 | + "# print(\"A :\\n\", A)\n", |
| 180 | + "# print(\"B :\\n\", B)" |
181 | 181 | ]
|
182 | 182 | },
|
183 | 183 | {
|
|
271 | 271 | "metadata": {},
|
272 | 272 | "outputs": [],
|
273 | 273 | "source": [
|
274 |
| - "Ad, Bd = discretize_linear_system(A, B, dt) #, exact=True)\n", |
275 |
| - "# print(\"A :\\n\", Ad)\n", |
276 |
| - "# print(\"B :\\n\", Bd)\n", |
277 |
| - "P = scipy.linalg.solve_discrete_are(Ad, Bd, Q_lqr, R_lqr)\n", |
| 274 | + "Ad, Bd = discretize_linear_system(A, B, dt)\n", |
278 | 275 | "\n",
|
| 276 | + "P = scipy.linalg.solve_discrete_are(Ad, Bd, Q_lqr, R_lqr)\n", |
279 | 277 | "btp = np.dot(Bd.T, P)\n",
|
280 |
| - "\n", |
281 | 278 | "gain_lqr = np.dot(np.linalg.inv(R_lqr + np.dot(btp, Bd)),\n",
|
282 | 279 | " np.dot(btp, Ad))\n",
|
283 | 280 | "\n",
|
284 |
| - "# We can also comment out the above two lines of code \n", |
285 |
| - "# and use the following line instead to compute for the continuous-time case\n", |
| 281 | + "### We can also comment out the above lines of code \n", |
| 282 | + "### and use the following lines instead to compute for the continuous-time case\n", |
286 | 283 | "# P = scipy.linalg.solve_continuous_are(A, B, Q_lqr, R_lqr)\n",
|
287 | 284 | "# gain_lqr = np.dot(np.linalg.inv(R_lqr), np.dot(B.T, P))\n",
|
288 | 285 | "\n",
|
289 |
| - "# print(\"A (discretized):\\n\", Ad)\n", |
290 |
| - "# print(\"B (discretized):\\n\", Bd)\n", |
291 | 286 | "print(\"gain:\\n\", gain_lqr)\n",
|
292 | 287 | "print(\"shape of gain:\", gain_lqr.shape)\n"
|
293 | 288 | ]
|
|
309 | 304 | "source": [
|
310 | 305 | "SEED = 42\n",
|
311 | 306 | "\n",
|
312 |
| - "obs, info = envs.reset()#seed=SEED)\n", |
| 307 | + "obs, info = envs.reset(seed=SEED)\n", |
313 | 308 | "state = obs_to_state(obs)\n",
|
314 | 309 | "# print(obs)\n",
|
315 | 310 | "# Step through the environment\n",
|
|
324 | 319 | " goal = np.array([0, 0, 5, 0, 1, 0, 0, 0, 0, 0, 0, 0]) # set goal state\n",
|
325 | 320 | "\n",
|
326 | 321 | " control_input = -gain_lqr @ (state - goal) + u_op\n",
|
327 |
| - " # control_input = np.clip(control_input, MIN_THRUST, MAX_THRUST) \n", |
| 322 | + "\n", |
328 | 323 | " control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high)\n",
|
329 | 324 | " action = control_input.reshape(1,4).astype(np.float32)\n",
|
330 | 325 | " # print(action)\n",
|
|
353 | 348 | "cell_type": "markdown",
|
354 | 349 | "metadata": {},
|
355 | 350 | "source": [
|
356 |
| - "### 2.5.6 Plots" |
| 351 | + "### 2.1.6 Plots" |
357 | 352 | ]
|
358 | 353 | },
|
359 | 354 | {
|
|
522 | 517 | "### 2.2.3 Recursion"
|
523 | 518 | ]
|
524 | 519 | },
|
| 520 | + { |
| 521 | + "cell_type": "markdown", |
| 522 | + "metadata": {}, |
| 523 | + "source": [ |
| 524 | + "It will take some time. :)" |
| 525 | + ] |
| 526 | + }, |
525 | 527 | {
|
526 | 528 | "cell_type": "code",
|
527 | 529 | "execution_count": null,
|
|
540 | 542 | " input_pre = input_stack\n",
|
541 | 543 | "\n",
|
542 | 544 | " # Forward / \"rollout\" of the current policy\n",
|
543 |
| - " obs, info = envs.reset() #seed=SEED)\n", |
| 545 | + " obs, info = envs.reset(seed=SEED)\n", |
544 | 546 | " state = obs_to_state(obs) # (12,)\n",
|
545 | 547 | "\n",
|
546 | 548 | " for step in range(2500):\n",
|
|
551 | 553 | " # Clip the control input to the specified range\n",
|
552 | 554 | " control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high) \n",
|
553 | 555 | " \n",
|
554 |
| - " # Convert to np.ndarray\n", |
| 556 | + " # Reshape and Convert to np.ndarray\n", |
555 | 557 | " action = control_input.reshape(1,4).astype(np.float32) # (1, 4)\n",
|
556 | 558 | "\n",
|
557 | 559 | " # Save rollout data.\n",
|
|
570 | 572 | " \n",
|
571 | 573 | " envs.close()\n",
|
572 | 574 | " \n",
|
573 |
| - " # TODO: Compute cost to see if it could converge\n", |
| 575 | + " # TODO: Compute cost to see if it diverse or converge\n", |
574 | 576 | " # cost_curr = 0\n",
|
575 | 577 | " # for i in range(state_stack.shape[0]):\n",
|
576 | 578 | "\n",
|
|
665 | 667 | " iter += 1"
|
666 | 668 | ]
|
667 | 669 | },
|
668 |
| - { |
669 |
| - "cell_type": "code", |
670 |
| - "execution_count": null, |
671 |
| - "metadata": {}, |
672 |
| - "outputs": [], |
673 |
| - "source": [ |
674 |
| - "print(Sm)" |
675 |
| - ] |
676 |
| - }, |
677 | 670 | {
|
678 | 671 | "cell_type": "code",
|
679 | 672 | "execution_count": null,
|
|
696 | 689 | " # Compute control action (force) using the iLQR gain\n",
|
697 | 690 | " control_input = input_ff[:, i] + gains_fb[i].dot(state) # gains_fb[:, i].dot(state) + input_ff[i]\n",
|
698 | 691 | "\n",
|
699 |
| - " # Clip the control iptput to the specified range\n", |
700 |
| - " control_input = np.clip(control_input, 0.028161688, 0.14834145)\n", |
| 692 | + " # Clip the control input to the specified range\n", |
| 693 | + " control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high) \n", |
701 | 694 | " \n",
|
702 | 695 | " # Convert to np.ndarray\n",
|
703 |
| - " action = np.array([control_input], dtype=np.float32) \n", |
| 696 | + " action = control_input.reshape(1,4).astype(np.float32) # (1, 4)\n", |
| 697 | + " \n", |
| 698 | + " # # Convert to np.ndarray\n", |
| 699 | + " # action = np.array([control_input], dtype=np.float32) \n", |
704 | 700 | "\n",
|
705 | 701 | " # Take a step in the environment with the computed action\n",
|
706 | 702 | " obs, reward, terminated, truncated, _ = envs.step(action)\n",
|
|
716 | 712 | " if terminated or truncated:\n",
|
717 | 713 | " print(\"Episode ended at step:\", i)\n",
|
718 | 714 | " break\n",
|
719 |
| - " envs.render()\n", |
| 715 | + "\n", |
| 716 | + " if (i * fps) % envs.sim.freq < fps:\n", |
| 717 | + " envs.render()\n", |
| 718 | + " time.sleep(1 / fps)\n", |
| 719 | + " # envs.render()\n", |
720 | 720 | "# Close the environment\n",
|
721 | 721 | "envs.sim.close()\n",
|
722 | 722 | "envs.close()"
|
|
736 | 736 | "plt.plot(time_log, x_log_ilqr, label=\"x(iLQR)\", color=\"blue\")\n",
|
737 | 737 | "plt.plot(time_log, y_log_ilqr, label=\"y(iLQR)\", color=\"green\")\n",
|
738 | 738 | "plt.plot(time_log, z_log_ilqr, label=\"z(iLQR)\", color=\"red\")\n",
|
739 |
| - "# plt.plot(time_log, x_log, label=\"x(LQR)\", color=\"blue\", linestyle=\"--\")\n", |
740 |
| - "# plt.plot(time_log, y_log, label=\"y(LQR)\", color=\"green\", linestyle=\"--\")\n", |
741 |
| - "# plt.plot(time_log, z_log, label=\"z(LQR)\", color=\"red\", linestyle=\"--\")\n", |
| 739 | + "plt.plot(time_log, x_log, label=\"x(LQR)\", color=\"blue\", linestyle=\"--\")\n", |
| 740 | + "plt.plot(time_log, y_log, label=\"y(LQR)\", color=\"green\", linestyle=\"--\")\n", |
| 741 | + "plt.plot(time_log, z_log, label=\"z(LQR)\", color=\"red\", linestyle=\"--\")\n", |
742 | 742 | "plt.xlabel(\"Time (s)\")\n",
|
743 | 743 | "plt.ylabel(\"position\")\n",
|
744 | 744 | "plt.title(\"position vs Time\")\n",
|
|
763 | 763 | "plt.plot(time_log, thrust_values_ilqr[:, 2], label=\"Motor 3(iLQR)\", color=\"green\")\n",
|
764 | 764 | "plt.plot(time_log, thrust_values_ilqr[:, 3], label=\"Motor 4(iLQR)\", color=\"red\")\n",
|
765 | 765 | "\n",
|
766 |
| - "# plt.plot(time_log, thrust_values[:, 0], label=\"Motor 1(LQR)\", color=\"blue\", linestyle=\"--\" )\n", |
767 |
| - "# plt.plot(time_log, thrust_values[:, 1], label=\"Motor 2(LQR)\", color=\"orange\", linestyle=\"--\")\n", |
768 |
| - "# plt.plot(time_log, thrust_values[:, 2], label=\"Motor 3(LQR)\", color=\"green\", linestyle=\"--\")\n", |
769 |
| - "# plt.plot(time_log, thrust_values[:, 3], label=\"Motor 4(LQR)\", color=\"red\", linestyle=\"--\")\n", |
| 766 | + "plt.plot(time_log, thrust_values[:, 0], label=\"Motor 1(LQR)\", color=\"blue\", linestyle=\"--\" )\n", |
| 767 | + "plt.plot(time_log, thrust_values[:, 1], label=\"Motor 2(LQR)\", color=\"orange\", linestyle=\"--\")\n", |
| 768 | + "plt.plot(time_log, thrust_values[:, 2], label=\"Motor 3(LQR)\", color=\"green\", linestyle=\"--\")\n", |
| 769 | + "plt.plot(time_log, thrust_values[:, 3], label=\"Motor 4(LQR)\", color=\"red\", linestyle=\"--\")\n", |
770 | 770 | "\n",
|
771 | 771 | "plt.xlabel('Time Steps')\n",
|
772 | 772 | "plt.ylabel('Thrust (N)')\n",
|
|
775 | 775 | "plt.grid()\n",
|
776 | 776 | "plt.show()"
|
777 | 777 | ]
|
778 |
| - }, |
779 |
| - { |
780 |
| - "cell_type": "code", |
781 |
| - "execution_count": null, |
782 |
| - "metadata": {}, |
783 |
| - "outputs": [], |
784 |
| - "source": [] |
785 | 778 | }
|
786 | 779 | ],
|
787 | 780 | "metadata": {
|
|
0 commit comments