Skip to content

Commit 0f03ae8

Browse files
authored
Update OurDataset.py
1 parent d8dfd70 commit 0f03ae8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

OurDataset.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@
66
from typing import Tuple, Optional, Union
77

88
class OurDataset(Dataset):
9-
def __init__(self, pkl_path,seed, transform = None):
9+
def __init__(self, pkl_path, transform = None):
1010
dataset_info = pickle.load(open(pkl_path, 'rb+'))
1111

12-
self.seed = seed
1312
self.transform=transform
1413
self.rgb_image_name = dataset_info.image_rgb_name
1514
self.lunkuo_image_name = dataset_info.image_lunkuo_name
1615
def __getitem__(self, index)-> Tuple[torch.Tensor, ...]:
1716

17+
seed = torch.randint(0, 100000, (1,)).item()
1818
rgb_image= self.rgb_image_name[index]
1919
lunkuo_image= self.lunkuo_image_name[index]
2020

2121
rgb_img_pil = Image.open(rgb_image)
2222
lunkuo_img_pil = Image.open(lunkuo_image).convert('RGB')
2323

2424
if self.transform is not None:
25-
torch.manual_seed(self.seed)
25+
torch.manual_seed(seed)
2626
rgb_image = self.transform(rgb_img_pil)
27-
torch.manual_seed(self.seed)
27+
torch.manual_seed(seed)
2828
lunkuo_image = self.transform(lunkuo_img_pil)
2929

3030
return rgb_image,lunkuo_image
3131
def __len__(self):
32-
return len(self.rgb_image_name)
32+
return len(self.rgb_image_name)

0 commit comments

Comments
 (0)