@@ -39,6 +39,7 @@
         CountingBatchedEnv,
         CountingEnv,
         CountingEnvCountPolicy,
+        CountingEnvWithString,
         DiscreteActionConvMockEnv,
         DiscreteActionConvMockEnvNumpy,
         EnvWithScalarAction,
@@ -66,6 +67,7 @@
         CountingBatchedEnv,
         CountingEnv,
         CountingEnvCountPolicy,
+        CountingEnvWithString,
         DiscreteActionConvMockEnv,
         DiscreteActionConvMockEnvNumpy,
         EnvWithScalarAction,
@@ -77,7 +79,7 @@
         MultiKeyCountingEnvPolicy,
         NestedCountingEnv,
     )
-from tensordict import TensorDict, TensorDictBase, unravel_key
+from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictSequential
 from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
 from torch import multiprocessing as mp, nn, Tensor
@@ -118,6 +120,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    Hash,
     InitTracker,
     LineariseRewards,
     MultiStepTransform,
@@ -2180,6 +2183,246 @@ def test_transform_no_env(self, device, batch):
         pytest.skip("TrajCounter cannot be called without env")


+class TestHash(TransformBase):
+    @pytest.mark.parametrize("datatype", ["tensor", "str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "tensor":
+            obs = torch.tensor(10)
+            hash_fn = hash
+        elif datatype == "str":
+            obs = "abcdefg"
+            hash_fn = Hash.reproducible_hash
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+
+            def fn0(x):
+                return torch.stack([Hash.reproducible_hash(x_) for x_ in x])
+
+            hash_fn = fn0
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Hash(in_keys=["observation"], out_keys=["hashing"], hash_fn=hash_fn)
+        td_hashed = t(td)
+
+        assert td_hashed.get("observation") is td.get("observation")
+
+        if datatype == "NonTensorStack":
+            assert (
+                td_hashed["hashing"] == hash_fn(td.get("observation").tolist())
+            ).all()
+        elif datatype == "str":
+            assert all(td_hashed["hashing"] == hash_fn(td["observation"]))
+        else:
+            assert td_hashed["hashing"] == hash_fn(td["observation"])
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hashing"],
+                hash_fn=hash,
+            )
+            base_env = CountingEnv()
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hashing"],
+            )
+            base_env = CountingEnvWithString()
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "tensor":
+                t = Hash(
+                    in_keys=["observation"],
+                    out_keys=["hashing"],
+                    hash_fn=hash,
+                )
+                base_env = CountingEnv()
+
+            elif datatype == "str":
+                t = Hash(
+                    in_keys=["string"],
+                    out_keys=["hashing"],
+                )
+                base_env = CountingEnvWithString()
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "tensor":
+                t = Hash(
+                    in_keys=["observation"],
+                    out_keys=["hashing"],
+                    hash_fn=hash,
+                )
+                base_env = CountingEnv()
+            elif datatype == "str":
+                t = Hash(
+                    in_keys=["string"],
+                    out_keys=["hashing"],
+                )
+                base_env = CountingEnvWithString()
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hashing"],
+                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+            )
+            base_env = CountingEnv
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hashing"],
+                hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
+            )
+            base_env = CountingEnvWithString
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hashing"],
+                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+            )
+            base_env = CountingEnv
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hashing"],
+                hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
+            )
+            base_env = CountingEnvWithString
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "tensor":
+            obs = torch.tensor(10)
+        elif datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        t = Compose(t)
+        td_hashed = t(td)
+
+        assert td_hashed["observation"] is td["observation"]
+        assert td_hashed["hashing"] == hash(td["observation"])
+
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        raise pytest.skip("No inverse for Hash")
+
+
 class TestStack(TransformBase):
     def test_single_trans_env_check(self):
         t = Stack(
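For context on what the `str` and `NonTensorStack` cases exercise: the tests rely on `Hash.reproducible_hash` producing a deterministic, stackable tensor digest for a string, unlike the builtin `hash`, which is salted per interpreter run for `str` inputs. Below is a minimal sketch of that assumed contract — a SHA-256 digest of the UTF-8 bytes returned as a `uint8` tensor; the actual torchrl implementation may differ in digest choice or output layout.

```python
import hashlib

import torch


def reproducible_hash(string: str) -> torch.Tensor:
    """Sketch of the contract the tests assume for Hash.reproducible_hash.

    Hashes the UTF-8 bytes of ``string`` with SHA-256 and returns the
    32-byte digest as a uint8 tensor, so equal strings map to equal
    tensors across runs and processes (builtin hash() would not, since
    str hashing is randomized via PYTHONHASHSEED).
    """
    digest = hashlib.sha256(string.encode("utf-8")).digest()
    return torch.frombuffer(bytearray(digest), dtype=torch.uint8)


# Deterministic: equal inputs give equal digests, and a batch of strings
# can be stacked into a [batch, 32] uint8 tensor, as the batched-env
# hash_fn lambdas in the tests do.
assert (reproducible_hash("abcdefg") == reproducible_hash("abcdefg")).all()
batch = torch.stack([reproducible_hash(s) for s in ["abcde", "fghij"]])
assert batch.shape == (2, 32)
```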