@@ -24,6 +24,7 @@ limitations under the License.
24
24
#include < variant>
25
25
#include < vector>
26
26
27
+ #include " absl/algorithm/container.h"
27
28
#include " absl/log/log.h"
28
29
#include " absl/status/status.h"
29
30
#include " absl/strings/str_cat.h"
@@ -197,32 +198,25 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
197
198
auto get_original_operand_slice =
198
199
[&](const HloInstruction* start,
199
200
const ShapeIndex& index ) -> absl::StatusOr<BufferAllocation::Slice> {
200
- if (const auto * param = DynCast<HloParameterInstruction>(start)) {
201
- return GetAllocationSlice (
202
- buffer_assignment, fusion.operand (param->parameter_number ()), index );
203
- }
204
-
201
+ auto * param = DynCast<HloParameterInstruction>(start);
205
202
auto slice_adaptor = HloFindIf (
206
203
{HloInstructionAdaptor (*start)}, adaptor,
207
204
[](auto node) { return node.opcode () == HloOpcode::kDynamicSlice ; });
208
- if (!slice_adaptor.has_value ()) {
209
- return absl::InternalError (
210
- " DynamicAddressComputationFusion expects all operands to be either "
211
- " sliced or parameter" );
212
- }
205
+ if (slice_adaptor.has_value ()) {
206
+ slice_instr = const_cast <HloDynamicIndexInstruction*>(
207
+ static_cast <const HloDynamicIndexInstruction*>(
208
+ &slice_adaptor->instruction ()));
213
209
214
- slice_instr = const_cast <HloDynamicIndexInstruction*>(
215
- static_cast <const HloDynamicIndexInstruction*>(
216
- &slice_adaptor->instruction ()));
210
+ if (!IsContiguousSlice (slice_instr->operand (0 )->shape (),
211
+ slice_instr->shape ())) {
212
+ return absl::InternalError (
213
+ " DynamicAddressComputationFusion only handles contiguous slices "
214
+ " currently" );
215
+ }
217
216
218
- if (!IsContiguousSlice (slice_instr->operand (0 )->shape (),
219
- slice_instr->shape ())) {
220
- return absl::InternalError (
221
- " DynamicAddressComputationFusion only handles contiguous slices "
222
- " currently" );
217
+ param = Cast<HloParameterInstruction>(slice_instr->operand (0 ));
223
218
}
224
219
225
- const auto * param = Cast<HloParameterInstruction>(slice_instr->operand (0 ));
226
220
return GetAllocationSlice (buffer_assignment,
227
221
fusion.operand (param->parameter_number ()), index );
228
222
};
@@ -316,6 +310,13 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
316
310
BufferAllocation::Slice (workspace->allocation (), 0 , workspace->size ());
317
311
}
318
312
313
+ if (absl::c_all_of (offset_buffer_indices, [&](auto offset_slices) {
314
+ return offset_slices == std::nullopt;
315
+ }))
316
+ return absl::InternalError (
317
+ " DynamicAddressComputationFusion expects at least one sliced "
318
+ " operand/result" );
319
+
319
320
// Creating embedded GEMM thunk.
320
321
bool deterministic_ops =
321
322
ir_emitter_context.debug_options ().xla_gpu_deterministic_ops ();
0 commit comments