|
5 | 5 | #include <libasr/asr_verify.h>
|
6 | 6 | #include <libasr/pass/array_op.h>
|
7 | 7 | #include <libasr/pass/pass_utils.h>
|
| 8 | +#include <libasr/pass/intrinsic_function_registry.h> |
8 | 9 |
|
9 | 10 | #include <vector>
|
10 | 11 | #include <utility>
|
@@ -207,6 +208,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
207 | 208 | ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, res, ref, nullptr));
|
208 | 209 | doloop_body.push_back(al, assign);
|
209 | 210 | });
|
| 211 | + *current_expr = nullptr; |
210 | 212 | result_var = nullptr;
|
211 | 213 | use_custom_loop_params = false;
|
212 | 214 | }
|
@@ -281,6 +283,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
281 | 283 | });
|
282 | 284 | result_var = nullptr;
|
283 | 285 | use_custom_loop_params = false;
|
| 286 | + *current_expr = nullptr; |
284 | 287 | }
|
285 | 288 |
|
286 | 289 | void replace_IntegerConstant(ASR::IntegerConstant_t* x) {
|
@@ -347,6 +350,24 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
347 | 350 | }
|
348 | 351 | }
|
349 | 352 |
|
| 353 | + ASR::ttype_t* get_result_type(ASR::expr_t* arr_expr, ASR::exprType class_type) { |
| 354 | + switch( class_type ) { |
| 355 | + case ASR::exprType::RealCompare: |
| 356 | + case ASR::exprType::ComplexCompare: |
| 357 | + case ASR::exprType::LogicalCompare: |
| 358 | + case ASR::exprType::IntegerCompare: { |
| 359 | + ASR::ttype_t* arr_expr_type = ASRUtils::expr_type(arr_expr); |
| 360 | + ASR::dimension_t* m_dims; |
| 361 | + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(arr_expr_type, m_dims); |
| 362 | + return ASRUtils::TYPE(ASR::make_Logical_t(al, arr_expr->base.loc, |
| 363 | + 4, m_dims, n_dims)); |
| 364 | + } |
| 365 | + default: { |
| 366 | + return PassUtils::get_matching_type(arr_expr, al); |
| 367 | + } |
| 368 | + } |
| 369 | + } |
| 370 | + |
350 | 371 | template <typename T>
|
351 | 372 | void replace_ArrayOpCommon(T* x, std::string res_prefix) {
|
352 | 373 | const Location& loc = x->base.base.loc;
|
@@ -387,8 +408,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
387 | 408 | }
|
388 | 409 |
|
389 | 410 | if( result_var == nullptr ) {
|
| 411 | + ASR::ttype_t* result_var_type = get_result_type(left, x->class_type); |
390 | 412 | result_var = PassUtils::create_var(result_counter, res_prefix, loc,
|
391 |
| - left, al, current_scope); |
| 413 | + result_var_type, al, current_scope); |
392 | 414 | result_counter += 1;
|
393 | 415 | }
|
394 | 416 | *current_expr = result_var;
|
@@ -422,8 +444,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
422 | 444 | n_dims = rank_right;
|
423 | 445 | }
|
424 | 446 | if( result_var == nullptr ) {
|
| 447 | + ASR::ttype_t* result_var_type = get_result_type(arr_expr, x->class_type); |
425 | 448 | result_var = PassUtils::create_var(result_counter, res_prefix, loc,
|
426 |
| - arr_expr, al, current_scope); |
| 449 | + result_var_type, al, current_scope); |
427 | 450 | result_counter += 1;
|
428 | 451 | }
|
429 | 452 | *current_expr = result_var;
|
@@ -547,6 +570,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
547 | 570 | });
|
548 | 571 | result_var = nullptr;
|
549 | 572 | use_custom_loop_params = false;
|
| 573 | + *current_expr = nullptr; |
550 | 574 | }
|
551 | 575 | }
|
552 | 576 | return ;
|
@@ -590,6 +614,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
590 | 614 | op_el_wise, nullptr));
|
591 | 615 | doloop_body.push_back(al, assign);
|
592 | 616 | });
|
| 617 | + result_var = nullptr; |
593 | 618 | use_custom_loop_params = false;
|
594 | 619 | }
|
595 | 620 | }
|
@@ -647,6 +672,9 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
|
647 | 672 | }
|
648 | 673 |
|
649 | 674 | void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
|
| 675 | + if( !ASRUtils::IntrinsicFunctionRegistry::is_elemental(x->m_intrinsic_id) ) { |
| 676 | + return ; |
| 677 | + } |
650 | 678 | LCOMPILERS_ASSERT(current_scope != nullptr);
|
651 | 679 | const Location& loc = x->base.base.loc;
|
652 | 680 | std::vector<bool> array_mask(x->n_args, false);
|
@@ -924,6 +952,8 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit
|
924 | 952 | pass_result.n = 0;
|
925 | 953 | pass_result.reserve(al, 1);
|
926 | 954 | remove_original_statement = false;
|
| 955 | + replacer.result_var = nullptr; |
| 956 | + replacer.result_type = nullptr; |
927 | 957 | Vec<ASR::stmt_t*>* parent_body_copy = parent_body;
|
928 | 958 | parent_body = &body;
|
929 | 959 | visit_stmt(*m_body[i]);
|
@@ -1147,10 +1177,19 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit
|
1147 | 1177 |
|
1148 | 1178 | inline void visit_AssignmentUtil(const ASR::Assignment_t& x) {
|
1149 | 1179 | ASR::expr_t** current_expr_copy_9 = current_expr;
|
| 1180 | + ASR::expr_t* original_value = x.m_value; |
1150 | 1181 | current_expr = const_cast<ASR::expr_t**>(&(x.m_value));
|
1151 | 1182 | this->call_replacer();
|
1152 | 1183 | current_expr = current_expr_copy_9;
|
1153 |
| - this->visit_expr(*x.m_value); |
| 1184 | + if( x.m_value == original_value ) { |
| 1185 | + ASR::expr_t* result_var_copy = replacer.result_var; |
| 1186 | + replacer.result_var = nullptr; |
| 1187 | + this->visit_expr(*x.m_value); |
| 1188 | + replacer.result_var = result_var_copy; |
| 1189 | + remove_original_statement = false; |
| 1190 | + } else if( x.m_value ) { |
| 1191 | + this->visit_expr(*x.m_value); |
| 1192 | + } |
1154 | 1193 | if (x.m_overloaded) {
|
1155 | 1194 | this->visit_stmt(*x.m_overloaded);
|
1156 | 1195 | }
|
|
0 commit comments