-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathio_utils.py
152 lines (129 loc) · 4.92 KB
/
io_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import logging
import os
import posixpath
import shutil
from pathlib import Path

import boto3
import pandas as pd
logger = logging.getLogger(__name__)
def validate_experience(experience):
    """
    Validate that a collected experience record has all required keys.

    Args:
        experience (dict): a single experience record.

    Returns:
        bool: True if every required key is present, False otherwise.
    """
    required_keys = ("shared_context", "actions_context", "action_prob",
                     "action", "reward", "user_id")
    # all() short-circuits on the first missing key.
    return all(key in experience for key in required_keys)
class CSVReader():
    """Reader object that streams experiences from CSV file chunks.

    Files are read sequentially in the order given (the original
    docstring claimed random order, which the code never did).
    Rows containing NaN values are dropped before being yielded.
    """

    def __init__(self, input_files):
        # Paths of the CSV files to iterate over.
        self.files = input_files

    def get_iterator(self):
        """Yield one dict per complete (NaN-free) row across all files."""
        for file in self.files:
            # Chunked reading keeps memory bounded for large files; the
            # context manager closes the underlying file handle, which
            # the original code leaked.
            with pd.read_csv(file, chunksize=1000) as reader:
                for df in reader:
                    df_no_nans = df.dropna()
                    for _, row in df_no_nans.iterrows():
                        yield row.to_dict()
class JsonLinesReader():
    """Reader object that streams experiences from JSON-lines files.

    Files are consumed sequentially in the order given (the original
    docstring claimed random order, which the code never did).  Blank
    lines and corrupt records are skipped with a log message.
    ``get_experience`` returns ``None`` once every file is exhausted.
    """

    # Every record must carry these fields to be considered valid.
    REQUIRED_FIELDS = ("observation", "action", "reward", "prob")

    def __init__(self, input_files):
        self.files = input_files
        self.cur_file = None                    # currently open file handle
        self.cur_index = 0                      # index of the next file to open
        self.max_index = len(input_files) - 1
        self.done = False                       # True once all files are exhausted

    def get_experience(self):
        """Return the next valid experience dict, or None when exhausted."""
        experience = self._try_parse(self._next_line())
        while not experience and not self.done:
            logger.debug("Skipping empty line in {}".format(self.cur_file))
            experience = self._try_parse(self._next_line())
        return experience

    def _try_parse(self, line):
        """Parse one JSON line; return the dict, or None if blank/corrupt."""
        if line is None or line.strip() == '':
            return None
        try:
            line_json = json.loads(line.strip())
            # Explicit checks instead of `assert`, which is silently
            # stripped when Python runs with -O — validation must not
            # depend on assertions being enabled.
            for field in self.REQUIRED_FIELDS:
                if field not in line_json:
                    raise ValueError("{} not found in record".format(field))
            return line_json
        except Exception:
            # Best-effort reader: corrupt records are logged and skipped,
            # never fatal.
            logger.exception("Ignoring corrupt json record in {}: {}".format(
                self.cur_file, line))
            return None

    def _next_line(self):
        """Return the next raw line, advancing across files as needed."""
        if not self.cur_file:
            self.cur_file = self._next_file()
        if self.done is True:
            return None
        line = self.cur_file.readline()
        tries = 0
        # Skip up to 100 consecutive empty files before giving up.
        while not line and tries < 100:
            tries += 1
            self.cur_file.close()
            self.cur_file = self._next_file()
            if self.done is True:
                return None
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file {}".format(self.cur_file))
        if not line:
            raise ValueError("Failed to read next line from files: {}".format(
                self.files))
        return line

    def _next_file(self):
        """Open and return the next input file, or None when all are done."""
        if self.cur_index > self.max_index:
            self.done = True
            return None
        path = self.files[self.cur_index]
        self.cur_index += 1
        return open(path, "r")
def get_vw_model(disk_path=None):
    """
    Locate the VW model artifacts under *disk_path*.

    Args:
        disk_path (str or Path): directory tree containing the extracted
            model files.  Despite the default, a real path is required;
            passing None raises TypeError from Path().

    Returns:
        tuple(str, str): POSIX paths of the 'vw.metadata' file and the
        'vw.model' weights file found on disk.

    Raises:
        ValueError: if either artifact cannot be found under disk_path.
    """
    def _find_artifact(root, name):
        # Recursive glob; the first match wins (same as original behavior).
        matches = list(root.rglob(name))
        if not matches:
            raise ValueError(
                "Algorithm Error: '{}' not found in model files.".format(name))
        return matches[0]

    sagemaker_model_path = Path(disk_path)
    metadata_path = _find_artifact(sagemaker_model_path, "vw.metadata")
    model_path = _find_artifact(sagemaker_model_path, "vw.model")
    return metadata_path.as_posix(), model_path.as_posix()
def extract_model(tar_gz_folder):
    """
    Unpack ``model.tar.gz`` found in *tar_gz_folder* into that same
    folder, then locate the VW artifacts inside it.

    Args:
        tar_gz_folder (str): directory holding 'model.tar.gz'.

    Returns:
        tuple(str, str): paths of the 'vw.metadata' and 'vw.model' files.
    """
    archive_path = os.path.join(tar_gz_folder, "model.tar.gz")
    shutil.unpack_archive(filename=archive_path, extract_dir=tar_gz_folder)
    return get_vw_model(tar_gz_folder)
def parse_s3_uri(uri):
    """
    Split an S3 URI into its components.

    Args:
        uri (str): a URI of the form "s3://bucket/key/parts/file".

    Returns:
        tuple(str, str, str): (bucket, key, file_name) where key is the
        full object key and file_name its last path segment.
    """
    # Strip only a *leading* scheme.  The original str.replace() removed
    # every occurrence of "s3://", corrupting any key that happened to
    # contain that substring.
    scheme = "s3://"
    if uri.startswith(scheme):
        uri = uri[len(scheme):]
    bucket, *key_parts = uri.split("/")
    file_name = key_parts[-1]
    key = "/".join(key_parts)
    return bucket, key, file_name
def download_manifest_data(manifest_file_path, output_dir):
    """
    Download the S3 objects listed in a SageMaker manifest file.

    Args:
        manifest_file_path (Path): path to the JSON manifest.  The first
            element of the manifest holds the common S3 prefix; the
            remaining elements are file names relative to that prefix.
        output_dir (Path): local directory the files are downloaded into.
    """
    with open(manifest_file_path.as_posix()) as f:
        manifest = json.load(f)
    s3_prefix = manifest[0]["prefix"]
    s3 = boto3.client('s3')
    for file in manifest[1:]:
        # posixpath keeps "/" separators even on Windows; os.path.join
        # would build an invalid S3 URI there.
        s3_uri = posixpath.join(s3_prefix, file)
        bucket, key, file_name = parse_s3_uri(s3_uri)
        output_file = os.path.join(output_dir.as_posix(), file_name)
        s3.download_file(bucket, key, output_file)
        # Module logger instead of print, consistent with the rest of
        # this file.
        logger.info("Downloaded file %s", output_file)