#!/usr/bin/python3.9
# -*- coding: utf-8 -*-
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""CLI to run the baseline Deep Q-learning and Random agents
on a sample CyberBattle gym environment and plot the respective
cumulative rewards in the terminal.
Example usage:
python -m run --training_episode_count 50 --iteration_count 9000 --rewardplot_width 80 --chain_size=20 --ownership_goal 1.0
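
To evaluate a trained DQL agent instead, something along these lines should work
(all flags are defined below):
python -m run --eval --eval_type manual --gymid CyberBattleTinyMicro-v100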
"""
import argparse
import logging
import sys
import time

import asciichartpy
import gym
import torch

import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.learner as learner
from cyberbattle.agents.baseline.notebooks import notebook_dql_debug_with_tinymicro as training_module
from cyberbattle.agents.baseline.notebooks import notebook_debug_tinymicro as evaluation_module
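
# Command-line options: which mode to run (training, evaluation, or the
# standalone DQL/random baseline at the bottom of this script) plus the
# environment and agent hyperparameters.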
parser = argparse.ArgumentParser(description='Run simulation with DQL baseline agent.')
parser.add_argument('--train', action='store_true', help='Run training code (DQL agent)')
parser.add_argument('--no_train', dest='train', action='store_false', help='Do not run training code (DQL agent)')
parser.add_argument('--eval', action='store_true', help='Run evaluation code (DQL agent)')
parser.add_argument('--eval_type', default='manual', help='Evaluation type of DQL agent')
parser.add_argument('--log_results', action='store_true', help='Turn on logging of results to a file')
parser.add_argument('--no_log_results', dest='log_results', action='store_false', help='Do not log results to a file')
parser.set_defaults(log_results=True)
parser.set_defaults(train=True)
parser.set_defaults(eval=False)
parser.add_argument('--training_episode_count', default=2000, type=int,
                    help='number of training episodes')
parser.add_argument('--eval_episode_count', default=10, type=int,
                    help='number of evaluation episodes')
parser.add_argument('--iteration_count', default=50, type=int,
                    help='number of simulation iterations per episode')
parser.add_argument('--reward_goal', default=2180, type=int,
                    help='minimum cumulative reward the attacker must collect to reach its goal')
parser.add_argument('--ownership_goal', default=1.0, type=float,
help='percentage of network nodes to own for the attacker to reach its goal')
parser.add_argument('--rewardplot_width', default=80, type=int,
help='width of the reward plot (values are averaged across iterations to fit in the desired width)')
parser.add_argument('--chain_size', default=4, type=int,
help='size of the chain of the CyberBattleChain sample environment')
parser.add_argument('--gymid', default='CyberBattleTinyMicro-v100', type=str,
help='Gym environment name to run in a simulator')
parser.add_argument('--checkpoint_name', default='best', type=str,
help='checkpoint name, either date or manual or best')
parser.add_argument('--checkpoint_date', default='', type=str,
                    help='checkpoint date in the form %%Y%%m%%d_%%H%%M%%S')  # '%%' renders as a literal '%' in argparse help text
parser.add_argument('--eps_exp_decay', default=2000, type=int,
                    help='number of episodes over which epsilon decays exponentially towards epsilon_minimum')
parser.add_argument('--reward_clip', action='store_true', help='Apply reward clipping to [-1, 1]')
parser.set_defaults(reward_clip=False)
parser.add_argument('--gamma', default=0.015, type=float,
help='gamma hyperparameter value for RL algorithms (default: 0.015)')
parser.add_argument('--seed', default=int(time.time()), type=int,  # int() because argparse does not apply type= to non-string defaults
                    help='random seed (defaults to the current time in whole seconds)')
parser.add_argument('--random_agent', dest='run_random_agent', action='store_true', help='run the random agent as a baseline for comparison')
parser.add_argument('--no-random_agent', dest='run_random_agent', action='store_false', help='do not run the random agent as a baseline for comparison')
parser.set_defaults(run_random_agent=False)
parser.add_argument('--qtabular', dest='run_qtabular', action='store_true', help='run the q-tabular learning agent')
parser.add_argument('--no-qtabular', dest='run_qtabular', action='store_false', help='do not run the q-tabular learning agent')
parser.set_defaults(run_qtabular=False)
args = parser.parse_args()
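
# Dispatch on the requested mode: --eval runs the evaluation notebook module,
# --train (the default) runs the training notebook module, and --no_train
# falls through to the standalone DQL baseline run below.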
if args.eval:
evaluation_module.main(args=args)
else:
if args.train:
training_module.main(args=args)
else:
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
print(f"torch cuda available={torch.cuda.is_available()}")
cyberbattlechain = gym.make('CyberBattleChain-v0',
size=args.chain_size,
attacker_goal=cyberbattle_env.AttackerGoal(
own_atleast_percent=args.ownership_goal,
reward=args.reward_goal))
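        # Static bounds on node and credential counts, used to size the learner's
        # observation features and action space.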
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=22,
maximum_node_count=22,
identifiers=cyberbattlechain.identifiers
)
all_runs = []
# Run Deep Q-learning
dqn_learning_run = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01), # torch default is 1e-2
episode_count=args.training_episode_count,
iteration_count=args.iteration_count,
epsilon=0.90,
render=True,
# epsilon_multdecay=0.75, # 0.999,
epsilon_exponential_decay=5000, # 10000
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
all_runs.append(dqn_learning_run)
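
        # Optionally run a purely random policy (epsilon=1.0) as a baseline for comparison.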
if args.run_random_agent:
random_run = learner.epsilon_greedy_search(
cyberbattlechain,
ep,
learner=learner.RandomPolicy(),
episode_count=args.eval_episode_count,
iteration_count=args.iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
all_runs.append(random_run)
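
        # Render episode durations and averaged cumulative rewards as ASCII charts
        # in the terminal.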
colors = [asciichartpy.red, asciichartpy.green, asciichartpy.yellow, asciichartpy.blue]
print("Episode duration -- DQN=Red, Random=Green")
print(asciichartpy.plot(p.episodes_lengths_for_all_runs(all_runs), {'height': 30, 'colors': colors}))
print("Cumulative rewards -- DQN=Red, Random=Green")
c = p.averaged_cummulative_rewards(all_runs, args.rewardplot_width)
print(asciichartpy.plot(c, {'height': 10, 'colors': colors}))