|
| 1 | +import backend |
| 2 | + |
| 3 | +import argparse |
| 4 | +import bz2 |
| 5 | +import glob |
| 6 | +import random |
| 7 | +import os.path |
| 8 | +import multiprocessing |
| 9 | + |
| 10 | +import pandas |
| 11 | + |
# NOTE(review): backend.logged_main is project-local; presumably it wraps main
# with logging setup -- confirm against the backend module.
@backend.logged_main
def main():
    """Select the most active players and split them into train/val/test CSVs.

    Every ``*csv.bz2`` file in the input directory is summarised in parallel
    by ``check_player``; files whose mean ELO falls outside the open interval
    ``(--min_elo, --max_elo)`` are dropped.  The remaining players with the
    highest row counts (``num_train + num_val + num_test`` of them) are
    shuffled with the seeded RNG and handed to ``write_output_file`` once per
    split.

    Raises:
        ValueError: if fewer players pass the ELO filter than the three
            splits require.  Raised before any output file is written.
    """
    parser = argparse.ArgumentParser(
        description='Read all the metadata and select top n players for training/validation/testing',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('inputs', help='input csvs dir')
    parser.add_argument('output_train', help='output csv for training data')
    parser.add_argument('num_train', type=int, help='num for main training')
    parser.add_argument('output_val', help='output csv for validation data')
    parser.add_argument('num_val', type=int, help='num for big validation run')
    parser.add_argument('output_test', help='output csv for testing data')
    parser.add_argument('num_test', type=int, help='num for holdout set')
    parser.add_argument('--pool_size', type=int, help='Number of models to run in parallel', default=48)
    parser.add_argument('--min_elo', type=int, help='min elo to select', default=1100)
    parser.add_argument('--max_elo', type=int, help='max elo to select', default=2000)
    parser.add_argument('--seed', type=int, help='random seed', default=1)
    args = parser.parse_args()
    random.seed(args.seed)

    # glob returns matches in arbitrary (filesystem-dependent) order.  The
    # count sort below is stable, so without a fixed input order, ties in
    # count -- and therefore the shuffle result -- would vary between
    # machines even with the same seed.
    targets = sorted(glob.glob(os.path.join(args.inputs, '*csv.bz2')))

    with multiprocessing.Pool(args.pool_size) as pool:
        players = pool.starmap(
            check_player,
            ((t, args.min_elo, args.max_elo) for t in targets),
        )

    num_needed = args.num_train + args.num_val + args.num_test
    eligible = [p for p in players if p is not None]
    if len(eligible) < num_needed:
        # Fail before touching any output file instead of crashing half-way
        # through the splits with "pop from empty list".
        raise ValueError(
            f"need {num_needed} players but only {len(eligible)} of "
            f"{len(targets)} input files passed the ELO filter"
        )

    # Keep the players with the most rows (index 1 of each tuple is the count).
    players_top = sorted(eligible, key=lambda x: x[1], reverse=True)[:num_needed]

    random.shuffle(players_top)

    # Each call pops its share off the end of players_top, so the three
    # splits are disjoint.
    write_output_file(args.output_train, args.num_train, players_top)
    write_output_file(args.output_val, args.num_val, players_top)
    write_output_file(args.output_test, args.num_test, players_top)
| 45 | + |
def write_output_file(path, count, targets):
    """Write ``count`` players taken from the end of ``targets`` to a CSV file.

    The file receives a ``player,count,ELO`` header followed by one row per
    player.  ``targets`` is consumed in place: every written entry is popped
    off the end of the list, so successive calls on the same list yield
    disjoint sets of players.

    Args:
        path: Destination file; created or overwritten.
        count: Number of players to write.
        targets: List of indexable records whose items 0, 1 and 2 are the
            player name, row count and ELO.

    Raises:
        IndexError: if ``targets`` holds fewer than ``count`` entries.  The
            check runs up front, so neither the file nor ``targets`` is
            touched in that case.
    """
    if count > len(targets):
        raise IndexError(
            f"cannot write {count} players to {path!r}: "
            f"only {len(targets)} available"
        )
    with open(path, 'wt') as f:
        f.write("player,count,ELO\n")
        for _ in range(count):
            row = targets.pop()
            f.write(f"{row[0]},{row[1]},{row[2]}\n")
| 52 | + |
def check_player(path, min_elo, max_elo):
    """Summarise one player's CSV file and apply the ELO filter.

    Args:
        path: CSV file readable by ``pandas.read_csv`` that has an ``ELO``
            column.  The player name is the file name up to its first ``.``.
        min_elo: Exclusive lower bound on the mean ELO.
        max_elo: Exclusive upper bound on the mean ELO.

    Returns:
        ``(player, count, elo)`` where ``count`` is the total number of rows
        and ``elo`` is the mean of the ``ELO`` column over the last 10000
        rows, or ``None`` when that mean is not strictly between the bounds
        (including when it is NaN, e.g. for an empty file).
    """
    df = pandas.read_csv(path, low_memory=False)
    # Only the last 10000 rows contribute to the rating
    # (presumably the most recent games -- verify row order with the producer).
    elo = df['ELO'][-10000:].mean()
    if not min_elo < elo < max_elo:
        return None
    # os.path.basename honours the platform's separators, unlike split('/').
    player = os.path.basename(path).split('.')[0]
    return player, len(df), elo
| 61 | + |
# Run the selection only when executed as a script.  The guard also matters
# for the multiprocessing pool used in main(): worker processes started with
# the "spawn" method re-import this module and must not re-enter main().
if __name__ == "__main__":
    main()
0 commit comments