Skip to content

Commit 7a1e4cb

Browse files
authored
C++ implementation of Prim's Minimum Spanning Tree Algorithm (#685)
1 parent e520be3 commit 7a1e4cb

File tree

8 files changed

+319
-35
lines changed

8 files changed

+319
-35
lines changed

pydatastructs/graphs/_backend/cpp/Algorithms.hpp

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
#include <queue>
44
#include <string>
55
#include <unordered_set>
6+
#include <variant>
7+
#include "GraphEdge.hpp"
68
#include "AdjacencyList.hpp"
79
#include "AdjacencyMatrix.hpp"
810

9-
1011
static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
1112
PyObject* graph_obj;
1213
const char* source_name;
@@ -153,3 +154,151 @@ static PyObject* breadth_first_search_adjacency_matrix(PyObject* self, PyObject*
153154

154155
Py_RETURN_NONE;
155156
}
157+
158+
static PyObject* minimum_spanning_tree_prim_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
159+
160+
PyObject* graph_obj;
161+
static const char* kwlist[] = {"graph", nullptr};
162+
163+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", const_cast<char**>(kwlist),
164+
&AdjacencyListGraphType, &graph_obj)) {
165+
return nullptr;
166+
}
167+
168+
AdjacencyListGraph* graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);
169+
170+
struct EdgeTuple {
171+
std::string source;
172+
std::string target;
173+
std::variant<std::monostate, int64_t, double, std::string> value;
174+
DataType value_type;
175+
176+
bool operator>(const EdgeTuple& other) const {
177+
if (value_type != other.value_type)
178+
return value_type > other.value_type;
179+
if (std::holds_alternative<int64_t>(value))
180+
return std::get<int64_t>(value) > std::get<int64_t>(other.value);
181+
if (std::holds_alternative<double>(value))
182+
return std::get<double>(value) > std::get<double>(other.value);
183+
if (std::holds_alternative<std::string>(value))
184+
return std::get<std::string>(value) > std::get<std::string>(other.value);
185+
return false;
186+
}
187+
};
188+
189+
std::priority_queue<EdgeTuple, std::vector<EdgeTuple>, std::greater<>> pq;
190+
std::unordered_set<std::string> visited;
191+
192+
PyObject* mst_graph = PyObject_CallObject(reinterpret_cast<PyObject*>(&AdjacencyListGraphType), nullptr);
193+
AdjacencyListGraph* mst = reinterpret_cast<AdjacencyListGraph*>(mst_graph);
194+
195+
std::string start = graph->node_map.begin()->first;
196+
visited.insert(start);
197+
198+
AdjacencyListGraphNode* start_node = graph->node_map[start];
199+
200+
Py_INCREF(start_node);
201+
mst->nodes.push_back(start_node);
202+
mst->node_map[start] = start_node;
203+
204+
for (const auto& [adj_name, _] : start_node->adjacent) {
205+
std::string key = make_edge_key(start, adj_name);
206+
GraphEdge* edge = graph->edges[key];
207+
EdgeTuple et;
208+
et.source = start;
209+
et.target = adj_name;
210+
et.value_type = edge->value_type;
211+
212+
switch (edge->value_type) {
213+
case DataType::Int:
214+
et.value = std::get<int64_t>(edge->value);
215+
break;
216+
case DataType::Double:
217+
et.value = std::get<double>(edge->value);
218+
break;
219+
case DataType::String:
220+
et.value = std::get<std::string>(edge->value);
221+
break;
222+
default:
223+
et.value = std::monostate{};
224+
}
225+
226+
pq.push(et);
227+
}
228+
229+
while (!pq.empty()) {
230+
EdgeTuple edge = pq.top();
231+
pq.pop();
232+
233+
if (visited.count(edge.target)) continue;
234+
visited.insert(edge.target);
235+
236+
for (const std::string& name : {edge.source, edge.target}) {
237+
if (!mst->node_map.count(name)) {
238+
AdjacencyListGraphNode* node = graph->node_map[name];
239+
Py_INCREF(node);
240+
mst->nodes.push_back(node);
241+
mst->node_map[name] = node;
242+
}
243+
}
244+
245+
AdjacencyListGraphNode* u = mst->node_map[edge.source];
246+
AdjacencyListGraphNode* v = mst->node_map[edge.target];
247+
248+
Py_INCREF(v);
249+
Py_INCREF(u);
250+
u->adjacent[edge.target] = reinterpret_cast<PyObject*>(v);
251+
v->adjacent[edge.source] = reinterpret_cast<PyObject*>(u);
252+
253+
std::string key_uv = make_edge_key(edge.source, edge.target);
254+
GraphEdge* new_edge = PyObject_New(GraphEdge, &GraphEdgeType);
255+
PyObject_Init(reinterpret_cast<PyObject*>(new_edge), &GraphEdgeType);
256+
new (&new_edge->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
257+
new_edge->value_type = edge.value_type;
258+
Py_INCREF(u);
259+
Py_INCREF(v);
260+
new_edge->source = reinterpret_cast<PyObject*>(u);
261+
new_edge->target = reinterpret_cast<PyObject*>(v);
262+
mst->edges[key_uv] = new_edge;
263+
264+
std::string key_vu = make_edge_key(edge.target, edge.source);
265+
GraphEdge* new_edge_rev = PyObject_New(GraphEdge, &GraphEdgeType);
266+
PyObject_Init(reinterpret_cast<PyObject*>(new_edge_rev), &GraphEdgeType);
267+
new (&new_edge_rev->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
268+
new_edge_rev->value_type = edge.value_type;
269+
Py_INCREF(u);
270+
Py_INCREF(v);
271+
new_edge_rev->source = reinterpret_cast<PyObject *>(v);
272+
new_edge_rev->target = reinterpret_cast<PyObject*>(u);
273+
mst->edges[key_vu] = new_edge_rev;
274+
275+
AdjacencyListGraphNode* next_node = graph->node_map[edge.target];
276+
277+
for (const auto& [adj_name, _] : next_node->adjacent) {
278+
if (visited.count(adj_name)) continue;
279+
std::string key = make_edge_key(edge.target, adj_name);
280+
GraphEdge* adj_edge = graph->edges[key];
281+
EdgeTuple adj_et;
282+
adj_et.source = edge.target;
283+
adj_et.target = adj_name;
284+
adj_et.value_type = adj_edge->value_type;
285+
286+
switch (adj_edge->value_type) {
287+
case DataType::Int:
288+
adj_et.value = std::get<int64_t>(adj_edge->value);
289+
break;
290+
case DataType::Double:
291+
adj_et.value = std::get<double>(adj_edge->value);
292+
break;
293+
case DataType::String:
294+
adj_et.value = std::get<std::string>(adj_edge->value);
295+
break;
296+
default:
297+
adj_et.value = std::monostate{};
298+
}
299+
300+
pq.push(adj_et);
301+
}
302+
}
303+
return reinterpret_cast<PyObject*>(mst);
304+
}

pydatastructs/graphs/_backend/cpp/algorithms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
static PyMethodDef AlgorithmsMethods[] = {
77
{"bfs_adjacency_list", (PyCFunction)breadth_first_search_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency list with callback"},
88
{"bfs_adjacency_matrix", (PyCFunction)breadth_first_search_adjacency_matrix, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency matrix with callback"},
9+
{"minimum_spanning_tree_prim_adjacency_list", (PyCFunction)minimum_spanning_tree_prim_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run Prim's algorithm on adjacency list"},
910
{NULL, NULL, 0, NULL}
1011
};
1112

pydatastructs/graphs/algorithms.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,20 @@ def minimum_spanning_tree(graph, algorithm, **kwargs):
338338
should be used only for such graphs. Using with other
339339
types of graphs may lead to unwanted results.
340340
"""
341-
raise_if_backend_is_not_python(
342-
minimum_spanning_tree, kwargs.get('backend', Backend.PYTHON))
343-
import pydatastructs.graphs.algorithms as algorithms
344-
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
345-
if not hasattr(algorithms, func):
346-
raise NotImplementedError(
347-
"Currently %s algoithm for %s implementation of graphs "
348-
"isn't implemented for finding minimum spanning trees."
349-
%(algorithm, graph._impl))
350-
return getattr(algorithms, func)(graph)
341+
backend = kwargs.get('backend', Backend.PYTHON)
342+
if backend == Backend.PYTHON:
343+
import pydatastructs.graphs.algorithms as algorithms
344+
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
345+
if not hasattr(algorithms, func):
346+
raise NotImplementedError(
347+
"Currently %s algoithm for %s implementation of graphs "
348+
"isn't implemented for finding minimum spanning trees."
349+
%(algorithm, graph._impl))
350+
return getattr(algorithms, func)(graph)
351+
else:
352+
from pydatastructs.graphs._backend.cpp._algorithms import minimum_spanning_tree_prim_adjacency_list
353+
if graph._impl == "adjacency_list" and algorithm == 'prim':
354+
return minimum_spanning_tree_prim_adjacency_list(graph)
351355

352356
def _minimum_spanning_tree_parallel_kruskal_adjacency_list(graph, num_threads):
353357
mst = _generate_mst_object(graph)

pydatastructs/graphs/tests/test_adjacency_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ def test_adjacency_list():
6767
g2.add_vertex(v)
6868
g2.add_edge('v_4', 'v', 0)
6969
g2.add_edge('v_5', 'v', 0)
70-
g2.add_edge('v_6', 'v', 0)
70+
g2.add_edge('v_6', 'v', "h")
7171
assert g2.is_adjacent('v_4', 'v') is True
7272
assert g2.is_adjacent('v_5', 'v') is True
7373
assert g2.is_adjacent('v_6', 'v') is True
7474
e1 = g2.get_edge('v_4', 'v')
7575
e2 = g2.get_edge('v_5', 'v')
7676
e3 = g2.get_edge('v_6', 'v')
77-
assert (str(e1)) == "('v_4', 'v')"
78-
assert (str(e2)) == "('v_5', 'v')"
79-
assert (str(e3)) == "('v_6', 'v')"
77+
assert (str(e1)) == "('v_4', 'v', 0)"
78+
assert (str(e2)) == "('v_5', 'v', 0)"
79+
assert (str(e3)) == "('v_6', 'v', h)"
8080
g2.remove_edge('v_4', 'v')
8181
assert g2.is_adjacent('v_4', 'v') is False
8282
g2.remove_vertex('v')

pydatastructs/graphs/tests/test_algorithms.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,46 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
185185
for k, v in mst.edge_weights.items():
186186
assert (k, v.value) in expected_mst
187187

188+
def _test_minimum_spanning_tree_cpp(ds, algorithm, *args):
189+
if (ds == 'List' and algorithm == "prim"):
190+
a1 = AdjacencyListGraphNode('a', 0, backend = Backend.CPP)
191+
b1 = AdjacencyListGraphNode('b', 0, backend = Backend.CPP)
192+
c1 = AdjacencyListGraphNode('c', 0, backend = Backend.CPP)
193+
d1 = AdjacencyListGraphNode('d', 0, backend = Backend.CPP)
194+
e1 = AdjacencyListGraphNode('e', 0, backend = Backend.CPP)
195+
g = Graph(a1, b1, c1, d1, e1, backend = Backend.CPP)
196+
g.add_edge(a1.name, c1.name, 10)
197+
g.add_edge(c1.name, a1.name, 10)
198+
g.add_edge(a1.name, d1.name, 7)
199+
g.add_edge(d1.name, a1.name, 7)
200+
g.add_edge(c1.name, d1.name, 9)
201+
g.add_edge(d1.name, c1.name, 9)
202+
g.add_edge(d1.name, b1.name, 32)
203+
g.add_edge(b1.name, d1.name, 32)
204+
g.add_edge(d1.name, e1.name, 23)
205+
g.add_edge(e1.name, d1.name, 23)
206+
mst = minimum_spanning_tree(g, "prim", backend = Backend.CPP)
207+
expected_mst = ["('a', 'd', 7)", "('d', 'c', 9)", "('e', 'd', 23)", "('b', 'd', 32)",
208+
"('d', 'a', 7)", "('c', 'd', 9)", "('d', 'e', 23)", "('d', 'b', 32)"]
209+
assert str(mst.get_edge('a', 'd')) in expected_mst
210+
assert str(mst.get_edge('e', 'd')) in expected_mst
211+
assert str(mst.get_edge('d', 'c')) in expected_mst
212+
assert str(mst.get_edge('b', 'd')) in expected_mst
213+
assert mst.num_edges() == 8
214+
a=AdjacencyListGraphNode('0', 0, backend = Backend.CPP)
215+
b=AdjacencyListGraphNode('1', 0, backend = Backend.CPP)
216+
c=AdjacencyListGraphNode('2', 0, backend = Backend.CPP)
217+
d=AdjacencyListGraphNode('3', 0, backend = Backend.CPP)
218+
g2 = Graph(a,b,c,d,backend = Backend.CPP)
219+
g2.add_edge('0', '1', 74)
220+
g2.add_edge('1', '0', 74)
221+
g2.add_edge('0', '3', 55)
222+
g2.add_edge('3', '0', 55)
223+
g2.add_edge('1', '2', 74)
224+
g2.add_edge('2', '1', 74)
225+
mst2=minimum_spanning_tree(g2, "prim", backend = Backend.CPP)
226+
assert mst2.num_edges() == 6
227+
188228
fmst = minimum_spanning_tree
189229
fmstp = minimum_spanning_tree_parallel
190230
_test_minimum_spanning_tree(fmst, "List", "kruskal")
@@ -193,6 +233,7 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
193233
_test_minimum_spanning_tree(fmstp, "List", "kruskal", 3)
194234
_test_minimum_spanning_tree(fmstp, "Matrix", "kruskal", 3)
195235
_test_minimum_spanning_tree(fmstp, "List", "prim", 3)
236+
_test_minimum_spanning_tree_cpp("List", "prim")
196237

197238
def test_strongly_connected_components():
198239

0 commit comments

Comments
 (0)