Skip to content

Commit 7a4c1ea

Browse files
committed
add seed and offset
1 parent 8d2a550 commit 7a4c1ea

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,30 @@ DIOPI_RT_API diopiError_t diopiGeneratorSetState(
183183
return diopiSuccess;
184184
}
185185

186+
DIOPI_RT_API diopiError_t diopiGeneratorGetSeedAndOffset(
187+
diopiGeneratorHandle_t th, uint64_t& seed, uint64_t& offset) {
188+
auto generator = reinterpret_cast<at::Generator*>(th);
189+
auto gen_impl = at::check_generator<dipu::DIPUGeneratorImpl>(*generator);
190+
{
191+
offset = gen_impl->get_offset();
192+
seed = gen_impl->current_seed();
193+
}
194+
195+
return diopiSuccess;
196+
}
197+
198+
DIOPI_RT_API diopiError_t diopiGeneratorSetSeedAndOffset(
199+
diopiGeneratorHandle_t th, uint64_t seed, uint64_t offset) {
200+
auto generator = reinterpret_cast<at::Generator*>(th);
201+
auto gen_impl = at::check_generator<dipu::DIPUGeneratorImpl>(*generator);
202+
{
203+
gen_impl->set_offset(offset);
204+
gen_impl->set_current_seed(seed);
205+
}
206+
207+
return diopiSuccess;
208+
}
209+
186210
DIOPI_RT_API diopiError_t diopiRecordStart(const char* record_name,
187211
void** record) {
188212
*record = new RecordBlockCreator(record_name);

0 commit comments

Comments
 (0)