KAFKA-19478 [3/N]: Use heaps to discover the least loaded process #20172

Open
wants to merge 3 commits into
base: trunk
@@ -20,12 +20,14 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;

@@ -115,7 +117,7 @@ private void initialize(final GroupSpec groupSpec, final TopologyDescriber topol
Set<Integer> partitionNoSet = entry.getValue();
for (int partitionNo : partitionNoSet) {
TaskId taskId = new TaskId(entry.getKey(), partitionNo);
localState.standbyTaskToPrevMember.putIfAbsent(taskId, new HashSet<>());
localState.standbyTaskToPrevMember.putIfAbsent(taskId, new ArrayList<>());
Member Author

minor: no need to deduplicate here, so I'd rather just use an ArrayList

localState.standbyTaskToPrevMember.get(taskId).add(member);
}
}
@@ -171,38 +173,43 @@ private void assignActive(final Set<TaskId> activeTasks) {
final TaskId task = it.next();
final Member prevMember = localState.activeTaskToPrevMember.get(task);
if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true);
updateHelpers(prevMember, true);
ProcessState processState = localState.processIdToState.get(prevMember.processId);
processState.addTask(prevMember.memberId, task, true);
maybeUpdateTasksPerMember(processState.activeTaskCount());
Member Author

Just inlining updateHelpers

it.remove();
}
}

// 2. re-assigning tasks to clients that previously have seen the same task (as standby task)
for (Iterator<TaskId> it = activeTasks.iterator(); it.hasNext();) {
final TaskId task = it.next();
final Set<Member> prevMembers = localState.standbyTaskToPrevMember.get(task);
final Member prevMember = findMemberWithLeastLoad(prevMembers, task, true);
final ArrayList<Member> prevMembers = localState.standbyTaskToPrevMember.get(task);
final Member prevMember = findPrevMemberWithLeastLoad(prevMembers, null);
if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true);
updateHelpers(prevMember, true);
ProcessState processState = localState.processIdToState.get(prevMember.processId);
processState.addTask(prevMember.memberId, task, true);
maybeUpdateTasksPerMember(processState.activeTaskCount());
it.remove();
}
}

// 3. assign any remaining unassigned tasks
PriorityQueue<ProcessState> processByLoad = new PriorityQueue<>(Comparator.comparingDouble(ProcessState::load));
processByLoad.addAll(localState.processIdToState.values());
Member Author

Initial build of the priority queue by load

for (Iterator<TaskId> it = activeTasks.iterator(); it.hasNext();) {
final TaskId task = it.next();
final Set<Member> allMembers = localState.processIdToState.entrySet().stream().flatMap(entry -> entry.getValue().memberToTaskCounts().keySet().stream()
.map(memberId -> new Member(entry.getKey(), memberId))).collect(Collectors.toSet());
final Member member = findMemberWithLeastLoad(allMembers, task, false);
ProcessState processWithLeastLoad = processByLoad.poll();
Member Author

Replace the iteration over all members in findMemberWithLeastLoad with a poll of the priority queue.

if (processWithLeastLoad == null) {
throw new TaskAssignorException("No process available to assign active task {}." + task);
}
String member = memberWithLeastLoad(processWithLeastLoad);
Member Author

memberWithLeastLoad still uses linear search within the process, as before.

if (member == null) {
log.error("Unable to assign active task {} to any member.", task);
throw new TaskAssignorException("No member available to assign active task {}." + task);
Comment on lines +203 to 207
Copilot AI Jul 21, 2025

The error message format is incorrect. The placeholder '{}' is not being used properly with string concatenation. Should be either 'No process available to assign active task ' + task + '.' or use proper string formatting.

Comment on lines +203 to 207
Copilot AI Jul 21, 2025

The error message format is incorrect. The placeholder '{}' is not being used properly with string concatenation. Should be either 'No member available to assign active task ' + task + '.' or use proper string formatting.
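To illustrate what both comments point at, here is a minimal sketch of what the concatenated message evaluates to, next to a `String.format` variant. The class name, the method names, and the task string `0_1` are hypothetical and exist only for this illustration; it does not use the PR's `TaskAssignorException` or `TaskId`.

```java
public class PlaceholderSketch {
    // Mirrors the pattern in the diff: the "{}" is never substituted,
    // the task is simply appended after the trailing period.
    static String concatenated(Object task) {
        return "No process available to assign active task {}." + task;
    }

    // One possible fix: let String.format substitute the value.
    static String formatted(Object task) {
        return String.format("No process available to assign active task %s.", task);
    }

    public static void main(String[] args) {
        // prints: No process available to assign active task {}.0_1
        System.out.println(concatenated("0_1"));
        // prints: No process available to assign active task 0_1.
        System.out.println(formatted("0_1"));
    }
}
```

The `{}` placeholder is only interpreted by SLF4J logger calls such as `log.error(...)`, not by plain string concatenation in an exception message.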

}
localState.processIdToState.get(member.processId).addTask(member.memberId, task, true);
processWithLeastLoad.addTask(member, task, true);
it.remove();
updateHelpers(member, true);

maybeUpdateTasksPerMember(processWithLeastLoad.activeTaskCount());
processByLoad.add(processWithLeastLoad); // Add it back to the queue after updating its state
Member Author

After we have changed the load, we need to add it back to the priority queue, so that it is inserted at the correct position
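The poll, update, re-add pattern described here can be sketched in isolation. This is a simplified illustration, not the assignor's code: `Proc` is a hypothetical stand-in for `ProcessState`, load is assumed to be task count divided by capacity, and the tie-break on id is added only to make the result deterministic.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class LeastLoadedHeapSketch {
    // Hypothetical stand-in for ProcessState: a capacity and a task counter.
    static final class Proc {
        final String id;
        final int capacity;
        int tasks;

        Proc(String id, int capacity) {
            this.id = id;
            this.capacity = capacity;
        }

        double load() {
            return (double) tasks / capacity;
        }
    }

    // Assigns numTasks tasks, one at a time, to the currently least loaded
    // process and returns the process ids in assignment order.
    static List<String> assign(List<Proc> procs, int numTasks) {
        PriorityQueue<Proc> byLoad = new PriorityQueue<>(
            Comparator.comparingDouble(Proc::load).thenComparing(p -> p.id));
        byLoad.addAll(procs);
        List<String> order = new ArrayList<>();
        for (int i = 0; i < numTasks; i++) {
            Proc least = byLoad.poll();      // remove the minimum from the heap
            if (least == null) {
                throw new IllegalStateException("no process available");
            }
            least.tasks++;                   // mutate only while outside the heap
            order.add(least.id);
            byLoad.add(least);               // re-insert at its new position
        }
        return order;
    }

    public static void main(String[] args) {
        List<Proc> procs = List.of(new Proc("a", 2), new Proc("b", 1));
        System.out.println(assign(procs, 3)); // prints [a, b, a]
    }
}
```

`java.util.PriorityQueue` does not re-order an element whose sort key changes while it is still in the queue, which is why the element is modified only between `poll()` and `add()`.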

}
}

@@ -214,85 +221,125 @@ private void maybeUpdateTasksPerMember(final int activeTasksNo) {
}
}

private Member findMemberWithLeastLoad(final Set<Member> members, TaskId taskId, final boolean returnSameMember) {
private boolean assignStandbyToMemberWithLeastLoad(PriorityQueue<ProcessState> queue, TaskId taskId) {
Member Author

This is the same as the active case assignment above. The difference is that there may be processes that already have the task, so consider more than one candidate process here (see recursion below).

ProcessState processWithLeastLoad = queue.poll();
if (processWithLeastLoad == null) {
return false;
}
boolean found = false;
if (!processWithLeastLoad.hasTask(taskId)) {
String memberId = memberWithLeastLoad(processWithLeastLoad);
if (memberId != null) {
processWithLeastLoad.addTask(memberId, taskId, false);
found = true;
}
} else if (!queue.isEmpty()) {
found = assignStandbyToMemberWithLeastLoad(queue, taskId);
Member Author

We are using recursion here. If the least loaded process already has the task, we recurse to find the next least loaded process. Note that the least loaded process is not added back to the queue until the recursive call returns, so it cannot be selected again during the recursion.

Recursion is fine here, because we know that we get only numStandbyTasks + 1 recursions, since only numStandbyTasks processes can have the task already. By default, we only allow 2 standby replicas, so we'd get at most 3 recursive calls here.
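The recursion can be sketched on its own as follows. This is a simplified illustration under stated assumptions: `Proc` is a hypothetical stand-in for `ProcessState`, load is assumed to be the number of tasks held, member selection inside a process is omitted, and the tie-break on id only makes the result deterministic.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

public class StandbyHeapSketch {
    // Hypothetical stand-in for ProcessState: just an id and a set of tasks.
    static final class Proc {
        final String id;
        final Set<String> tasks;

        Proc(String id, String... initialTasks) {
            this.id = id;
            this.tasks = new HashSet<>(Arrays.asList(initialTasks));
        }

        int load() {
            return tasks.size();
        }
    }

    static PriorityQueue<Proc> newQueue(List<Proc> procs) {
        PriorityQueue<Proc> queue = new PriorityQueue<>(
            Comparator.comparingInt(Proc::load).thenComparing(p -> p.id));
        queue.addAll(procs);
        return queue;
    }

    // Returns the id of the process that received the task, or null if every
    // process in the queue already holds it.
    static String assignStandby(PriorityQueue<Proc> queue, String task) {
        Proc least = queue.poll();
        if (least == null) {
            return null;
        }
        String assignedTo = null;
        if (!least.tasks.contains(task)) {
            least.tasks.add(task);
            assignedTo = least.id;
        } else if (!queue.isEmpty()) {
            // 'least' stays out of the queue here, so the recursive call
            // can only pick one of the remaining processes.
            assignedTo = assignStandby(queue, task);
        }
        queue.add(least); // re-insert after the (possible) load change
        return assignedTo;
    }

    public static void main(String[] args) {
        PriorityQueue<Proc> queue = newQueue(List.of(
            new Proc("a", "t0"), new Proc("b", "x", "y"), new Proc("c", "x", "y", "z")));
        System.out.println(assignStandby(queue, "t0")); // prints b
        System.out.println(assignStandby(queue, "t0")); // prints c
        System.out.println(assignStandby(queue, "t0")); // prints null
    }
}
```

Every polled process is added back on the way out of the recursion, so the queue still contains all processes after each call, whether or not the task was placed.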

}
queue.add(processWithLeastLoad); // Add it back to the queue after updating its state
return found;
}

/**
* Finds the previous member with the least load for a given task.
*
* @param members The list of previous members owning the task.
* @param taskId The taskId, to check if the previous member already has the task. Can be null, if we assign it
* for the first time (e.g., during active task assignment).
*
* @return Previous member with the least load that does not have the task, or null if no such member exists.
*/
private Member findPrevMemberWithLeastLoad(final ArrayList<Member> members, final TaskId taskId) {
Member Author

findPrevMemberWithLeastLoad works very similarly to the old findMemberWithLeastLoad: it does a linear search among a collection of candidates.

However, we no longer use it to find the least loaded node among all members; we use a priority queue there.

This is only used to select the least loaded node among all members that previously owned the task. I replaced the Java Streams based iteration with a loop, since it's more efficient.
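As a rough illustration of a loop-based linear scan over a small candidate list, here is a simplified sketch. It is not the logic of findPrevMemberWithLeastLoad: the class name, the `load` map, and the `exclude` set (standing in for "already has the task") are hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;

public class PrevMemberScanSketch {
    // Returns the candidate with the smallest load that is not excluded,
    // or null if there is no such candidate.
    static String leastLoaded(List<String> candidates,
                              Map<String, Double> load,
                              Set<String> exclude) {
        String best = null;
        double bestLoad = Double.POSITIVE_INFINITY;
        for (int i = 0; i < candidates.size(); i++) {
            String candidate = candidates.get(i);
            if (exclude.contains(candidate)) {
                continue; // e.g. the candidate already holds the task
            }
            double candidateLoad = load.get(candidate);
            if (candidateLoad < bestLoad) {
                best = candidate;
                bestLoad = candidateLoad;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String result = leastLoaded(
            List.of("a", "b", "c"),
            Map.of("a", 0.5, "b", 0.25, "c", 0.1),
            Set.of("c"));
        System.out.println(result); // prints b
    }
}
```

For a handful of previous owners per task, a plain loop like this avoids allocating a stream pipeline on every call.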

if (members == null || members.isEmpty()) {
return null;
}
Optional<ProcessState> processWithLeastLoad = members.stream()
.map(member -> localState.processIdToState.get(member.processId))
.min(Comparator.comparingDouble(ProcessState::load));

// if the same exact former member is needed
if (returnSameMember) {
return localState.standbyTaskToPrevMember.get(taskId).stream()
.filter(standby -> standby.processId.equals(processWithLeastLoad.get().processId()))
.findFirst()
.orElseGet(() -> memberWithLeastLoad(processWithLeastLoad.get()));

Member candidate = members.get(0);
ProcessState candidateProcessState = localState.processIdToState.get(candidate.processId);
double candidateProcessLoad = candidateProcessState.load();
double candidateMemberLoad = candidateProcessState.memberToTaskCounts().get(candidate.memberId);
for (int i = 1; i < members.size(); i++) {
Member member = members.get(i);
ProcessState processState = localState.processIdToState.get(member.processId);
double newProcessLoad = processState.load();
if (newProcessLoad < candidateProcessLoad && (taskId == null || !processState.hasTask(taskId))) {
double newMemberLoad = processState.memberToTaskCounts().get(member.memberId);
if (newMemberLoad < candidateMemberLoad) {
candidateProcessLoad = newProcessLoad;
candidateMemberLoad = newMemberLoad;
candidate = member;
}
}
}
return memberWithLeastLoad(processWithLeastLoad.get());

if (taskId == null || !candidateProcessState.hasTask(taskId)) {
return candidate;
}
return null;
}

private Member memberWithLeastLoad(final ProcessState processWithLeastLoad) {
private String memberWithLeastLoad(final ProcessState processWithLeastLoad) {
Map<String, Integer> members = processWithLeastLoad.memberToTaskCounts();
if (members.isEmpty()) {
return null;
}
if (members.size() == 1) {
return members.keySet().iterator().next();
}
Optional<String> memberWithLeastLoad = processWithLeastLoad.memberToTaskCounts().entrySet().stream()
.min(Map.Entry.comparingByValue())
.map(Map.Entry::getKey);
return memberWithLeastLoad.map(memberId -> new Member(processWithLeastLoad.processId(), memberId)).orElse(null);
return memberWithLeastLoad.orElse(null);
}

private boolean hasUnfulfilledQuota(final Member member) {
return localState.processIdToState.get(member.processId).memberToTaskCounts().get(member.memberId) < localState.tasksPerMember;
}

private void assignStandby(final Set<TaskId> standbyTasks, final int numStandbyReplicas) {
ArrayList<StandbyToAssign> toLeastLoaded = new ArrayList<>(standbyTasks.size() * numStandbyReplicas);
Member Author

This ArrayList stores all standby tasks that we couldn't assign to a member that previously owned the task and that therefore need to be assigned to the "least loaded" node.

for (TaskId task : standbyTasks) {
for (int i = 0; i < numStandbyReplicas; i++) {

final Set<String> availableProcesses = localState.processIdToState.values().stream()
.filter(process -> !process.hasTask(task))
.map(ProcessState::processId)
.collect(Collectors.toSet());

if (availableProcesses.isEmpty()) {
log.warn("{} There is not enough available capacity. " +
"You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.",
errorMessage(numStandbyReplicas, i, task));
break;
}
Member standby = null;

// prev active task
Member prevMember = localState.activeTaskToPrevMember.get(task);
if (prevMember != null && availableProcesses.contains(prevMember.processId) && isLoadBalanced(prevMember.processId)) {
standby = prevMember;
if (prevMember != null) {
ProcessState prevMemberProcessState = localState.processIdToState.get(prevMember.processId);
if (!prevMemberProcessState.hasTask(task) && isLoadBalanced(prevMemberProcessState)) {
prevMemberProcessState.addTask(prevMember.memberId, task, false);
continue;
}
}

// prev standby tasks
if (standby == null) {
final Set<Member> prevMembers = localState.standbyTaskToPrevMember.get(task);
if (prevMembers != null && !prevMembers.isEmpty()) {
prevMembers.removeIf(member -> !availableProcesses.contains(member.processId));
prevMember = findMemberWithLeastLoad(prevMembers, task, true);
if (prevMember != null && isLoadBalanced(prevMember.processId)) {
standby = prevMember;
final ArrayList<Member> prevMembers = localState.standbyTaskToPrevMember.get(task);
if (prevMembers != null && !prevMembers.isEmpty()) {
prevMember = findPrevMemberWithLeastLoad(prevMembers, task);
if (prevMember != null) {
ProcessState prevMemberProcessState = localState.processIdToState.get(prevMember.processId);
if (isLoadBalanced(prevMemberProcessState)) {
prevMemberProcessState.addTask(prevMember.memberId, task, false);
continue;
}
}
}

// others
if (standby == null) {
final Set<Member> availableMembers = availableProcesses.stream()
.flatMap(pId -> localState.processIdToState.get(pId).memberToTaskCounts().keySet().stream()
.map(mId -> new Member(pId, mId))).collect(Collectors.toSet());
standby = findMemberWithLeastLoad(availableMembers, task, false);
if (standby == null) {
log.warn("{} Error in standby task assignment!", errorMessage(numStandbyReplicas, i, task));
break;
}
}
localState.processIdToState.get(standby.processId).addTask(standby.memberId, task, false);
updateHelpers(standby, false);
toLeastLoaded.add(new StandbyToAssign(task, numStandbyReplicas - i));
break;
}
}

PriorityQueue<ProcessState> processByLoad = new PriorityQueue<>(Comparator.comparingDouble(ProcessState::load));
processByLoad.addAll(localState.processIdToState.values());
for (StandbyToAssign toAssign : toLeastLoaded) {
for (int i = 0; i < toAssign.remainingReplicas; i++) {
if (!assignStandbyToMemberWithLeastLoad(processByLoad, toAssign.taskId)) {
log.warn("{} There is not enough available capacity. " +
"You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.",
errorMessage(numStandbyReplicas, i, toAssign.taskId));
break;
}
}
}
}

@@ -301,21 +348,13 @@ private String errorMessage(final int numStandbyReplicas, final int i, final Tas
" of " + numStandbyReplicas + " standby tasks for task [" + task + "].";
}

private boolean isLoadBalanced(final String processId) {
final ProcessState process = localState.processIdToState.get(processId);
private boolean isLoadBalanced(final ProcessState process) {
Member Author

minor: passing the process in here saves us from looking it up in the hashmap again.

final double load = process.load();
boolean isLeastLoadedProcess = localState.processIdToState.values().stream()
.allMatch(p -> p.load() >= load);
return process.hasCapacity() || isLeastLoadedProcess;
}

private void updateHelpers(final Member member, final boolean isActive) {
Member Author

I just inlined this function

if (isActive) {
// update task per process
maybeUpdateTasksPerMember(localState.processIdToState.get(member.processId).activeTaskCount());
}
}

private static int computeTasksPerMember(final int numberOfTasks, final int numberOfMembers) {
if (numberOfMembers == 0) {
return 0;
@@ -327,6 +366,16 @@ private static int computeTasksPerMember(final int numberOfTasks, final int numb
return tasksPerMember;
}

static class StandbyToAssign {
private final TaskId taskId;
private final int remainingReplicas;

public StandbyToAssign(final TaskId taskId, final int remainingReplicas) {
this.taskId = taskId;
this.remainingReplicas = remainingReplicas;
}
}

static class Member {
private final String processId;
private final String memberId;
@@ -340,11 +389,11 @@ public Member(final String processId, final String memberId) {
private static class LocalState {
// helper data structures:
Map<TaskId, Member> activeTaskToPrevMember;
Map<TaskId, Set<Member>> standbyTaskToPrevMember;
Map<TaskId, ArrayList<Member>> standbyTaskToPrevMember;
Map<String, ProcessState> processIdToState;

int allTasks;
int totalCapacity;
int tasksPerMember;
}
}
}
Original file line number Diff line number Diff line change
@@ -140,15 +140,19 @@ public void shouldWorkWithRebalance(


final Properties props = new Properties();
final String appId = safeUniqueTestName(testInfo);
props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
props.put(StreamsConfig.APPLICATION_ID_CONFIG, safeUniqueTestName(testInfo));
props.put(StreamsConfig.APPLICATION_ID_CONFIG, appId);
props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled);
props.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled);
// decrease the session timeout so that we can trigger the rebalance soon after old client left closed
props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);
props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500);
if (streamsProtocolEnabled) {
props.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.name().toLowerCase(Locale.getDefault()));
// decrease the session timeout so that we can trigger the rebalance soon after old client left closed
CLUSTER.setGroupSessionTimeout(appId, 10000);
CLUSTER.setGroupHeartbeatTimeout(appId, 1000);
Member Author

Previously, the integration test set the session timeout and the heartbeat interval incorrectly for the new protocol. They need to be set at the group level.

This sometimes made the test flaky with the new assignment algorithm, since we iteratively cycle out the "oldest" member, and tend to assign the tasks from the next-oldest member. But, due to the high session timeout and heartbeat timeout, it could sometimes take too long for the new member to get the new tasks assigned before being cycled out as well.

I don't see a problem in the assignment logic here. In fact, it seems useful to assign tasks to "old" members, since they are stable. We are just cycling out the tasks too quickly in this integration test.

} else {
// decrease the session timeout so that we can trigger the rebalance soon after old client left closed
props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);
}

// cycle out Streams instances as long as the test is running.
Original file line number Diff line number Diff line change
@@ -99,7 +99,7 @@ private Properties streamsConfiguration(final boolean streamsProtocolEnabled) {
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class);
if (streamsProtocolEnabled) {
streamsConfiguration.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.name().toLowerCase(Locale.getDefault()));
CLUSTER.setStandbyReplicas("app-" + safeTestName, 1);
CLUSTER.setGroupStandbyReplicas("app-" + safeTestName, 1);
Member Author

Just renamed

} else {
streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
}