Commit 10f285a

Use generators when appropriate

1 parent 8ae2a19 commit 10f285a

31 files changed: +96 -134 lines
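
The whole commit is one mechanical refactor: list comprehensions passed to a call that consumes the iterable exactly once (str.join, tuple, list.extend) become generator expressions, dropping the intermediate list. A minimal sketch of the before/after, with illustrative names that are not from the repository:

opts = ["alpha", "beta"]

# before: builds a temporary list, then join consumes it
before = ", ".join([o.upper() for o in opts])
# after: the generator streams items straight into join
after = ", ".join(o.upper() for o in opts)

assert before == after == "ALPHA, BETA"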

pytensor/configparser.py

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ def get_config_hash(self):
         )
         return hash_from_code(
             "\n".join(
-                [f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts]
+                f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts
             )
         )
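
One caveat worth noting for str.join in particular: CPython materializes the iterable internally anyway (it needs two passes over the items), so for join the change is about readability more than memory. A rough, machine-dependent check:

import timeit

setup = "data = list(range(1000))"
t_list = timeit.timeit('",".join([str(x) for x in data])', setup=setup, number=1_000)
t_gen = timeit.timeit('",".join(str(x) for x in data)', setup=setup, number=1_000)
print(f"list comp: {t_list:.3f}s  generator: {t_gen:.3f}s")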

pytensor/d3viz/formatting.py

Lines changed: 1 addition & 1 deletion

@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
     for k, v in d.items():
         if v is not None:
             if isinstance(v, list):
-                v = "\t".join([str(x) for x in v])
+                v = "\t".join(str(x) for x in v)
             else:
                 v = str(v)
             v = str(v)

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 3 deletions

@@ -1264,7 +1264,7 @@ def __str__(self):
         return getattr(
             self,
             "__name__",
-            f"{type(self).__name__}({','.join([str(o) for o in self.rewrites])})",
+            f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})",
         )

     def tracks(self):

@@ -1666,7 +1666,7 @@ def pattern_to_str(pattern):
     if isinstance(pattern, list | tuple):
         return "{}({})".format(
             str(pattern[0]),
-            ", ".join([pattern_to_str(p) for p in pattern[1:]]),
+            ", ".join(pattern_to_str(p) for p in pattern[1:]),
         )
     elif isinstance(pattern, dict):
         return "{} subject to {}".format(

@@ -2569,7 +2569,7 @@ def print_profile(cls, stream, prof, level=0):
             d = sorted(
                 loop_process_count[i].items(), key=lambda a: a[1], reverse=True
             )
-            loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
+            loop_times = " ".join(str((str(k), v)) for k, v in d[:5])
             if len(d) > 5:
                 loop_times += " ..."
             print(
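
A syntactic point the first hunk relies on: a generator expression that is the sole argument to a call needs no extra parentheses, so the brackets can simply be dropped even when the call sits inside an f-string. A small standalone check:

rewrites = ["a", "b"]
label = f"Chain({','.join(str(r) for r in rewrites)})"
assert label == "Chain(a,b)"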

pytensor/link/c/basic.py

Lines changed: 5 additions & 5 deletions

@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
     behavior = code_gen(blocks)

     # declares the storage
-    storage_decl = "\n".join([f"PyObject* {arg};" for arg in args])
+    storage_decl = "\n".join(f"PyObject* {arg};" for arg in args)
     # in the constructor, sets the storage to the arguments
-    storage_set = "\n".join([f"this->{arg} = {arg};" for arg in args])
+    storage_set = "\n".join(f"this->{arg} = {arg};" for arg in args)
     # increments the storage's refcount in the constructor
-    storage_incref = "\n".join([f"Py_XINCREF({arg});" for arg in args])
+    storage_incref = "\n".join(f"Py_XINCREF({arg});" for arg in args)
     # decrements the storage's refcount in the destructor
-    storage_decref = "\n".join([f"Py_XDECREF(this->{arg});" for arg in args])
+    storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)

     args_names = ", ".join(args)
-    args_decl = ", ".join([f"PyObject* {arg}" for arg in args])
+    args_decl = ", ".join(f"PyObject* {arg}" for arg in args)

     # The following code stores the exception data in __ERROR, which
     # is a special field of the struct. __ERROR is a list of length 3
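
A self-contained sketch of the struct_gen pattern above: each generator expression feeds str.join to emit one C statement per argument. `args` here is a stand-in for the real argument list:

args = ["storage_x", "storage_y"]

# one C declaration / incref per argument, joined into a code block
storage_decl = "\n".join(f"PyObject* {arg};" for arg in args)
storage_incref = "\n".join(f"Py_XINCREF({arg});" for arg in args)

print(storage_decl)
# PyObject* storage_x;
# PyObject* storage_y;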

pytensor/link/c/cmodule.py

Lines changed: 3 additions & 3 deletions

@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
     cflags = list(flags)
     # to support path that includes spaces, we need to wrap it with double quotes on Windows
     path_wrapper = '"' if os.name == "nt" else ""
-    cflags.extend([f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs()])
+    cflags.extend(f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs())

     res = GCC_compiler.try_compile_tmp(
         test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True

@@ -2573,8 +2573,8 @@ def compile_str(
         cmd.extend(preargs)
         # to support path that includes spaces, we need to wrap it with double quotes on Windows
         path_wrapper = '"' if os.name == "nt" else ""
-        cmd.extend([f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs])
-        cmd.extend([f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs])
+        cmd.extend(f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs)
+        cmd.extend(f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs)
         if hide_symbols and sys.platform != "win32":
             # This has been available since gcc 4.0 so we suppose it
             # is always available. We pass it here since it
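
list.extend accepts any iterable, not just a list, so a generator can be passed directly. Illustrative paths, not the flags the compiler actually receives here:

cmd = ["g++", "-O3"]
include_dirs = ["/usr/include", "/opt/include"]

# extend consumes the generator item by item; no temporary list is built
cmd.extend(f"-I{idir}" for idir in include_dirs)
assert cmd == ["g++", "-O3", "-I/usr/include", "-I/opt/include"]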

pytensor/link/c/params_type.py

Lines changed: 4 additions & 10 deletions

@@ -263,9 +263,7 @@ def __init__(self, params_type, **kwargs):

     def __repr__(self):
         return "Params({})".format(
-            ", ".join(
-                [(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
-            )
+            ", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self))
         )

     def __getattr__(self, key):

@@ -425,9 +423,7 @@ def __getattr__(self, key):

     def __repr__(self):
         return "ParamsType<{}>".format(
-            ", ".join(
-                [(f"{self.fields[i]}:{self.types[i]}") for i in range(self.length)]
-            )
+            ", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length))
         )

     def __eq__(self, other):

@@ -748,10 +744,8 @@ def c_support_code(self, **kwargs):
         }}
         """.format(
             "\n".join(
-                [
-                    ("case %d: extract_%s(object); break;" % (i, self.fields[i]))
-                    for i in range(self.length)
-                ]
+                ("case %d: extract_%s(object); break;" % (i, self.fields[i]))
+                for i in range(self.length)
             )
         )
         final_struct_code = """

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 2 deletions

@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     nout = len(node.outputs)
     core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)

-    input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
-    output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
+    input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
+    output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
     output_dtypes = tuple(out.type.dtype for out in node.outputs)
     inplace_pattern = tuple(op.inplace_pattern.items())
     core_output_shapes = tuple(() for _ in range(nout))
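
tuple() likewise consumes any iterable, so the wrapping list was redundant. A toy stand-in for the broadcast patterns collected above (a dimension of size 1 is broadcastable):

shapes = [(2, 1), (1, 3)]

bc_patterns = tuple(dim == 1 for shape in shapes for dim in shape)
assert bc_patterns == (False, True, True, False)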

pytensor/link/numba/dispatch/scalar.py

Lines changed: 3 additions & 7 deletions

@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     unique_names = unique_name_generator(
         [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
     )
-    input_names = ", ".join(
-        [unique_names(v, force_unique=True) for v in node.inputs]
-    )
+    input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
     if not has_pyx_skip_dispatch:
         scalar_op_src = f"""
 def {scalar_op_fn_name}({input_names}):

@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):

     input_names = [unique_names(v, force_unique=True) for v in node.inputs]
     converted_call_args = ", ".join(
-        [
-            f"direct_cast({i_name}, {i_tmp_dtype_name})"
-            for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
-        ]
+        f"direct_cast({i_name}, {i_tmp_dtype_name})"
+        for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
     )
     if not has_pyx_skip_dispatch:
         scalar_op_src = f"""

pytensor/link/numba/dispatch/scan.py

Lines changed: 1 addition & 1 deletion

@@ -373,7 +373,7 @@ def add_output_storage_post_proc_stmt(
     inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)

     inner_out_to_outer_out_stmts = "\n".join(
-        [f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
+        f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)
     )

     scan_op_src = f"""

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 4 additions & 8 deletions

@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            [
-                f"{item_name} = to_scalar({shape_name})"
-                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
-            ]
+            f"{item_name} = to_scalar({shape_name})"
+            for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
         ),
         " " * 4,
     )

@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            [
-                f"{item_name} = to_scalar({shape_name})"
-                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
-            ]
+            f"{item_name} = to_scalar({shape_name})"
+            for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
         ),
         " " * 4,
     )

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 2 additions & 4 deletions

@@ -43,10 +43,8 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
     out_signature = ", ".join(outputs)
     inner_out_signature = ", ".join(inner_outputs)
     store_outputs = "\n".join(
-        [
-            f"{output}[...] = {inner_output}"
-            for output, inner_output in zip(outputs, inner_outputs)
-        ]
+        f"{output}[...] = {inner_output}"
+        for output, inner_output in zip(outputs, inner_outputs)
     )
     func_src = f"""
 def store_core_outputs({inp_signature}, {out_signature}):

pytensor/link/vm.py

Lines changed: 3 additions & 5 deletions

@@ -1112,7 +1112,7 @@ def make_vm(
     for i, node in enumerate(nodes):
         prereq_var_idxs = []
         for prereq_node in ords.get(node, []):
-            prereq_var_idxs.extend([vars_idx[v] for v in prereq_node.outputs])
+            prereq_var_idxs.extend(vars_idx[v] for v in prereq_node.outputs)
         prereq_var_idxs = list(set(prereq_var_idxs))
         prereq_var_idxs.sort()  # TODO: why sort?
         node_prereqs.append(prereq_var_idxs)

@@ -1323,9 +1323,7 @@ def __setstate__(self, d):

     def __repr__(self):
         args_str = ", ".join(
-            [
-                f"{name}={getattr(self, name)}"
-                for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
-            ]
+            f"{name}={getattr(self, name)}"
+            for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
         )
         return f"{type(self).__name__}({args_str})"

pytensor/misc/check_duplicate_key.py

Lines changed: 2 additions & 2 deletions

@@ -9,10 +9,10 @@
 DISPLAY_DUPLICATE_KEYS = False
 DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE = False

-dirs = []
+dirs: list = []
 if len(sys.argv) > 1:
     for compiledir in sys.argv[1:]:
-        dirs.extend([os.path.join(compiledir, d) for d in os.listdir(compiledir)])
+        dirs.extend(os.path.join(compiledir, d) for d in os.listdir(compiledir))
 else:
     dirs = os.listdir(config.compiledir)
     dirs = [os.path.join(config.compiledir, d) for d in dirs]
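
The added `dirs: list = []` annotation is presumably for the benefit of a type checker: a bare empty literal gives it nothing to infer an element type from, and mypy, for example, can report "Need type annotation" in that situation. Minimal sketch:

items: list = []  # annotation makes the intended type explicit up front
items.extend(str(n) for n in range(3))
assert items == ["0", "1", "2"]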

pytensor/printing.py

Lines changed: 13 additions & 13 deletions

@@ -229,32 +229,32 @@ def debugprint(
             topo_orders.append(None)
         elif isinstance(obj, Apply):
             outputs_to_print.extend(obj.outputs)
-            profile_list.extend([None for item in obj.outputs])
-            storage_maps.extend([None for item in obj.outputs])
-            topo_orders.extend([None for item in obj.outputs])
+            profile_list.extend(None for item in obj.outputs)
+            storage_maps.extend(None for item in obj.outputs)
+            topo_orders.extend(None for item in obj.outputs)
         elif isinstance(obj, Function):
             if print_fgraph_inputs:
                 inputs_to_print.extend(obj.maker.fgraph.inputs)
             outputs_to_print.extend(obj.maker.fgraph.outputs)
-            profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
+            profile_list.extend(obj.profile for item in obj.maker.fgraph.outputs)
             if print_storage:
                 storage_maps.extend(
-                    [obj.vm.storage_map for item in obj.maker.fgraph.outputs]
+                    obj.vm.storage_map for item in obj.maker.fgraph.outputs
                 )
             else:
-                storage_maps.extend([None for item in obj.maker.fgraph.outputs])
+                storage_maps.extend(None for item in obj.maker.fgraph.outputs)
             topo = obj.maker.fgraph.toposort()
-            topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
+            topo_orders.extend(topo for item in obj.maker.fgraph.outputs)
         elif isinstance(obj, FunctionGraph):
             if print_fgraph_inputs:
                 inputs_to_print.extend(obj.inputs)
             outputs_to_print.extend(obj.outputs)
-            profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
+            profile_list.extend(getattr(obj, "profile", None) for item in obj.outputs)
             storage_maps.extend(
-                [getattr(obj, "storage_map", None) for item in obj.outputs]
+                getattr(obj, "storage_map", None) for item in obj.outputs
             )
             topo = obj.toposort()
-            topo_orders.extend([topo for item in obj.outputs])
+            topo_orders.extend(topo for item in obj.outputs)
         elif isinstance(obj, int | float | np.ndarray):
             print(obj, file=_file)
         elif isinstance(obj, In | Out):

@@ -980,10 +980,10 @@ def process(self, output, pstate):
             name = self.names[idx]
         with set_precedence(pstate):
             inputs_str = ", ".join(
-                [pprinter.process(input, pstate) for input in node.inputs]
+                pprinter.process(input, pstate) for input in node.inputs
             )
             keywords_str = ", ".join(
-                [f"{kw}={getattr(node.op, kw)}" for kw in self.keywords]
+                f"{kw}={getattr(node.op, kw)}" for kw in self.keywords
             )

         if keywords_str and inputs_str:

@@ -1050,7 +1050,7 @@ def process(self, output, pstate):
         with set_precedence(pstate):
             r = "{}({})".format(
                 str(node.op),
-                ", ".join([pprinter.process(input, pstate) for input in node.inputs]),
+                ", ".join(pprinter.process(input, pstate) for input in node.inputs),
             )

         pstate.memo[output] = r
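
When the generator yields the same constant once per element, as in the debugprint changes above, it is equivalent to repeating the value; both forms below produce the same list:

outputs = ["o1", "o2", "o3"]

a, b = [], []
a.extend(None for _ in outputs)
b.extend([None] * len(outputs))
assert a == b == [None, None, None]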

pytensor/scalar/basic.py

Lines changed: 4 additions & 4 deletions

@@ -4224,8 +4224,8 @@ def __init__(self, inputs, outputs, name="Composite"):
         inputs, outputs = res[0], res2[1]

         self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
-        self.inputs_type = tuple([input.type for input in self.inputs])
-        self.outputs_type = tuple([output.type for output in self.outputs])
+        self.inputs_type = tuple(input.type for input in self.inputs)
+        self.outputs_type = tuple(output.type for output in self.outputs)
         self.nin = len(inputs)
         self.nout = len(outputs)
         super().__init__()

@@ -4247,7 +4247,7 @@ def __str__(self):
         if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
             self._name = "Composite{...}"
         else:
-            outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
+            outputs_str = ", ".join(pprint(output) for output in self.fgraph.outputs)
             self._name = f"Composite{{{outputs_str}}}"

         return self._name

@@ -4295,7 +4295,7 @@ def output_types(self, input_types):
         return self.outputs_type

     def make_node(self, *inputs):
-        if tuple([i.type for i in self.inputs]) == tuple([i.type for i in inputs]):
+        if tuple(i.type for i in self.inputs) == tuple(i.type for i in inputs):
             return super().make_node(*inputs)
         else:
             # Make a new op with the right input type.
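
tuple(...) over a generator yields an ordinary tuple, so equality checks like the one in make_node above behave exactly as before the change:

a = tuple(x * 2 for x in (1, 2, 3))
b = tuple([x * 2 for x in (1, 2, 3)])
assert a == b == (2, 4, 6)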

pytensor/scalar/loop.py

Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def make_node(self, n_steps, *inputs):
                 f"Got {n_steps.type.dtype}",
             )

-        if self.inputs_type == tuple([i.type for i in inputs]):
+        if self.inputs_type == tuple(i.type for i in inputs):
             return super().make_node(n_steps, *inputs)
         else:
             # Make a new op with the right input types.

pytensor/scan/rewriting.py

Lines changed: 1 addition & 1 deletion

@@ -1936,7 +1936,7 @@ def merge(self, nodes):
             profile=old_op.profile,
             truncate_gradient=old_op.truncate_gradient,
             allow_gc=old_op.allow_gc,
-            name="&".join([nd.op.name for nd in nodes]),
+            name="&".join(nd.op.name for nd in nodes),
         )
         new_outs = new_op(*outer_ins)
pytensor/scan/utils.py

Lines changed: 7 additions & 9 deletions

@@ -749,15 +749,13 @@ def var_mappings(self):
     def field_names(self):
         res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
         res.extend(
-            [
-                attr
-                for attr in self.__dict__
-                if attr.startswith("inner_in")
-                or attr.startswith("inner_out")
-                or attr.startswith("outer_in")
-                or attr.startswith("outer_out")
-                or attr == "n_steps"
-            ]
+            attr
+            for attr in self.__dict__
+            if attr.startswith("inner_in")
+            or attr.startswith("inner_out")
+            or attr.startswith("outer_in")
+            or attr.startswith("outer_out")
+            or attr == "n_steps"
         )
         return res
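
A multi-clause generator passed straight to extend, as in field_names above, reads like the list version minus the brackets. A standalone sketch with a hypothetical class:

class Info:
    inner_in_a = 1
    outer_out_b = 2
    unrelated = 3

res = []
res.extend(
    attr
    for attr in vars(Info)
    if attr.startswith("inner_in") or attr.startswith("outer_out")
)
assert sorted(res) == ["inner_in_a", "outer_out_b"]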

pytensor/tensor/basic.py

Lines changed: 2 additions & 2 deletions

@@ -1554,7 +1554,7 @@ def _check_runtime_broadcast(node, value, shape):
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
-        sh = tuple([int(i) for i in inputs[1:]])
+        sh = tuple(int(i) for i in inputs[1:])
         self._check_runtime_broadcast(node, v, sh)

         if out[0] is None or out[0].shape != sh:

@@ -4180,7 +4180,7 @@ def debug_perform(self, node, inputs, out_):

     def perform(self, node, inputs, out_):
         (out,) = out_
-        sh = tuple([int(i) for i in inputs])
+        sh = tuple(int(i) for i in inputs)
         if out[0] is None or out[0].shape != sh:
             out[0] = np.empty(sh, dtype=self.dtype)
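
The "when appropriate" in the commit title matters: a generator is single-use, so it is the wrong replacement wherever the sequence is consumed more than once or its length is needed before iteration:

gen = (int(i) for i in ["1", "2"])
assert tuple(gen) == (1, 2)
assert tuple(gen) == ()  # the generator is already exhausted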
