Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@
def cli():
    """Start training through ``pufferl.train`` for the 'puffer_breakout' env."""
    env_name = 'puffer_breakout'
    pufferl.train(env_name)

# Simple trainer based on pufferl functions
def simple_trainer(env_name='puffer_breakout'):
    """Train a policy with a hand-written loop built from pufferl functions.

    Loads the config for ``env_name``, overrides a few values, builds the
    vectorized environment and the policy through the pufferl loaders, then
    alternates ``evaluate()`` and ``train()`` until ``trainer.epoch`` reaches
    ``trainer.total_epochs``.

    Args:
        env_name: Environment name passed to ``pufferl.load_config``,
            ``pufferl.load_env`` and ``pufferl.load_policy``.
    """
    args = pufferl.load_config(env_name)

    # You can customize the puffer-provided config
    args['vec']['num_envs'] = 2
    args['env']['num_envs'] = 2048
    args['policy']['hidden_size'] = 256
    args['rnn']['input_size'] = 256
    args['rnn']['hidden_size'] = 256
    args['train']['total_timesteps'] = 10_000_000
    args['train']['learning_rate'] = 0.03

    # Or, you can create and use a separate config file
    # args = pufferl.load_config_file(<YOUR_OWN_CONFIG.ini>, fill_in_default=True)

    vecenv = pufferl.load_env(env_name, args)
    policy = pufferl.load_policy(args, vecenv, env_name)

    trainer = pufferl.PuffeRL(args['train'], vecenv, policy)

    try:
        while trainer.epoch < trainer.total_epochs:
            trainer.evaluate()
            # train() returns logs; this example does not use them.
            trainer.train()

        trainer.print_dashboard()
    finally:
        # Call close() even if a step raises or the run is interrupted.
        trainer.close()

class Policy(torch.nn.Module):
def __init__(self, env):
super().__init__()
Expand Down
59 changes: 39 additions & 20 deletions pufferlib/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
elif args['wandb']:
logger = WandbLogger(args)

train_config = dict(**args['train'], env=env_name)
train_config = { **args['train'], 'env': env_name }
pufferl = PuffeRL(train_config, vecenv, policy, logger)

all_logs = []
Expand Down Expand Up @@ -992,7 +992,7 @@ def eval(env_name, args=None, vecenv=None, policy=None):
if len(frames) > 0 and len(frames) == args['save_frames']:
import imageio
imageio.mimsave(args['gif_path'], frames, fps=args['fps'], loop=0)
frames.append('Done')
print(f'Saved {len(frames)} frames to {args["gif_path"]}')

def sweep(args=None, env_name=None):
args = args or load_config(env_name)
Expand Down Expand Up @@ -1120,6 +1120,39 @@ def load_policy(args, vecenv, env_name=''):
return policy

def load_config(env_name):
    """Find the bundled config for ``env_name`` and process it into args.

    ``'default'`` loads only ``config/default.ini``. Any other name is looked
    up by scanning every ``.ini`` under ``config/`` (recursively); each file
    is layered on top of the defaults and is selected when ``env_name``
    appears among the whitespace-separated names in its ``[base] env_name``
    entry.

    Args:
        env_name: Environment name to look up, or ``'default'``.

    Returns:
        Whatever ``process_config`` returns for the selected config.

    Raises:
        pufferlib.APIUsageError: If no config file lists ``env_name``.
    """
    base_dir = os.path.dirname(os.path.realpath(__file__))
    default_ini = os.path.join(base_dir, 'config/default.ini')

    if env_name == 'default':
        parser = configparser.ConfigParser()
        parser.read(default_ini)
        return process_config(parser)

    ini_pattern = os.path.join(base_dir, 'config/**/*.ini')
    for ini_path in glob.glob(ini_pattern, recursive=True):
        # Fresh parser per candidate so files never leak into each other.
        parser = configparser.ConfigParser()
        parser.read([default_ini, ini_path])
        known_names = parser['base']['env_name'].split()
        if env_name in known_names:
            return process_config(parser)

    raise pufferlib.APIUsageError('No config for env_name {}'.format(env_name))

def load_config_file(file_path, fill_in_default=True):
    """Load a user-supplied .ini config file and process it into args.

    Args:
        file_path: Path to the config file to load.
        fill_in_default: When true, ``config/default.ini`` next to this
            module is read first, so the user's file only needs to contain
            the values it overrides.

    Returns:
        Whatever ``process_config`` returns for the parsed config.

    Raises:
        pufferlib.APIUsageError: If ``file_path`` is not an existing file.
    """
    # ConfigParser.read() silently skips paths it cannot open, so a missing
    # path or a directory must be rejected explicitly here.
    if not os.path.isfile(file_path):
        raise pufferlib.APIUsageError(f'No config file found: {file_path}')

    config_paths = [file_path]

    if fill_in_default:
        puffer_dir = os.path.dirname(os.path.realpath(__file__))
        # Defaults go first so values from file_path override them.
        config_paths.insert(0, os.path.join(puffer_dir, 'config/default.ini'))

    p = configparser.ConfigParser()
    p.read(config_paths)

    return process_config(p)

def process_config(config):
parser = argparse.ArgumentParser(
description=f':blowfish: PufferLib [bright_cyan]{pufferlib.__version__}[/]'
' demo options. Shows valid args for your env and policy',
Expand All @@ -1144,34 +1177,19 @@ def load_config(env_name):
parser.add_argument('--tag', type=str, default=None, help='Tag for experiment')
args = parser.parse_known_args()[0]

# Load defaults and config
puffer_dir = os.path.dirname(os.path.realpath(__file__))
puffer_config_dir = os.path.join(puffer_dir, 'config/**/*.ini')
puffer_default_config = os.path.join(puffer_dir, 'config/default.ini')
if env_name == 'default':
p = configparser.ConfigParser()
p.read(puffer_default_config)
else:
for path in glob.glob(puffer_config_dir, recursive=True):
p = configparser.ConfigParser()
p.read([puffer_default_config, path])
if env_name in p['base']['env_name'].split(): break
else:
raise pufferlib.APIUsageError('No config for env_name {}'.format(env_name))

# Dynamic help menu from config
def puffer_type(value):
    """Parse ``value`` as a Python literal, falling back to the raw value.

    Used both to convert config strings into defaults and as the argparse
    ``type`` callable, so anything that is not a valid literal (for example a
    plain name) is returned unchanged.
    """
    try:
        return ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the errors literal_eval raises on malformed input.
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        return value

for section in p.sections():
for key in p[section]:
for section in config.sections():
for key in config[section]:
fmt = f'--{key}' if section == 'base' else f'--{section}.{key}'
parser.add_argument(
fmt.replace('_', '-'),
default=puffer_type(p[section][key]),
default=puffer_type(config[section][key]),
type=puffer_type
)

Expand All @@ -1189,6 +1207,7 @@ def puffer_type(value):

prev[subkey] = value

args['train']['env'] = args['env_name'] or '' # for trainer dashboard
args['train']['use_rnn'] = args['rnn_name'] is not None
return args

Expand Down
Loading