Skip to content

Commit 6d8b511

Browse files
xunnanxufacebook-github-bot
authored andcommitted
reduce gloo rendezvous cost for TCP backend
Summary: The original version would repeat the local device address N times in the payload sent to store during mesh connection rendezvous. This yields quadratic payload sent to store (N ranks, each sending N). During later get, there's an outer for loop to grab the remote pair info, discard most of them and just use one (that matches the current rank). In total this op is cubic which is not efficient for large scale jobs. In reality if one device is used, for TCP based connection, only the seq number is different, the device addresses are all the same. This change aims to reduce the payload size sent to store to linear (approx. almost constant -> 1 addr + N seq numbers). In total this would make the original cubic op to quadratic (or more like O(1.xN) if that makes sense) Reviewed By: bmaurer Differential Revision: D45740631 fbshipit-source-id: aa089bfd81f3d392c0aa23a65f497bbacbdf0384
1 parent c6f3a5b commit 6d8b511

File tree

14 files changed

+331
-76
lines changed

14 files changed

+331
-76
lines changed

gloo/common/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
set(GLOO_COMMON_SRCS
22
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
3+
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
34
)
45

56
set(GLOO_COMMON_HDRS
67
"${CMAKE_CURRENT_SOURCE_DIR}/aligned_allocator.h"
78
"${CMAKE_CURRENT_SOURCE_DIR}/common.h"
89
"${CMAKE_CURRENT_SOURCE_DIR}/error.h"
910
"${CMAKE_CURRENT_SOURCE_DIR}/logging.h"
11+
"${CMAKE_CURRENT_SOURCE_DIR}/store.h"
1012
"${CMAKE_CURRENT_SOURCE_DIR}/string.h"
13+
"${CMAKE_CURRENT_SOURCE_DIR}/utils.h"
1114
)
1215

1316
if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")

gloo/common/store.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <chrono>
12+
#include <string>
13+
#include <vector>
14+
15+
namespace gloo {
16+
17+
class IStore {
18+
public:
19+
virtual ~IStore() = default;
20+
21+
virtual void set(const std::string& key, const std::vector<char>& data) = 0;
22+
23+
virtual std::vector<char> get(const std::string& key) = 0;
24+
25+
virtual void wait(
26+
const std::vector<std::string>& keys,
27+
const std::chrono::milliseconds& timeout) = 0;
28+
};
29+
30+
} // namespace gloo

gloo/common/utils.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <system_error>
10+
#include <unistd.h>
11+
12+
#include "gloo/common/utils.h"
13+
14+
namespace gloo {
15+
16+
constexpr int HOSTNAME_MAX_SIZE = 192;
17+
18+
std::string getHostname() {
19+
// Get Hostname using syscall
20+
char hostname[HOSTNAME_MAX_SIZE]; // NOLINT
21+
int rv = gethostname(hostname, HOSTNAME_MAX_SIZE);
22+
if (rv != 0) {
23+
throw std::system_error(errno, std::system_category());
24+
}
25+
return std::string(hostname);
26+
}
27+
28+
} // namespace gloo

gloo/common/utils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <string>
12+
13+
namespace gloo {
14+
15+
std::string getHostname();
16+
17+
} // namespace gloo

gloo/rendezvous/context.cc

Lines changed: 3 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "gloo/rendezvous/context.h"
1010

11+
#include <memory>
12+
1113
#include "gloo/common/logging.h"
1214
#include "gloo/transport/address.h"
1315

@@ -30,85 +32,13 @@ Context::Context(int rank, int size, int base)
3032
Context::~Context() {
3133
}
3234

33-
std::vector<char> Context::extractAddress(
34-
std::vector<char>& allAddrs, int i) {
35-
// Extract address from the list of all addresses
36-
int adjRank = (rank > i ? rank - 1 : rank);
37-
// Adjust for the fact that nodes do not store address for themselves
38-
int addrSize = allAddrs.size() / (size - 1);
39-
return std::vector<char>(allAddrs.begin() + adjRank * addrSize,
40-
allAddrs.begin() + (adjRank + 1) * addrSize);
41-
}
42-
4335
void Context::connectFullMesh(
4436
rendezvous::Store& store,
4537
std::shared_ptr<transport::Device>& dev) {
46-
std::vector<char> allBytes;
47-
int localRank = 0;
48-
49-
// Get Hostname using syscall
50-
char hostname[HOSTNAME_MAX_SIZE]; // NOLINT
51-
int rv = gethostname(hostname, HOSTNAME_MAX_SIZE);
52-
if (rv != 0) {
53-
throw std::system_error(errno, std::system_category());
54-
}
55-
56-
auto localHostName = std::string(hostname);
57-
// Add global rank <> hostname pair to the Store. This store is then passed
58-
// to Gloo when connectFullMesh is called, where Gloo uses the global rank <>
59-
// hostname mapping to compute local ranks.
60-
std::string localKey("rank_" + std::to_string(rank));
61-
const std::vector<char> value(localHostName.begin(), localHostName.end());
62-
store.set(localKey, value);
63-
64-
for (int i = 0; i < size; i++) {
65-
if (i == rank) {
66-
break;
67-
}
68-
69-
std::string key("rank_" + std::to_string(i));
70-
auto val = store.get(key);
71-
auto hostName = std::string((const char*)val.data(), val.size());
72-
73-
if (hostName == localHostName) {
74-
localRank++;
75-
}
76-
}
77-
78-
// Create pairs
7938
auto transportContext = dev->createContext(rank, size);
8039
transportContext->setTimeout(getTimeout());
81-
for (int i = 0; i < size; i++) {
82-
if (i == rank) {
83-
continue;
84-
}
85-
86-
auto& pair = transportContext->createPair(i);
87-
pair->setLocalRank(localRank);
88-
auto addrBytes = pair->address().bytes();
89-
allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
90-
}
91-
92-
std::ostringstream storeKey;
93-
storeKey << rank;
94-
store.set(storeKey.str(), allBytes);
9540

96-
// Connect every pair
97-
for (int i = 0; i < size; i++) {
98-
if (i == rank) {
99-
continue;
100-
}
101-
102-
// Wait for address of other side of this pair to become available
103-
std::ostringstream key;
104-
key << i;
105-
store.wait({key.str()}, getTimeout());
106-
107-
// Connect to other side of this pair
108-
auto allAddrs = store.get(key.str());
109-
auto addr = extractAddress(allAddrs, i);
110-
transportContext->getPair(i)->connect(addr);
111-
}
41+
transportContext->createAndConnectAllPairs(store);
11242

11343
device_ = dev;
11444
transportContext_ = std::move(transportContext);

gloo/rendezvous/context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#pragma once
1010

1111
#include <memory>
12+
#include <string>
1213
#include <vector>
1314

1415
#include "gloo/common/error.h"
16+
#include "gloo/common/store.h"
1517
#include "gloo/context.h"
1618
#include "gloo/rendezvous/store.h"
1719
#include "gloo/transport/address.h"
@@ -32,8 +34,6 @@ class Context : public ::gloo::Context {
3234
std::shared_ptr<transport::Device>& dev);
3335

3436
protected:
35-
std::vector<char> extractAddress(std::vector<char>& allAddrs, int i);
36-
3737
friend class ContextFactory;
3838
};
3939

gloo/rendezvous/store.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
#include "gloo/common/logging.h"
1616
#include "gloo/common/error.h"
17+
#include "gloo/common/store.h"
1718

1819
//can be used by upstream users to know whether this is available or not.
1920
#define GLOO_STORE_HAS_STORE_V2 1
2021

2122
namespace gloo {
2223
namespace rendezvous {
2324

24-
class Store {
25+
class Store: public IStore {
2526
public:
2627
static constexpr std::chrono::milliseconds kDefaultTimeout =
2728
std::chrono::seconds(30);

gloo/transport/address.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class Address {
2222
virtual ~Address() = 0;
2323

2424
virtual std::string str() const = 0;
25+
2526
virtual std::vector<char> bytes() const = 0;
2627
};
2728

gloo/transport/context.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include "gloo/common/utils.h"
910
#include "gloo/transport/context.h"
1011

1112
namespace gloo {
@@ -22,6 +23,81 @@ std::unique_ptr<transport::Pair>& Context::getPair(int rank) {
2223
return pairs_.at(rank);
2324
}
2425

26+
void Context::createAndConnectAllPairs(IStore& store) {
27+
// this is the default un-optimized version of the rendezvous protocol
28+
// where each rank would write N pairs to the store
29+
// and then for each remote peer load the N addresses
30+
// and only pick the 1 useful
31+
// A more efficient version (for transport supporting multiplexing like TCP)
32+
// can be seen in gloo/transport/tcp/context.cc
33+
34+
std::vector<char> allBytes;
35+
int localRank = 0;
36+
37+
auto localHostName = getHostname();
38+
// Add global rank <> hostname pair to the Store. This store is then passed
39+
// to Gloo when connectFullMesh is called, where Gloo uses the global rank <>
40+
// hostname mapping to compute local ranks.
41+
std::string localKey("rank_" + std::to_string(rank));
42+
const std::vector<char> value(localHostName.begin(), localHostName.end());
43+
store.set(localKey, value);
44+
45+
for (int i = 0; i < size; i++) {
46+
if (i == rank) {
47+
break;
48+
}
49+
50+
std::string key("rank_" + std::to_string(i));
51+
auto val = store.get(key);
52+
auto hostName = std::string((const char*)val.data(), val.size());
53+
54+
if (hostName == localHostName) {
55+
localRank++;
56+
}
57+
}
58+
59+
// Create pairs
60+
for (int i = 0; i < size; i++) {
61+
if (i == rank) {
62+
continue;
63+
}
64+
65+
auto& pair = createPair(i);
66+
pair->setLocalRank(localRank);
67+
auto addrBytes = pair->address().bytes();
68+
allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
69+
}
70+
71+
store.set(std::to_string(rank), allBytes);
72+
73+
// Connect every pair
74+
for (int i = 0; i < size; i++) {
75+
if (i == rank) {
76+
continue;
77+
}
78+
79+
// Wait for address of other side of this pair to become available
80+
std::ostringstream key;
81+
key << i;
82+
store.wait({key.str()}, getTimeout());
83+
84+
// Connect to other side of this pair
85+
auto allAddrs = store.get(key.str());
86+
auto addr = extractAddress(allAddrs, i);
87+
getPair(i)->connect(addr);
88+
}
89+
}
90+
91+
std::vector<char> Context::extractAddress(
92+
const std::vector<char>& allAddrs, int i) const {
93+
// Extract address from the list of all addresses
94+
int adjRank = (rank > i ? rank - 1 : rank);
95+
// Adjust for the fact that nodes do not store address for themselves
96+
int addrSize = allAddrs.size() / (size - 1);
97+
return std::vector<char>(allAddrs.begin() + adjRank * addrSize,
98+
allAddrs.begin() + (adjRank + 1) * addrSize);
99+
}
100+
25101
Context::LazyTally::LazyTally(std::vector<Tally>& vec, slot_t slot)
26102
: vec_(vec), slot_(slot), initialized_(false) {}
27103

gloo/transport/context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <unordered_map>
1818
#include <vector>
1919

20+
#include "gloo/common/store.h"
2021
#include "gloo/transport/pair.h"
2122
#include "gloo/transport/unbound_buffer.h"
2223

@@ -48,6 +49,8 @@ class Context {
4849

4950
virtual std::unique_ptr<Pair>& createPair(int rank) = 0;
5051

52+
virtual void createAndConnectAllPairs(IStore& store);
53+
5154
// Creates unbound buffer to be used with the ranks in this context.
5255
// It is not bound to a specific rank, but still bound to this
5356
// context. This is needed to support recv-from-any semantics, where
@@ -90,6 +93,8 @@ class Context {
9093
// any kind of send/recv operation.
9194
std::chrono::milliseconds timeout_;
9295

96+
std::vector<char> extractAddress(const std::vector<char>& allAddrs, int i) const;
97+
9398
protected:
9499
// Keep track of pending send and recv notifications or operations
95100
// for a single slot.

gloo/transport/ibverbs/address.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class Address : public ::gloo::transport::Address {
3131
virtual std::string str() const override;
3232

3333
protected:
34+
explicit Address(const Address&) = default;
35+
3436
struct {
3537
uint32_t lid;
3638
uint32_t qpn;

gloo/transport/tcp/address.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "gloo/transport/tcp/address.h"
1010

1111
#include <arpa/inet.h>
12+
#include <memory>
1213
#include <mutex>
1314
#include <string.h>
1415

0 commit comments

Comments
 (0)