|
300 | 300 | "\n",
|
301 | 301 | "from utils import Config, ReplayBuffer\n",
|
302 | 302 | "from agents import MADDPG \n",
|
303 |
| - "\n", |
| 303 | + "from load_env import env_vars\n", |
304 | 304 | "env = UnityEnvironment(file_name=\"./Tennis_Linux_NoVis/Tennis.x86_64\")\n",
|
305 | 305 | "# get the default brain\n",
|
306 | 306 | "brain_name = env.brain_names[0]\n",
|
|
342 | 342 | "SOLVED_SCORE = 0.5\n",
|
343 | 343 | "\n",
|
344 | 344 | "# Initialize wandb run\n",
|
345 |
| - "wandb.init(project=\"tennis-maddpg\", entity=\"minhna1112\") # Replace 'your_username' with your wandb username\n", |
| 345 | + "wandb.init(project=env_vars.WANDB_PROJECT_NAME, \n", |
| 346 | + " entity=env_vars.WANDB_USER_NAME) # Replace 'your_username' with your wandb username\n", |
346 | 347 | "\n",
|
347 | 348 | "model_dir= os.getcwd()+\"/checkpoints\"\n",
|
348 | 349 | "os.makedirs(model_dir, exist_ok=True)\n",
|
|
422 | 423 | "execution_count": null,
|
423 | 424 | "metadata": {},
|
424 | 425 | "outputs": [],
|
425 |
| - "source": [] |
| 426 | + "source": [ |
| 427 | + "fig = plt.figure(figsize=(10,4))\n", |
| 428 | + "ax = fig.add_subplot(111)\n", |
| 429 | + "plt.plot(np.arange(1, len(scores_all)+1), scores_all, label='Scores')\n", |
| 430 | + "plt.plot(np.arange(1, len(moving_average)+1), moving_average, c='r', label='Average')\n", |
| 431 | + "plt.ylabel('Score')\n", |
| 432 | + "plt.xlabel('Episode #')\n", |
| 433 | + "ax.legend(fontsize='large', loc='upper left')\n", |
| 434 | + "fig.savefig('result.png')\n", |
| 435 | + "plt.show()" |
| 436 | + ] |
426 | 437 | }
|
427 | 438 | ],
|
428 | 439 | "metadata": {
|
|
0 commit comments