Skip to content

Commit 87133b6

Browse files
authored
Explicit type instantiations (f<<T>>) (#1980)
This adds in support for [Explicit type parameter instantiation](https://rfcs.luau.org/explicit-type-parameter-instantiation.html). This PR is still a work in progress, but some important notes: 1. Metatables with `__call` are **not** supported. While `t<<A>>()` is obvious what it should do with `__call`, it's not obvious what `t<<A>>` on its own would be. 2. Intersection types are not supported at the moment either. Both of these are possible to bring later.
1 parent f071cb1 commit 87133b6

35 files changed

+1763
-122
lines changed

Analysis/include/Luau/AstUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
namespace Luau
1010
{
1111

12-
// Search through the expression 'expr' for types that are known to represent
13-
// uniquely held references. Append these types to 'uniqueTypes'.
12+
// Search through the expression 'expr' for typeArguments that are known to represent
13+
// uniquely held references. Append these typeArguments to 'uniqueTypes'.
1414
void findUniqueTypes(NotNull<DenseHashSet<TypeId>> uniqueTypes, AstExpr* expr, NotNull<const DenseHashMap<const AstExpr*, TypeId>> astTypes);
1515

1616
void findUniqueTypes(

Analysis/include/Luau/Constraint.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ struct FunctionCallConstraint
9999
class AstExprCall* callSite = nullptr;
100100
std::vector<std::optional<TypeId>> discriminantTypes;
101101

102+
std::vector<TypeId> typeArguments;
103+
std::vector<TypePackId> typePackArguments;
104+
102105
// When we dispatch this constraint, we update the key at this map to record
103106
// the overload that we selected.
104107
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes = nullptr;
@@ -292,6 +295,16 @@ struct PushFunctionTypeConstraint
292295
bool isSelf;
293296
};
294297

298+
// Binds the function to a set of explicitly specified types,
299+
// for f<<T>>.
300+
struct TypeInstantiationConstraint
301+
{
302+
TypeId functionType;
303+
TypeId placeholderType;
304+
std::vector<TypeId> typeArguments;
305+
std::vector<TypePackId> typePackArguments;
306+
};
307+
295308
struct PushTypeConstraint
296309
{
297310
TypeId expectedType;
@@ -321,7 +334,8 @@ using ConstraintV = Variant<
321334
EqualityConstraint,
322335
SimplifyConstraint,
323336
PushFunctionTypeConstraint,
324-
PushTypeConstraint>;
337+
PushTypeConstraint,
338+
TypeInstantiationConstraint>;
325339

326340
struct Constraint
327341
{

Analysis/include/Luau/ConstraintGenerator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ struct ConstraintGenerator
333333
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
334334
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
335335
Inference check(const ScopePtr& scope, AstExprInterpString* interpString);
336+
Inference check(const ScopePtr& scope, AstExprInstantiate* explicitTypeInstantiation);
336337
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
337338
std::tuple<TypeId, TypeId, RefinementId> checkBinary(
338339
const ScopePtr& scope,
@@ -482,6 +483,8 @@ struct ConstraintGenerator
482483

483484
void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block);
484485

486+
std::pair<std::vector<TypeId>, std::vector<TypePackId>> resolveTypeArguments(const ScopePtr& scope, const AstArray<AstTypeOrPack>& typeArguments);
487+
485488
/** Given a function type annotation, return a vector describing the expected types of the calls to the function
486489
* For example, calling a function with annotation ((number) -> string & ((string) -> number))
487490
* yields a vector of size 1, with value: [number | string]

Analysis/include/Luau/ConstraintSolver.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ struct ConstraintSolver
241241
bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint, bool force);
242242
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
243243
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
244-
244+
bool tryDispatch(const TypeInstantiationConstraint& c, NotNull<const Constraint> constraint);
245245

246246
bool tryDispatchHasIndexer(
247247
int& recursionDepth,
@@ -470,6 +470,14 @@ struct ConstraintSolver
470470

471471
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
472472

473+
TypeId instantiateFunctionType(
474+
TypeId functionTypeId,
475+
const std::vector<TypeId>& typeArguments,
476+
const std::vector<TypePackId>& typePackArguments,
477+
NotNull<Scope> scope,
478+
const Location& location
479+
);
480+
473481
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
474482

475483
void throwTimeLimitError() const;

Analysis/include/Luau/DataFlowGraph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ struct DataFlowGraphBuilder
189189
DataFlowResult visitExpr(AstExprTypeAssertion* t);
190190
DataFlowResult visitExpr(AstExprIfElse* i);
191191
DataFlowResult visitExpr(AstExprInterpString* i);
192+
DataFlowResult visitExpr(AstExprInstantiate* i);
192193
DataFlowResult visitExpr(AstExprError* error);
193194

194195
void visitLValue(AstExpr* e, DefId incomingDef);

Analysis/include/Luau/Error.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,36 @@ struct GenericBoundsMismatch
535535
bool operator==(const GenericBoundsMismatch& rhs) const;
536536
};
537537

538+
// Used `f<<T>>` where f is not a function
539+
struct InstantiateGenericsOnNonFunction
540+
{
541+
enum class InterestingEdgeCase
542+
{
543+
None,
544+
MetatableCall,
545+
Intersection,
546+
};
547+
548+
InterestingEdgeCase interestingEdgeCase;
549+
550+
bool operator==(const InstantiateGenericsOnNonFunction&) const;
551+
};
552+
553+
// Provided too many generics inside `f<<T>>`
554+
struct TypeInstantiationCountMismatch
555+
{
556+
std::optional<std::string> functionName;
557+
TypeId functionType;
558+
559+
size_t providedTypes = 0;
560+
size_t maximumTypes = 0;
561+
562+
size_t providedTypePacks = 0;
563+
size_t maximumTypePacks = 0;
564+
565+
bool operator==(const TypeInstantiationCountMismatch&) const;
566+
};
567+
538568
// Error when referencing a type function without providing explicit generics.
539569
//
540570
// type function create_table_with_key()
@@ -609,8 +639,9 @@ using TypeErrorData = Variant<
609639
MultipleNonviableOverloads,
610640
RecursiveRestraintViolation,
611641
GenericBoundsMismatch,
612-
UnappliedTypeFunction>;
613-
642+
UnappliedTypeFunction,
643+
InstantiateGenericsOnNonFunction,
644+
TypeInstantiationCountMismatch>;
614645

615646
struct TypeErrorSummary
616647
{

Analysis/include/Luau/ToString.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ std::string dump(const std::optional<TypeId>& ty);
140140
std::string dump(TypePackId ty);
141141
std::string dump(const std::optional<TypePackId>& ty);
142142
std::string dump(const std::vector<TypeId>& types);
143+
std::string dump(const std::vector<TypePackId>& types);
143144
std::string dump(DenseHashMap<TypeId, TypeId>& types);
144145
std::string dump(DenseHashMap<TypePackId, TypePackId>& types);
145146

@@ -163,4 +164,5 @@ inline std::string toString(const TypeOrPack& tyOrTp)
163164
std::string dump(const TypeOrPack& tyOrTp);
164165

165166
std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& opts);
167+
std::string toStringVector(const std::vector<TypePackId>& typePacks, ToStringOptions& opts);
166168
} // namespace Luau

Analysis/include/Luau/TypeChecker2.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ struct TypeChecker2
173173
void visit(AstExprTypeAssertion* expr);
174174
void visit(AstExprIfElse* expr);
175175
void visit(AstExprInterpString* interpString);
176+
void visit(AstExprInstantiate* explicitTypeInstantiation);
176177
void visit(AstExprError* expr);
177178
TypeId flattenPack(TypePackId pack);
178179
void visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks);
@@ -233,6 +234,8 @@ struct TypeChecker2
233234

234235
void suggestAnnotations(AstExprFunction* expr, TypeId ty);
235236

237+
void checkTypeInstantiation(AstExpr* baseFunctionExpr, TypeId fnType, const Location& location, const AstArray<AstTypeOrPack>& typeArguments);
238+
236239
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const;
237240
bool isErrorSuppressing(Location loc, TypeId ty);
238241
bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2);

Analysis/include/Luau/TypeInfer.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ struct TypeChecker
134134
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
135135
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
136136
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInterpString& expr);
137+
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInstantiate& expr);
137138

138139
TypeId checkExprTable(
139140
const ScopePtr& scope,
@@ -227,6 +228,14 @@ struct TypeChecker
227228
const std::vector<std::optional<TypeId>>& expectedTypes = {}
228229
);
229230

231+
TypeId instantiateTypeParameters(
232+
const ScopePtr& scope,
233+
TypeId baseType,
234+
const AstArray<AstTypeOrPack>& explicitTypes,
235+
const AstExpr* functionExpr,
236+
const Location& location
237+
);
238+
230239
static std::optional<AstExpr*> matchRequire(const AstExprCall& call);
231240
TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location);
232241

0 commit comments

Comments
 (0)