Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions neat/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, config, initial_state=None):
self.species.speciate(config, self.population, self.generation)
else:
self.population, self.species, self.generation = initial_state
self.species.reporters = self.reporters
# If the reproduction object has a genome indexer,
# set it to continue from the last genome ID.
if hasattr(self.reproduction, "genome_indexer"):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,52 @@ def eval_genomes(genomes, config):
last_genome_key + 1
)

def test_reporter_consistency_after_checkpoint_restore(self):
"""
Test that ReportSets in the different objects in population are the same
after restoring from a checkpoint.
"""
# Load configuration.
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, 'test_configuration')
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
neat.DefaultSpeciesSet, neat.DefaultStagnation,
config_path)

p = neat.Population(config)
filename_prefix = 'neat-checkpoint-test_population'
checkpointer = neat.Checkpointer(1, 5, filename_prefix=filename_prefix)
p.add_reporter(checkpointer)

reporter_set = p.reporters
self.assertEqual(reporter_set, p.reproduction.reporters)
self.assertEqual(reporter_set, p.species.reporters)

def eval_genomes(genomes, config):
for genome_id, genome in genomes:
genome.fitness = 0.5

p.run(eval_genomes, 5)

filename = '{0}{1}'.format(
filename_prefix, checkpointer.last_generation_checkpoint
)
restored_population = neat.Checkpointer.restore_checkpoint(filename)

# Check that the reporters are consistent
restored_reporter_set = restored_population.reporters
self.assertEqual(
restored_reporter_set,
restored_population.reproduction.reporters,
msg="Reproduction reporters do not match after restore"
)
self.assertEqual(
restored_reporter_set,
restored_population.species.reporters,
msg="Species reporters do not match after restore"
)


# def test_minimal():
# # sample fitness function
# def eval_fitness(population):
Expand Down
Loading