Skip to content

Commit db5ff2b

Browse files
Prerak SinghPrerak Singh
authored andcommitted
added cpp implementation for prims
1 parent 6d81439 commit db5ff2b

File tree

7 files changed

+242
-32
lines changed

7 files changed

+242
-32
lines changed

pydatastructs/graphs/_backend/cpp/Algorithms.hpp

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
#include <queue>
44
#include <string>
55
#include <unordered_set>
6+
#include <variant>
7+
#include "GraphEdge.hpp"
68
#include "AdjacencyList.hpp"
79
#include "AdjacencyMatrix.hpp"
810

9-
1011
static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
1112
PyObject* graph_obj;
1213
const char* source_name;
@@ -153,3 +154,111 @@ static PyObject* breadth_first_search_adjacency_matrix(PyObject* self, PyObject*
153154

154155
Py_RETURN_NONE;
155156
}
157+
158+
static PyObject* minimum_spanning_tree_prim_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
159+
160+
PyObject* graph_obj;
161+
static const char* kwlist[] = {"graph", nullptr};
162+
163+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", const_cast<char**>(kwlist),
164+
&AdjacencyListGraphType, &graph_obj)) {
165+
return nullptr;
166+
}
167+
168+
AdjacencyListGraph* graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);
169+
170+
struct EdgeTuple {
171+
std::string source;
172+
std::string target;
173+
std::variant<std::monostate, int64_t, double, std::string> value;
174+
DataType value_type;
175+
176+
bool operator>(const EdgeTuple& other) const {
177+
if (value_type != other.value_type)
178+
return value_type > other.value_type;
179+
if (std::holds_alternative<int64_t>(value))
180+
return std::get<int64_t>(value) > std::get<int64_t>(other.value);
181+
if (std::holds_alternative<double>(value))
182+
return std::get<double>(value) > std::get<double>(other.value);
183+
if (std::holds_alternative<std::string>(value))
184+
return std::get<std::string>(value) > std::get<std::string>(other.value);
185+
return false;
186+
}
187+
};
188+
189+
std::priority_queue<EdgeTuple, std::vector<EdgeTuple>, std::greater<>> pq;
190+
std::unordered_set<std::string> visited;
191+
192+
PyObject* mst_graph = PyObject_CallObject(reinterpret_cast<PyObject*>(&AdjacencyListGraphType), nullptr);
193+
AdjacencyListGraph* mst = reinterpret_cast<AdjacencyListGraph*>(mst_graph);
194+
195+
std::string start = graph->node_map.begin()->first;
196+
visited.insert(start);
197+
198+
AdjacencyListGraphNode* start_node = graph->node_map[start];
199+
200+
Py_INCREF(start_node);
201+
mst->nodes.push_back(start_node);
202+
mst->node_map[start] = start_node;
203+
204+
for (const auto& [adj_name, _] : start_node->adjacent) {
205+
std::string key = make_edge_key(start, adj_name);
206+
GraphEdge* edge = graph->edges[key];
207+
pq.push({start, adj_name, edge->value, edge->value_type});
208+
}
209+
210+
while (!pq.empty()) {
211+
EdgeTuple edge = pq.top();
212+
pq.pop();
213+
214+
if (visited.count(edge.target)) continue;
215+
visited.insert(edge.target);
216+
217+
for (const std::string& name : {edge.source, edge.target}) {
218+
if (!mst->node_map.count(name)) {
219+
AdjacencyListGraphNode* node = graph->node_map[name];
220+
Py_INCREF(node);
221+
mst->nodes.push_back(node);
222+
mst->node_map[name] = node;
223+
}
224+
}
225+
226+
AdjacencyListGraphNode* u = mst->node_map[edge.source];
227+
AdjacencyListGraphNode* v = mst->node_map[edge.target];
228+
229+
Py_INCREF(v);
230+
Py_INCREF(u);
231+
u->adjacent[edge.target] = reinterpret_cast<PyObject*>(v);
232+
v->adjacent[edge.source] = reinterpret_cast<PyObject*>(u);
233+
234+
std::string key_uv = make_edge_key(edge.source, edge.target);
235+
GraphEdge* new_edge = PyObject_New(GraphEdge, &GraphEdgeType);
236+
Py_INCREF(u);
237+
Py_INCREF(v);
238+
new_edge->source = reinterpret_cast<PyObject*>(u);
239+
new_edge->target = reinterpret_cast<PyObject*>(v);
240+
new (&new_edge->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
241+
new_edge->value_type = edge.value_type;
242+
mst->edges[key_uv] = new_edge;
243+
244+
std::string key_vu = make_edge_key(edge.target, edge.source);
245+
GraphEdge* new_edge_rev = PyObject_New(GraphEdge, &GraphEdgeType);
246+
new_edge_rev->source = reinterpret_cast<PyObject*>(v);
247+
new_edge_rev->target = reinterpret_cast<PyObject*>(u);
248+
new (&new_edge_rev->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
249+
new_edge_rev->value_type = edge.value_type;
250+
mst->edges[key_vu] = new_edge_rev;
251+
252+
AdjacencyListGraphNode* next_node = graph->node_map[edge.target];
253+
254+
for (const auto& [adj_name, _] : next_node->adjacent) {
255+
if (visited.count(adj_name)) continue;
256+
std::string key = make_edge_key(edge.target, adj_name);
257+
GraphEdge* adj_edge = graph->edges[key];
258+
pq.push({edge.target, adj_name, adj_edge->value, adj_edge->value_type});
259+
}
260+
}
261+
262+
Py_INCREF(mst);
263+
return reinterpret_cast<PyObject*>(mst);
264+
}

pydatastructs/graphs/_backend/cpp/algorithms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
static PyMethodDef AlgorithmsMethods[] = {
77
{"bfs_adjacency_list", (PyCFunction)breadth_first_search_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency list with callback"},
88
{"bfs_adjacency_matrix", (PyCFunction)breadth_first_search_adjacency_matrix, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency matrix with callback"},
9+
{"minimum_spanning_tree_prim_adjacency_list", (PyCFunction)minimum_spanning_tree_prim_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run Prim's algorithm on adjacency list"},
910
{NULL, NULL, 0, NULL}
1011
};
1112

pydatastructs/graphs/algorithms.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,20 @@ def minimum_spanning_tree(graph, algorithm, **kwargs):
338338
should be used only for such graphs. Using with other
339339
types of graphs may lead to unwanted results.
340340
"""
341-
raise_if_backend_is_not_python(
342-
minimum_spanning_tree, kwargs.get('backend', Backend.PYTHON))
343-
import pydatastructs.graphs.algorithms as algorithms
344-
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
345-
if not hasattr(algorithms, func):
346-
raise NotImplementedError(
347-
"Currently %s algoithm for %s implementation of graphs "
348-
"isn't implemented for finding minimum spanning trees."
349-
%(algorithm, graph._impl))
350-
return getattr(algorithms, func)(graph)
341+
backend = kwargs.get('backend', Backend.PYTHON)
342+
if backend == Backend.PYTHON:
343+
import pydatastructs.graphs.algorithms as algorithms
344+
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
345+
if not hasattr(algorithms, func):
346+
raise NotImplementedError(
347+
"Currently %s algoithm for %s implementation of graphs "
348+
"isn't implemented for finding minimum spanning trees."
349+
%(algorithm, graph._impl))
350+
return getattr(algorithms, func)(graph)
351+
else:
352+
from pydatastructs.graphs._backend.cpp._algorithms import minimum_spanning_tree_prim_adjacency_list
353+
if graph._impl == "adjacency_list" and algorithm == 'prim':
354+
return minimum_spanning_tree_prim_adjacency_list(graph)
351355

352356
def _minimum_spanning_tree_parallel_kruskal_adjacency_list(graph, num_threads):
353357
mst = _generate_mst_object(graph)

pydatastructs/graphs/tests/test_adjacency_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ def test_adjacency_list():
6767
g2.add_vertex(v)
6868
g2.add_edge('v_4', 'v', 0)
6969
g2.add_edge('v_5', 'v', 0)
70-
g2.add_edge('v_6', 'v', 0)
70+
g2.add_edge('v_6', 'v', "h")
7171
assert g2.is_adjacent('v_4', 'v') is True
7272
assert g2.is_adjacent('v_5', 'v') is True
7373
assert g2.is_adjacent('v_6', 'v') is True
7474
e1 = g2.get_edge('v_4', 'v')
7575
e2 = g2.get_edge('v_5', 'v')
7676
e3 = g2.get_edge('v_6', 'v')
77-
assert (str(e1)) == "('v_4', 'v')"
78-
assert (str(e2)) == "('v_5', 'v')"
79-
assert (str(e3)) == "('v_6', 'v')"
77+
assert (str(e1)) == "('v_4', 'v', 0)"
78+
assert (str(e2)) == "('v_5', 'v', 0)"
79+
assert (str(e3)) == "('v_6', 'v', h)"
8080
g2.remove_edge('v_4', 'v')
8181
assert g2.is_adjacent('v_4', 'v') is False
8282
g2.remove_vertex('v')

pydatastructs/graphs/tests/test_algorithms.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,33 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
185185
for k, v in mst.edge_weights.items():
186186
assert (k, v.value) in expected_mst
187187

188+
def _test_minimum_spanning_tree_cpp(ds, algorithm, *args):
189+
if (ds == 'List' and algorithm == "prim"):
190+
a1 = AdjacencyListGraphNode('a',0, backend = Backend.CPP)
191+
b1 = AdjacencyListGraphNode('b',0, backend = Backend.CPP)
192+
c1 = AdjacencyListGraphNode('c',0, backend = Backend.CPP)
193+
d1 = AdjacencyListGraphNode('d',0, backend = Backend.CPP)
194+
e1 = AdjacencyListGraphNode('e',0, backend = Backend.CPP)
195+
g = Graph(a1, b1, c1, d1, e1, backend = Backend.CPP)
196+
g.add_edge(a1.name, c1.name, 10)
197+
g.add_edge(c1.name, a1.name, 10)
198+
g.add_edge(a1.name, d1.name, 7)
199+
g.add_edge(d1.name, a1.name, 7)
200+
g.add_edge(c1.name, d1.name, 9)
201+
g.add_edge(d1.name, c1.name, 9)
202+
g.add_edge(d1.name, b1.name, 32)
203+
g.add_edge(b1.name, d1.name, 32)
204+
g.add_edge(d1.name, e1.name, 23)
205+
g.add_edge(e1.name, d1.name, 23)
206+
mst = minimum_spanning_tree(g, "prim", backend = Backend.CPP)
207+
expected_mst = ["('a', 'd', 7)", "('d', 'c', 9)", "('e', 'd', 23)", "('b', 'd', 32)",
208+
"('d', 'a', 7)", "('c', 'd', 9)", "('d', 'e', 23)", "('d', 'b', 32)"]
209+
assert str(mst.get_edge('a','d')) in expected_mst
210+
assert str(mst.get_edge('e','d')) in expected_mst
211+
assert str(mst.get_edge('d','c')) in expected_mst
212+
assert str(mst.get_edge('b','d')) in expected_mst
213+
assert mst.num_edges() == 8
214+
188215
fmst = minimum_spanning_tree
189216
fmstp = minimum_spanning_tree_parallel
190217
_test_minimum_spanning_tree(fmst, "List", "kruskal")
@@ -193,6 +220,7 @@ def _test_minimum_spanning_tree(func, ds, algorithm, *args):
193220
_test_minimum_spanning_tree(fmstp, "List", "kruskal", 3)
194221
_test_minimum_spanning_tree(fmstp, "Matrix", "kruskal", 3)
195222
_test_minimum_spanning_tree(fmstp, "List", "prim", 3)
223+
_test_minimum_spanning_tree_cpp("List", "prim")
196224

197225
def test_strongly_connected_components():
198226

pydatastructs/utils/_backend/cpp/GraphEdge.hpp

Lines changed: 84 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#define PY_SSIZE_T_CLEAN
55
#include <Python.h>
66
#include <string>
7+
#include <variant>
78
#include "GraphNode.hpp"
89

910
extern PyTypeObject GraphEdgeType;
@@ -12,56 +13,123 @@ typedef struct {
1213
PyObject_HEAD
1314
PyObject* source;
1415
PyObject* target;
15-
PyObject* value;
16+
std::variant<std::monostate, int64_t, double, std::string> value;
17+
DataType value_type;
1618
} GraphEdge;
1719

1820
static void GraphEdge_dealloc(GraphEdge* self) {
1921
Py_XDECREF(self->source);
2022
Py_XDECREF(self->target);
21-
Py_XDECREF(self->value);
2223
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
2324
}
2425

2526
static PyObject* GraphEdge_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
26-
GraphEdge* self;
27-
self = reinterpret_cast<GraphEdge*>(type->tp_alloc(type, 0));
27+
GraphEdge* self = PyObject_New(GraphEdge, &GraphEdgeType);
2828
if (!self) return NULL;
2929

30+
new (&self->value) std::variant<std::monostate, int64_t, double, std::string>();
31+
self->value_type = DataType::None;
32+
3033
static char* kwlist[] = {"node1", "node2", "value", NULL};
3134
PyObject* node1;
3235
PyObject* node2;
3336
PyObject* value = Py_None;
3437

3538
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, &node1, &node2, &value)) {
36-
PyErr_SetString(PyExc_ValueError, "Invalid arguments: Expected (GraphNode, GraphNode, optional value)");
39+
PyErr_SetString(PyExc_ValueError, "Expected (GraphNode, GraphNode, optional value)");
3740
return NULL;
3841
}
3942

4043
Py_INCREF(node1);
4144
Py_INCREF(node2);
42-
Py_INCREF(value);
43-
4445
self->source = node1;
4546
self->target = node2;
46-
self->value = value;
47+
48+
if (value == Py_None) {
49+
self->value_type = DataType::None;
50+
self->value = std::monostate{};
51+
} else if (PyLong_Check(value)) {
52+
self->value_type = DataType::Int;
53+
self->value = static_cast<int64_t>(PyLong_AsLongLong(value));
54+
} else if (PyFloat_Check(value)) {
55+
self->value_type = DataType::Double;
56+
self->value = PyFloat_AsDouble(value);
57+
} else if (PyUnicode_Check(value)) {
58+
const char* str = PyUnicode_AsUTF8(value);
59+
self->value_type = DataType::String;
60+
self->value = std::string(str);
61+
} else {
62+
PyErr_SetString(PyExc_TypeError, "Unsupported edge value type (must be int, float, str, or None)");
63+
return NULL;
64+
}
4765

4866
return reinterpret_cast<PyObject*>(self);
4967
}
5068

5169
static PyObject* GraphEdge_str(GraphEdge* self) {
52-
std::string source_name = (reinterpret_cast<GraphNode*>(self->source))->name;
53-
std::string target_name = (reinterpret_cast<GraphNode*>(self->target))->name;
54-
55-
if (source_name.empty() || target_name.empty()) {
56-
PyErr_SetString(PyExc_AttributeError, "Both nodes must have a 'name' attribute.");
57-
return NULL;
70+
std::string src = reinterpret_cast<GraphNode*>(self->source)->name;
71+
std::string tgt = reinterpret_cast<GraphNode*>(self->target)->name;
72+
std::string val_str;
73+
74+
switch (self->value_type) {
75+
case DataType::Int:
76+
val_str = std::to_string(std::get<int64_t>(self->value));
77+
break;
78+
case DataType::Double:
79+
val_str = std::to_string(std::get<double>(self->value));
80+
break;
81+
case DataType::String:
82+
val_str = std::get<std::string>(self->value);
83+
break;
84+
case DataType::None:
85+
default:
86+
val_str = "None";
87+
break;
5888
}
5989

60-
PyObject* str_repr = PyUnicode_FromFormat("('%s', '%s')", source_name.c_str(), target_name.c_str());
90+
return PyUnicode_FromFormat("('%s', '%s', %s)", src.c_str(), tgt.c_str(), val_str.c_str());
91+
}
6192

62-
return str_repr;
93+
static PyObject* GraphEdge_get_value(GraphEdge* self, void* closure) {
94+
switch (self->value_type) {
95+
case DataType::Int:
96+
return PyLong_FromLongLong(std::get<int64_t>(self->value));
97+
case DataType::Double:
98+
return PyFloat_FromDouble(std::get<double>(self->value));
99+
case DataType::String:
100+
return PyUnicode_FromString(std::get<std::string>(self->value).c_str());
101+
case DataType::None:
102+
default:
103+
Py_RETURN_NONE;
104+
}
63105
}
64106

107+
static int GraphEdge_set_value(GraphEdge* self, PyObject* value) {
108+
if (value == Py_None) {
109+
self->value_type = DataType::None;
110+
self->value = std::monostate{};
111+
} else if (PyLong_Check(value)) {
112+
self->value_type = DataType::Int;
113+
self->value = static_cast<int64_t>(PyLong_AsLongLong(value));
114+
} else if (PyFloat_Check(value)) {
115+
self->value_type = DataType::Double;
116+
self->value = PyFloat_AsDouble(value);
117+
} else if (PyUnicode_Check(value)) {
118+
const char* str = PyUnicode_AsUTF8(value);
119+
self->value_type = DataType::String;
120+
self->value = std::string(str);
121+
} else {
122+
PyErr_SetString(PyExc_TypeError, "Edge value must be int, float, str, or None.");
123+
return -1;
124+
}
125+
return 0;
126+
}
127+
128+
static PyGetSetDef GraphEdge_getsetters[] = {
129+
{"value", (getter)GraphEdge_get_value, (setter)GraphEdge_set_value, "Get or set edge value", NULL},
130+
{NULL}
131+
};
132+
65133
inline PyTypeObject GraphEdgeType = {
66134
/* tp_name */ PyVarObject_HEAD_INIT(NULL, 0) "GraphEdge",
67135
/* tp_basicsize */ sizeof(GraphEdge),

pydatastructs/utils/tests/test_misc_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_GraphEdge():
5353
h_1 = AdjacencyListGraphNode('h_1', 1, backend = Backend.CPP)
5454
h_2 = AdjacencyListGraphNode('h_2', 2, backend = Backend.CPP)
5555
e2 = GraphEdge(h_1, h_2, value = 2, backend = Backend.CPP)
56-
assert str(e2) == "('h_1', 'h_2')"
56+
assert str(e2) == "('h_1', 'h_2', 2)"
5757

5858
def test_BinomialTreeNode():
5959
b = BinomialTreeNode(1,1)

0 commit comments

Comments
 (0)