@@ -391,20 +391,61 @@ class ContinuousBox(Box):
     _low: torch.Tensor
     _high: torch.Tensor
     device: torch.device | None = None
+    _batch_size: torch.Size | None = None
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value: torch.Size | tuple):
+        # Check batch size is compatible with low and high
+        value = _remove_neg_shapes(value)
+        if self._batch_size is None:
+            if value != self._low.shape[: len(value)]:
+                raise ValueError(
+                    f"Batch size {value} is not compatible with low and high {self._low.shape}"
+                )
+        if value is None:
+            self._batch_size = None
+            self._low = self.low.clone()
+            self._high = self.high.clone()
+            return
+        # Remove batch size from low and high
+        if value:
+            # Check that low and high have a single value
+            td_low_high = TensorDict(
+                low=self.low, high=self.high, batch_size=value
+            ).flatten()
+            td_low_high0 = td_low_high[0]
+            if torch.allclose(
+                td_low_high0["low"], td_low_high["low"]
+            ) and torch.allclose(td_low_high0["high"], td_low_high["high"]):
+                self._low = td_low_high0["low"].clone()
+                self._high = td_low_high0["high"].clone()
+                self._batch_size = torch.Size(value)
+        else:
+            self._low = self.low.clone()
+            self._high = self.high.clone()
+            self._batch_size = torch.Size(value)
 
     # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
     @property
     def low(self):
         low = self._low
         if self.device is not None and low.device != self.device:
             low = low.to(self.device)
+        if self._batch_size:
+            low = low.expand((*self._batch_size, *low.shape)).clone()
         return low
 
     @property
     def high(self):
         high = self._high
         if self.device is not None and high.device != self.device:
             high = high.to(self.device)
+        if self._batch_size:
+            high = high.expand((*self._batch_size, *high.shape)).clone()
         return high
 
     def unbind(self, dim: int = 0):
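
The setter above deduplicates storage when the bounds are constant along the batch
dimensions: it flattens the batch dims, checks that every element equals the first,
and keeps only that single element; the low/high properties then re-expand on access.
A minimal standalone sketch of that storage trick (plain torch, illustrative names,
not the torchrl class itself):

import torch

low = torch.zeros(1024, 3)  # 1024 identical rows of shape (3,)
batch_size = torch.Size([1024])

# Collapse the batch dims and check that low holds a single repeated value
flat = low.flatten(0, len(batch_size) - 1)
if torch.allclose(flat[0], flat):
    stored = flat[0].clone()  # keep one (3,) row instead of 1024

# Reading back: expand rebuilds the batched tensor and clone materializes it
restored = stored.expand((*batch_size, *stored.shape)).clone()
assert torch.equal(restored, low)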
@@ -417,15 +458,30 @@ def unbind(self, dim: int = 0):
     def low(self, value):
         self.device = value.device
         self._low = value
+        if self._batch_size is not None:
+            if value.shape[: len(self._batch_size)] != self._batch_size:
+                raise ValueError(
+                    f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
+                )
+        if self._batch_size:
+            self._low = self._low.flatten(0, len(self._batch_size) - 1)[0].clone()
 
     @high.setter
     def high(self, value):
         self.device = value.device
         self._high = value
+        if self._batch_size is not None:
+            if value.shape[: len(self._batch_size)] != self._batch_size:
+                raise ValueError(
+                    f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
+                )
+        if self._batch_size:
+            self._high = self._high.flatten(0, len(self._batch_size) - 1)[0].clone()
 
     def __post_init__(self):
         self.low = self.low.clone()
         self.high = self.high.clone()
+        self._batch_size = None
 
     def __iter__(self):
         yield self.low
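
Once a batch size is registered, the low/high setters above validate the leading
dims of any incoming value and immediately re-collapse it to a single element. A
sketch of that round-trip, under the same assumptions as the previous snippet:

import torch

batch_size = torch.Size([4, 5])
value = torch.full((4, 5, 3), -1.0)  # new bound, constant over the batch dims

# Shape check from the setter: leading dims must match the registered batch size
assert value.shape[: len(batch_size)] == batch_size

# Collapse: flatten the batch dims and keep only the first element
stored = value.flatten(0, len(batch_size) - 1)[0].clone()
print(stored.shape)  # torch.Size([3]) -- one element stored instead of 20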
@@ -2366,6 +2422,10 @@ def __init__(
         )
         self.encode = self._encode_eager
 
+    def _register_batch_size(self, batch_size: torch.Size | tuple):
+        # Register batch size in the space to decrease the memory footprint of the specs
+        self.space.batch_size = batch_size
+
     def index(
         self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
     ) -> torch.Tensor | TensorDictBase:
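
A hypothetical end-to-end use of the helper added above, assuming a Bounded spec
whose bounds are constant across the batch dims (the private method name comes
from this diff; the constructor call is an assumption about the torchrl API):

import torch
from torchrl.data import Bounded

# Scalar bounds broadcast to the full shape, so low/high repeat over the batch dim
spec = Bounded(low=-1.0, high=1.0, shape=(1024, 3), dtype=torch.float32)
spec._register_batch_size((1024,))  # storage shrinks to one (3,) low/high pair
assert spec.space.low.shape == torch.Size([1024, 3])  # reads still look batched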
@@ -5191,6 +5251,8 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
                 f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                 f"Composite.shape={self.shape}."
             )
+        if isinstance(spec, Bounded):
+            spec._register_batch_size(self.shape)
         self._specs[name] = spec
         return self
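
With the hook above, setting a Bounded spec inside a batched Composite registers
the container's batch dims on the child automatically. A sketch, assuming
__setitem__ routes through Composite.set as in current torchrl:

import torch
from torchrl.data import Bounded, Composite

comp = Composite(shape=(32,))
comp["obs"] = Bounded(low=0.0, high=1.0, shape=(32, 64), dtype=torch.float32)
# set() ran spec._register_batch_size((32,)): storage is (64,), reads are (32, 64)
assert comp["obs"].space.low.shape == torch.Size([32, 64])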