Skip to content

Commit 4726aad

Browse files
committed
Introduce capability classifiers
1 parent 84a2f25 commit 4726aad

20 files changed

+262
-55
lines changed

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,11 @@ object Capabilities:
258258
trait Capability extends Showable:
259259

260260
private var myCaptureSet: CaptureSet | Null = uninitialized
261-
private var myCaptureSetValid: Validity = invalid
261+
private var captureSetValid: Validity = invalid
262262
private var mySingletonCaptureSet: CaptureSet.Const | Null = null
263263
private var myDerived: List[DerivedCapability] = Nil
264+
private var myClassifiers: Classifiers = UnknownClassifier
265+
private var classifiersValid: Validity = invalid
264266

265267
protected def cached[C <: DerivedCapability](newRef: C): C =
266268
def recur(refs: List[DerivedCapability]): C = refs match
@@ -292,10 +294,7 @@ object Capabilities:
292294
case Maybe(ref1) => Maybe(ref1.restrict(cls))
293295
case ReadOnly(ref1) => ReadOnly(ref1.restrict(cls).asInstanceOf[Restricted])
294296
case self @ Restricted(ref1, prevCls) =>
295-
val combinedCls =
296-
if prevCls.isSubClass(cls) then prevCls
297-
else if cls.isSubClass(prevCls) then cls
298-
else defn.NothingClass
297+
val combinedCls = leastClassifier(prevCls, cls)
299298
if combinedCls == prevCls then self
300299
else cached(Restricted(ref1, combinedCls))
301300
case self: (ObjectCapability | RootCapability | Reach) => cached(Restricted(self, cls))
@@ -469,7 +468,7 @@ object Capabilities:
469468

470469
def derivesFromCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_Capability)
471470
def derivesFromMutable(using Context): Boolean = derivesFromCapTrait(defn.Caps_Mutable)
472-
def derivesFromSharedCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_SharedCapability)
471+
def derivesFromSharable(using Context): Boolean = derivesFromCapTrait(defn.Caps_Sharable)
473472

474473
/** The capture set consisting of exactly this reference */
475474
def singletonCaptureSet(using Context): CaptureSet.Const =
@@ -479,7 +478,7 @@ object Capabilities:
479478

480479
/** The capture set of the type underlying this reference */
481480
def captureSetOfInfo(using Context): CaptureSet =
482-
if myCaptureSetValid == currentId then myCaptureSet.nn
481+
if captureSetValid == currentId then myCaptureSet.nn
483482
else if myCaptureSet.asInstanceOf[AnyRef] eq CaptureSet.Pending then CaptureSet.empty
484483
else
485484
myCaptureSet = CaptureSet.Pending
@@ -491,11 +490,60 @@ object Capabilities:
491490
myCaptureSet = null
492491
else
493492
myCaptureSet = computed
494-
myCaptureSetValid = currentId
493+
captureSetValid = currentId
495494
computed
496495

496+
/** The transitive classifiers of this capability. */
497+
def transClassifiers(using Context): Classifiers =
498+
def toClassifiers(cls: ClassSymbol): Classifiers =
499+
if cls == defn.AnyClass then Unclassified
500+
else ClassifiedAs(cls :: Nil)
501+
if classifiersValid != currentId then
502+
myClassifiers = this match
503+
case self: FreshCap =>
504+
toClassifiers(self.hiddenSet.classifier)
505+
case self: RootCapability =>
506+
Unclassified
507+
case Restricted(_, cls) =>
508+
assert(cls != defn.AnyClass)
509+
if cls == defn.NothingClass then ClassifiedAs(Nil)
510+
else ClassifiedAs(cls :: Nil)
511+
case ReadOnly(ref1) =>
512+
ref1.transClassifiers
513+
case Maybe(ref1) =>
514+
ref1.transClassifiers
515+
case Reach(_) =>
516+
captureSetOfInfo.transClassifiers
517+
case self: CoreCapability =>
518+
joinClassifiers(toClassifiers(self.classifier), captureSetOfInfo.transClassifiers)
519+
if myClassifiers != UnknownClassifier then
520+
classifiersValid == currentId
521+
myClassifiers
522+
end transClassifiers
523+
524+
def tryClassifyAs(cls: ClassSymbol)(using Context): Boolean =
525+
cls == defn.AnyClass
526+
|| this.match
527+
case self: FreshCap =>
528+
self.hiddenSet.tryClassifyAs(cls)
529+
case self: RootCapability =>
530+
true
531+
case Restricted(_, cls1) =>
532+
assert(cls != defn.AnyClass)
533+
cls1.isSubClass(cls)
534+
case ReadOnly(ref1) =>
535+
ref1.tryClassifyAs(cls)
536+
case Maybe(ref1) =>
537+
ref1.tryClassifyAs(cls)
538+
case Reach(_) =>
539+
captureSetOfInfo.tryClassifyAs(cls)
540+
case self: CoreCapability =>
541+
self.classifier.isSubClass(cls)
542+
&& captureSetOfInfo.tryClassifyAs(cls)
543+
497544
def invalidateCaches() =
498-
myCaptureSetValid = invalid
545+
captureSetValid = invalid
546+
classifiersValid = invalid
499547

500548
/** x subsumes x
501549
* x =:= y ==> x subsumes y
@@ -603,12 +651,15 @@ object Capabilities:
603651

604652
vs.ifNotSeen(this)(x.hiddenSet.elems.exists(_.subsumes(y)))
605653
|| levelOK
654+
&& ( y.tryClassifyAs(x.hiddenSet.classifier)
655+
|| { capt.println(i"$y is not classified as $x"); false }
656+
)
606657
&& canAddHidden
607658
&& vs.addHidden(x.hiddenSet, y)
608659
case x: ResultCap =>
609660
val result = y match
610661
case y: ResultCap => vs.unify(x, y)
611-
case _ => y.derivesFromSharedCapability
662+
case _ => y.derivesFromSharable
612663
if !result then
613664
TypeComparer.addErrorNote(CaptureSet.ExistentialSubsumesFailure(x, y))
614665
result
@@ -618,7 +669,7 @@ object Capabilities:
618669
case _: ResultCap => false
619670
case _: FreshCap if CCState.collapseFresh => true
620671
case _ =>
621-
y.derivesFromSharedCapability
672+
y.derivesFromSharable
622673
|| canAddHidden && vs != VarState.HardSeparate && CCState.capIsRoot
623674
case _ =>
624675
y match
@@ -674,6 +725,39 @@ object Capabilities:
674725
def toText(printer: Printer): Text = printer.toTextCapability(this)
675726
end Capability
676727

728+
/** Result type of `transClassifiers`. Interprete as follows:
729+
* UnknownClassifier: No list could be computed since some capture sets
730+
* are still unsolved variables
731+
* Unclassified : No set exists since some parts of tcs are not classified
732+
* ClassifiedAs(clss: All parts of tcss are classified with classes in clss
733+
*/
734+
enum Classifiers:
735+
case UnknownClassifier
736+
case Unclassified
737+
case ClassifiedAs(clss: List[ClassSymbol])
738+
739+
export Classifiers.{UnknownClassifier, Unclassified, ClassifiedAs}
740+
741+
/** The least classifier between `cls1` and `cls2`, which are either
742+
* AnyClass, NothingClass, or a class directly extending caps.Classifier.
743+
* @return if oen of cls1, cls2 is a subclass of the other, the subclass
744+
* otherwise NothingClass (which is a subclass of all classes)
745+
*/
746+
def leastClassifier(cls1: ClassSymbol, cls2: ClassSymbol)(using Context): ClassSymbol =
747+
if cls1.isSubClass(cls2) then cls1
748+
else if cls2.isSubClass(cls1) then cls2
749+
else defn.NothingClass
750+
751+
def joinClassifiers(cs1: Classifiers, cs2: Classifiers)(using Context): Classifiers =
752+
// Drop classes that subclass classes of the other set
753+
def filterSub(cs1: List[ClassSymbol], cs2: List[ClassSymbol]) =
754+
cs1.filter(cls1 => !cs2.exists(cls2 => cls1.isSubClass(cls2)))
755+
(cs1, cs2) match
756+
case (Unclassified, _) | (_, Unclassified) => Unclassified
757+
case (UnknownClassifier, _) | (_, UnknownClassifier) => UnknownClassifier
758+
case (ClassifiedAs(cs1), ClassifiedAs(cs2)) =>
759+
ClassifiedAs(filterSub(cs1, cs2) ++ filterSub(cs2, cs1))
760+
677761
/** The place of - and cause for - creating a fresh capability. Used for
678762
* error diagnostics
679763
*/

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ extension (tp: Type)
376376

377377
def derivesFromCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_Capability)
378378
def derivesFromMutable(using Context): Boolean = derivesFromCapTrait(defn.Caps_Mutable)
379-
def derivesFromSharedCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_SharedCapability)
379+
def derivesFromSharedCapability(using Context): Boolean = derivesFromCapTrait(defn.Caps_Sharable)
380380

381381
/** Drop @retains annotations everywhere */
382382
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
@@ -442,6 +442,30 @@ extension (tp: Type)
442442
def dropUseAndConsumeAnnots(using Context): Type =
443443
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)
444444

445+
/** If `tp` is a function or method, a type of the same kind with the given
446+
* argument and result types.
447+
*/
448+
def derivedFunctionOrMethod(argTypes: List[Type], resType: Type)(using Context): Type = tp match
449+
case tp @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp) =>
450+
val args1 = argTypes :+ resType
451+
if args.corresponds(args1)(_ eq _) then tp
452+
else tp.derivedAppliedType(tycon, args1)
453+
case tp @ defn.RefinedFunctionOf(rinfo) =>
454+
val rinfo1 = rinfo.derivedFunctionOrMethod(argTypes, resType)
455+
if rinfo1 eq rinfo then tp
456+
else if rinfo1.isInstanceOf[PolyType] then tp.derivedRefinedType(refinedInfo = rinfo1)
457+
else rinfo1.toFunctionType(alwaysDependent = true)
458+
case tp: MethodType =>
459+
tp.derivedLambdaType(paramInfos = argTypes, resType = resType)
460+
case tp: PolyType =>
461+
assert(argTypes.isEmpty)
462+
tp.derivedLambdaType(resType = resType)
463+
case _ =>
464+
tp
465+
466+
def classifier(using Context): ClassSymbol =
467+
tp.classSymbols.map(_.classifier).foldLeft(defn.AnyClass)(leastClassifier)
468+
445469
extension (tp: MethodType)
446470
/** A method marks an existential scope unless it is the prefix of a curried method */
447471
def marksExistentialScope(using Context): Boolean =
@@ -473,6 +497,16 @@ extension (cls: ClassSymbol)
473497
val selfType = bc.givenSelfType
474498
bc.is(CaptureChecked) && selfType.exists && selfType.captureSet.elems == refs.elems
475499

500+
def isClassifiedCapabilityClass(using Context): Boolean =
501+
cls.derivesFrom(defn.Caps_Capability) && cls.parentSyms.contains(defn.Caps_Classifier)
502+
503+
def classifier(using Context): ClassSymbol =
504+
if cls.derivesFrom(defn.Caps_Capability) then
505+
cls.baseClasses
506+
.filter(_.parentSyms.contains(defn.Caps_Classifier))
507+
.foldLeft(defn.AnyClass)(leastClassifier)
508+
else defn.AnyClass
509+
476510
extension (sym: Symbol)
477511

478512
/** This symbol is one of `retains` or `retainsCap` */
@@ -628,28 +662,6 @@ object FunctionOrMethod:
628662
case defn.RefinedFunctionOf(rinfo) => unapply(rinfo)
629663
case _ => None
630664

631-
/** If `tp` is a function or method, a type of the same kind with the given
632-
* argument and result types.
633-
*/
634-
extension (self: Type)
635-
def derivedFunctionOrMethod(argTypes: List[Type], resType: Type)(using Context): Type = self match
636-
case self @ AppliedType(tycon, args) if defn.isNonRefinedFunction(self) =>
637-
val args1 = argTypes :+ resType
638-
if args.corresponds(args1)(_ eq _) then self
639-
else self.derivedAppliedType(tycon, args1)
640-
case self @ defn.RefinedFunctionOf(rinfo) =>
641-
val rinfo1 = rinfo.derivedFunctionOrMethod(argTypes, resType)
642-
if rinfo1 eq rinfo then self
643-
else if rinfo1.isInstanceOf[PolyType] then self.derivedRefinedType(refinedInfo = rinfo1)
644-
else rinfo1.toFunctionType(alwaysDependent = true)
645-
case self: MethodType =>
646-
self.derivedLambdaType(paramInfos = argTypes, resType = resType)
647-
case self: PolyType =>
648-
assert(argTypes.isEmpty)
649-
self.derivedLambdaType(resType = resType)
650-
case _ =>
651-
self
652-
653665
/** An extractor for a contains argument */
654666
object ContainsImpl:
655667
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,20 @@ sealed abstract class CaptureSet extends Showable:
410410
if mutability != Ignored then res.mutability = Reader
411411
res
412412

413+
def transClassifiers(using Context): Classifiers =
414+
if isConst then
415+
(ClassifiedAs(Nil) /: elems.map(_.transClassifiers))(joinClassifiers)
416+
else UnknownClassifier
417+
418+
def tryClassifyAs(cls: ClassSymbol)(using Context): Boolean =
419+
elems.forall(_.tryClassifyAs(cls))
420+
421+
def adoptClassifier(cls: ClassSymbol)(using Context): Unit =
422+
for elem <- elems do
423+
elem.stripReadOnly match
424+
case fresh: FreshCap => fresh.hiddenSet.adoptClassifier(cls)
425+
case _ =>
426+
413427
/** A bad root `elem` is inadmissible as a member of this set. What is a bad roots depends
414428
* on the value of `rootLimit`.
415429
* If the limit is null, all capture roots are good.
@@ -651,6 +665,25 @@ object CaptureSet:
651665
*/
652666
private[CaptureSet] var rootLimit: Symbol | Null = null
653667

668+
private var myClassifier: ClassSymbol = defn.AnyClass
669+
def classifier: ClassSymbol = myClassifier
670+
671+
private def narrowClassifier(cls: ClassSymbol)(using Context): Unit =
672+
val newClassifier = leastClassifier(classifier, cls)
673+
if newClassifier == defn.NothingClass then
674+
println(i"conflicting classifications for $this, was $classifier, now $cls")
675+
myClassifier = newClassifier
676+
677+
override def adoptClassifier(cls: ClassSymbol)(using Context): Unit =
678+
if !classifier.isSubClass(cls) then // serves as recursion brake
679+
narrowClassifier(cls)
680+
super.adoptClassifier(cls)
681+
682+
override def tryClassifyAs(cls: ClassSymbol)(using Context): Boolean =
683+
classifier.isSubClass(cls)
684+
|| super.tryClassifyAs(cls)
685+
&& { narrowClassifier(cls); true }
686+
654687
/** A handler to be invoked when new elems are added to this set */
655688
var newElemAddedHandler: Capability => Context ?=> Unit = _ => ()
656689

@@ -682,14 +715,15 @@ object CaptureSet:
682715
addIfHiddenOrFail(elem)
683716
else if !levelOK(elem) then
684717
failWith(IncludeFailure(this, elem, levelError = true)) // or `elem` is not visible at the level of the set.
718+
else if !elem.tryClassifyAs(classifier) then
719+
failWith(IncludeFailure(this, elem))
685720
else
686721
// id == 108 then assert(false, i"trying to add $elem to $this")
687722
assert(elem.isWellformed, elem)
688723
assert(!this.isInstanceOf[HiddenSet] || summon[VarState].isSeparating, summon[VarState])
689724
includeElem(elem)
690725
if isBadRoot(rootLimit, elem) then
691726
rootAddedHandler()
692-
newElemAddedHandler(elem)
693727
val normElem = if isMaybeSet then elem else elem.stripMaybe
694728
// assert(id != 5 || elems.size != 3, this)
695729
val res = deps.forall: dep =>

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ object CapturingType:
4040
apply(parent1, refs ++ refs1, boxed)
4141
case _ =>
4242
if parent.derivesFromMutable then refs.setMutable()
43+
val classifier = parent.classifier
44+
refs.adoptClassifier(parent.classifier)
4345
AnnotatedType(parent, CaptureAnnotation(refs, boxed)(defn.RetainsAnnot))
4446

4547
/** An extractor for CapturingTypes. Capturing types are recognized if

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ object CheckCaptures:
113113
case ReadOnlyCapability(ref) =>
114114
check(ref)
115115
case OnlyCapability(ref, cls) =>
116+
if !cls.isClassifiedCapabilityClass then
117+
report.error(
118+
em"""${ref.showRef}.only[${cls.name}] is not well-formed since $cls is not a classifier class.
119+
|A classifier class is a class extending `caps.Capability` and directly extending `caps.Classifier`.""",
120+
ann.srcPos)
116121
check(ref)
117122
case tpe =>
118123
report.error(em"$elem: $tpe is not a legal element of a capture set", ann.srcPos)

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
598598
* - If the reference is to a this type of the enclosing class, the
599599
* access must be in a @consume method.
600600
*
601-
* References that extend cpas.Sharable are excluded from checking.
601+
* References that extend caps.Sharable are excluded from checking.
602602
* As a side effect, add all checked references with the given position `pos`
603603
* to the global `consumed` map.
604604
*
@@ -612,7 +612,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
612612
val badParams = mutable.ListBuffer[Symbol]()
613613
def currentOwner = role.dclSym.orElse(ctx.owner)
614614
for hiddenRef <- refsToCheck.deductSymRefs(role.dclSym).deduct(explicitRefs(tpe)) do
615-
if !hiddenRef.derivesFromSharedCapability then
615+
if !hiddenRef.derivesFromSharable then
616616
hiddenRef.pathRoot match
617617
case ref: TermRef =>
618618
val refSym = ref.symbol
@@ -649,7 +649,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
649649
role match
650650
case _: TypeRole.Argument | _: TypeRole.Qualifier =>
651651
for ref <- refsToCheck do
652-
if !ref.derivesFromSharedCapability then
652+
if !ref.derivesFromSharable then
653653
consumed.put(ref, pos)
654654
case _ =>
655655
end checkConsumedRefs

0 commit comments

Comments
 (0)