Skip to content

Commit 169267b

Browse files
committed
[FLINK-27286] Add communication infra to support training high dimension models
1 parent ba327b0 commit 169267b

28 files changed

+3784
-0
lines changed

flink-ml-lib/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ under the License.
138138
<scope>test</scope>
139139
<type>test-jar</type>
140140
</dependency>
141+
<dependency>
142+
<groupId>it.unimi.dsi</groupId>
143+
<artifactId>fastutil</artifactId>
144+
<version>8.5.12</version>
145+
</dependency>
141146
</dependencies>
142147

143148
<build>
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.ps;
20+
21+
import org.apache.flink.api.common.typeutils.TypeSerializer;
22+
import org.apache.flink.api.java.tuple.Tuple2;
23+
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
24+
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
25+
import org.apache.flink.ml.util.Bits;
26+
27+
import java.io.ByteArrayInputStream;
28+
import java.io.ByteArrayOutputStream;
29+
import java.io.IOException;
30+
import java.lang.reflect.Array;
31+
import java.util.ArrayList;
32+
import java.util.Comparator;
33+
import java.util.Iterator;
34+
import java.util.List;
35+
36+
/**
37+
* {@link Message} is responsible for encoding all information exchanged between {@link
38+
* WorkerOperator} and {@link ServerOperator}. The message format follows this structure:
39+
*
40+
* <p>`workerId serverId stageId keyLength keys valuesLength values`
41+
*
42+
* <p>where the message fields include the worker ID, server ID, stage ID, length of the keys, keys
43+
* themselves, length of the values, and the values.
44+
*/
45+
public class Message {
46+
private static final int WORKER_ID_OFFSET = 0;
47+
private static final int SERVER_ID_OFFSET = Integer.BYTES;
48+
private static final int STAGE_ID_OFFSET = Integer.BYTES + SERVER_ID_OFFSET;
49+
private static final int KVS_OFFSET = Integer.BYTES + STAGE_ID_OFFSET;
50+
51+
/** The storage of message in bytes. */
52+
public final byte[] bytes;
53+
54+
/** Constructs a message instance from the bytes. */
55+
public Message(byte[] bytes) {
56+
this.bytes = bytes;
57+
}
58+
59+
/** Constructs a message instance from long keys and double values. */
60+
public Message(int workerId, int serverId, int stageId, long[] keys, double[] values) {
61+
int sizeInBytes =
62+
KVS_OFFSET
63+
+ Bits.getLongArraySizeInBytes(keys)
64+
+ Bits.getDoubleArraySizeInBytes(values);
65+
bytes = new byte[sizeInBytes];
66+
Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
67+
Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
68+
Bits.putInt(bytes, STAGE_ID_OFFSET, stageId);
69+
int offset = Bits.putLongArray(keys, bytes, KVS_OFFSET);
70+
Bits.putDoubleArray(values, bytes, offset);
71+
}
72+
73+
/** Constructs a message instance from long keys and generics values. */
74+
public <V> Message(
75+
int workerId,
76+
int serverId,
77+
int stageId,
78+
long[] keys,
79+
V[] values,
80+
TypeSerializer<V> serializer)
81+
throws IOException {
82+
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
83+
DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
84+
new DataOutputViewStreamWrapper(byteArrayOutputStream);
85+
dataOutputViewStreamWrapper.writeInt(workerId);
86+
dataOutputViewStreamWrapper.writeInt(serverId);
87+
dataOutputViewStreamWrapper.writeInt(stageId);
88+
89+
dataOutputViewStreamWrapper.writeInt(keys.length);
90+
for (long key : keys) {
91+
dataOutputViewStreamWrapper.writeLong(key);
92+
}
93+
dataOutputViewStreamWrapper.writeInt(values.length);
94+
for (V value : values) {
95+
serializer.serialize(value, dataOutputViewStreamWrapper);
96+
}
97+
bytes = byteArrayOutputStream.toByteArray();
98+
}
99+
100+
/** Retrieves the keys. */
101+
public long[] getKeys() {
102+
return Bits.getLongArray(bytes, KVS_OFFSET);
103+
}
104+
105+
/** Retrieves the values using the given serializer. */
106+
public <V> V[] getValues(TypeSerializer<V> serializer) throws IOException {
107+
int numIndices = Bits.getInt(bytes, KVS_OFFSET);
108+
int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES;
109+
int numValues = Bits.getInt(bytes, offset);
110+
offset += Integer.BYTES;
111+
112+
// Since the generics got erased, we use reflections to create the array.
113+
V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues);
114+
ByteArrayInputStream byteArrayInputStream =
115+
new ByteArrayInputStream(bytes, offset, bytes.length - offset);
116+
DataInputViewStreamWrapper dataInputViewStreamWrapper =
117+
new DataInputViewStreamWrapper(byteArrayInputStream);
118+
for (int i = 0; i < numValues; i++) {
119+
result[i] = serializer.deserialize(dataInputViewStreamWrapper);
120+
}
121+
return result;
122+
}
123+
124+
/**
125+
* Retrieves the values in double array.
126+
*
127+
* <p>Note that getting double array in this function using {@link Bits#getDoubleArray(byte[],
128+
* int)} is faster than {@link Message#getValues} by up to 2.3X.
129+
*/
130+
public double[] getValuesInDoubleArray() {
131+
int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES;
132+
return Bits.getDoubleArray(bytes, offset);
133+
}
134+
135+
/** Retrieves the worker id. */
136+
public int getWorkerId() {
137+
return Bits.getInt(bytes, WORKER_ID_OFFSET);
138+
}
139+
140+
/** Sets the worker id. */
141+
public void setWorkerId(int workerId) {
142+
Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
143+
}
144+
145+
/** Retrieves the server id. */
146+
public int getServerId() {
147+
return Bits.getInt(bytes, SERVER_ID_OFFSET);
148+
}
149+
150+
/** Sets the server id. */
151+
public void setServerId(int serverId) {
152+
Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
153+
}
154+
155+
public int getStageId() {
156+
return Bits.getInt(bytes, STAGE_ID_OFFSET);
157+
}
158+
159+
/**
160+
* Assembles the received messages from servers according to the server id. Note that these
161+
* message should be the responses from the same stage.
162+
*/
163+
public static Message assembleMessages(Iterator<byte[]> messageIterator) {
164+
List<Message> messages = new ArrayList<>();
165+
while (messageIterator.hasNext()) {
166+
messages.add(new Message(messageIterator.next()));
167+
}
168+
messages.sort(Comparator.comparingInt(Message::getServerId));
169+
170+
int numMessages = messages.size();
171+
int numKeys = 0, numValues = 0;
172+
int numAssembledBytes = 0;
173+
int workerId = -1;
174+
int stageId = -1;
175+
for (Message message : messages) {
176+
if (workerId == -1) {
177+
workerId = message.getWorkerId();
178+
stageId = message.getStageId();
179+
}
180+
numKeys += message.getNumKeys();
181+
numValues += message.getNumValues();
182+
numAssembledBytes += message.bytes.length;
183+
}
184+
numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2);
185+
byte[] assembledBytes = new byte[numAssembledBytes];
186+
Bits.putInt(assembledBytes, WORKER_ID_OFFSET, workerId);
187+
Bits.putInt(assembledBytes, STAGE_ID_OFFSET, stageId);
188+
int keysOffset = KVS_OFFSET;
189+
Bits.putInt(assembledBytes, keysOffset, numKeys);
190+
keysOffset += Integer.BYTES;
191+
int valuesOffset = keysOffset + numKeys * Long.BYTES;
192+
Bits.putInt(assembledBytes, valuesOffset, numValues);
193+
valuesOffset += Integer.BYTES;
194+
195+
for (Message message : messages) {
196+
Tuple2<Integer, Integer> keysOffsetAndLength = message.getKeysOffsetAndLength();
197+
System.arraycopy(
198+
message.bytes,
199+
keysOffsetAndLength.f0,
200+
assembledBytes,
201+
keysOffset,
202+
keysOffsetAndLength.f1);
203+
keysOffset += keysOffsetAndLength.f1;
204+
Tuple2<Integer, Integer> valuesOffsetAndLength = message.getValuesOffSetAndLength();
205+
System.arraycopy(
206+
message.bytes,
207+
valuesOffsetAndLength.f0,
208+
assembledBytes,
209+
valuesOffset,
210+
valuesOffsetAndLength.f1);
211+
valuesOffset += valuesOffsetAndLength.f1;
212+
}
213+
214+
Message message = new Message(assembledBytes);
215+
message.setServerId(-1);
216+
return message;
217+
}
218+
219+
private Tuple2<Integer, Integer> getKeysOffsetAndLength() {
220+
int start = KVS_OFFSET + Integer.BYTES;
221+
int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES;
222+
return Tuple2.of(start, numBytes);
223+
}
224+
225+
private Tuple2<Integer, Integer> getValuesOffSetAndLength() {
226+
int start =
227+
Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES
228+
+ KVS_OFFSET
229+
+ Integer.BYTES
230+
+ Integer.BYTES;
231+
return Tuple2.of(start, bytes.length - start);
232+
}
233+
234+
private int getNumKeys() {
235+
return Bits.getInt(bytes, KVS_OFFSET);
236+
}
237+
238+
private int getNumValues() {
239+
return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys());
240+
}
241+
}

0 commit comments

Comments
 (0)