Skip to content

Commit 3cf4a16

Browse files
author
minhna4
committed
Save wandb username and project name into env vars
1 parent 6e3a8d1 commit 3cf4a16

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

.env

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
WANDB_PROJECT_NAME="tennis-maddpg"
2+
WANDB_USER_NAME="minhna1112"

Tennis.ipynb

+14-3
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@
300300
"\n",
301301
"from utils import Config, ReplayBuffer\n",
302302
"from agents import MADDPG \n",
303-
"\n",
303+
"from load_env import env_vars\n",
304304
"env = UnityEnvironment(file_name=\"./Tennis_Linux_NoVis/Tennis.x86_64\")\n",
305305
"# get the default brain\n",
306306
"brain_name = env.brain_names[0]\n",
@@ -342,7 +342,8 @@
342342
"SOLVED_SCORE = 0.5\n",
343343
"\n",
344344
"# Initialize wandb run\n",
345-
"wandb.init(project=\"tennis-maddpg\", entity=\"minhna1112\") # Replace 'your_username' with your wandb username\n",
345+
"wandb.init(project=env_vars.WANDB_PROJECT_NAME, \n",
346+
" entity=env_vars.WANDB_USER_NAME) # Replace 'your_username' with your wandb username\n",
346347
"\n",
347348
"model_dir= os.getcwd()+\"/checkpoints\"\n",
348349
"os.makedirs(model_dir, exist_ok=True)\n",
@@ -422,7 +423,17 @@
422423
"execution_count": null,
423424
"metadata": {},
424425
"outputs": [],
425-
"source": []
426+
"source": [
427+
"fig = plt.figure(figsize=(10,4))\n",
428+
"ax = fig.add_subplot(111)\n",
429+
"plt.plot(np.arange(1, len(scores_all)+1), scores_all, label='Scores')\n",
430+
"plt.plot(np.arange(1, len(moving_average)+1), moving_average, c='r', label='Average')\n",
431+
"plt.ylabel('Score')\n",
432+
"plt.xlabel('Episode #')\n",
433+
"ax.legend(fontsize='large', loc='upper left')\n",
434+
"fig.savefig('result.png')\n",
435+
"plt.show()"
436+
]
426437
}
427438
],
428439
"metadata": {

load_env.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from dotenv import load_dotenv
2+
import os
3+
4+
load_dotenv("./env")
5+
class ENV:
6+
WANDB_PROJECT_NAME = os.getenv('WANDB_PROJECT_NAME')
7+
WANDB_USER_NAME = os.getenv('WANDB_USER_NAME')
8+
9+
env_vars = ENV()

0 commit comments

Comments
 (0)