Skip to content

Commit cd0b05b

Browse files
tyb0807 authored and tensorflower-gardener committed
[xla:gpu] Simplify the search for operand slices in DynamicAddressComputationFusion
PiperOrigin-RevId: 618110053
1 parent 42419a5 commit cd0b05b

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

third_party/xla/xla/service/gpu/fusions/BUILD

+1
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ cc_library(
8080
"//xla/service/gpu/runtime:custom_call_thunk",
8181
"//xla/service/gpu/runtime:gemm_thunk",
8282
"//xla/service/gpu/runtime:kernel_thunk",
83+
"@com_google_absl//absl/algorithm:container",
8384
"@com_google_absl//absl/log",
8485
"@com_google_absl//absl/status",
8586
"@com_google_absl//absl/strings",

third_party/xla/xla/service/gpu/fusions/custom.cc

+20-19
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <variant>
2525
#include <vector>
2626

27+
#include "absl/algorithm/container.h"
2728
#include "absl/log/log.h"
2829
#include "absl/status/status.h"
2930
#include "absl/strings/str_cat.h"
@@ -197,32 +198,25 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
197198
auto get_original_operand_slice =
198199
[&](const HloInstruction* start,
199200
const ShapeIndex& index) -> absl::StatusOr<BufferAllocation::Slice> {
200-
if (const auto* param = DynCast<HloParameterInstruction>(start)) {
201-
return GetAllocationSlice(
202-
buffer_assignment, fusion.operand(param->parameter_number()), index);
203-
}
204-
201+
auto* param = DynCast<HloParameterInstruction>(start);
205202
auto slice_adaptor = HloFindIf(
206203
{HloInstructionAdaptor(*start)}, adaptor,
207204
[](auto node) { return node.opcode() == HloOpcode::kDynamicSlice; });
208-
if (!slice_adaptor.has_value()) {
209-
return absl::InternalError(
210-
"DynamicAddressComputationFusion expects all operands to be either "
211-
"sliced or parameter");
212-
}
205+
if (slice_adaptor.has_value()) {
206+
slice_instr = const_cast<HloDynamicIndexInstruction*>(
207+
static_cast<const HloDynamicIndexInstruction*>(
208+
&slice_adaptor->instruction()));
213209

214-
slice_instr = const_cast<HloDynamicIndexInstruction*>(
215-
static_cast<const HloDynamicIndexInstruction*>(
216-
&slice_adaptor->instruction()));
210+
if (!IsContiguousSlice(slice_instr->operand(0)->shape(),
211+
slice_instr->shape())) {
212+
return absl::InternalError(
213+
"DynamicAddressComputationFusion only handles contiguous slices "
214+
"currently");
215+
}
217216

218-
if (!IsContiguousSlice(slice_instr->operand(0)->shape(),
219-
slice_instr->shape())) {
220-
return absl::InternalError(
221-
"DynamicAddressComputationFusion only handles contiguous slices "
222-
"currently");
217+
param = Cast<HloParameterInstruction>(slice_instr->operand(0));
223218
}
224219

225-
const auto* param = Cast<HloParameterInstruction>(slice_instr->operand(0));
226220
return GetAllocationSlice(buffer_assignment,
227221
fusion.operand(param->parameter_number()), index);
228222
};
@@ -316,6 +310,13 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
316310
BufferAllocation::Slice(workspace->allocation(), 0, workspace->size());
317311
}
318312

313+
if (absl::c_all_of(offset_buffer_indices, [&](auto offset_slices) {
314+
return offset_slices == std::nullopt;
315+
}))
316+
return absl::InternalError(
317+
"DynamicAddressComputationFusion expects at least one sliced "
318+
"operand/result");
319+
319320
// Creating embedded GEMM thunk.
320321
bool deterministic_ops =
321322
ir_emitter_context.debug_options().xla_gpu_deterministic_ops();

0 commit comments

Comments
 (0)