Skip to content

Commit 1556ade

Browse files
authored
Merge pull request #1007 from ansharlubis/add-restriction-clean
Adding hard-coded restrictions to generics (clean history)
2 parents afb9bb6 + 70f02ba commit 1556ade

26 files changed

+376
-51
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,6 @@ RUN(NAME test_bit_length LABELS cpython llvm)
225225
RUN(NAME generics_01 LABELS cpython llvm)
226226
RUN(NAME generics_02 LABELS cpython llvm)
227227
RUN(NAME generics_array_01 LABELS llvm)
228+
RUN(NAME generics_list_01 LABELS cpython llvm)
228229
RUN(NAME test_statistics LABELS cpython llvm)
229230
RUN(NAME test_str_attributes LABELS cpython llvm)

integration_tests/generics_01.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from ltypes import TypeVar
1+
from ltypes import TypeVar, SupportsPlus
22

3-
T = TypeVar('T')
3+
T = TypeVar('T', bound=SupportsPlus)
44

55
def f(x: T, y: T) -> T:
66
return x + y

integration_tests/generics_list_01.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from ltypes import TypeVar, SupportsPlus, SupportsZero, Divisible
2+
from ltypes import f64, i32
3+
4+
T = TypeVar('T', bound=SupportsPlus|SupportsZero|Divisible)
5+
6+
def mean(x: list[T]) -> f64:
7+
k: i32 = len(x)
8+
if k == 0:
9+
return 0.0
10+
sum: T
11+
sum = 0
12+
i: i32
13+
for i in range(k):
14+
sum = sum + x[i]
15+
return sum/k
16+
17+
print(mean([1,2,3]))
18+
print(mean([1.0,2.0,3.0]))

src/libasr/ASR.asdl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ ttype
329329
| Dict(ttype key_type, ttype value_type)
330330
| Pointer(ttype type)
331331
| CPtr()
332-
| TypeParameter(identifier param, dimension* dims)
332+
| TypeParameter(identifier param, dimension* dims, restriction* rt)
333+
334+
restriction = Restriction(trait rt)
335+
336+
trait = SupportsZero | SupportsPlus | Divisible | Any
333337

334338
binop = Add | Sub | Mul | Div | Pow | BitAnd | BitOr | BitXor | BitLShift | BitRShift
335339

@@ -346,6 +350,7 @@ cast_kind
346350
| IntegerToReal
347351
| LogicalToReal
348352
| RealToReal
353+
| TemplateToReal
349354
| IntegerToInteger
350355
| RealToComplex
351356
| IntegerToComplex

src/libasr/asr_utils.h

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ static inline std::string type_to_str(const ASR::ttype_t *t)
142142
return type_to_str(ASRUtils::type_get_past_pointer(
143143
const_cast<ASR::ttype_t*>(t))) + " pointer";
144144
}
145+
case ASR::ttypeType::TypeParameter: {
146+
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
147+
return tp->m_param;
148+
}
145149
default : throw LCompilersException("Not implemented " + std::to_string(t->type) + ".");
146150
}
147151
}
@@ -983,16 +987,12 @@ static inline bool is_logical(ASR::ttype_t &x) {
983987
}
984988

985989
static inline bool is_generic(ASR::ttype_t &x) {
986-
return ASR::is_a<ASR::TypeParameter_t>(*type_get_past_pointer(&x));
987-
}
988-
989-
static inline std::string get_parameter_name(const ASR::ttype_t* t) {
990-
switch (t->type) {
991-
case ASR::ttypeType::TypeParameter: {
992-
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
993-
return tp->m_param;
990+
switch (x.type) {
991+
case ASR::ttypeType::List: {
992+
ASR::List_t *list_type = ASR::down_cast<ASR::List_t>(type_get_past_pointer(&x));
993+
return is_generic(*list_type->m_type);
994994
}
995-
default: throw LCompilersException("Cannot obtain type parameter from this type");
995+
default : return ASR::is_a<ASR::TypeParameter_t>(*type_get_past_pointer(&x));
996996
}
997997
}
998998

@@ -1156,7 +1156,7 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t,
11561156
ASR::dimension_t* dimsp = dims ? dims->p : tp->m_dims;
11571157
size_t dimsn = dims ? dims->n : tp->n_dims;
11581158
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
1159-
tp->m_param, dimsp, dimsn));
1159+
tp->m_param, dimsp, dimsn, tp->m_rt, tp->n_rt));
11601160
}
11611161
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
11621162
}
@@ -1183,7 +1183,7 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR
11831183
case ASR::ttypeType::TypeParameter: {
11841184
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
11851185
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
1186-
tp->m_param, nullptr, 0));
1186+
tp->m_param, nullptr, 0, tp->m_rt, tp->n_rt));
11871187
}
11881188
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
11891189
}
@@ -1416,6 +1416,17 @@ static inline ASR::ttype_t* get_type_parameter(ASR::ttype_t* t) {
14161416
}
14171417
}
14181418

1419+
static inline bool has_trait(ASR::TypeParameter_t *tp, ASR::traitType rt) {
1420+
for (size_t i=0; i<tp->n_rt; i++) {
1421+
ASR::Restriction_t *restriction = ASR::down_cast<ASR::Restriction_t>(tp->m_rt[i]);
1422+
if (restriction->m_rt == rt) {
1423+
return true;
1424+
}
1425+
}
1426+
return false;
1427+
}
1428+
1429+
14191430
class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
14201431

14211432
private:

src/libasr/pass/instantiate_template.cpp

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
118118
return result;
119119
}
120120

121+
122+
121123
ASR::asr_t* duplicate_Var(ASR::Var_t *x) {
122124
std::string sym_name = ASRUtils::symbol_name(x->m_v);
123125
ASR::symbol_t *sym = current_scope->get_symbol(sym_name);
@@ -139,6 +141,16 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
139141
return ASR::make_ArrayItem_t(al, x->base.base.loc, m_v, args.p, x->n_args, type, m_value);
140142
}
141143

144+
ASR::asr_t* duplicate_ListItem(ASR::ListItem_t *x) {
145+
ASR::expr_t *m_a = duplicate_expr(x->m_a);
146+
ASR::expr_t *m_pos = duplicate_expr(x->m_pos);
147+
ASR::ttype_t *type = substitute_type(x->m_type);
148+
ASR::expr_t *m_value = duplicate_expr(x->m_value);
149+
150+
return ASR::make_ListItem_t(al, x->base.base.loc,
151+
m_a, m_pos, type, m_value);
152+
}
153+
142154
ASR::array_index_t duplicate_array_index(ASR::array_index_t x) {
143155
ASR::expr_t *left = duplicate_expr(x.m_left);
144156
ASR::expr_t *right = duplicate_expr(x.m_right);
@@ -150,6 +162,21 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
150162
return result;
151163
}
152164

165+
ASR::asr_t* duplicate_Assignment(ASR::Assignment_t *x) {
166+
ASR::expr_t *target = duplicate_expr(x->m_target);
167+
ASR::ttype_t *target_type = substitute_type(ASRUtils::expr_type(x->m_target));
168+
ASR::expr_t *value = duplicate_expr(x->m_value);
169+
if (ASRUtils::is_real(*target_type) && ASR::is_a<ASR::IntegerConstant_t>(*x->m_value)) {
170+
ASR::IntegerConstant_t *int_value = ASR::down_cast<ASR::IntegerConstant_t>(x->m_value);
171+
if (int_value->m_n == 0) {
172+
value = ASRUtils::EXPR(ASR::make_RealConstant_t(al, value->base.loc, 0,
173+
ASRUtils::duplicate_type(al, target_type)));
174+
}
175+
}
176+
ASR::stmt_t *overloaded = duplicate_stmt(x->m_overloaded);
177+
return ASR::make_Assignment_t(al, x->base.base.loc, target, value, overloaded);
178+
}
179+
153180
ASR::asr_t* duplicate_TemplateBinOp(ASR::TemplateBinOp_t *x) {
154181
ASR::expr_t *left = duplicate_expr(x->m_left);
155182
ASR::expr_t *right = duplicate_expr(x->m_right);
@@ -172,11 +199,20 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
172199
return ASR::make_DoLoop_t(al, x->base.base.loc, head, m_body.p, x->n_body);
173200
}
174201

202+
ASR::asr_t* duplicate_Cast(ASR::Cast_t *x) {
203+
ASR::expr_t *arg = duplicate_expr(x->m_arg);
204+
ASR::ttype_t *type = substitute_type(ASRUtils::expr_type(x->m_arg));
205+
if (ASRUtils::is_real(*type)) {
206+
return (ASR::asr_t*) arg;
207+
}
208+
return ASRUtils::make_Cast_t_value(al, x->base.base.loc, arg, ASR::cast_kindType::IntegerToReal, x->m_type);
209+
}
210+
175211
ASR::ttype_t* substitute_type(ASR::ttype_t *param_type) {
176212
if (ASR::is_a<ASR::List_t>(*param_type)) {
177-
ASR::List_t *list_type = ASR::down_cast<ASR::List_t>(param_type);
178-
ASR::ttype_t *elem_type = substitute_type(list_type->m_type);
179-
return ASRUtils::TYPE(ASR::make_List_t(al, param_type->base.loc, elem_type));
213+
ASR::List_t *tlist = ASR::down_cast<ASR::List_t>(param_type);
214+
return ASRUtils::TYPE(ASR::make_List_t(al, param_type->base.loc,
215+
substitute_type(tlist->m_type)));
180216
}
181217
if (ASR::is_a<ASR::TypeParameter_t>(*param_type)) {
182218
ASR::TypeParameter_t *param = ASR::down_cast<ASR::TypeParameter_t>(param_type);
@@ -191,28 +227,26 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
191227
ASR::Real_t* tnew = ASR::down_cast<ASR::Real_t>(t);
192228
return ASRUtils::TYPE(ASR::make_Real_t(al, t->base.loc,
193229
tnew->m_kind, param->m_dims, param->n_dims));
194-
}
230+
}
195231
case ASR::ttypeType::Character: {
196232
ASR::Character_t* tnew = ASR::down_cast<ASR::Character_t>(t);
197233
return ASRUtils::TYPE(ASR::make_Character_t(al, t->base.loc,
198234
tnew->m_kind, tnew->m_len, tnew->m_len_expr,
199235
param->m_dims, param->n_dims));
200-
}
236+
}
201237
default: return subs[param->m_param];
202238
}
203239
}
204240
return param_type;
205241
}
206242

243+
// Commented out part is not yet considered for generic functions
207244
ASR::asr_t* make_BinOp_helper(ASR::expr_t *left, ASR::expr_t *right,
208245
ASR::binopType op, const Location &loc) {
209246
ASR::ttype_t *left_type = ASRUtils::expr_type(left);
210247
ASR::ttype_t *right_type = ASRUtils::expr_type(right);
211248
ASR::ttype_t *dest_type = nullptr;
212249
ASR::expr_t *value = nullptr;
213-
214-
// bool right_is_int = ASRUtils::is_character(*left_type) && ASRUtils::is_integer(*right_type);
215-
// bool left_is_int = ASRUtils::is_integer(*left_type) && ASRUtils::is_character(*right_type);
216250

217251
if ((ASRUtils::is_integer(*left_type) || ASRUtils::is_real(*left_type) ||
218252
ASRUtils::is_complex(*left_type) || ASRUtils::is_logical(*left_type)) &&
@@ -223,7 +257,6 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
223257
dest_type = ASRUtils::expr_type(left);
224258
} else if (ASRUtils::is_character(*left_type) && ASRUtils::is_character(*right_type)
225259
&& op == ASR::binopType::Add) {
226-
// string concat
227260
ASR::Character_t *left_type2 = ASR::down_cast<ASR::Character_t>(left_type);
228261
ASR::Character_t *right_type2 = ASR::down_cast<ASR::Character_t>(right_type);
229262
LFORTRAN_ASSERT(left_type2->n_dims == 0);
@@ -250,9 +283,10 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
250283
int64_t result;
251284
switch (op) {
252285
case (ASR::binopType::Add): { result = left_value + right_value; break; }
253-
case (ASR::binopType::Sub): { result = left_value - right_value; break; }
254-
case (ASR::binopType::Mul): { result = left_value * right_value; break; }
286+
// case (ASR::binopType::Sub): { result = left_value - right_value; break; }
287+
// case (ASR::binopType::Mul): { result = left_value * right_value; break; }
255288
case (ASR::binopType::Div): { result = left_value / right_value; break; }
289+
/*
256290
case (ASR::binopType::Pow): { result = std::pow(left_value, right_value); break; }
257291
case (ASR::binopType::BitAnd): { result = left_value & right_value; break; }
258292
case (ASR::binopType::BitOr): { result = left_value | right_value; break; }
@@ -271,17 +305,20 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
271305
result = left_value >> right_value;
272306
break;
273307
}
308+
*/
274309
default: { LFORTRAN_ASSERT(false); } // should never happen
275310
}
276311
value = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(al, loc, result, dest_type));
277312
}
278313
return ASR::make_IntegerBinOp_t(al, loc, left, op, right, dest_type, value);
279314
} else if (ASRUtils::is_real(*dest_type)) {
315+
/*
280316
if (op == ASR::binopType::BitAnd || op == ASR::binopType::BitOr || op == ASR::binopType::BitXor ||
281317
op == ASR::binopType::BitLShift || op == ASR::binopType::BitRShift) {
282318
throw LCompilersException("ICE: failure in instantiation: Unsupported binary operation on floats: '"
283319
+ ASRUtils::binop_to_str_python(op) + "'");
284320
}
321+
*/
285322
right = cast_helper(left_type, right);
286323
dest_type = ASRUtils::expr_type(right);
287324
if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) {

0 commit comments

Comments
 (0)