|
5 | 5 | #include <libasr/asr_verify.h> |
6 | 6 | #include <libasr/pass/array_op.h> |
7 | 7 | #include <libasr/pass/pass_utils.h> |
| 8 | +#include <libasr/pass/intrinsic_function_registry.h> |
8 | 9 |
|
9 | 10 | #include <vector> |
10 | 11 | #include <utility> |
@@ -207,6 +208,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
207 | 208 | ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, res, ref, nullptr)); |
208 | 209 | doloop_body.push_back(al, assign); |
209 | 210 | }); |
| 211 | + *current_expr = nullptr; |
210 | 212 | result_var = nullptr; |
211 | 213 | use_custom_loop_params = false; |
212 | 214 | } |
@@ -281,6 +283,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
281 | 283 | }); |
282 | 284 | result_var = nullptr; |
283 | 285 | use_custom_loop_params = false; |
| 286 | + *current_expr = nullptr; |
284 | 287 | } |
285 | 288 |
|
286 | 289 | void replace_IntegerConstant(ASR::IntegerConstant_t* x) { |
@@ -347,6 +350,24 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
347 | 350 | } |
348 | 351 | } |
349 | 352 |
|
| 353 | + ASR::ttype_t* get_result_type(ASR::expr_t* arr_expr, ASR::exprType class_type) { |
| 354 | + switch( class_type ) { |
| 355 | + case ASR::exprType::RealCompare: |
| 356 | + case ASR::exprType::ComplexCompare: |
| 357 | + case ASR::exprType::LogicalCompare: |
| 358 | + case ASR::exprType::IntegerCompare: { |
| 359 | + ASR::ttype_t* arr_expr_type = ASRUtils::expr_type(arr_expr); |
| 360 | + ASR::dimension_t* m_dims; |
| 361 | + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(arr_expr_type, m_dims); |
| 362 | + return ASRUtils::TYPE(ASR::make_Logical_t(al, arr_expr->base.loc, |
| 363 | + 4, m_dims, n_dims)); |
| 364 | + } |
| 365 | + default: { |
| 366 | + return PassUtils::get_matching_type(arr_expr, al); |
| 367 | + } |
| 368 | + } |
| 369 | + } |
| 370 | + |
350 | 371 | template <typename T> |
351 | 372 | void replace_ArrayOpCommon(T* x, std::string res_prefix) { |
352 | 373 | const Location& loc = x->base.base.loc; |
@@ -387,8 +408,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
387 | 408 | } |
388 | 409 |
|
389 | 410 | if( result_var == nullptr ) { |
| 411 | + ASR::ttype_t* result_var_type = get_result_type(left, x->class_type); |
390 | 412 | result_var = PassUtils::create_var(result_counter, res_prefix, loc, |
391 | | - left, al, current_scope); |
| 413 | + result_var_type, al, current_scope); |
392 | 414 | result_counter += 1; |
393 | 415 | } |
394 | 416 | *current_expr = result_var; |
@@ -422,8 +444,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
422 | 444 | n_dims = rank_right; |
423 | 445 | } |
424 | 446 | if( result_var == nullptr ) { |
| 447 | + ASR::ttype_t* result_var_type = get_result_type(arr_expr, x->class_type); |
425 | 448 | result_var = PassUtils::create_var(result_counter, res_prefix, loc, |
426 | | - arr_expr, al, current_scope); |
| 449 | + result_var_type, al, current_scope); |
427 | 450 | result_counter += 1; |
428 | 451 | } |
429 | 452 | *current_expr = result_var; |
@@ -547,6 +570,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
547 | 570 | }); |
548 | 571 | result_var = nullptr; |
549 | 572 | use_custom_loop_params = false; |
| 573 | + *current_expr = nullptr; |
550 | 574 | } |
551 | 575 | } |
552 | 576 | return ; |
@@ -590,6 +614,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
590 | 614 | op_el_wise, nullptr)); |
591 | 615 | doloop_body.push_back(al, assign); |
592 | 616 | }); |
| 617 | + result_var = nullptr; |
593 | 618 | use_custom_loop_params = false; |
594 | 619 | } |
595 | 620 | } |
@@ -647,6 +672,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> { |
647 | 672 | } |
648 | 673 |
|
649 | 674 | void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) { |
| 675 | + if( !ASRUtils::IntrinsicFunctionRegistry::is_elemental(x->m_intrinsic_id) ) { |
| 676 | + return ; |
| 677 | + } |
650 | 678 | LCOMPILERS_ASSERT(current_scope != nullptr); |
651 | 679 | const Location& loc = x->base.base.loc; |
652 | 680 | std::vector<bool> array_mask(x->n_args, false); |
@@ -924,6 +952,8 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit |
924 | 952 | pass_result.n = 0; |
925 | 953 | pass_result.reserve(al, 1); |
926 | 954 | remove_original_statement = false; |
| 955 | + replacer.result_var = nullptr; |
| 956 | + replacer.result_type = nullptr; |
927 | 957 | Vec<ASR::stmt_t*>* parent_body_copy = parent_body; |
928 | 958 | parent_body = &body; |
929 | 959 | visit_stmt(*m_body[i]); |
@@ -1147,10 +1177,19 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit |
1147 | 1177 |
|
1148 | 1178 | inline void visit_AssignmentUtil(const ASR::Assignment_t& x) { |
1149 | 1179 | ASR::expr_t** current_expr_copy_9 = current_expr; |
| 1180 | + ASR::expr_t* original_value = x.m_value; |
1150 | 1181 | current_expr = const_cast<ASR::expr_t**>(&(x.m_value)); |
1151 | 1182 | this->call_replacer(); |
1152 | 1183 | current_expr = current_expr_copy_9; |
1153 | | - this->visit_expr(*x.m_value); |
| 1184 | + if( x.m_value == original_value ) { |
| 1185 | + ASR::expr_t* result_var_copy = replacer.result_var; |
| 1186 | + replacer.result_var = nullptr; |
| 1187 | + this->visit_expr(*x.m_value); |
| 1188 | + replacer.result_var = result_var_copy; |
| 1189 | + remove_original_statement = false; |
| 1190 | + } else if( x.m_value ) { |
| 1191 | + this->visit_expr(*x.m_value); |
| 1192 | + } |
1154 | 1193 | if (x.m_overloaded) { |
1155 | 1194 | this->visit_stmt(*x.m_overloaded); |
1156 | 1195 | } |
|
0 commit comments