Skip to content

Commit 8ebc41a

Browse files
authored
Merge pull request #118 from legend-exp/speedup_dag
Speedup dag
2 parents b1bbe44 + 29fae09 commit 8ebc41a

29 files changed

+534
-675
lines changed

workflow/Snakefile

+6-3
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,26 @@ from legenddataflow.pre_compile_catalog import pre_compile_catalog
2727
utils.subst_vars_in_snakemake_config(workflow, config)
2828
config = AttrsDict(config)
2929

30+
os.environ["XDG_CACHE_HOME"] = config.get("XDG_CACHE_HOME", ".snakemake/cache")
31+
3032
check_in_cycle = True
3133
configs = utils.config_path(config)
3234
chan_maps = utils.chan_map_path(config)
3335
meta = utils.metadata_path(config)
3436
det_status = utils.det_status_path(config)
3537
basedir = workflow.basedir
3638

37-
det_status_textdb = pre_compile_catalog(Path(det_status) / "statuses")
38-
channelmap_textdb = pre_compile_catalog(Path(chan_maps) / "channelmaps")
39-
4039
time = datetime.now().strftime("%Y%m%dT%H%M%SZ")
4140

4241
# NOTE: this will attempt a clone of legend-metadata, if the directory does not exist
4342
metadata = LegendMetadata(meta)
4443
if "legend_metadata_version" in config:
4544
metadata.checkout(config.legend_metadata_version)
4645

46+
47+
det_status_textdb = pre_compile_catalog(Path(det_status) / "statuses")
48+
channelmap_textdb = pre_compile_catalog(Path(chan_maps) / "channelmaps")
49+
4750
part = CalGrouping(config, Path(det_status) / "cal_groupings.yaml")
4851

4952

workflow/profiles/lngs-build-raw/config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ keep-going: true
77
rerun-incomplete: true
88
config:
99
- system=lngs
10+
- XDG_CACHE_HOME=.snakemake/cache

workflow/profiles/lngs/config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ keep-going: true
88
rerun-incomplete: true
99
config:
1010
- system=lngs
11+
- XDG_CACHE_HOME=.snakemake/cache

workflow/rules/ann.smk

+4-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ rule build_pan:
5353
params:
5454
timestamp="{timestamp}",
5555
datatype="{datatype}",
56+
table_map=lambda wildcards: get_table_mapping(
57+
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "dsp"
58+
),
5659
output:
5760
tier_file=get_pattern_tier(config, "pan", check_in_cycle=check_in_cycle),
5861
db_file=get_pattern_pars_tmp(config, "pan_db"),
@@ -66,7 +69,7 @@ rule build_pan:
6669
shell:
6770
execenv_pyexe(config, "build-tier-dsp") + "--log {log} "
6871
"--configs {configs} "
69-
"--metadata {meta} "
72+
"--table-map '{params.table_map}' "
7073
"--tier pan "
7174
"--datatype {params.datatype} "
7275
"--timestamp {params.timestamp} "

workflow/rules/chanlist_gen.smk

+15-9
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ from dbetto import TextDB
1616
from dbetto.catalog import Catalog
1717

1818

19-
def get_chanlist(config, keypart, workflow, det_status, channelmap, system):
20-
key = ChannelProcKey.parse_keypart(keypart)
19+
def get_chanlist(config, timestamp, datatype, workflow, det_status, channelmap, system):
2120

2221
if isinstance(det_status, (str, Path)):
2322
det_status = TextDB(det_status, lazy=True)
@@ -26,19 +25,19 @@ def get_chanlist(config, keypart, workflow, det_status, channelmap, system):
2625
channelmap = TextDB(channelmap, lazy=True)
2726

2827
if isinstance(det_status, TextDB):
29-
status_map = det_status.statuses.on(key.timestamp, system=key.datatype)
28+
status_map = det_status.statuses.on(timestamp, system=datatype)
3029
else:
31-
status_map = det_status.valid_for(key.timestamp, system=key.datatype)
30+
status_map = det_status.valid_for(timestamp, system=datatype)
3231
if isinstance(channelmap, TextDB):
33-
chmap = channelmap.channelmaps.on(key.timestamp)
32+
chmap = channelmap.channelmaps.on(timestamp, system=datatype)
3433
else:
35-
chmap = channelmap.valid_for(key.timestamp)
34+
chmap = channelmap.valid_for(timestamp, system=datatype)
3635

3736
# only restrict to a certain system (geds, spms, ...)
3837
channels = []
3938
for channel in chmap.map("system", unique=False)[system].map("name"):
4039
if channel not in status_map:
41-
msg = f"{channel} is not found in the status map (on {key.timestamp})"
40+
msg = f"{channel} is not found in the status map (on {timestamp})"
4241
raise RuntimeError(msg)
4342
if status_map[channel].processable is False:
4443
continue
@@ -62,8 +61,11 @@ def get_par_chanlist(
6261
name=None,
6362
extension="yaml",
6463
):
64+
key = ChannelProcKey.parse_keypart(keypart)
6565

66-
chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)
66+
chan_list = get_chanlist(
67+
setup, key.timestamp, key.datatype, workflow, det_status, chan_maps, system
68+
)
6769

6870
par_pattern = get_pattern_pars_tmp_channel(
6971
setup, tier, name, datatype=datatype, extension=extension
@@ -85,7 +87,11 @@ def get_plt_chanlist(
8587
name=None,
8688
):
8789

88-
chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)
90+
key = ChannelProcKey.parse_keypart(keypart)
91+
92+
chan_list = get_chanlist(
93+
setup, key.timestamp, key.datatype, workflow, det_status, chan_maps, system
94+
)
8995

9096
par_pattern = get_pattern_plts_tmp_channel(setup, tier, name)
9197

workflow/rules/common.smk

+36
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,39 @@ def get_table_name(metadata, config, datatype, timestamp, detector, tier):
131131
f"metadata must be a string or a Catalog object, not {type(metadata)}"
132132
)
133133
return config.table_format[tier].format(ch=chmap[detector].daq.rawid)
134+
135+
136+
def get_all_channels(channelmap, timestamp, datatype):
137+
if isinstance(channelmap, (str, Path)):
138+
channelmap = TextDB(channelmap, lazy=True)
139+
140+
if isinstance(channelmap, TextDB):
141+
chmap = channelmap.channelmaps.on(timestamp, system=datatype)
142+
else:
143+
chmap = channelmap.valid_for(timestamp, system=datatype)
144+
145+
channels = list(chmap)
146+
147+
if len(channels) == 0:
148+
print("WARNING: No channels found") # noqa: T201
149+
150+
return channels
151+
152+
153+
def get_table_mapping(channelmap, timestamp, datatype, tier):
154+
if isinstance(channelmap, (str, Path)):
155+
channelmap = TextDB(channelmap, lazy=True)
156+
channel_dict = channelmap.valid_for(timestamp, system=datatype)
157+
detectors = get_all_channels(channelmap, timestamp, datatype)
158+
return json.dumps(
159+
{
160+
detector: f"ch{channel_dict[detector].daq.rawid:07}/{tier}"
161+
for detector in detectors
162+
}
163+
)
164+
165+
166+
def strip_channel_wildcard_constraint(files):
167+
if isinstance(files, str):
168+
files = [files]
169+
return [re.sub(r"\{channel,[^}]+\}", "{channel}", file) for file in files]

workflow/rules/dsp.smk

+4-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ rule build_dsp:
4141
timestamp="{timestamp}",
4242
datatype="{datatype}",
4343
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
44+
table_map=lambda wildcards: get_table_mapping(
45+
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "raw"
46+
),
4447
output:
4548
tier_file=patt.get_pattern_tier(config, "dsp", check_in_cycle=check_in_cycle),
46-
db_file=patt.get_pattern_pars_tmp(config, "dsp_db"),
4749
log:
4850
patt.get_pattern_log(config, "tier_dsp", time),
4951
group:
@@ -55,10 +57,9 @@ rule build_dsp:
5557
execenv_pyexe(config, "build-tier-dsp") + "--log {log} "
5658
"--tier dsp "
5759
f"--configs {ro(configs)} "
58-
"--metadata {meta} "
60+
"--table-map '{params.table_map}' "
5961
"--datatype {params.datatype} "
6062
"--timestamp {params.timestamp} "
6163
"--input {params.ro_input[raw_file]} "
6264
"--output {output.tier_file} "
63-
"--db-file {output.db_file} "
6465
"--pars-file {params.ro_input[pars_files]}"

workflow/rules/dsp_pars_spms.smk

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ rule build_pars_dsp_tau_spms:
1515
datatype="{datatype}",
1616
channels=lambda wildcards: get_chanlist(
1717
config,
18-
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-{wildcards.datatype}-{wildcards.timestamp}-channels",
18+
wildcards.timestamp,
19+
wildcards.datatype,
1920
workflow,
2021
det_status_textdb,
2122
channelmap_textdb,
@@ -32,7 +33,8 @@ rule build_pars_dsp_tau_spms:
3233
)
3334
for channel in get_chanlist(
3435
config,
35-
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-{wildcards.datatype}-{wildcards.timestamp}-channels",
36+
wildcards.timestamp,
37+
wildcards.datatype,
3638
workflow,
3739
det_status_textdb,
3840
channelmap_textdb,

workflow/rules/evt.smk

+31-41
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ rule build_evt:
1919
tcm_file=get_pattern_tier(config, "tcm", check_in_cycle=False),
2020
ann_file=lambda wildcards: (
2121
[]
22-
if int(wildcards["period"][1:]) > 11
22+
if int(wildcards["period"][1:]) > 9
2323
else get_pattern_tier(config, "ann", check_in_cycle=False)
2424
),
2525
par_files=lambda wildcards: hit_par_catalog.get_par_file(
@@ -36,31 +36,26 @@ rule build_evt:
3636
tier="evt",
3737
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
3838
log:
39-
get_pattern_log(config, f"tier_evt", time),
39+
get_pattern_log(config, "tier_evt", time),
4040
group:
4141
"tier-evt"
4242
resources:
4343
runtime=300,
4444
mem_swap=50,
45-
run:
46-
shell_string = (
47-
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
48-
f"--metadata {ro(meta)} "
49-
"--log {log} "
50-
"--tier {params.tier} "
51-
"--datatype {params.datatype} "
52-
"--timestamp {params.timestamp} "
53-
"--xtc-file {params.ro_input[xtalk_matrix]} "
54-
"--par-files {params.ro_input[par_files]} "
55-
"--hit-file {params.ro_input[hit_file]} "
56-
"--tcm-file {params.ro_input[tcm_file]} "
57-
"--dsp-file {params.ro_input[dsp_file]} "
58-
"--output {output} "
59-
)
60-
if input.ann_file is not None:
61-
shell_string += "--ann-file {params.ro_input[ann_file]} "
62-
63-
shell(shell_string)
45+
shell:
46+
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
47+
f"--metadata {ro(meta)} "
48+
"--log {log} "
49+
"--tier {params.tier} "
50+
"--datatype {params.datatype} "
51+
"--timestamp {params.timestamp} "
52+
"--xtc-file {params.ro_input[xtalk_matrix]} "
53+
"--par-files {params.ro_input[par_files]} "
54+
"--hit-file {params.ro_input[hit_file]} "
55+
"--tcm-file {params.ro_input[tcm_file]} "
56+
"--dsp-file {params.ro_input[dsp_file]} "
57+
"--output {output} "
58+
"--ann-file {params.ro_input[ann_file]} "
6459

6560

6661
rule build_pet:
@@ -87,31 +82,26 @@ rule build_pet:
8782
tier="pet",
8883
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
8984
log:
90-
get_pattern_log(config, f"tier_pet", time),
85+
get_pattern_log(config, "tier_pet", time),
9186
group:
9287
"tier-evt"
9388
resources:
9489
runtime=300,
9590
mem_swap=50,
96-
run:
97-
shell_string = (
98-
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
99-
f"--metadata {ro(meta)} "
100-
"--log {log} "
101-
"--tier {params.tier} "
102-
"--datatype {params.datatype} "
103-
"--timestamp {params.timestamp} "
104-
"--xtc-file {params.ro_input[xtalk_matrix]} "
105-
"--par-files {params.ro_input[par_files]} "
106-
"--hit-file {params.ro_input[hit_file]} "
107-
"--tcm-file {params.ro_input[tcm_file]} "
108-
"--dsp-file {params.ro_input[dsp_file]} "
109-
"--output {output} "
110-
)
111-
if input.ann_file is not None:
112-
shell_string += "--ann-file {params.ro_input[ann_file]} "
113-
114-
shell(shell_string)
91+
shell:
92+
execenv_pyexe(config, "build-tier-evt") + f"--configs {ro(configs)} "
93+
f"--metadata {ro(meta)} "
94+
"--log {log} "
95+
"--tier {params.tier} "
96+
"--datatype {params.datatype} "
97+
"--timestamp {params.timestamp} "
98+
"--xtc-file {params.ro_input[xtalk_matrix]} "
99+
"--par-files {params.ro_input[par_files]} "
100+
"--hit-file {params.ro_input[hit_file]} "
101+
"--tcm-file {params.ro_input[tcm_file]} "
102+
"--dsp-file {params.ro_input[dsp_file]} "
103+
"--output {output} "
104+
"--ann-file {params.ro_input[ann_file]} "
115105

116106

117107
for evt_tier in ("evt", "pet"):

workflow/rules/hit.smk

+4-3
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ rule build_hit:
3333
),
3434
output:
3535
tier_file=get_pattern_tier(config, "hit", check_in_cycle=check_in_cycle),
36-
db_file=get_pattern_pars_tmp(config, "hit_db"),
3736
params:
3837
timestamp="{timestamp}",
3938
datatype="{datatype}",
4039
tier="hit",
4140
ro_input=lambda _, input: {k: ro(v) for k, v in input.items()},
41+
table_map=lambda wildcards: get_table_mapping(
42+
channelmap_textdb, wildcards.timestamp, wildcards.datatype, "dsp"
43+
),
4244
log:
4345
get_pattern_log(config, "tier_hit", time),
4446
group:
@@ -47,12 +49,11 @@ rule build_hit:
4749
runtime=300,
4850
shell:
4951
execenv_pyexe(config, "build-tier-hit") + f"--configs {ro(configs)} "
50-
"--metadata {meta} "
52+
"--table-map '{params.table_map}' "
5153
"--tier {params.tier} "
5254
"--datatype {params.datatype} "
5355
"--timestamp {params.timestamp} "
5456
"--pars-file {params.ro_input[pars_file]} "
5557
"--output {output.tier_file} "
5658
"--input {params.ro_input[dsp_file]} "
57-
"--db-file {output.db_file} "
5859
"--log {log}"

0 commit comments

Comments
 (0)