Skip to content

Commit

Permalink
update CLAM
Browse files Browse the repository at this point in the history
  • Loading branch information
Guo Zhengrui committed Aug 4, 2024
1 parent b789af5 commit dddca5f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions CLAM/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import math
from itertools import islice
import collections
from .survival_utils import collate_MIL_survival
# from .survival_utils import collate_MIL_survival
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
Expand Down Expand Up @@ -55,10 +55,10 @@ def get_simple_loader(dataset, batch_size=1, num_workers=1):
loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs)
return loader

def get_simple_loader_survival(dataset, batch_size=1, num_workers=1):
kwargs = {'num_workers': 4, 'pin_memory': False, 'num_workers': num_workers} if device.type == "cuda" else {}
loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL_survival, **kwargs)
return loader
# def get_simple_loader_survival(dataset, batch_size=1, num_workers=1):
# kwargs = {'num_workers': 4, 'pin_memory': False, 'num_workers': num_workers} if device.type == "cuda" else {}
# loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL_survival, **kwargs)
# return loader

def get_split_loader(split_dataset, training = False, testing = False, weighted = False, batch_size=1):
"""
Expand Down

0 comments on commit dddca5f

Please sign in to comment.