train.py
import json
import logging
import os

from env import MovieLens100KEnv
from io_utils import extract_model
from vw_agent import VWAgent
from vw_utils import MODEL_CHANNEL, MODEL_OUTPUT_DIR, DATA_OUTPUT_DIR

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def main():
""" Train a Vowpal Wabbit (VW) model through C++ process. """
    channel_names = json.loads(os.environ['SM_CHANNELS'])
    hyperparameters = json.loads(os.environ['SM_HPS'])

    # Fetch algorithm hyperparameters
    num_arms = int(hyperparameters.get("num_arms", 0))  # Used if arm features are not present
    num_policies = int(hyperparameters.get("num_policies", 3))
    exploration_policy = hyperparameters.get("exploration_policy", "egreedy").lower()
    epsilon = float(hyperparameters.get("epsilon", 0))
    mellowness = float(hyperparameters.get("mellowness", 0.01))
    arm_features_present = bool(hyperparameters.get("arm_features", True))

    # Fetch environment parameters
    item_pool_size = int(hyperparameters.get("item_pool_size", 0))
    top_k = int(hyperparameters.get("top_k", 5))
    max_users = int(hyperparameters.get("max_users", 3))
    total_interactions = int(hyperparameters.get("total_interactions", 1000))

    if not arm_features_present and num_arms == 0:
        raise ValueError("Customer Error: Please provide a non-zero value for 'num_arms'.")

    logger.info("channels: %s", channel_names)
    logger.info("hyperparameters: %s", hyperparameters)

    # Different exploration policies in VW:
    # https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-algorithms
    valid_policies = ["egreedy", "bag", "regcbopt", "regcb"]
    if exploration_policy not in valid_policies:
        raise ValueError(f"Customer Error: exploration_policy must be one of {valid_policies}.")
if exploration_policy == "egreedy":
vw_args_base = f"--cb_explore_adf --cb_type mtr --epsilon {epsilon}"
elif exploration_policy in ["regcbopt", "regcb"]:
vw_args_base = f"--cb_explore_adf --cb_type mtr --{exploration_policy} --mellowness {mellowness}"
else:
vw_args_base = f"--cb_explore_adf --cb_type mtr --{exploration_policy} {num_policies}"

    # If no pre-trained model is present, train from scratch; otherwise resume from it.
    if MODEL_CHANNEL not in channel_names:
        logger.info(f"No pre-trained model has been specified in channel {MODEL_CHANNEL}. "
                    f"Training will start from scratch.")
        vw_agent = VWAgent(cli_args=vw_args_base,
                           output_dir=MODEL_OUTPUT_DIR,
                           model_path=None,
                           test_only=False,
                           quiet_mode=False,
                           adf_mode=arm_features_present,
                           num_actions=num_arms)
    else:
        # Load the pre-trained model and continue training it.
        model_folder = os.environ[f'SM_CHANNEL_{MODEL_CHANNEL.upper()}']
        metadata_path, weights_path = extract_model(model_folder)
        logger.info(f"Loading model from {weights_path}")
        vw_agent = VWAgent.load_model(metadata_loc=metadata_path,
                                      weights_loc=weights_path,
                                      test_only=False,
                                      quiet_mode=False,
                                      output_dir=MODEL_OUTPUT_DIR)

    # Start the VW C++ process. This Python program communicates with it over pipes.
    vw_agent.start()

    if "movielens" not in channel_names:
        raise ValueError(
            "Cannot find `movielens` channel. Please make sure to provide the data as the `movielens` channel.")

    # Initialize the MovieLens environment
    env = MovieLens100KEnv(data_dir=os.environ['SM_CHANNEL_MOVIELENS'],
                           item_pool_size=item_pool_size,
                           top_k=top_k,
                           max_users=max_users)
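
    # Track per-interaction regret for the learned agent and for the environment's
    # random baseline (as reported by get_feedback), so the two learning curves can
    # be compared after training.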
    regrets = []
    random_regrets = []
    obs = env.reset()

    # Learn by interacting with the environment
    for i in range(total_interactions):
        user_features, items_features = obs
        actions, probs = vw_agent.choose_actions(shared_features=user_features,
                                                 candidate_arms_features=items_features,
                                                 user_id=env.current_user_id,
                                                 candidate_ids=env.current_item_pool,
                                                 top_k=top_k)  # use the configured top_k rather than a hard-coded 5
        clicks, regret, random_regret = env.get_feedback(actions)
        regrets.append(regret)
        random_regrets.append(random_regret)
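
        # Feed the observed click for each recommended item back to VW as a bandit
        # example; VW minimizes cost, so the reward is negated via cost_fn.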
        for index, reward in enumerate(clicks):
            vw_agent.learn(shared_features=user_features,
                           candidate_arms_features=items_features,
                           action_index=actions[index],
                           reward=reward,
                           user_id=env.current_user_id,
                           candidate_ids=env.current_item_pool,
                           action_prob=probs[index],
                           cost_fn=lambda x: -x)

        # Step the environment to pick the next user and a new list of candidate items
        obs, rewards, done, info = env.step(actions)
        if i % 500 == 0:
            logger.info(f"Processed {i} interactions")
    stdout = vw_agent.save_model(close=True)
    print(stdout.decode())
    logger.info(f"Model learned using {total_interactions} training experiences.")
all_regrets = {"agent": regrets, "random": random_regrets}
with open(os.path.join(DATA_OUTPUT_DIR, "output.json"), "w") as file:
json.dump(all_regrets, file)


if __name__ == '__main__':
    main()