Skip to content

Commit c767793

Browse files
committed
Add generalized compare funcs
1 parent 09fc2fa commit c767793

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/libasr/codegen/c_utils.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class CCPPDSUtils {
208208
func_decls += indent + tab + "int32_t current_end_point;\n";
209209
func_decls += indent + tab + list_element_type + "* data;\n";
210210
func_decls += indent + "};\n\n";
211-
generate_compare_list_element(list_type->m_type);
211+
generate_compare_funcs((ASR::ttype_t *)list_type);
212212
list_init(list_struct_type, list_type_code, list_element_type);
213213
list_deepcopy(list_struct_type, list_type_code, list_element_type, list_type->m_type);
214214
resize_if_needed(list_struct_type, list_type_code, list_element_type);
@@ -272,7 +272,7 @@ class CCPPDSUtils {
272272
return func_decls;
273273
}
274274

275-
void generate_compare_list_element(ASR::ttype_t *t) {
275+
void generate_compare_funcs(ASR::ttype_t *t) {
276276
std::string type_code = ASRUtils::get_type_code(t, true);
277277
if (compareTwoDS.find(type_code) != compareTwoDS.end()) {
278278
return;
@@ -289,7 +289,7 @@ class CCPPDSUtils {
289289
signature = indent + signature;
290290
tmp_gen += indent + signature + " {\n";
291291
ASR::ttype_t *tt = ASR::down_cast<ASR::List_t>(t)->m_type;
292-
generate_compare_list_element(tt);
292+
generate_compare_funcs(tt);
293293
std::string ele_func = compareTwoDS[ASRUtils::get_type_code(tt, true)];
294294
tmp_gen += indent + tab + "if (a.current_end_point != b.current_end_point)\n";
295295
tmp_gen += indent + tab + tab + "return false;\n";
@@ -299,6 +299,29 @@ class CCPPDSUtils {
299299
tmp_gen += indent + tab + "}\n";
300300
tmp_gen += indent + tab + "return true;\n";
301301

302+
} else if (ASR::is_a<ASR::Tuple_t>(*t)) {
303+
ASR::Tuple_t *tt = ASR::down_cast<ASR::Tuple_t>(t);
304+
std::string signature = "bool " + cmp_func + "(" + element_type + " a, " + element_type+ " b)";
305+
func_decls += indent + "inline " + signature + ";\n";
306+
signature = indent + signature;
307+
tmp_gen += indent + signature + " {\n";
308+
tmp_gen += indent + tab + "if (a.length != b.length)\n";
309+
tmp_gen += indent + tab + tab + "return false;\n";
310+
tmp_gen += indent + tab + "bool ans = true;\n";
311+
for (size_t i=0; i<tt->n_type; i++) {
312+
generate_compare_funcs(tt->m_type[i]);
313+
std::string ele_func = compareTwoDS[ASRUtils::get_type_code(tt->m_type[i], true)];
314+
std::string num = std::to_string(i);
315+
tmp_gen += indent + tab + "ans &= " + ele_func + "(a.element_" +
316+
num + ", " + "b.element_" + num + ");\n";
317+
}
318+
tmp_gen += indent + tab + "return ans;\n";
319+
} else if (ASR::is_a<ASR::Character_t>(*t)) {
320+
std::string signature = "bool " + cmp_func + "(" + element_type + " a, " + element_type + " b)";
321+
func_decls += indent + "inline " + signature + ";\n";
322+
signature = indent + signature;
323+
tmp_gen += indent + signature + " {\n";
324+
tmp_gen += indent + tab + "return strcmp(a, b) == 0;\n";
302325
} else {
303326
std::string signature = "bool " + cmp_func + "(" + element_type + " a, " + element_type + " b)";
304327
func_decls += indent + "inline " + signature + ";\n";
@@ -558,6 +581,7 @@ class CCPPDSUtils {
558581
CUtils::get_c_type_from_ttype_t(tuple_type->m_type[i]) + " element_" + std::to_string(i) + ";\n";
559582
}
560583
func_decls += indent + "};\n\n";
584+
generate_compare_funcs((ASR::ttype_t *)tuple_type);
561585
return tuple_struct_type;
562586
}
563587

0 commit comments

Comments
 (0)