forked from CoactiveAI/dataperf-vision-selection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
82 lines (55 loc) · 2.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import os
from datetime import datetime, timezone

import yaml
from pyspark.sql import SparkSession, DataFrame

import constants as c
def get_spark_session(spark_driver_memory: str) -> SparkSession:
    """Build (or reuse) a SparkSession with the requested driver memory.

    Args:
        spark_driver_memory: value for the 'spark.driver.memory' Spark
            config, e.g. '8g'.

    Returns:
        The active SparkSession (created if none exists).
    """
    builder = SparkSession.builder.config(
        'spark.driver.memory', spark_driver_memory)
    return builder.getOrCreate()
def get_emb_dim(df: DataFrame) -> int:
    """Return the length of the embedding vector in the first row of df.

    Assumes df has at least one row and an embedding column named
    c.EMB_COL whose values support len().
    """
    first_row = df.select(c.EMB_COL).take(1)[0]
    embedding = first_row[0]
    return len(embedding)
def load_yaml(path: str) -> dict:
    """Parse a YAML file and return its contents as a dict.

    Args:
        path: filesystem path of the YAML file to read.

    Returns:
        The parsed YAML content.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The error is printed
            and then re-raised. (Previously the error was swallowed and the
            function fell through to `return yaml_dict`, which raised a
            confusing UnboundLocalError instead.)
    """
    with open(path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def load_emb_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
    """Load an embeddings parquet file and validate schema and dimensionality.

    Args:
        ss: active SparkSession used for reading.
        path: path to the parquet embeddings file.
        dim: expected embedding dimensionality from the setup file.

    Returns:
        The loaded DataFrame, guaranteed to have the ID and embedding
        columns and embeddings of length `dim`.
    """
    emb_df = ss.read.parquet(path)
    for required_col in (c.EMB_COL, c.ID_COL):
        assert required_col in emb_df.columns, \
            f'Embedding file does not have "{required_col}" column'
    found_dim = get_emb_dim(emb_df)
    assert found_dim == dim, \
        f'Embedding file dim={found_dim}, but setup file specifies dim={dim}'
    return emb_df
def load_train_df(ss: SparkSession, path: str) -> DataFrame:
    """Load the training-set CSV (expects a header row) and validate columns.

    Args:
        ss: active SparkSession used for reading.
        path: path to the training CSV file.

    Returns:
        The loaded DataFrame, guaranteed to have the label and ID columns.
    """
    train_df = ss.read.option('header', True).csv(path)
    for required_col in (c.LABEL_COL, c.ID_COL):
        assert required_col in train_df.columns, \
            f'{path}: Train file does not have "{required_col}" column'
    return train_df
def add_emb_col(df: DataFrame, emb_df: DataFrame) -> DataFrame:
    """Attach embeddings to df via an inner join on the ID column.

    Args:
        df: DataFrame to enrich; must contain the ID column.
        emb_df: DataFrame holding ID and embedding columns.

    Returns:
        df joined with the (ID, embedding) projection of emb_df.
    """
    id_to_emb = emb_df.select(c.ID_COL, c.EMB_COL)
    joined = df.join(id_to_emb, c.ID_COL)
    return joined
def load_test_df(ss: SparkSession, path: str, dim: int) -> DataFrame:
    """Load the test-set parquet file and validate schema and embedding dim.

    Args:
        ss: active SparkSession used for reading.
        path: path to the test parquet file.
        dim: expected embedding dimensionality from the setup file.

    Returns:
        The loaded DataFrame, guaranteed to have the label and ID columns
        and embeddings of length `dim`.
    """
    df = ss.read.parquet(path)
    for col in [c.LABEL_COL, c.ID_COL]:
        # Bug fix: the message previously said "Train file" — a copy-paste
        # from load_train_df — even though this function validates the
        # test file.
        assert col in df.columns, \
            f'{path}: Test file does not have "{col}" column'
    actual_dim = get_emb_dim(df)
    assert actual_dim == dim, \
        f'Test file dim={actual_dim}, but setup file specifies dim={dim}'
    return df
def save_results(data: dict, save_dir: str, verbose: bool = False) -> None:
    """Write `data` as pretty-printed JSON into save_dir.

    The filename is '<prefix>_UTC-<timestamp>.json', timestamped in UTC.

    Args:
        data: JSON-serializable results to persist.
        save_dir: existing directory to write the results file into.
        verbose: when True, echo the output path and the JSON payload
            to stdout.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
    # datetime produces the identical strftime output.
    dt = datetime.now(timezone.utc).strftime("UTC-%Y-%m-%d-%H-%M-%S")
    filename = f'{c.RESULT_FILE_PREFIX}_{dt}.json'
    path = os.path.join(save_dir, filename)
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)
    if verbose:
        print(f'Results saved in {path}')
        print(json.dumps(data, indent=4))