Skip to content

Commit 61cace0

Browse files
committed
Extend ASR, add restriction as bound to ltypes
1 parent a851b13 commit 61cace0

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

src/libasr/ASR.asdl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ ttype
329329
| Dict(ttype key_type, ttype value_type)
330330
| Pointer(ttype type)
331331
| CPtr()
332-
| TypeParameter(identifier param, dimension* dims)
332+
| TypeParameter(identifier param, dimension* dims, restriction* rt)
333+
334+
restriction = Restriction(trait rt)
335+
336+
trait = SupportsZero | SupportsPlus | Divisible | Any
333337

334338
binop = Add | Sub | Mul | Div | Pow | BitAnd | BitOr | BitXor | BitLShift | BitRShift
335339

@@ -346,6 +350,7 @@ cast_kind
346350
| IntegerToReal
347351
| LogicalToReal
348352
| RealToReal
353+
| TemplateToReal
349354
| IntegerToInteger
350355
| RealToComplex
351356
| IntegerToComplex

src/libasr/asr_utils.h

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ static inline std::string type_to_str(const ASR::ttype_t *t)
142142
return type_to_str(ASRUtils::type_get_past_pointer(
143143
const_cast<ASR::ttype_t*>(t))) + " pointer";
144144
}
145+
case ASR::ttypeType::TypeParameter: {
146+
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
147+
return tp->m_param;
148+
}
145149
default : throw LCompilersException("Not implemented " + std::to_string(t->type) + ".");
146150
}
147151
}
@@ -983,16 +987,12 @@ static inline bool is_logical(ASR::ttype_t &x) {
983987
}
984988

985989
static inline bool is_generic(ASR::ttype_t &x) {
986-
return ASR::is_a<ASR::TypeParameter_t>(*type_get_past_pointer(&x));
987-
}
988-
989-
static inline std::string get_parameter_name(const ASR::ttype_t* t) {
990-
switch (t->type) {
991-
case ASR::ttypeType::TypeParameter: {
992-
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
993-
return tp->m_param;
990+
switch (x.type) {
991+
case ASR::ttypeType::List: {
992+
ASR::List_t *list_type = ASR::down_cast<ASR::List_t>(type_get_past_pointer(&x));
993+
return is_generic(*list_type->m_type);
994994
}
995-
default: throw LCompilersException("Cannot obtain type parameter from this type");
995+
default : return ASR::is_a<ASR::TypeParameter_t>(*type_get_past_pointer(&x));
996996
}
997997
}
998998

@@ -1156,7 +1156,7 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t,
11561156
ASR::dimension_t* dimsp = dims ? dims->p : tp->m_dims;
11571157
size_t dimsn = dims ? dims->n : tp->n_dims;
11581158
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
1159-
tp->m_param, dimsp, dimsn));
1159+
tp->m_param, dimsp, dimsn, tp->m_rt, tp->n_rt));
11601160
}
11611161
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
11621162
}
@@ -1183,7 +1183,7 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR
11831183
case ASR::ttypeType::TypeParameter: {
11841184
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
11851185
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
1186-
tp->m_param, nullptr, 0));
1186+
tp->m_param, nullptr, 0, tp->m_rt, tp->n_rt));
11871187
}
11881188
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
11891189
}
@@ -1416,6 +1416,17 @@ static inline ASR::ttype_t* get_type_parameter(ASR::ttype_t* t) {
14161416
}
14171417
}
14181418

1419+
static inline bool has_trait(ASR::TypeParameter_t *tp, ASR::traitType rt) {
1420+
for (size_t i=0; i<tp->n_rt; i++) {
1421+
ASR::Restriction_t *restriction = ASR::down_cast<ASR::Restriction_t>(tp->m_rt[i]);
1422+
if (restriction->m_rt == rt) {
1423+
return true;
1424+
}
1425+
}
1426+
return false;
1427+
}
1428+
1429+
14191430
class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
14201431

14211432
private:

src/runtime/ltypes/ltypes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ def __init__(self, type, dims):
3838
c64 = Type("c64")
3939
CPtr = Type("c_ptr")
4040

41+
# Restrictions
42+
43+
class Any:
44+
def __init__(self):
45+
pass
46+
47+
class SupportsPlus:
48+
def __init__(self):
49+
pass
50+
51+
class SupportsZero:
52+
def __init__(self):
53+
pass
54+
55+
class Divisible:
56+
def __init__(self):
57+
pass
58+
59+
4160
# Overloading support
4261

4362
def ltype(x):

0 commit comments

Comments
 (0)