Skip to content

Commit

Permalink
PullRequest: 873 [Shuffle] refactor shuffle singleton classes (#445)
Browse files Browse the repository at this point in the history
Merge branch opensource_shuffle_refactor of [email protected]:AntGraph/GeaFlow.git into dev_opensource
https://code.alipay.com/AntGraph/GeaFlow/pull_requests/873

Signed-off-by: 知尘 <[email protected]>

* PullRequest: 863 [Shuffle] refactor shuffle singleton classes
* fix ut
* update

Co-authored-by: 唉唉 <[email protected]>
  • Loading branch information
qingwen220 and 唉唉 authored Jan 17, 2025
1 parent 69b4ce0 commit a421b12
Show file tree
Hide file tree
Showing 31 changed files with 242 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.antgroup.geaflow.shuffle.api.writer.IShuffleWriter;
import com.antgroup.geaflow.shuffle.api.writer.IWriterContext;
import com.antgroup.geaflow.shuffle.api.writer.WriterContext;
import com.antgroup.geaflow.shuffle.config.ShuffleConfig;
import com.antgroup.geaflow.shuffle.desc.IOutputDesc;
import com.antgroup.geaflow.shuffle.desc.OutputType;
import com.antgroup.geaflow.shuffle.message.Shard;
Expand Down Expand Up @@ -87,6 +88,7 @@ public void update(UpdateEmitterRequest request) {

int outputNum = outputDescList.size();
AtomicBoolean[] flags = new AtomicBoolean[outputNum];
ShuffleConfig shuffleConfig = ShuffleManager.getInstance().getShuffleConfig();
for (int i = 0; i < outputNum; i++) {
IOutputDesc outputDesc = outputDescList.get(i);
if (outputDesc.getType() == OutputType.RESPONSE) {
Expand All @@ -102,7 +104,7 @@ public void update(UpdateEmitterRequest request) {
IWriterContext writerContext = WriterContext.newBuilder()
.setPipelineId(request.getPipelineId())
.setPipelineName(request.getPipelineName())
.setConfig(initEmitterRequest.getConfiguration())
.setConfig(shuffleConfig)
.setVertexId(forwardOutputDesc.getPartitioner().getOpId())
.setEdgeId(forwardOutputDesc.getEdgeId())
.setTaskId(taskArgs.getTaskId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.antgroup.geaflow.shuffle.pipeline.slice.SliceManager;
import com.antgroup.geaflow.shuffle.pipeline.slice.SpillablePipelineSlice;
import com.antgroup.geaflow.shuffle.serialize.AbstractMessageIterator;
import com.antgroup.geaflow.shuffle.service.ShuffleManager;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

Expand All @@ -39,7 +40,8 @@ public class PrefetchMessageBuffer<T> implements IInputMessageBuffer<T> {
/**
 * Creates the buffer together with its backing spillable slice and registers that slice
 * with the slice manager under the given slice id.
 *
 * @param logTag tag handed to the SpillablePipelineSlice constructor.
 * @param sliceId id of the slice to create; its edge id is also kept in edgeId.
 */
public PrefetchMessageBuffer(String logTag, SliceId sliceId) {
this.slice = new SpillablePipelineSlice(logTag, sliceId);
this.edgeId = sliceId.getEdgeId();
SliceManager.getInstance().register(sliceId, this.slice);
SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager();
sliceManager.register(sliceId, this.slice);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.antgroup.geaflow.shuffle.serialize.EncoderMessageIterator;
import com.antgroup.geaflow.shuffle.serialize.IMessageIterator;
import com.antgroup.geaflow.shuffle.serialize.MessageIterator;
import com.antgroup.geaflow.shuffle.service.ShuffleManager;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -198,7 +199,8 @@ private List<PipelineSliceMeta> buildPrefetchSlice(List<PipelineSliceMeta> slice
PipelineSliceMeta slice = slices.get(0);
SliceId tmp = slice.getSliceId();
SliceId sliceId = new SliceId(tmp.getPipelineId(), tmp.getEdgeId(), -1, tmp.getSliceIndex());
SpillablePipelineSlice resultSlice = (SpillablePipelineSlice) SliceManager.getInstance().getSlice(sliceId);
SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager();
SpillablePipelineSlice resultSlice = (SpillablePipelineSlice) sliceManager.getSlice(sliceId);
if (resultSlice == null || !resultSlice.isReady2read() || resultSlice.isReleased()) {
throw new GeaflowRuntimeException("illegal slice: " + sliceId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

package com.antgroup.geaflow.shuffle.api.writer;

import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.encoder.IEncoder;
import com.antgroup.geaflow.common.shuffle.DataExchangeMode;
import com.antgroup.geaflow.shuffle.config.ShuffleConfig;
import com.antgroup.geaflow.shuffle.message.PipelineInfo;
import java.io.Serializable;

Expand Down Expand Up @@ -76,7 +76,7 @@ public interface IWriterContext extends Serializable {
*
* @return the shuffle configuration.
*/
Configuration getConfig();
ShuffleConfig getConfig();

/**
* Get the encoder used to serialize and deserialize data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.antgroup.geaflow.shuffle.serialize.EncoderRecordSerializer;
import com.antgroup.geaflow.shuffle.serialize.IRecordSerializer;
import com.antgroup.geaflow.shuffle.serialize.RecordSerializer;
import com.antgroup.geaflow.shuffle.service.ShuffleManager;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -60,7 +61,7 @@ public abstract class ShardWriter<T, R> {

public void init(IWriterContext writerContext) {
this.writerContext = writerContext;
this.shuffleConfig = ShuffleConfig.getInstance();
this.shuffleConfig = writerContext.getConfig();
this.writeMetrics = new ShuffleWriteMetrics();

this.pipelineId = writerContext.getPipelineInfo().getPipelineId();
Expand Down Expand Up @@ -91,11 +92,12 @@ private BufferBuilder[] buildBufferBuilder(int channels) {
/**
 * Creates one slice per channel and registers each of them with the slice manager.
 *
 * @param channels number of slices to create; slice i is identified by SliceId(writerId, i).
 * @param refCount value passed through to newSlice for every slice.
 * @return the created slices, indexed by channel.
 */
protected IPipelineSlice[] buildResultSlices(int channels, int refCount) {
IPipelineSlice[] slices = new IPipelineSlice[channels];
// Every slice of this writer shares one WriterId built from pipeline id, edge id and task index.
WriterId writerId = new WriterId(this.pipelineId, this.edgeId, this.taskIndex);
SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager();
for (int i = 0; i < channels; i++) {
SliceId sliceId = new SliceId(writerId, i);
IPipelineSlice slice = this.newSlice(this.taskLogTag, sliceId, refCount);
slices[i] = slice;
SliceManager.getInstance().register(sliceId, slice);
sliceManager.register(sliceId, slice);
}
return slices;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

package com.antgroup.geaflow.shuffle.api.writer;

import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.encoder.IEncoder;
import com.antgroup.geaflow.common.shuffle.DataExchangeMode;
import com.antgroup.geaflow.shuffle.config.ShuffleConfig;
import com.antgroup.geaflow.shuffle.message.PipelineInfo;

public class WriterContext implements IWriterContext {
Expand All @@ -29,7 +29,7 @@ public class WriterContext implements IWriterContext {
private String taskName;
private DataExchangeMode dataExchangeMode;
private int targetChannels;
private Configuration config;
private ShuffleConfig config;
private int refCount;
private IEncoder<?> encoder;

Expand Down Expand Up @@ -72,7 +72,7 @@ public WriterContext setChannelNum(int targetChannels) {
return this;
}

/**
 * Stores the configuration for this writer context.
 *
 * @param config configuration to keep in the config field.
 * @return this context, so that setter calls can be chained.
 */
public WriterContext setConfig(Configuration config) {
public WriterContext setConfig(ShuffleConfig config) {
this.config = config;
return this;
}
Expand Down Expand Up @@ -118,7 +118,7 @@ public int getTaskIndex() {
}

/**
 * Returns the configuration held in the config field of this writer context.
 *
 * @return the stored configuration.
 */
@Override
public Configuration getConfig() {
public ShuffleConfig getConfig() {
return config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@

import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.shuffle.StorageLevel;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -95,25 +94,7 @@ public class ShuffleConfig {
private final int flushBufferTimeoutMs;
private final StorageLevel storageLevel;

private static ShuffleConfig INSTANCE;

public static synchronized ShuffleConfig getInstance(Configuration config) {
if (INSTANCE == null) {
INSTANCE = new ShuffleConfig(config);
}
return INSTANCE;
}

@VisibleForTesting
public static synchronized void reset(Configuration config) {
INSTANCE = new ShuffleConfig(config);
}

public static ShuffleConfig getInstance() {
return INSTANCE;
}

private ShuffleConfig(Configuration config) {
public ShuffleConfig(Configuration config) {
this.configuration = config;

// netty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@

package com.antgroup.geaflow.shuffle.pipeline.buffer;

import com.antgroup.geaflow.shuffle.service.ShuffleManager;

public abstract class AbstractBuffer implements OutBuffer {

private final boolean memoryTrack;
private final ShuffleMemoryTracker memoryTracker;
protected int refCount;

/**
 * Creates a buffer with memory tracking switched on or off.
 *
 * @param enableMemoryTrack whether memory used by this buffer is reported to the shuffle
 *     memory tracker.
 */
public AbstractBuffer(boolean memoryTrack) {
this.memoryTrack = memoryTrack;
public AbstractBuffer(boolean enableMemoryTrack) {
// When tracking is enabled the tracker comes from the ShuffleManager instance;
// otherwise the field is left null.
this.memoryTracker = enableMemoryTrack
? ShuffleManager.getInstance().getShuffleMemoryTracker() : null;
}

/**
 * Creates a buffer that uses the given memory tracker.
 *
 * @param memoryTracker tracker stored in the memoryTracker field; requireMemory and
 *     releaseMemory do nothing when it is null.
 */
public AbstractBuffer(ShuffleMemoryTracker memoryTracker) {
this.memoryTracker = memoryTracker;
}

@Override
Expand All @@ -33,20 +40,15 @@ public boolean isDisposable() {
return this.refCount <= 0;
}

@Override
public boolean isMemoryTracking() {
return this.memoryTrack;
}

/**
 * Requests memory from the shuffle memory tracker when tracking is enabled for this buffer;
 * does nothing otherwise.
 *
 * @param dataSize amount passed to the tracker's requireMemory.
 */
protected void requireMemory(long dataSize) {
if (this.memoryTrack) {
ShuffleMemoryTracker.getInstance().requireMemory(dataSize);
if (this.memoryTracker != null) {
memoryTracker.requireMemory(dataSize);
}
}

/**
 * Returns memory to the shuffle memory tracker when tracking is enabled for this buffer;
 * does nothing otherwise.
 *
 * @param dataSize amount passed to the tracker's releaseMemory.
 */
protected void releaseMemory(long dataSize) {
if (this.memoryTrack) {
ShuffleMemoryTracker.getInstance().releaseMemory(dataSize);
if (this.memoryTracker != null) {
memoryTracker.releaseMemory(dataSize);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ public HeapBuffer(byte[] bytes) {
this(bytes, true);
}

/**
 * Wraps the given byte array and calls requireMemory with its length.
 *
 * @param bytes backing array stored in the bytes field.
 * @param memoryTracker tracker handed to the superclass constructor.
 */
public HeapBuffer(byte[] bytes, ShuffleMemoryTracker memoryTracker) {
super(memoryTracker);
this.bytes = bytes;
this.requireMemory(bytes.length);
}

public HeapBuffer(byte[] bytes, boolean memoryTrack) {
super(memoryTrack);
this.bytes = bytes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ public interface OutBuffer {
*/
boolean isDisposable();

/**
* Check if this buffer support memory track.
*
* @return if support memory track.
*/
boolean isMemoryTracking();

/**
* Release this buffer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_MEMORY_SAFETY_FRACTION;

import com.antgroup.geaflow.common.config.Configuration;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
Expand All @@ -33,9 +32,7 @@ public class ShuffleMemoryTracker {
private final long maxShuffleSize;
private final AtomicLong usedMemory;

private static volatile ShuffleMemoryTracker INSTANCE;

private ShuffleMemoryTracker(Configuration config) {
public ShuffleMemoryTracker(Configuration config) {
boolean memoryPool = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE);

// Set offHeap 0 or not enable memory pool.
Expand All @@ -50,17 +47,6 @@ private ShuffleMemoryTracker(Configuration config) {
maxMemorySize / FileUtils.ONE_MB, maxShuffleSize / FileUtils.ONE_MB);
}

public static synchronized ShuffleMemoryTracker getInstance(Configuration config) {
if (INSTANCE == null) {
INSTANCE = new ShuffleMemoryTracker(config);
}
return INSTANCE;
}

public static ShuffleMemoryTracker getInstance() {
return INSTANCE;
}

public boolean requireMemory(long requiredBytes) {
if (usedMemory.get() < 0) {
LOGGER.warn("memory statistic incorrect!");
Expand Down Expand Up @@ -90,9 +76,7 @@ public double getUsedRatio() {
return usedMemory.get() * 1.0 / maxShuffleSize;
}

@VisibleForTesting
public void release() {
INSTANCE = null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.antgroup.geaflow.shuffle.pipeline.slice.PipelineSliceListener;
import com.antgroup.geaflow.shuffle.pipeline.slice.PipelineSliceReader;
import com.antgroup.geaflow.shuffle.pipeline.slice.SliceManager;
import com.antgroup.geaflow.shuffle.service.ShuffleManager;
import com.antgroup.geaflow.shuffle.util.SliceNotFoundException;
import com.google.common.base.Preconditions;
import java.io.IOException;
Expand Down Expand Up @@ -60,7 +61,8 @@ public void requestSlice(long batchId) throws IOException {
if (this.sliceReader == null) {
LOGGER.info("Requesting Local slice {}", this.inputSliceId);
try {
this.sliceReader = SliceManager.getInstance()
SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager();
this.sliceReader = sliceManager
.createSliceReader(this.inputSliceId, this.initialBatchId, this);
} catch (SliceNotFoundException notFound) {
if (increaseBackoff()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.antgroup.geaflow.shuffle.network.netty.SliceOutputChannelHandler;
import com.antgroup.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer;
import com.antgroup.geaflow.shuffle.pipeline.channel.ChannelId;
import com.antgroup.geaflow.shuffle.service.ShuffleManager;
import java.io.IOException;

public class SequenceSliceReader implements PipelineSliceListener {
Expand All @@ -39,7 +40,8 @@ public SequenceSliceReader(ChannelId inputChannelId, SliceOutputChannelHandler r

/**
 * Remembers the slice id, creates a reader for it through the slice manager with this object
 * passed as the last argument, and then calls notifyDataAvailable().
 *
 * @param sliceId id of the slice to read.
 * @param startBatchId batch id passed to the slice manager's createSliceReader.
 * @throws IOException declared by this method; the visible body does not show which call
 *     raises it.
 */
public void createSliceReader(SliceId sliceId, long startBatchId) throws IOException {
this.sliceId = sliceId;
this.sliceReader = SliceManager.getInstance().createSliceReader(sliceId,
ShuffleManager shuffleManager = ShuffleManager.getInstance();
this.sliceReader = shuffleManager.getSliceManager().createSliceReader(sliceId,
startBatchId, this);
notifyDataAvailable();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,9 @@
public class SliceManager {

private static final Logger LOGGER = LoggerFactory.getLogger(SliceManager.class);
private static SliceManager INSTANCE;

private final Map<Long, Set<SliceId>> pipeline2slices = new HashMap<>();
private final Map<SliceId, IPipelineSlice> slices = new ConcurrentHashMap<>();

public static synchronized void init() {
if (INSTANCE == null) {
INSTANCE = new SliceManager();
}
}

public static SliceManager getInstance() {
return INSTANCE;
}

public void register(SliceId sliceId, IPipelineSlice slice) {
if (this.slices.containsKey(sliceId)) {
throw new GeaflowRuntimeException("slice already registered: " + sliceId);
Expand Down
Loading

0 comments on commit a421b12

Please sign in to comment.