Skip to content

Commit 50011dc

Browse files
kurtamohler and Vincent Moens authored and committed
[Feature] Add Hash transform
ghstack-source-id: dccf63f Pull Request resolved: #2648
1 parent 1a6c9e2 commit 50011dc

File tree

7 files changed

+610
-2
lines changed

7 files changed

+610
-2
lines changed

docs/source/reference/envs.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ to be able to create this other composition:
827827
FlattenObservation
828828
FrameSkipTransform
829829
GrayScale
830+
Hash
830831
InitTracker
831832
KLRewardTransform
832833
LineariseReward
@@ -853,6 +854,7 @@ to be able to create this other composition:
853854
TimeMaxPool
854855
ToTensorImage
855856
TrajCounter
857+
UnaryTransform
856858
UnsqueezeTransform
857859
VC1Transform
858860
VIPRewardTransform

test/mocking_classes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import random
8+
import string
79
from typing import Dict, List, Optional
810

911
import torch
@@ -1066,6 +1068,34 @@ def _step(
10661068
return tensordict
10671069

10681070

1071+
class CountingEnvWithString(CountingEnv):
    """A ``CountingEnv`` that also produces a random lowercase string.

    On every reset and step, a freshly generated string is written to the
    ``"string"`` entry of the output tensordict; the corresponding
    :class:`NonTensor` spec is registered in ``observation_spec`` so spec
    checks are aware of it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Declare the extra non-tensor observation alongside the
        # numeric observations inherited from CountingEnv.
        string_spec = NonTensor(shape=self.batch_size, device=self.device)
        self.observation_spec.set("string", string_spec)

    def get_random_string(self):
        """Return a random lowercase ASCII string of length 4 to 30."""
        length = random.randint(4, 30)
        letters = [random.choice(string.ascii_lowercase) for _ in range(length)]
        return "".join(letters)

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        # Delegate the counting logic to the parent, then attach the string.
        out = super()._reset(tensordict, **kwargs)
        out["string"] = self.get_random_string()
        return out

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Delegate the counting logic to the parent, then attach the string.
        out = super()._step(tensordict)
        out["string"] = self.get_random_string()
        return out
1097+
1098+
10691099
class MultiAgentCountingEnv(EnvBase):
10701100
"""A multi-agent env that is done after a given number of steps.
10711101

test/test_transforms.py

Lines changed: 244 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
CountingBatchedEnv,
4040
CountingEnv,
4141
CountingEnvCountPolicy,
42+
CountingEnvWithString,
4243
DiscreteActionConvMockEnv,
4344
DiscreteActionConvMockEnvNumpy,
4445
EnvWithScalarAction,
@@ -66,6 +67,7 @@
6667
CountingBatchedEnv,
6768
CountingEnv,
6869
CountingEnvCountPolicy,
70+
CountingEnvWithString,
6971
DiscreteActionConvMockEnv,
7072
DiscreteActionConvMockEnvNumpy,
7173
EnvWithScalarAction,
@@ -77,7 +79,7 @@
7779
MultiKeyCountingEnvPolicy,
7880
NestedCountingEnv,
7981
)
80-
from tensordict import TensorDict, TensorDictBase, unravel_key
82+
from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key
8183
from tensordict.nn import TensorDictSequential
8284
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
8385
from torch import multiprocessing as mp, nn, Tensor
@@ -118,6 +120,7 @@
118120
FrameSkipTransform,
119121
GrayScale,
120122
gSDENoise,
123+
Hash,
121124
InitTracker,
122125
LineariseRewards,
123126
MultiStepTransform,
@@ -2180,6 +2183,246 @@ def test_transform_no_env(self, device, batch):
21802183
pytest.skip("TrajCounter cannot be called without env")
21812184

21822185

2186+
class TestHash(TransformBase):
    """Tests for the ``Hash`` transform: applied standalone, in a ``Compose``,
    on (serial/parallel, possibly transformed) envs, in an ``nn.Sequential``
    model, and appended to a replay buffer."""

    @pytest.mark.parametrize("datatype", ["tensor", "str", "NonTensorStack"])
    def test_transform_no_env(self, datatype):
        # Apply Hash directly to a tensordict, without any env, for three
        # input datatypes: a tensor (builtin `hash`), a string
        # (Hash.reproducible_hash), and a stack of NonTensorData strings
        # (reproducible_hash applied element-wise and stacked).
        if datatype == "tensor":
            obs = torch.tensor(10)
            hash_fn = hash
        elif datatype == "str":
            obs = "abcdefg"
            hash_fn = Hash.reproducible_hash
        elif datatype == "NonTensorStack":
            obs = torch.stack(
                [
                    NonTensorData(data="abcde"),
                    NonTensorData(data="fghij"),
                    NonTensorData(data="klmno"),
                ]
            )

            def fn0(x):
                # Hash each string of the stack and stack the results.
                return torch.stack([Hash.reproducible_hash(x_) for x_ in x])

            hash_fn = fn0
        else:
            raise RuntimeError(f"please add a test case for datatype {datatype}")

        td = TensorDict(
            {
                "observation": obs,
            }
        )

        t = Hash(in_keys=["observation"], out_keys=["hashing"], hash_fn=hash_fn)
        td_hashed = t(td)

        # The input entry must be passed through untouched (same object).
        assert td_hashed.get("observation") is td.get("observation")

        if datatype == "NonTensorStack":
            assert (
                td_hashed["hashing"] == hash_fn(td.get("observation").tolist())
            ).all()
        elif datatype == "str":
            assert all(td_hashed["hashing"] == hash_fn(td["observation"]))
        else:
            assert td_hashed["hashing"] == hash_fn(td["observation"])

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_single_trans_env_check(self, datatype):
        # Hash on a single transformed env; for "str" the default hash_fn
        # is used on the string observation of CountingEnvWithString.
        if datatype == "tensor":
            t = Hash(
                in_keys=["observation"],
                out_keys=["hashing"],
                hash_fn=hash,
            )
            base_env = CountingEnv()
        elif datatype == "str":
            t = Hash(
                in_keys=["string"],
                out_keys=["hashing"],
            )
            base_env = CountingEnvWithString()
        env = TransformedEnv(base_env, t)
        check_env_specs(env)

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_serial_trans_env_check(self, datatype):
        # Transform applied inside each sub-env of a SerialEnv.
        def make_env():
            if datatype == "tensor":
                t = Hash(
                    in_keys=["observation"],
                    out_keys=["hashing"],
                    hash_fn=hash,
                )
                base_env = CountingEnv()

            elif datatype == "str":
                t = Hash(
                    in_keys=["string"],
                    out_keys=["hashing"],
                )
                base_env = CountingEnvWithString()

            return TransformedEnv(base_env, t)

        env = SerialEnv(2, make_env)
        check_env_specs(env)

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
        # Transform applied inside each sub-env of a ParallelEnv.
        def make_env():
            if datatype == "tensor":
                t = Hash(
                    in_keys=["observation"],
                    out_keys=["hashing"],
                    hash_fn=hash,
                )
                base_env = CountingEnv()
            elif datatype == "str":
                t = Hash(
                    in_keys=["string"],
                    out_keys=["hashing"],
                )
                base_env = CountingEnvWithString()
            return TransformedEnv(base_env, t)

        env = maybe_fork_ParallelEnv(2, make_env)
        try:
            check_env_specs(env)
        finally:
            # Best-effort close: workers may already be down.
            try:
                env.close()
            except RuntimeError:
                pass

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_trans_serial_env_check(self, datatype):
        # Transform applied on top of a SerialEnv: the hash_fn must handle
        # the batch dimension (2 sub-envs) itself.
        if datatype == "tensor":
            t = Hash(
                in_keys=["observation"],
                out_keys=["hashing"],
                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
            )
            base_env = CountingEnv
        elif datatype == "str":
            t = Hash(
                in_keys=["string"],
                out_keys=["hashing"],
                hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
            )
            base_env = CountingEnvWithString

        env = TransformedEnv(SerialEnv(2, base_env), t)
        check_env_specs(env)

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
        # Transform applied on top of a ParallelEnv; batched hash_fn as above.
        if datatype == "tensor":
            t = Hash(
                in_keys=["observation"],
                out_keys=["hashing"],
                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
            )
            base_env = CountingEnv
        elif datatype == "str":
            t = Hash(
                in_keys=["string"],
                out_keys=["hashing"],
                hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
            )
            base_env = CountingEnvWithString

        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
        try:
            check_env_specs(env)
        finally:
            # Best-effort close: workers may already be down.
            try:
                env.close()
            except RuntimeError:
                pass

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_transform_compose(self, datatype):
        # Hash wrapped in a Compose behaves like the bare transform.
        if datatype == "tensor":
            obs = torch.tensor(10)
        elif datatype == "str":
            obs = "abcdefg"

        td = TensorDict(
            {
                "observation": obs,
            }
        )
        t = Hash(
            in_keys=["observation"],
            out_keys=["hashing"],
            hash_fn=hash,
        )
        t = Compose(t)
        td_hashed = t(td)

        assert td_hashed["observation"] is td["observation"]
        assert td_hashed["hashing"] == hash(td["observation"])

    def test_transform_model(self):
        # Hash used as a module inside nn.Sequential, with nested in/out keys.
        t = Hash(
            in_keys=[("next", "observation"), ("observation",)],
            out_keys=[("next", "hashing"), ("hashing",)],
            hash_fn=hash,
        )
        model = nn.Sequential(t, nn.Identity())
        td = TensorDict(
            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
        )
        td_out = model(td)
        assert ("next", "hashing") in td_out.keys(True)
        assert ("hashing",) in td_out.keys(True)
        assert td_out["next", "hashing"] == hash(td["next", "observation"])
        assert td_out["hashing"] == hash(td["observation"])

    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
    def test_transform_env(self):
        # Hash on a real gym env: the output spec must expose "hashing"
        # while keeping the original observation entry intact.
        t = Hash(
            in_keys=["observation"],
            out_keys=["hashing"],
            hash_fn=hash,
        )
        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
        assert env.observation_spec["hashing"]
        assert "observation" in env.observation_spec
        assert "observation" in env.base_env.observation_spec
        check_env_specs(env)

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        # Hash appended to a replay buffer: hashing keys are added on sample.
        t = Hash(
            in_keys=[("next", "observation"), ("observation",)],
            out_keys=[("next", "hashing"), ("hashing",)],
            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
        )
        rb = rbclass(storage=LazyTensorStorage(10))
        rb.append_transform(t)
        td = TensorDict(
            {
                "observation": torch.randn(3, 4),
                "next": TensorDict(
                    {"observation": torch.randn(3, 4)},
                    [],
                ),
            },
            [],
        ).expand(10)
        rb.extend(td)
        td = rb.sample(2)
        assert "hashing" in td.keys()
        assert "observation" in td.keys()
        assert ("next", "observation") in td.keys(True)

    def test_transform_inverse(self):
        # Hash has no inverse transform; skip the inherited test case.
        raise pytest.skip("No inverse for Hash")
2424+
2425+
21832426
class TestStack(TransformBase):
21842427
def test_single_trans_env_check(self):
21852428
t = Stack(

torchrl/envs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
FrameSkipTransform,
6868
GrayScale,
6969
gSDENoise,
70+
Hash,
7071
InitTracker,
7172
KLRewardTransform,
7273
LineariseRewards,
@@ -97,6 +98,7 @@
9798
TrajCounter,
9899
Transform,
99100
TransformedEnv,
101+
UnaryTransform,
100102
UnsqueezeTransform,
101103
VC1Transform,
102104
VecGymEnvTransform,

torchrl/envs/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
FrameSkipTransform,
3232
GrayScale,
3333
gSDENoise,
34+
Hash,
3435
InitTracker,
3536
LineariseRewards,
3637
NoopResetEnv,
@@ -58,6 +59,7 @@
5859
TrajCounter,
5960
Transform,
6061
TransformedEnv,
62+
UnaryTransform,
6163
UnsqueezeTransform,
6264
VecGymEnvTransform,
6365
VecNorm,

0 commit comments

Comments
 (0)