Skip to content

Refactor RNGHierarchy to enable tying the random number generators to detectors#523

Draft
anand-avinash wants to merge 3 commits into
masterfrom
seeder
Draft

Refactor RNGHierarchy to enable tying the random number generators to detectors#523
anand-avinash wants to merge 3 commits into
masterfrom
seeder

Conversation

@anand-avinash
Copy link
Copy Markdown
Contributor

The current RNGHierarchy class guarantees reproducible RNG hierarchy on a given MPI rank provided that the user provided seed, size of the MPI communicator, and number of detectors on that rank remain the same. This can be modified to pass the MPI communicator to RNGHierarchy instead of the size of the communicator. The size of the communicator will be inferred from the communicator itself within the class during its instantiation, keeping the overall behavior of the class unchanged, when the user provides the global MPI communicator to this class.

However, this updated interface can be used to provided same RNG for a given detector across different rank (or time blocks) if a suitable MPI communicator is supplied. Consider an example where 3 detector and their TODs are distributed across 2 detector blocks and 3 time blocks with a total 6 MPI processes:

import time
import litebird_sim as lbs
from litebird_sim import RNGHierarchy

sim = lbs.Simulation(
    start_time=0.0,
    duration_s=86400.0,
    random_seed=12345,
)

sim.create_observations(
    detectors=[
        lbs.DetectorInfo(name="det_A", sampling_rate_hz=1.0),
        lbs.DetectorInfo(name="det_B", sampling_rate_hz=1.0),
        lbs.DetectorInfo(name="det_C", sampling_rate_hz=1.0),
    ],
    split_list_over_processes=False,
    num_of_obs_per_detector=1,
    n_blocks_det=2,
    n_blocks_time=3,
)


global_comm = lbs.MPI_COMM_WORLD

global_rng = RNGHierarchy(
    base_seed=123,
    comm=global_comm,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


det_block_rng = RNGHierarchy(
    base_seed=456,
    comm=sim.observations[0].comm_det_block,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


time_block_rng = RNGHierarchy(
    base_seed=789,
    comm=sim.observations[0].comm_time_block,
    num_detectors_per_rank=sim.observations[0].n_detectors,
)


def print_det_level_generators(
    global_rank, local_rank, det_names, det_level_generators
):
    assert len(det_names) == len(det_level_generators)
    for det_name, gen in zip(det_names, det_level_generators):
        print(
            f"global_rank = {global_rank}, local_rank = {local_rank}",
            det_name,
            gen.__getstate__(),
        )


### RNG hierarchy for global communicator

print_det_level_generators(
    global_comm.rank,
    global_comm.rank,
    sim.observations[0].name,
    global_rng.get_detector_level_generators_on_rank(global_comm.rank),
)

time.sleep(0.5)
if global_comm.rank == 0:
    print("\n")
time.sleep(0.5)


### RNG hierarchy for time block communicator

print_det_level_generators(
    global_comm.rank,
    sim.observations[0].comm_time_block.rank,
    sim.observations[0].name,
    time_block_rng.get_detector_level_generators_on_rank(
        sim.observations[0].comm_time_block.rank
    ),
)


time.sleep(0.5)
if global_comm.rank == 0:
    print("\n")
time.sleep(0.5)


### RNG hierarchy for detector block communicator
### This serves no purpose

# print_det_level_generators(
#     global_comm.rank,
#     sim.observations[0].comm_det_block.rank,
#     sim.observations[0].name,
#     det_block_rng.get_detector_level_generators_on_rank(
#         sim.observations[0].comm_det_block.rank
#     ),
# )

The data distributions would look like the following:

det A, det B det C
time_block 1 rank 0 rank 3
time_block 2 rank 1 rank 4
time_block 3 rank 2 rank 5

With this distribution, the global communicator will be partitioned into three time block communicators accessible with sim.observations[0].comm_time_block. The first one will contain ranks 0 and 3, second will contain ranks 1 and 4, and the third will contain ranks 2 and 5. As a result, since for ranks 0, 1, and 2 the size of sim.observations[0].comm_time_block and the number of detectors per rank are same, RNGHierarchy will produce exactly the same hierarchy with a common seed on these ranks. The same goes for ranks 3, 4, and 5. So, the detector level RNGs for detector A will be same across rank 0, 1, and 2; detector level RNGs for detector B will be same across ranks 0, 1, and 2; and detector level RNGs for detector C will be same across ranks 3, 4,and 5 -- guaranteeing that random numbers produced for given detector are same across different MPI ranks. This completely solves the issue raised with #510.

Here I am showing the result of the print statements from the script I added above for RNG state verification:

### RNG hierarchy for global communicator - each block of TOD has an RNG with unique state


global_rank = 0, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 144619848914138535302906084289083140116, 'inc': 137901272774577361021997142407036157741}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 0, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 68488701009948191967673224139517843068, 'inc': 181070714709730587431689905102663183521}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 1, local_rank = 1 det_A {'bit_generator': 'PCG64', 'state': {'state': 31074665416037212809444864448348876291, 'inc': 271684247880976466715616379799581786115}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 1 det_B {'bit_generator': 'PCG64', 'state': {'state': 39625932023379125989704284680160134973, 'inc': 245272176840412043495140278152880377797}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 2, local_rank = 2 det_A {'bit_generator': 'PCG64', 'state': {'state': 327913130787634335600406052353961127019, 'inc': 285897430671043131730482796764641238999}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 2 det_B {'bit_generator': 'PCG64', 'state': {'state': 310112831948863994768429857504937681947, 'inc': 27809136912511448367194439464235745091}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 3, local_rank = 3 det_C {'bit_generator': 'PCG64', 'state': {'state': 83821247557234737719374691102984488495, 'inc': 106074201384900303510227639608675842753}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 4, local_rank = 4 det_C {'bit_generator': 'PCG64', 'state': {'state': 94339045799319636164492511438536753607, 'inc': 246373086755039278226331009459390934807}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 5, local_rank = 5 det_C {'bit_generator': 'PCG64', 'state': {'state': 228074635286621183168880197838952638317, 'inc': 14417719495420430718010116231659948761}, 'has_uint32': 0, 'uinteger': 0}



### RNG hierarchy for time block communicator - RNGs corresponding to a given detector have same state across different MPI ranks


global_rank = 0, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 0 det_A {'bit_generator': 'PCG64', 'state': {'state': 211423005980925913106542892991315253702, 'inc': 103874591357702541628135110433674472181}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 0, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 1, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 2, local_rank = 0 det_B {'bit_generator': 'PCG64', 'state': {'state': 333871133010337104201869106881567755991, 'inc': 156961187590396004783913536309161093731}, 'has_uint32': 0, 'uinteger': 0}

global_rank = 3, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 4, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}
global_rank = 5, local_rank = 1 det_C {'bit_generator': 'PCG64', 'state': {'state': 103944651775319475029580060687993738622, 'inc': 51519893583648684950972087145344717905}, 'has_uint32': 0, 'uinteger': 0}

Please take a look at this, if it seems suitable and sufficient, I would update the docstrings before merging this to master branch.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 20, 2026

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  litebird_sim
  seeding.py
  simulations.py
Project Total  

This report was generated by python-coverage-comment-action

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant