Skip to content

Commit f3abb76

Browse files
Raise warnings when test_val is accessed
1 parent b7b309d commit f3abb76

File tree

12 files changed

+95
-1
lines changed

12 files changed

+95
-1
lines changed

pytensor/graph/basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,10 @@ def get_test_value(self):
479479
if not hasattr(self.tag, "test_value"):
480480
detailed_err_msg = get_variable_trace_string(self)
481481
raise TestValueError(f"{self} has no test value {detailed_err_msg}")
482-
482+
warnings.warn(
483+
"test_value machinery is deprecated and will stop working in the future.",
484+
FutureWarning,
485+
)
483486
return self.tag.test_value
484487

485488
def __str__(self):

pytensor/graph/fg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
22

33
import time
4+
import warnings
45
from collections import OrderedDict
56
from collections.abc import Iterable, Sequence
67
from typing import TYPE_CHECKING, Any, Literal, Union, cast
@@ -493,6 +494,10 @@ def replace(
493494
return
494495

495496
if config.compute_test_value != "off":
497+
warnings.warn(
498+
"test_value machinery is deprecated and will stop working in the future.",
499+
FutureWarning,
500+
)
496501
try:
497502
tval = pytensor.graph.op.get_test_value(var)
498503
new_tval = pytensor.graph.op.get_test_value(new_var)

pytensor/graph/op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def compute_test_value(node: Apply):
6969
7070
"""
7171
# Gather the test values for each input of the node
72+
73+
warnings.warn(
74+
"compute_test_value is deprecated and will stop working in the future.",
75+
FutureWarning,
76+
)
77+
7278
storage_map = {}
7379
compute_map = {}
7480
for i, ins in enumerate(node.inputs):
@@ -301,6 +307,10 @@ def __call__(
301307
n.name = f"{name}_{i}"
302308

303309
if config.compute_test_value != "off":
310+
warnings.warn(
311+
"test_value machinery is deprecated and will stop working in the future.",
312+
FutureWarning,
313+
)
304314
compute_test_value(node)
305315

306316
if self.default_output is not None:
@@ -711,6 +721,11 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
711721
if config.compute_test_value == "off":
712722
return []
713723

724+
warnings.warn(
725+
"test_value machinery is deprecated and will stop working in the future.",
726+
FutureWarning,
727+
)
728+
714729
rval = []
715730

716731
for i, arg in enumerate(args):

pytensor/graph/rewriting/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,10 @@ def transform(self, fgraph, node, *args, **kwargs):
983983
if isinstance(input, pytensor.compile.SharedVariable):
984984
pass
985985
elif hasattr(input.tag, "test_value"):
986+
warnings.warn(
987+
"compute_test_value is deprecated and will stop working in the future.",
988+
FutureWarning,
989+
)
986990
givens[input] = pytensor.shared(
987991
input.type.filter(input.tag.test_value),
988992
input.name,

pytensor/misc/pkl_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pickle
1010
import sys
1111
import tempfile
12+
import warnings
1213
import zipfile
1314
from collections import defaultdict
1415
from contextlib import closing
@@ -61,6 +62,10 @@ class StripPickler(Pickler):
6162
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
6263
# Can't use super as Pickler isn't a new style class
6364
super().__init__(file, protocol)
65+
warnings.warn(
66+
"compute_test_value is deprecated and will stop working in the future.",
67+
FutureWarning,
68+
)
6469
self.tag_to_remove = ["trace", "test_value"]
6570
if extra_tag_to_remove:
6671
self.tag_to_remove.extend(extra_tag_to_remove)

pytensor/scalar/basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import builtins
1414
import math
15+
import warnings
1516
from collections.abc import Callable, Mapping
1617
from copy import copy
1718
from itertools import chain
@@ -4414,6 +4415,10 @@ def apply(self, fgraph):
44144415
if i.dtype == "float16":
44154416
mapping[i] = get_scalar_type("float32")()
44164417
if hasattr(i.tag, "test_value"):
4418+
warnings.warn(
4419+
"test_value machinery is deprecated and will stop working in the future.",
4420+
FutureWarning,
4421+
)
44174422
mapping[i].tag.test_value = i.tag.test_value
44184423
else:
44194424
mapping[i] = i

pytensor/scan/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,10 @@ def wrap_into_list(x):
598598

599599
# Try to transfer test_value to the new variable
600600
if config.compute_test_value != "off":
601+
warnings.warn(
602+
"test_value machinery is deprecated and will stop working in the future.",
603+
FutureWarning,
604+
)
601605
try:
602606
nw_slice.tag.test_value = get_test_value(_seq_val_slice)
603607
except TestValueError:
@@ -725,6 +729,10 @@ def wrap_into_list(x):
725729

726730
# Try to transfer test_value to the new variable
727731
if config.compute_test_value != "off":
732+
warnings.warn(
733+
"test_value machinery is deprecated and will stop working in the future.",
734+
FutureWarning,
735+
)
728736
try:
729737
arg.tag.test_value = get_test_value(actual_arg)
730738
except TestValueError:
@@ -780,6 +788,10 @@ def wrap_into_list(x):
780788

781789
# Try to transfer test_value to the new variable
782790
if config.compute_test_value != "off":
791+
warnings.warn(
792+
"test_value machinery is deprecated and will stop working in the future.",
793+
FutureWarning,
794+
)
783795
try:
784796
nw_slice.tag.test_value = get_test_value(_init_out_var_slice)
785797
except TestValueError:

pytensor/scan/op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import dataclasses
4747
import logging
4848
import time
49+
import warnings
4950
from collections import OrderedDict
5051
from collections.abc import Callable, Iterable
5152
from copy import copy
@@ -2650,6 +2651,10 @@ def compute_all_gradients(known_grads):
26502651
# fct add and we want to keep it for all Scan op. This is
26512652
# used in T_Scan.test_grad_multiple_outs_taps to test
26522653
# that.
2654+
warnings.warn(
2655+
"test_value machinery is deprecated and will stop working in the future.",
2656+
FutureWarning,
2657+
)
26532658
if info.as_while:
26542659
n = n_steps.tag.test_value
26552660
else:

pytensor/scan/rewriting.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
import dataclasses
5+
import warnings
56
from itertools import chain
67
from sys import maxsize
78
from typing import cast
@@ -305,6 +306,10 @@ def add_to_replace(y):
305306
pushed_out_node = nd.op.make_node(*new_inputs)
306307

307308
if config.compute_test_value != "off":
309+
warnings.warn(
310+
"test_value machinery is deprecated and will stop working in the future.",
311+
FutureWarning,
312+
)
308313
compute_test_value(pushed_out_node)
309314

310315
# Step 2. Create variables to replace the old outputs of the node
@@ -511,6 +516,10 @@ def add_to_replace(y):
511516
nw_outer_node = nd.op.make_node(*outside_ins)
512517

513518
if config.compute_test_value != "off":
519+
warnings.warn(
520+
"test_value machinery is deprecated and will stop working in the future.",
521+
FutureWarning,
522+
)
514523
compute_test_value(nw_outer_node)
515524

516525
# Step 2. Create variables for replacements
@@ -545,6 +554,10 @@ def add_to_replace(y):
545554
replace_with_out.append(new_outer)
546555

547556
if hasattr(new_outer.tag, "test_value"):
557+
warnings.warn(
558+
"test_value machinery is deprecated and will stop working in the future.",
559+
FutureWarning,
560+
)
548561
new_sh = new_outer.tag.test_value.shape
549562
ref_sh = (outside_ins.tag.test_value.shape[0],)
550563
ref_sh += nd.outputs[0].tag.test_value.shape
@@ -982,6 +995,10 @@ def attempt_scan_inplace(
982995
new_lsi = inp.owner.op.make_node(*inp.owner.inputs)
983996

984997
if config.compute_test_value != "off":
998+
warnings.warn(
999+
"test_value machinery is deprecated and will stop working in the future.",
1000+
FutureWarning,
1001+
)
9851002
compute_test_value(new_lsi)
9861003

9871004
new_lsi_out = new_lsi.outputs

pytensor/scan/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import dataclasses
55
import logging
6+
import warnings
67
from collections import OrderedDict, namedtuple
78
from collections.abc import Callable, Sequence
89
from itertools import chain
@@ -74,6 +75,10 @@ def safe_new(
7475
nw_x.name = nw_name
7576
if config.compute_test_value != "off":
7677
# Copy test value, cast it if necessary
78+
warnings.warn(
79+
"test_value machinery is deprecated and will stop working in the future.",
80+
FutureWarning,
81+
)
7782
try:
7883
x_test_value = get_test_value(x)
7984
except TestValueError:
@@ -104,6 +109,10 @@ def safe_new(
104109
# between test values, due to inplace operations for instance. This may
105110
# not be the most efficient memory-wise, though.
106111
if config.compute_test_value != "off":
112+
warnings.warn(
113+
"test_value machinery is deprecated and will stop working in the future.",
114+
FutureWarning,
115+
)
107116
try:
108117
nw_x.tag.test_value = copy.deepcopy(get_test_value(x))
109118
except TestValueError:

pytensor/tensor/blas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import logging
7979
import os
8080
import time
81+
import warnings
8182

8283
import numpy as np
8384

@@ -1967,6 +1968,10 @@ def R_op(self, inputs, eval_points):
19671968
test_values_enabled = config.compute_test_value != "off"
19681969

19691970
if test_values_enabled:
1971+
warnings.warn(
1972+
"test_value machinery is deprecated and will stop working in the future.",
1973+
FutureWarning,
1974+
)
19701975
try:
19711976
iv0 = pytensor.graph.op.get_test_value(inputs[0])
19721977
except TestValueError:

pytensor/tensor/random/rewriting/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from itertools import chain
23

34
from pytensor.compile import optdb
@@ -108,6 +109,10 @@ def local_rv_size_lift(fgraph, node):
108109
new_node = node.op.make_node(rng, None, *dist_params)
109110

110111
if config.compute_test_value != "off":
112+
warnings.warn(
113+
"test_value machinery is deprecated and will stop working in the future.",
114+
FutureWarning,
115+
)
111116
compute_test_value(new_node)
112117

113118
return new_node.outputs
@@ -187,6 +192,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
187192
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
188193

189194
if config.compute_test_value != "off":
195+
warnings.warn(
196+
"test_value machinery is deprecated and will stop working in the future.",
197+
FutureWarning,
198+
)
190199
compute_test_value(new_node)
191200

192201
out = new_node.outputs[1]

0 commit comments

Comments
 (0)