@@ -8461,6 +8461,28 @@ namespace ts {
8461
8461
return f(type) ? type : neverType;
8462
8462
}
8463
8463
8464
+ function mapType(type: Type, f: (t: Type) => Type): Type {
8465
+ return type.flags & TypeFlags.Union ? getUnionType(map((<UnionType>type).types, f)) : f(type);
8466
+ }
8467
+
8468
+ function extractTypesOfKind(type: Type, kind: TypeFlags) {
8469
+ return filterType(type, t => (t.flags & kind) !== 0);
8470
+ }
8471
+
8472
+ // Return a new type in which occurrences of the string and number primitive types in
8473
+ // typeWithPrimitives have been replaced with occurrences of string literals and numeric
8474
+ // literals in typeWithLiterals, respectively.
8475
+ function replacePrimitivesWithLiterals(typeWithPrimitives: Type, typeWithLiterals: Type) {
8476
+ if (isTypeSubsetOf(stringType, typeWithPrimitives) && maybeTypeOfKind(typeWithLiterals, TypeFlags.StringLiteral) ||
8477
+ isTypeSubsetOf(numberType, typeWithPrimitives) && maybeTypeOfKind(typeWithLiterals, TypeFlags.NumberLiteral)) {
8478
+ return mapType(typeWithPrimitives, t =>
8479
+ t.flags & TypeFlags.String ? extractTypesOfKind(typeWithLiterals, TypeFlags.String | TypeFlags.StringLiteral) :
8480
+ t.flags & TypeFlags.Number ? extractTypesOfKind(typeWithLiterals, TypeFlags.Number | TypeFlags.NumberLiteral) :
8481
+ t);
8482
+ }
8483
+ return typeWithPrimitives;
8484
+ }
8485
+
8464
8486
function isIncomplete(flowType: FlowType) {
8465
8487
return flowType.flags === 0;
8466
8488
}
@@ -8791,16 +8813,12 @@ namespace ts {
8791
8813
assumeTrue ? TypeFacts.EQUndefined : TypeFacts.NEUndefined;
8792
8814
return getTypeWithFacts(type, facts);
8793
8815
}
8794
- if (type.flags & TypeFlags.String && isTypeOfKind(valueType, TypeFlags.StringLiteral) ||
8795
- type.flags & TypeFlags.Number && isTypeOfKind(valueType, TypeFlags.NumberLiteral)) {
8796
- return assumeTrue? valueType : type;
8797
- }
8798
8816
if (type.flags & TypeFlags.NotUnionOrUnit) {
8799
8817
return type;
8800
8818
}
8801
8819
if (assumeTrue) {
8802
8820
const narrowedType = filterType(type, t => areTypesComparable(t, valueType));
8803
- return narrowedType.flags & TypeFlags.Never ? type : narrowedType;
8821
+ return narrowedType.flags & TypeFlags.Never ? type : replacePrimitivesWithLiterals( narrowedType, valueType) ;
8804
8822
}
8805
8823
if (isUnitType(valueType)) {
8806
8824
const regularType = getRegularTypeOfLiteralType(valueType);
@@ -8849,9 +8867,7 @@ namespace ts {
8849
8867
const discriminantType = getUnionType(clauseTypes);
8850
8868
const caseType =
8851
8869
discriminantType.flags & TypeFlags.Never ? neverType :
8852
- type.flags & TypeFlags.String && isTypeOfKind(discriminantType, TypeFlags.StringLiteral) ? discriminantType :
8853
- type.flags & TypeFlags.Number && isTypeOfKind(discriminantType, TypeFlags.NumberLiteral) ? discriminantType :
8854
- filterType(type, t => isTypeComparableTo(discriminantType, t));
8870
+ replacePrimitivesWithLiterals(filterType(type, t => isTypeComparableTo(discriminantType, t)), discriminantType);
8855
8871
if (!hasDefaultClause) {
8856
8872
return caseType;
8857
8873
}
0 commit comments