Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup dag #118

Merged
merged 14 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,26 @@ from legenddataflow.pre_compile_catalog import pre_compile_catalog
utils.subst_vars_in_snakemake_config(workflow, config)
config = AttrsDict(config)

os.environ["XDG_CACHE_HOME"] = config.get("XDG_CACHE_HOME", ".snakemake/cache")

check_in_cycle = True
configs = utils.config_path(config)
chan_maps = utils.chan_map_path(config)
meta = utils.metadata_path(config)
det_status = utils.det_status_path(config)
basedir = workflow.basedir

det_status_textdb = pre_compile_catalog(Path(det_status) / "statuses")
channelmap_textdb = pre_compile_catalog(Path(chan_maps) / "channelmaps")

time = datetime.now().strftime("%Y%m%dT%H%M%SZ")

# NOTE: this will attempt a clone of legend-metadata, if the directory does not exist
metadata = LegendMetadata(meta)
if "legend_metadata_version" in config:
metadata.checkout(config.legend_metadata_version)


det_status_textdb = pre_compile_catalog(Path(det_status) / "statuses")
channelmap_textdb = pre_compile_catalog(Path(chan_maps) / "channelmaps")

part = CalGrouping(config, Path(det_status) / "cal_groupings.yaml")


Expand Down
1 change: 1 addition & 0 deletions workflow/profiles/lngs-build-raw/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ keep-going: true
rerun-incomplete: true
config:
- system=lngs
- XDG_CACHE_HOME=.snakemake/cache
1 change: 1 addition & 0 deletions workflow/profiles/lngs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ keep-going: true
rerun-incomplete: true
config:
- system=lngs
- XDG_CACHE_HOME=.snakemake/cache
5 changes: 4 additions & 1 deletion workflow/rules/ann.smk
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ rule build_pan:
params:
timestamp="{timestamp}",
datatype="{datatype}",
table_map=lambda wildcards: get_table_mapping(
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "dsp"
),
output:
tier_file=get_pattern_tier(config, "pan", check_in_cycle=check_in_cycle),
db_file=get_pattern_pars_tmp(config, "pan_db"),
Expand All @@ -66,7 +69,7 @@ rule build_pan:
shell:
execenv_pyexe(config, "build-tier-dsp") + "--log {log} "
"--configs {configs} "
"--metadata {meta} "
"--table-map '{params.table_map}' "
"--tier pan "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
Expand Down
24 changes: 15 additions & 9 deletions workflow/rules/chanlist_gen.smk
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ from dbetto import TextDB
from dbetto.catalog import Catalog


def get_chanlist(config, keypart, workflow, det_status, channelmap, system):
key = ChannelProcKey.parse_keypart(keypart)
def get_chanlist(config, timestamp, datatype, workflow, det_status, channelmap, system):

if isinstance(det_status, (str, Path)):
det_status = TextDB(det_status, lazy=True)
Expand All @@ -26,19 +25,19 @@ def get_chanlist(config, keypart, workflow, det_status, channelmap, system):
channelmap = TextDB(channelmap, lazy=True)

if isinstance(det_status, TextDB):
status_map = det_status.statuses.on(key.timestamp, system=key.datatype)
status_map = det_status.statuses.on(timestamp, system=datatype)
else:
status_map = det_status.valid_for(key.timestamp, system=key.datatype)
status_map = det_status.valid_for(timestamp, system=datatype)
if isinstance(channelmap, TextDB):
chmap = channelmap.channelmaps.on(key.timestamp)
chmap = channelmap.channelmaps.on(timestamp, system=datatype)
else:
chmap = channelmap.valid_for(key.timestamp)
chmap = channelmap.valid_for(timestamp, system=datatype)

# only restrict to a certain system (geds, spms, ...)
channels = []
for channel in chmap.map("system", unique=False)[system].map("name"):
if channel not in status_map:
msg = f"{channel} is not found in the status map (on {key.timestamp})"
msg = f"{channel} is not found in the status map (on {timestamp})"
raise RuntimeError(msg)
if status_map[channel].processable is False:
continue
Expand All @@ -62,8 +61,11 @@ def get_par_chanlist(
name=None,
extension="yaml",
):
key = ChannelProcKey.parse_keypart(keypart)

chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)
chan_list = get_chanlist(
setup, key.timestamp, key.datatype, workflow, det_status, chan_maps, system
)

par_pattern = get_pattern_pars_tmp_channel(
setup, tier, name, datatype=datatype, extension=extension
Expand All @@ -85,7 +87,11 @@ def get_plt_chanlist(
name=None,
):

chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)
key = ChannelProcKey.parse_keypart(keypart)

chan_list = get_chanlist(
setup, key.timestamp, key.datatype, workflow, det_status, chan_maps, system
)

par_pattern = get_pattern_plts_tmp_channel(setup, tier, name)

Expand Down
36 changes: 36 additions & 0 deletions workflow/rules/common.smk
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,39 @@ def get_table_name(metadata, config, datatype, timestamp, detector, tier):
f"metadata must be a string or a Catalog object, not {type(metadata)}"
)
return config.table_format[tier].format(ch=chmap[detector].daq.rawid)


def get_all_channels(channelmap, timestamp, datatype):
    """Return the names of every channel valid at *timestamp* for *datatype*.

    *channelmap* may be a path to a channel-map database (wrapped lazily in a
    :class:`TextDB`), a :class:`TextDB` instance, or a catalog object exposing
    ``valid_for``. Emits a warning on stdout when the map is empty.
    """
    # A plain path is wrapped lazily so nothing is read until queried.
    if isinstance(channelmap, (str, Path)):
        channelmap = TextDB(channelmap, lazy=True)

    # TextDB and catalog objects expose different lookup APIs.
    if isinstance(channelmap, TextDB):
        valid_map = channelmap.channelmaps.on(timestamp, system=datatype)
    else:
        valid_map = channelmap.valid_for(timestamp, system=datatype)

    names = [name for name in valid_map]

    if not names:
        print("WARNING: No channels found")  # noqa: T201

    return names


def get_table_mapping(channelmap, timestamp, datatype, tier):
    """Return a JSON string mapping detector names to HDF5 table paths.

    Each detector name maps to ``ch<rawid:07d>/<tier>`` for the channel map
    valid at *timestamp* for *datatype*. *channelmap* may be a path, a
    :class:`TextDB`, or a catalog object exposing ``valid_for``.
    """
    if isinstance(channelmap, (str, Path)):
        channelmap = TextDB(channelmap, lazy=True)
    # Fix: a TextDB has no ``valid_for``; select the lookup API the same way
    # the sibling helper ``get_all_channels`` does, so path inputs work too.
    if isinstance(channelmap, TextDB):
        channel_dict = channelmap.channelmaps.on(timestamp, system=datatype)
    else:
        channel_dict = channelmap.valid_for(timestamp, system=datatype)
    detectors = get_all_channels(channelmap, timestamp, datatype)
    return json.dumps(
        {
            detector: f"ch{channel_dict[detector].daq.rawid:07}/{tier}"
            for detector in detectors
        }
    )


def strip_channel_wildcard_constraint(files):
    """Drop any regex constraint from ``{channel,...}`` wildcards.

    Accepts a single pattern string or a list of them; always returns a list
    in which every constrained ``{channel,<regex>}`` wildcard is reduced to a
    plain ``{channel}`` wildcard.
    """
    file_list = [files] if isinstance(files, str) else files
    constraint = re.compile(r"\{channel,[^}]+\}")
    return [constraint.sub("{channel}", entry) for entry in file_list]
7 changes: 4 additions & 3 deletions workflow/rules/dsp.smk
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ rule build_dsp:
timestamp="{timestamp}",
datatype="{datatype}",
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
table_map=lambda wildcards: get_table_mapping(
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "raw"
),
output:
tier_file=patt.get_pattern_tier(config, "dsp", check_in_cycle=check_in_cycle),
db_file=patt.get_pattern_pars_tmp(config, "dsp_db"),
log:
patt.get_pattern_log(config, "tier_dsp", time),
group:
Expand All @@ -55,10 +57,9 @@ rule build_dsp:
execenv_pyexe(config, "build-tier-dsp") + "--log {log} "
"--tier dsp "
f"--configs {ro(configs)} "
"--metadata {meta} "
"--table-map '{params.table_map}' "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--input {params.ro_input[raw_file]} "
"--output {output.tier_file} "
"--db-file {output.db_file} "
"--pars-file {params.ro_input[pars_files]}"
6 changes: 4 additions & 2 deletions workflow/rules/dsp_pars_spms.smk
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ rule build_pars_dsp_tau_spms:
datatype="{datatype}",
channels=lambda wildcards: get_chanlist(
config,
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-{wildcards.datatype}-{wildcards.timestamp}-channels",
wildcards.timestamp,
wildcards.datatype,
workflow,
det_status_textdb,
channelmap_textdb,
Expand All @@ -32,7 +33,8 @@ rule build_pars_dsp_tau_spms:
)
for channel in get_chanlist(
config,
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-{wildcards.datatype}-{wildcards.timestamp}-channels",
wildcards.timestamp,
wildcards.datatype,
workflow,
det_status_textdb,
channelmap_textdb,
Expand Down
72 changes: 31 additions & 41 deletions workflow/rules/evt.smk
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rule build_evt:
tcm_file=get_pattern_tier(config, "tcm", check_in_cycle=False),
ann_file=lambda wildcards: (
[]
if int(wildcards["period"][1:]) > 11
if int(wildcards["period"][1:]) > 9
else get_pattern_tier(config, "ann", check_in_cycle=False)
),
par_files=lambda wildcards: hit_par_catalog.get_par_file(
Expand All @@ -36,31 +36,26 @@ rule build_evt:
tier="evt",
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
log:
get_pattern_log(config, f"tier_evt", time),
get_pattern_log(config, "tier_evt", time),
group:
"tier-evt"
resources:
runtime=300,
mem_swap=50,
run:
shell_string = (
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
f"--metadata {ro(meta)} "
"--log {log} "
"--tier {params.tier} "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--xtc-file {params.ro_input[xtalk_matrix]} "
"--par-files {params.ro_input[par_files]} "
"--hit-file {params.ro_input[hit_file]} "
"--tcm-file {params.ro_input[tcm_file]} "
"--dsp-file {params.ro_input[dsp_file]} "
"--output {output} "
)
if input.ann_file is not None:
shell_string += "--ann-file {params.ro_input[ann_file]} "

shell(shell_string)
shell:
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
f"--metadata {ro(meta)} "
"--log {log} "
"--tier {params.tier} "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--xtc-file {params.ro_input[xtalk_matrix]} "
"--par-files {params.ro_input[par_files]} "
"--hit-file {params.ro_input[hit_file]} "
"--tcm-file {params.ro_input[tcm_file]} "
"--dsp-file {params.ro_input[dsp_file]} "
"--output {output} "
"--ann-file {params.ro_input[ann_file]} "


rule build_pet:
Expand All @@ -87,31 +82,26 @@ rule build_pet:
tier="pet",
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
log:
get_pattern_log(config, f"tier_pet", time),
get_pattern_log(config, "tier_pet", time),
group:
"tier-evt"
resources:
runtime=300,
mem_swap=50,
run:
shell_string = (
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
f"--metadata {ro(meta)} "
"--log {log} "
"--tier {params.tier} "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--xtc-file {params.ro_input[xtalk_matrix]} "
"--par-files {params.ro_input[par_files]} "
"--hit-file {params.ro_input[hit_file]} "
"--tcm-file {params.ro_input[tcm_file]} "
"--dsp-file {params.ro_input[dsp_file]} "
"--output {output} "
)
if input.ann_file is not None:
shell_string += "--ann-file {params.ro_input[ann_file]} "

shell(shell_string)
shell:
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
f"--metadata {ro(meta)} "
"--log {log} "
"--tier {params.tier} "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--xtc-file {params.ro_input[xtalk_matrix]} "
"--par-files {params.ro_input[par_files]} "
"--hit-file {params.ro_input[hit_file]} "
"--tcm-file {params.ro_input[tcm_file]} "
"--dsp-file {params.ro_input[dsp_file]} "
"--output {output} "
"--ann-file {params.ro_input[ann_file]} "


for evt_tier in ("evt", "pet"):
Expand Down
7 changes: 4 additions & 3 deletions workflow/rules/hit.smk
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ rule build_hit:
),
output:
tier_file=get_pattern_tier(config, "hit", check_in_cycle=check_in_cycle),
db_file=get_pattern_pars_tmp(config, "hit_db"),
params:
timestamp="{timestamp}",
datatype="{datatype}",
tier="hit",
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
table_map=lambda wildcards: get_table_mapping(
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "dsp"
),
log:
get_pattern_log(config, "tier_hit", time),
group:
Expand All @@ -47,12 +49,11 @@ rule build_hit:
runtime=300,
shell:
execenv_pyexe(config, "build-tier-hit") + f"--configs {ro(configs)} "
"--metadata {meta} "
"--table-map '{params.table_map}' "
"--tier {params.tier} "
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--pars-file {params.ro_input[pars_file]} "
"--output {output.tier_file} "
"--input {params.ro_input[dsp_file]} "
"--db-file {output.db_file} "
"--log {log}"
Loading
Loading