Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import TaskFileReader
from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
from trinity.common.config import DataSelectorConfig, FormatConfig, TasksetConfig
from trinity.common.workflows.workflow import Task


Expand Down Expand Up @@ -250,7 +250,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
]
)
async def test_task_scheduler(
self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders
self, buffer_config_kwargs, data_selector_kwargs, batch_tasks_orders
) -> None:
config = get_template_config()
config.mode = "explore"
Expand All @@ -276,8 +276,8 @@ async def test_task_scheduler(
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=TaskSelectorConfig(
**task_selector_kwargs,
data_selector=DataSelectorConfig(
**data_selector_kwargs,
),
),
TasksetConfig(
Expand All @@ -298,8 +298,8 @@ async def test_task_scheduler(
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=TaskSelectorConfig(
**task_selector_kwargs,
data_selector=DataSelectorConfig(
**data_selector_kwargs,
),
),
]
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
AlgorithmConfig,
BufferConfig,
Config,
DataSelectorConfig,
ExperienceBufferConfig,
ExplorerInput,
StageConfig,
TaskSelectorConfig,
TrainerInput,
)
from trinity.common.constants import (
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_trainer(self):
}
self.config.model.rope_theta = 10000
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
self.config.buffer.explorer_input.taskset.data_selector = DataSelectorConfig(
selector_type="shuffle", seed=42
)
eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
Expand Down
145 changes: 80 additions & 65 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ def __init__(
):
self.dataset = dataset
self.dataset_size = len(dataset)
if self.dataset_size == 0:
raise ValueError(f"Dataset [{name}] is empty and cannot be read in batches.")
self.name = name
self.current_batch_size = None
self.drop_last = drop_last

self.current_offset = offset

# convert epochs/steps to sample number
if total_steps:
if total_steps is not None:
self.total_samples = default_batch_size * total_steps
else:
self.total_samples = self.dataset_size * total_epochs
Expand Down Expand Up @@ -94,11 +96,63 @@ def select_batch(self, indices: List[int]) -> List:
class BaseFileReader(BufferReader):
async def read_async(self, batch_size: Optional[int] = None, **kwargs):
try:
return self.read(batch_size)
return self.read(batch_size, **kwargs)
except StopIteration as e:
raise StopAsyncIteration from e


class _DatasetFileReader(BaseFileReader):
def __init__(self, config: StorageConfig):
# Shared setup for dataset-backed readers: keep the config, copy its name and
# batch size, build the formatter/dataset pair through the subclass hook, then
# configure the optional selector.
self.config = config
self.name = config.name
# Assigned before the hook call below: the subclass hooks shown in this diff
# pass self.read_batch_size as the dataset's default_batch_size.
self.read_batch_size = config.batch_size
self.formatter, self.dataset = self._init_formatter_and_dataset(config)
# Must run after self.dataset exists; _init_selector hands self.dataset to the selector.
self._init_selector(config)

def _init_formatter_and_dataset(self, config: StorageConfig):
# Subclass hook. __init__ unpacks the return value as (formatter, dataset).
# The base implementation unconditionally raises NotImplementedError.
raise NotImplementedError

def _init_selector(self, config: StorageConfig):
if config.data_selector is not None:
from trinity.buffer.selector import SELECTORS
from trinity.buffer.selector.selector import BaseSelector

self.selector: BaseSelector = SELECTORS.get(config.data_selector.selector_type)(
self.dataset, config.data_selector
)
Comment on lines +120 to +122
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_init_selector calls SELECTORS.get(config.data_selector.selector_type)(...) without validating that the registry lookup succeeded. If selector_type is None or unsupported (or if config.check_and_update() was not run), this will raise a confusing TypeError: 'NoneType' object is not callable. Add an explicit check and raise a clear ValueError (or fall back to sequential) before instantiating the selector.

Suggested change
self.selector: BaseSelector = SELECTORS.get(config.data_selector.selector_type)(
self.dataset, config.data_selector
)
selector_type = getattr(config.data_selector, "selector_type", None)
selector_cls = SELECTORS.get(selector_type)
if selector_cls is None:
raise ValueError(
f"Unknown or missing data selector type {selector_type!r} for dataset [{self.name}]."
)
self.selector = selector_cls(self.dataset, config.data_selector)

Copilot uses AI. Check for mistakes.
else:
self.selector = None

def _read_samples(self, batch_size: int) -> Tuple[List, List]:
# Fetch one batch and return it as a (samples, indices) pair.
if self.selector is not None:
# Selector path: the selector picks the indices, the dataset materializes them.
indices = self.selector.get_indices(batch_size)
samples = self.dataset.select_batch(indices)
return samples, indices
# No selector: return whatever the dataset's own read_batch produces.
# The caller (read) unpacks that result as (samples, indices).
return self.dataset.read_batch(batch_size)

def state_dict(self):
# Return the state to checkpoint for this reader.
if self.selector is not None:
# Delegates entirely to the selector when one is configured.
return self.selector.state_dict()
# Without a selector the only saved state is the dataset's current offset.
return {"current_index": self.dataset.current_offset}

def load_state_dict(self, state_dict):
# Restore state in the same shape state_dict() produces.
if self.selector is not None:
self.selector.load_state_dict(state_dict)
else:
# "current_index" is read only on the no-selector path, mirroring state_dict():
# with a selector, the dict comes from the selector and is passed through untouched.
self.dataset.current_offset = state_dict["current_index"]

def __len__(self):
# Reports the wrapped dataset's dataset_size attribute.
return self.dataset.dataset_size

def _convert_batch(self, samples: List, indices: List) -> List:
# Subclass hook. read() calls it with the (samples, indices) pair from
# _read_samples and returns its result. The base implementation always raises.
raise NotImplementedError

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
# Read one batch and return it converted by _convert_batch.
# `or` makes any falsy batch_size (None, and also 0) fall back to read_batch_size.
batch_size = batch_size or self.read_batch_size
samples, indices = self._read_samples(batch_size)
# **kwargs is accepted in the signature but not used in this method.
return self._convert_batch(samples, indices)


class FileReader(BaseFileReader):
"""Provide a unified interface for Experience and Task file readers."""

Expand All @@ -109,7 +163,7 @@ def __init__(self, config: StorageConfig):
self.reader = TaskFileReader(config)

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
return self.reader.read(batch_size)
return self.reader.read(batch_size, **kwargs)

def state_dict(self):
return self.reader.state_dict()
Expand All @@ -125,15 +179,17 @@ def __len__(self):
return self.reader.__len__()


class ExperienceFileReader(BaseFileReader):
class ExperienceFileReader(_DatasetFileReader):
"""Reader for SFT / DPO file data."""

def __init__(self, config: StorageConfig):
self.formatter = FORMATTER.get(config.schema_type)(
super().__init__(config)

def _init_formatter_and_dataset(self, config: StorageConfig):
formatter = FORMATTER.get(config.schema_type)(
tokenizer_path=config.tokenizer_path, format_config=config.format
)
self.read_batch_size = config.batch_size
self.dataset = _HFBatchReader(
dataset = _HFBatchReader(
load_dataset(config.path, name=config.subset_name, split=config.split),
name=config.name,
default_batch_size=self.read_batch_size,
Expand All @@ -142,82 +198,41 @@ def __init__(self, config: StorageConfig):
total_steps=config.total_steps,
enable_progress_bar=config.enable_progress_bar,
)
self.selector = None
return formatter, dataset

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size)
def _convert_batch(self, samples: List, indices: List) -> List:
exp_list = []
for sample in samples:
experience = self.formatter.format(sample)
exp_list.append(experience)
return exp_list

def state_dict(self):
return {"current_index": self.dataset.current_offset}

def load_state_dict(self, state_dict):
self.dataset.current_offset = state_dict["current_index"]

def __len__(self):
return self.dataset.dataset_size


class TaskFileReader(BaseFileReader):
class TaskFileReader(_DatasetFileReader):
"""A Reader for task file data."""

def __init__(self, config: StorageConfig):
self.config = config
self.name = config.name
self.epoch = 0
datasets.disable_caching()
self.read_batch_size = config.batch_size
self.dataset = _HFBatchReader(
load_dataset(self.config.path, name=self.config.subset_name, split=self.config.split),
name=self.config.name,
super().__init__(config)

def _init_formatter_and_dataset(self, config):
formatter = FORMATTER.get("task")(config)
dataset = _HFBatchReader(
load_dataset(config.path, name=config.subset_name, split=config.split),
name=config.name,
default_batch_size=self.read_batch_size,
total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
offset=self.config.index,
drop_last=not self.config.is_eval,
total_steps=self.config.total_steps if not self.config.is_eval else None,
enable_progress_bar=self.config.enable_progress_bar,
total_epochs=config.total_epochs if not config.is_eval else 1,
offset=config.index,
drop_last=not config.is_eval,
total_steps=config.total_steps if not config.is_eval else None,
enable_progress_bar=config.enable_progress_bar,
)
self.formatter = FORMATTER.get("task")(config)
if self.config.task_selector is not None:
from trinity.buffer.selector import SELECTORS
from trinity.buffer.selector.selector import BaseSelector

self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)(
self.dataset, self.config.task_selector
)
else:
self.selector = None
return formatter, dataset

def _get_tasks(self, samples: List, indices: List) -> List:
def _convert_batch(self, samples: List, indices: List) -> List:
tasks = []
for sample, index in zip(samples, indices):
task = self.formatter.format(sample)
task.index["index"] = int(index)
tasks.append(task)
return tasks

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = batch_size or self.read_batch_size
if self.selector is not None:
indices = self.selector.get_indices(batch_size)
samples = self.dataset.select_batch(indices)
else:
samples, indices = self.dataset.read_batch(batch_size)
return self._get_tasks(samples, indices)

def state_dict(self):
if self.selector is not None:
return self.selector.state_dict()
return {"current_index": self.dataset.current_offset}

def load_state_dict(self, state_dict):
if self.selector is not None:
self.selector.load_state_dict(state_dict)
self.dataset.current_offset = state_dict["current_index"]

def __len__(self):
return self.dataset.dataset_size
14 changes: 7 additions & 7 deletions trinity/buffer/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from trinity.buffer.reader.file_reader import _HFBatchReader
from trinity.buffer.selector.difficulty_estimator import InterpolationBetaPREstimator
from trinity.common.config import TaskSelectorConfig
from trinity.common.config import DataSelectorConfig
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger

Expand All @@ -27,7 +27,7 @@ class BaseSelector:
- state_dict / load_state_dict: for saving/loading selector state (checkpointing)
"""

def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
self.data_source = data_source
self.config = config

Expand Down Expand Up @@ -82,7 +82,7 @@ class SequentialSelector(BaseSelector):
Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc.
"""

def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
super().__init__(data_source, config)
self.dataset_size = data_source.dataset_size
self.current_index = 0
Expand Down Expand Up @@ -117,7 +117,7 @@ class ShuffleSelector(BaseSelector):
Mimics standard PyTorch DataLoader with shuffle=True.
"""

def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
super().__init__(data_source, config)
self.dataset_size = data_source.dataset_size # Total available samples
self.current_index = 0 # Progress tracker
Expand Down Expand Up @@ -172,7 +172,7 @@ class RandomSelector(BaseSelector):
Can result in repeated samples within an epoch. Suitable for online or stochastic training regimes.
"""

def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
super().__init__(data_source, config)
self.dataset_size = data_source.dataset_size
self.current_index = 0
Expand Down Expand Up @@ -214,7 +214,7 @@ class OfflineEasy2HardSelector(BaseSelector):
(e.g., via teacher model confidence, length, BLEU score, etc.).
"""

def __init__(self, data_source, config: TaskSelectorConfig):
def __init__(self, data_source, config: DataSelectorConfig):
super().__init__(data_source, config)
self.logger = get_logger("offline_easy2hard_selector")

Expand Down Expand Up @@ -297,7 +297,7 @@ class DifficultyBasedSelector(BaseSelector):
Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`).
"""

def __init__(self, data_source, config: TaskSelectorConfig) -> None:
def __init__(self, data_source, config: DataSelectorConfig) -> None:
super().__init__(data_source, config)
self.logger = get_logger("difficulty_based_selector")

Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_taskset_scheduler(explorer_state: Dict, config: Config) -> "TasksetSched
TasksetSchedulerBase: The taskset scheduler instance
"""
taskset_configs = config.buffer.explorer_input.tasksets
if len(taskset_configs) == 1 and taskset_configs[0].task_selector.selector_type == "sequential":
if len(taskset_configs) == 1 and taskset_configs[0].data_selector.selector_type == "sequential":
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_taskset_scheduler assumes taskset_configs[0].data_selector is non-None. Since other code paths support data_selector = None to disable selection, a config that sets data_selector: null would raise an AttributeError here. Consider treating None as equivalent to sequential/no-selection in this condition.

Suggested change
if len(taskset_configs) == 1 and taskset_configs[0].data_selector.selector_type == "sequential":
data_selector = taskset_configs[0].data_selector
if len(taskset_configs) == 1 and (
data_selector is None or data_selector.selector_type == "sequential"
):

Copilot uses AI. Check for mistakes.
return SimpleTasksetScheduler(explorer_state, config)
else:
return TasksetScheduler(explorer_state, config)
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, explorer_state: Dict, config: Config):
)
taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0])
taskset_config.index = index
taskset_config.task_selector = None # disable selection
taskset_config.data_selector = None # disable selection
self.reader = get_buffer_reader(taskset_config)

async def read_async(self) -> List:
Expand Down
Loading
Loading