@@ -4,6 +4,7 @@ package tasty
4
4
import scala .jdk .CollectionConverters ._
5
5
6
6
import scala .quoted ._
7
+ import scala .util .control .NonFatal
7
8
8
9
import NameNormalizer ._
9
10
import SyntheticsSupport ._
@@ -124,6 +125,12 @@ trait TypesSupport:
124
125
++ keyword(" =>> " ).l
125
126
++ inner(resType)
126
127
128
+ case Refinement (parent, " apply" , mt : MethodType ) if isPolyOrEreased(parent) =>
129
+ val isCtx = isContextualMethod(mt)
130
+ val sym = defn.FunctionClass (mt.paramTypes.length, isCtx)
131
+ val at = sym.typeRef.appliedTo(mt.paramTypes :+ mt.resType)
132
+ inner(Refinement (at, " apply" , mt))
133
+
127
134
case r : Refinement => { // (parent, name, info)
128
135
def getRefinementInformation (t : TypeRepr ): List [TypeRepr ] = t match {
129
136
case r : Refinement => getRefinementInformation(r.parent) :+ r
@@ -164,16 +171,22 @@ trait TypesSupport:
164
171
case t : PolyType =>
165
172
val paramBounds = getParamBounds(t)
166
173
val method = t.resType.asInstanceOf [MethodType ]
167
- val paramList = getParamList(method)
168
- val resType = inner(method.resType)
169
- plain(" [" ).l ++ paramBounds ++ plain(" ]" ).l ++ keyword(" => " ).l ++ paramList ++ keyword(" => " ).l ++ resType
174
+ val rest = parseDependentFunctionType(method)
175
+ plain(" [" ).l ++ paramBounds ++ plain(" ]" ).l ++ keyword(" => " ).l ++ rest
170
176
case other => noSupported(s " Not supported type in refinement $info" )
171
177
}
172
178
173
179
def parseDependentFunctionType (info : TypeRepr ): SSignature = info match {
174
180
case m : MethodType =>
175
- val paramList = getParamList(m)
176
- paramList ++ keyword(" => " ).l ++ inner(m.resType)
181
+ val isCtx = isContextualMethod(m)
182
+ if isDependentMethod(m) then
183
+ val paramList = getParamList(m)
184
+ val arrow = keyword(if isCtx then " ?=> " else " => " ).l
185
+ val resType = inner(m.resType)
186
+ paramList ++ arrow ++ resType
187
+ else
188
+ val sym = defn.FunctionClass (m.paramTypes.length, isCtx)
189
+ inner(sym.typeRef.appliedTo(m.paramTypes :+ m.resType))
177
190
case other => noSupported(" Dependent function type without MethodType refinement" )
178
191
}
179
192
@@ -213,8 +226,9 @@ trait TypesSupport:
213
226
case Seq (rtpe) =>
214
227
plain(" ()" ).l ++ keyword(arrow).l ++ inner(rtpe)
215
228
case Seq (arg, rtpe) =>
216
- val partOfSignature = arg match
229
+ val partOfSignature = stripAnnotated( arg) match
217
230
case _ : TermRef | _ : TypeRef | _ : ConstantType | _ : ParamRef => inner(arg)
231
+ case at : AppliedType if ! isInfix(at) && ! at.isFunctionType && ! at.isTupleN => inner(arg)
218
232
case _ => inParens(inner(arg))
219
233
partOfSignature ++ keyword(arrow).l ++ inner(rtpe)
220
234
case args =>
@@ -385,3 +399,21 @@ trait TypesSupport:
385
399
case _ => false
386
400
387
401
at.args.size == 2 && (! at.typeSymbol.name.forall(isIdentifierPart) || infixAnnot)
402
+
403
+ private def isPolyOrEreased (using Quotes )(tr : reflect.TypeRepr ) =
404
+ Set (" scala.PolyFunction" , " scala.runtime.ErasedFunction" )
405
+ .contains(tr.typeSymbol.fullName)
406
+
407
+ private def isContextualMethod (using Quotes )(mt : reflect.MethodType ) =
408
+ mt.asInstanceOf [dotty.tools.dotc.core.Types .MethodType ].isContextualMethod
409
+
410
+ private def isDependentMethod (using Quotes )(mt : reflect.MethodType ) =
411
+ val method = mt.asInstanceOf [dotty.tools.dotc.core.Types .MethodType ]
412
+ try method.isParamDependent || method.isResultDependent
413
+ catch case NonFatal (_) => true
414
+
415
+ private def stripAnnotated (using Quotes )(tr : reflect.TypeRepr ): reflect.TypeRepr =
416
+ import reflect .*
417
+ tr match
418
+ case AnnotatedType (tr, _) => stripAnnotated(tr)
419
+ case other => other
0 commit comments