Save cell vars to nwb #79

Draft: wants to merge 7 commits into base branch develop
21 changes: 7 additions & 14 deletions bmtk/simulator/bionet/modules/record_cellvars.py
@@ -28,17 +28,6 @@
from bmtk.simulator.bionet.io_tools import io

from bmtk.utils.io import cell_vars
try:
# Check to see if h5py is built to run in parallel
if h5py.get_config().mpi:
MembraneRecorder = cell_vars.CellVarRecorderParallel
else:
MembraneRecorder = cell_vars.CellVarRecorder

except Exception as e:
MembraneRecorder = cell_vars.CellVarRecorder

MembraneRecorder._io = io

pc = h.ParallelContext()
MPI_RANK = int(pc.id())
@@ -86,8 +75,12 @@ def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', buf
self._local_gids = []
self._sections = sections

self._var_recorder = MembraneRecorder(self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS)
recorder_cls = cell_vars.get_cell_var_recorder_cls(file_name)
recorder_cls._io = io
self._var_recorder = recorder_cls(
self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS
)

self._gid_list = [] # list of all gids that will have their variables saved
self._data_block = {} # table of variable data indexed by [gid][variable]
@@ -119,7 +112,7 @@ def initialize(self, sim):
# TODO: Make sure the seg has the recorded variable(s)
sec_list.append(sec_id)
seg_list.append(seg.x)

self._var_recorder.add_cell(gid, sec_list, seg_list)

self._var_recorder.initialize(sim.n_steps, sim.nsteps_block)
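The core of this change, shared by record_cellvars.py and record_netcons.py below, is that the recorder class is no longer picked at import time by a module-level h5py check; each module now asks cell_vars.get_cell_var_recorder_cls() for a class based on the report file name, and still points the class at the bionet logger via recorder_cls._io = io. A minimal sketch of the new lifecycle, with illustrative file, gid, and variable names that are not taken from this PR:

from bmtk.utils.io import cell_vars

# '.nwb' reports get an NWB-backed recorder, anything else the SONATA/HDF5 one;
# the *Parallel variants are returned automatically when h5py was built against MPI.
nwb_cls = cell_vars.get_cell_var_recorder_cls('output/membrane_potential.nwb')
h5_cls = cell_vars.get_cell_var_recorder_cls('output/membrane_potential.h5')

recorder = h5_cls('output/membrane_potential.h5', '/tmp/bmtk_reports', ['v'],
                  buffer_data=True, mpi_rank=0, mpi_size=1)
recorder.add_cell(gid=0, sec_list=[0, 0], seg_list=[0.25, 0.75])   # two segments on one cell
recorder.initialize(n_steps=1000, buffer_size=100)
recorder.record_cell(0, 'v', seg_vals=[-70.0, -70.0], tstep=0)
recorder.close()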
18 changes: 7 additions & 11 deletions bmtk/simulator/bionet/modules/record_netcons.py
@@ -10,14 +10,6 @@
from bmtk.simulator.bionet.pointprocesscell import PointProcessCell

from bmtk.utils.io import cell_vars
try:
# Check to see if h5py is built to run in parallel
if h5py.get_config().mpi:
MembraneRecorder = cell_vars.CellVarRecorderParallel
else:
MembraneRecorder = cell_vars.CellVarRecorder
except Exception as e:
MembraneRecorder = cell_vars.CellVarRecorder

pc = h.ParallelContext()
MPI_RANK = int(pc.id())
@@ -46,9 +38,13 @@ def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', syn
self._all_gids = cells
self._local_gids = []
self._sections = sections

self._var_recorder = MembraneRecorder(self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS)

recorder_cls = cell_vars.get_cell_var_recorder_cls(file_name)
recorder_cls._io = io
self._var_recorder = recorder_cls(
self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS
)

self._virt_lookup = {}
self._gid_lookup = {}
2 changes: 1 addition & 1 deletion bmtk/simulator/pointnet/modules/multimeter_reporter.py
@@ -1,7 +1,7 @@
import os
import glob
import pandas as pd
from bmtk.utils.io.cell_vars import CellVarRecorder
from bmtk.utils.io.cell_vars import CellVarRecorderH5 as CellVarRecorder
from bmtk.simulator.pointnet.io_tools import io

import nest
198 changes: 143 additions & 55 deletions bmtk/utils/io/cell_vars.py
@@ -1,7 +1,12 @@
import os
from datetime import datetime
from collections import defaultdict
import h5py
import numpy as np

from pynwb import NWBFile, NWBHDF5IO
from nwbext_simulation_output import Compartments, CompartmentSeries

from bmtk.utils import io
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version

@@ -11,17 +16,38 @@
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nhosts = comm.Get_size()

except Exception as exc:
pass
comm = None
rank = 1


def get_cell_var_recorder_cls(file_name):
"""Return the right class for recording cellvars based on the filename and whether parallel h5py is enabled"""
try:
in_mpi = h5py.get_config().mpi
except Exception as e:
in_mpi = False

if file_name.endswith('.nwb'):
# NWB
if in_mpi:
return CellVarRecorderNWBParallel
else:
return CellVarRecorderNWB
else:
# HDF5
if in_mpi:
return CellVarRecorderH5Parallel
else:
return CellVarRecorderH5


class CellVarRecorder(object):
class CellVarRecorderH5(object):
"""Used to save cell membrane variables (V, Ca2+, etc) to the described hdf5 format.

For parallel simulations this class will write to a separate tmp file on each rank, then use the merge method to
combine the results. This is less efficient, but doesn't require the user to install mpi4py and build h5py in
parallel mode. For better performance use the CellVarRecorderParrallel class instead.
parallel mode. For better performance use one of the CellVarRecorder{H5,NWB}Parallel classes instead.
"""
_io = io

@@ -36,7 +62,7 @@ def __init__(self, var_name):

def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1):
self._file_name = file_name
self._h5_handle = None
self._file_handle = None
self._tmp_dir = tmp_dir
self._variables = variables if isinstance(variables, list) else [variables]
self._n_vars = len(self._variables) # Used later to keep track if more than one var is saved to the same file.
Expand All @@ -46,7 +72,8 @@ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0,
self._tmp_files = []
self._saved_file = file_name

if mpi_size > 1:
if mpi_size > 1 and not isinstance(self, ParallelRecorderMixin):

self._io.log_warning('Was unable to run h5py in parallel (mpi) mode.' +
' Saving of membrane variable(s) may slow down.')
tmp_fname = os.path.basename(file_name) # make sure file names don't clash if there are multiple reports
@@ -56,8 +83,8 @@

self._mapping_gids = [] # list of gids in the order they appear in the data
self._gid_map = {} # table for looking up the gid offsets
self._map_attrs = {} # Used for additonal attributes in /mapping

self._map_attrs = defaultdict(list) # Used for additional attributes in /mapping
self._mapping_element_ids = [] # sections
self._mapping_element_pos = [] # segments
self._mapping_index = [0] # index_pointer
@@ -123,10 +150,10 @@ def _calc_offset(self):
self._gids_beg = 0
self._gids_end = self._n_gids_local

def _create_h5_file(self):
self._h5_handle = h5py.File(self._file_name, 'w')
add_hdf5_version(self._h5_handle)
add_hdf5_magic(self._h5_handle)
def _create_file(self, **io_kwargs):
self._file_handle = h5py.File(self._file_name, 'w', **io_kwargs)
add_hdf5_version(self._file_handle)
add_hdf5_magic(self._file_handle)

def add_cell(self, gid, sec_list, seg_list, **map_attrs):
assert(len(sec_list) == len(seg_list))
Expand All @@ -140,16 +167,26 @@ def add_cell(self, gid, sec_list, seg_list, **map_attrs):
self._n_segments_local += n_segs
self._n_gids_local += 1
for k, v in map_attrs.items():
if k not in self._map_attrs:
self._map_attrs[k] = v
else:
self._map_attrs[k].extend(v)
self._map_attrs[k].extend(v)

def initialize(self, n_steps, buffer_size=0):
self._calc_offset()
self._create_h5_file()
self._create_file()
self._init_mapping()
self._total_steps = n_steps
self._buffer_block_size = buffer_size
self._init_buffers()

if not self._buffer_data:
# If data is not being buffered and instead written to the main block,
# we have to add a rank offset to the gid offset
for gid, gid_offset in self._gid_map.items():
self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg)

self._is_initialized = True

var_grp = self._h5_handle.create_group('/mapping')
def _init_mapping(self):
var_grp = self._file_handle.create_group('/mapping')
var_grp.create_dataset('gids', shape=(self._n_gids_all,), dtype=np.uint)
var_grp.create_dataset('element_id', shape=(self._n_segments_all,), dtype=np.uint)
var_grp.create_dataset('element_pos', shape=(self._n_segments_all,), dtype=np.float)
@@ -164,32 +201,24 @@ def initialize(self, n_steps, buffer_size=0):
var_grp['index_pointer'][self._gids_beg:(self._gids_end+1)] = self._mapping_index
for k, v in self._map_attrs.items():
var_grp[k][self._seg_offset_beg:self._seg_offset_end] = v

self._total_steps = n_steps
self._buffer_block_size = buffer_size
if not self._buffer_data:
# If data is not being buffered and instead written to the main block, we have to add a rank offset
# to the gid offset
for gid, gid_offset in self._gid_map.items():
self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg)


def _init_buffers(self):
for var_name, data_tables in self._data_blocks.items():
# If users are trying to save multiple variables in the same file put data table in its own /{var} group
# (not sonata compliant). Otherwise the data table is located at the root
data_grp = self._h5_handle if self._n_vars == 1 else self._h5_handle.create_group('/{}'.format(var_name))
data_grp = self._file_handle if self._n_vars == 1 else self._file_handle.create_group('/{}'.format(var_name))
if self._buffer_data:
# Set up in-memory block to buffer recorded variables before writing to the dataset
data_tables.buffer_block = np.zeros((buffer_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all),
data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all),
dtype=np.float, chunks=True)
data_tables.data_block.attrs['variable_name'] = var_name
else:
# Since we are not buffering data, we just write directly to the on-disk dataset
data_tables.buffer_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all),
data_tables.buffer_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all),
dtype=np.float, chunks=True)
data_tables.buffer_block.attrs['variable_name'] = var_name

self._is_initialized = True

def record_cell(self, gid, var_name, seg_vals, tstep):
"""Record cell parameters.
@@ -234,7 +263,7 @@ def flush(self):
data_table.data_block[blk_beg:blk_end, :] = data_table.buffer_block[:block_size, :]

def close(self):
self._h5_handle.close()
self._file_handle.close()

def merge(self):
if self._mpi_size > 1 and self._mpi_rank == 0:
@@ -290,7 +319,6 @@ def merge(self):
gids_ds[beg:end] = tmp_mapping_grp['gids']
index_pointer_ds[beg:(end+1)] = update_index


# combine the /var/data datasets
for var_name in self._variables:
data_name = '/data' if self._n_vars == 1 else '/{}/data'.format(var_name)
@@ -305,33 +333,85 @@ def merge(self):
os.remove(tmp_file)


class CellVarRecorderParallel(CellVarRecorder):
"""
Unlike the parent, this take advantage of parallel h5py to writting to the results file across different ranks.
class CellVarRecorderNWB(CellVarRecorderH5):
def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1):
super(CellVarRecorderNWB, self).__init__(
file_name, tmp_dir, variables, buffer_data=buffer_data,
mpi_rank=mpi_rank, mpi_size=mpi_size
)
self._compartments = Compartments('compartments')
self._compartmentseries = {}

def _create_file(self, **io_kwargs):
self._nwbio = NWBHDF5IO(self._file_name, 'w', **io_kwargs)
self._file_handle = NWBFile('description', 'id', datetime.now().astimezone()) # TODO: pass in descr, id

def add_cell(self, gid, sec_list, seg_list, **map_attrs):
if map_attrs:
raise NotImplementedError('Cannot use map_attrs with NWB') # TODO: support this
self._compartments.add_row(number=sec_list, position=seg_list, id=gid)
super(CellVarRecorderNWB, self).add_cell(gid, sec_list, seg_list, **map_attrs)

def _init_mapping(self):
# Cell/section id and pos are in the Compartments table
# 1/dt is the rate of the recorded datasets.
# tstart was used as session_start_time when creating the NWBFile
# nwb doesn't store tstop
pass

def _init_buffers(self):
self._file_handle.add_acquisition(self._compartments)
for var_name, data_tables in self._data_blocks.items():
cs = CompartmentSeries(
var_name, data=np.zeros((self._total_steps, self._n_segments_all)),
compartments=self._compartments, unit='mV', rate=1000.0/self.dt
)
self._compartmentseries[var_name] = cs
self._file_handle.add_acquisition(cs)
data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = self._compartmentseries[var_name].data

self._nwbio.write(self._file_handle)

# Re-read data sets so that pynwb forgets it has them in memory
# (forces immediate write upon modification)
self._nwbio.close()
self._nwbio = NWBHDF5IO(self._file_name, 'a', comm=comm)
self._file_handle = self._nwbio.read()
for var_name, data_tables in self._data_blocks.items():
self._data_blocks[var_name].data_block = self._file_handle.acquisition[var_name].data

def close(self):
self._nwbio.close()

def merge(self):
raise NotImplementedError("Can't merge NWB files across ranks")

"""
def __init__(self, file_name, tmp_dir, variables, buffer_data=True):
super(CellVarRecorder, self).__init__(file_name, tmp_dir, variables, buffer_data=buffer_data, mpi_rank=0,
mpi_size=1)

class ParallelRecorderMixin():
"""
When inherited along with one of the CellVarRecorder classes, this takes
advantage of parallel h5py to collectively write the results file from multiple ranks.
"""
def _calc_offset(self):
# iterate through the ranks let rank r determine the offset from rank r-1
for r in range(comm.Get_size()):
if rank == r:
if rank < (nhosts - 1):
# pass the num of segments and num of gids to the next rank
offsets = np.array([self._n_segments_local, self._n_gids_local], dtype=np.uint)
comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1))

if rank > 0:
# get num of segments and gids from prev. rank and calculate offsets
offset = np.empty(2, dtype=np.uint)
offsets = np.empty(2, dtype=np.uint)
comm.Recv([offsets, MPI.UNSIGNED_INT], source=(r-1))
self._seg_offset_beg = offsets[0]
self._seg_offset_end = self._seg_offset_beg + self._n_segments_local
self._gids_beg = offsets[1]

self._seg_offset_end = int(self._seg_offset_beg) \
+ int(self._n_segments_local)
self._gids_end = int(self._gids_beg) + int(self._n_gids_local)

self._gids_beg = offset[1]
self._gids_end = self._gids_beg + self._n_gids_local
if rank < (nhosts - 1):
# pass the next rank its offset
offsets = np.array([self._seg_offset_end, self._gids_end], dtype=np.uint)
comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1))

comm.Barrier()

@@ -345,10 +425,18 @@ def _calc_offset(self):
self._n_segments_all = total_counts[0]
self._n_gids_all = total_counts[1]

def _create_h5_file(self):
self._h5_handle = h5py.File(self._file_name, 'w', driver='mpio', comm=MPI.COMM_WORLD)
add_hdf5_version(self._h5_handle)
add_hdf5_magic(self._h5_handle)

def merge(self):
pass


class CellVarRecorderH5Parallel(ParallelRecorderMixin, CellVarRecorderH5):
def _create_file(self, **io_kwargs):
io_kwargs['driver'] = 'mpio'
io_kwargs['comm'] = comm
super(CellVarRecorderH5Parallel, self)._create_file(**io_kwargs)


class CellVarRecorderNWBParallel(ParallelRecorderMixin, CellVarRecorderNWB):
def _create_file(self, **io_kwargs):
io_kwargs['comm'] = comm
super(CellVarRecorderNWBParallel, self)._create_file(**io_kwargs)
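For reference, this is roughly how a report written by CellVarRecorderNWB could be read back with pynwb. The file and variable names below are assumptions for illustration; the only facts taken from this diff are that each recorded variable becomes a CompartmentSeries under the file's acquisition group and that the gid/section/position metadata live in the associated Compartments table (both types come from nwbext_simulation_output):

import nwbext_simulation_output  # registers the Compartments/CompartmentSeries types
from pynwb import NWBHDF5IO

with NWBHDF5IO('output/membrane_potential.nwb', 'r', load_namespaces=True) as nwb_io:
    nwbfile = nwb_io.read()
    cs = nwbfile.acquisition['v']      # one CompartmentSeries per recorded variable
    data = cs.data[:]                  # array of shape (n_steps, n_compartments)
    compartments = cs.compartments     # table with gid ('id'), section ('number'), position

Note that, unlike the HDF5 recorder, the NWB recorder cannot merge per-rank tmp files (merge() raises NotImplementedError), so multi-rank runs have to go through CellVarRecorderNWBParallel with an MPI-enabled h5py build.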