6
6
from typing import Tuple , Optional , Union
7
7
8
8
class OurDataset (Dataset ):
9
- def __init__ (self , pkl_path ,seed , transform = None ):
9
+ def __init__ (self , pkl_path , transform = None ):
10
10
dataset_info = pickle .load (open (pkl_path , 'rb+' ))
11
11
12
- self .seed = seed
13
12
self .transform = transform
14
13
self .rgb_image_name = dataset_info .image_rgb_name
15
14
self .lunkuo_image_name = dataset_info .image_lunkuo_name
16
15
def __getitem__ (self , index )-> Tuple [torch .Tensor , ...]:
17
16
17
+ seed = torch .randint (0 , 100000 , (1 ,)).item ()
18
18
rgb_image = self .rgb_image_name [index ]
19
19
lunkuo_image = self .lunkuo_image_name [index ]
20
20
21
21
rgb_img_pil = Image .open (rgb_image )
22
22
lunkuo_img_pil = Image .open (lunkuo_image ).convert ('RGB' )
23
23
24
24
if self .transform is not None :
25
- torch .manual_seed (self . seed )
25
+ torch .manual_seed (seed )
26
26
rgb_image = self .transform (rgb_img_pil )
27
- torch .manual_seed (self . seed )
27
+ torch .manual_seed (seed )
28
28
lunkuo_image = self .transform (lunkuo_img_pil )
29
29
30
30
return rgb_image ,lunkuo_image
31
31
def __len__ (self ):
32
- return len (self .rgb_image_name )
32
+ return len (self .rgb_image_name )
0 commit comments