Skip to content

Commit b95332b

Browse files
committed
spring support for scala default parameter values
1 parent 05bdc3d commit b95332b

File tree

8 files changed

+281
-61
lines changed

8 files changed

+281
-61
lines changed

commons-core/src/main/scala/com/avsystem/commons/SharedExtensions.scala

+10
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ object SharedExtensions extends SharedExtensions {
143143
def evalFuture: Future[A] = FutureCompanionOps.eval(a())
144144

145145
def evalTry: Try[A] = Try(a())
146+
147+
def recoverFrom[T <: Throwable : ClassTag](fallbackValue: => A): A =
148+
try a() catch {
149+
case _: T => fallbackValue
150+
}
151+
152+
def recoverToOpt[T <: Throwable : ClassTag]: Opt[A] =
153+
try Opt(a()) catch {
154+
case _: T => Opt.Empty
155+
}
146156
}
147157

148158
class NullableOps[A >: Null](private val a: A) extends AnyVal {

commons-spring/src/main/scala/com/avsystem/commons/spring/HoconBeanDefinitionReader.scala

+14-10
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
247247
getProps(obj).foreach {
248248
case (key, value) =>
249249
if (construct) {
250-
addConstructorArg(readConstructorArg(value, Some(key)))
250+
addConstructorArg(readConstructorArg(value, forcedName = key))
251251
} else {
252252
propertyValues.addPropertyValue(readPropertyValue(key, value))
253253
}
@@ -275,31 +275,35 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
275275
private def readConstructorArgs(value: ConfigValue) = {
276276
value.as[Option[Either[ConfigList, ConfigObject]]] match {
277277
case Some(Left(list)) =>
278-
list.iterator.asScala.map(configValue => readConstructorArg(configValue))
278+
list.iterator.asScala.zipWithIndex.map { case (configValue, idx) =>
279+
readConstructorArg(configValue, forcedIndex = idx)
280+
}
279281
case Some(Right(obj)) =>
280282
validateObj(props = true)(obj)
281283
getProps(obj).iterator.map { case (name, configValue) =>
282-
val (idxOpt, holder) = readConstructorArg(configValue)
283-
holder.setName(name)
284-
(idxOpt, holder)
284+
readConstructorArg(configValue, forcedName = name)
285285
}
286286
case None =>
287287
Iterator.empty
288288
}
289289
}
290290

291-
private def readConstructorArg(value: ConfigValue, forcedName: Option[String] = None) = value match {
291+
private def readConstructorArg(
292+
value: ConfigValue,
293+
forcedIndex: OptArg[Int] = OptArg.Empty,
294+
forcedName: OptArg[String] = OptArg.Empty
295+
) = value match {
292296
case ValueDefinition(obj) =>
293297
validateObj(required = Set(ValueAttr), allowed = Set(IndexAttr, TypeAttr, NameAttr))(obj)
294298
val vh = new ValueHolder(read(obj.get(ValueAttr)))
295299
obj.get(TypeAttr).as[Option[String]].foreach(vh.setType)
296-
(forcedName orElse obj.get(NameAttr).as[Option[String]]).foreach(vh.setName)
297-
val indexOpt = obj.get(IndexAttr).as[Option[Int]]
300+
(forcedName.toOption orElse obj.get(NameAttr).as[Option[String]]).foreach(vh.setName)
301+
val indexOpt = forcedIndex.toOption orElse obj.get(IndexAttr).as[Option[Int]]
298302
(indexOpt, vh)
299303
case _ =>
300304
val vh = new ValueHolder(read(value))
301305
forcedName.foreach(vh.setName)
302-
(None, vh)
306+
(forcedIndex.toOption, vh)
303307
}
304308

305309
private def readPropertyValue(name: String, value: ConfigValue) = value match {
@@ -344,6 +348,6 @@ class HoconBeanDefinitionReader(registry: BeanDefinitionRegistry)
344348
result
345349
}
346350

347-
def loadBeanDefinitions(resource: Resource) =
351+
def loadBeanDefinitions(resource: Resource): Int =
348352
loadBeanDefinitions(ConfigFactory.parseURL(resource.getURL).resolve)
349353
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package com.avsystem.commons
2+
package spring
3+
4+
import java.lang.reflect.{Constructor, Method, Modifier}
5+
6+
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder
7+
import org.springframework.beans.factory.config.{BeanDefinition, BeanDefinitionHolder, ConfigurableListableBeanFactory}
8+
import org.springframework.beans.factory.support.{BeanDefinitionRegistry, BeanDefinitionRegistryPostProcessor, ManagedList, ManagedMap, ManagedSet}
9+
import org.springframework.core.ParameterNameDiscoverer
10+
11+
import scala.beans.BeanProperty
12+
import scala.reflect.{ScalaLongSignature, ScalaSignature}
13+
14+
class ScalaDefaultValuesInjector extends BeanDefinitionRegistryPostProcessor {
15+
@BeanProperty var paramNameDiscoverer: ParameterNameDiscoverer =
16+
new ScalaParameterNameDiscoverer
17+
18+
def classLoader: ClassLoader =
19+
Thread.currentThread.getContextClassLoader.opt getOrElse getClass.getClassLoader
20+
21+
def loadClass(name: String): Class[_] = Class.forName(name, false, classLoader)
22+
23+
def postProcessBeanDefinitionRegistry(registry: BeanDefinitionRegistry): Unit = {
24+
def traverse(value: Any): Unit = value match {
25+
case bd: BeanDefinition =>
26+
bd.getConstructorArgumentValues.getGenericArgumentValues.asScala.foreach(traverse)
27+
bd.getConstructorArgumentValues.getIndexedArgumentValues.values.asScala.foreach(traverse)
28+
bd.getPropertyValues.getPropertyValueList.asScala.foreach(pv => traverse(pv.getValue))
29+
injectDefaultValues(bd)
30+
case bdw: BeanDefinitionHolder =>
31+
traverse(bdw.getBeanDefinition)
32+
case vh: ValueHolder =>
33+
traverse(vh.getValue)
34+
case ml: ManagedList[_] =>
35+
ml.asScala.foreach(traverse)
36+
case ms: ManagedSet[_] =>
37+
ms.asScala.foreach(traverse)
38+
case mm: ManagedMap[_, _] =>
39+
mm.asScala.foreach {
40+
case (k, v) =>
41+
traverse(k)
42+
traverse(v)
43+
}
44+
case _ =>
45+
}
46+
47+
registry.getBeanDefinitionNames
48+
.foreach(n => traverse(registry.getBeanDefinition(n)))
49+
}
50+
51+
private def isScalaClass(cls: Class[_]): Boolean = cls.getEnclosingClass match {
52+
case null => cls.getAnnotation(classOf[ScalaSignature]) != null ||
53+
cls.getAnnotation(classOf[ScalaLongSignature]) != null
54+
case encls => isScalaClass(encls)
55+
}
56+
57+
private def injectDefaultValues(bd: BeanDefinition): Unit = {
58+
val className = bd.getFactoryBeanName.opt getOrElse bd.getBeanClassName
59+
loadClass(className).recoverToOpt[ClassNotFoundException].filter(isScalaClass).foreach { clazz =>
60+
val usingConstructor = bd.getFactoryMethodName == null
61+
val factoryExecs =
62+
if (usingConstructor) clazz.getConstructors.toVector
63+
else clazz.getMethods.iterator.filter(_.getName == bd.getFactoryMethodName).toVector
64+
val factorySymbolName =
65+
if (usingConstructor) "$lessinit$greater" else bd.getFactoryMethodName
66+
67+
if (factoryExecs.size == 1) {
68+
val constrVals = bd.getConstructorArgumentValues
69+
val factoryExec = factoryExecs.head
70+
val paramNames = factoryExec match {
71+
case c: Constructor[_] => paramNameDiscoverer.getParameterNames(c)
72+
case m: Method => paramNameDiscoverer.getParameterNames(m)
73+
}
74+
(0 until factoryExec.getParameterCount).foreach { i =>
75+
def defaultValueMethod = clazz.getMethod(s"$factorySymbolName$$default$$${i + 1}")
76+
.recoverToOpt[NoSuchMethodException].filter(m => Modifier.isStatic(m.getModifiers))
77+
def specifiedNamed = paramNames != null &&
78+
constrVals.getGenericArgumentValues.asScala.exists(_.getName == paramNames(i))
79+
def specifiedIndexed =
80+
constrVals.getIndexedArgumentValues.get(i) != null
81+
if (!specifiedNamed && !specifiedIndexed) {
82+
defaultValueMethod.foreach { dvm =>
83+
constrVals.addIndexedArgumentValue(i, dvm.invoke(null))
84+
}
85+
}
86+
}
87+
}
88+
}
89+
}
90+
91+
def postProcessBeanFactory(beanFactory: ConfigurableListableBeanFactory): Unit = ()
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package com.avsystem.commons
2+
package spring
3+
4+
import java.lang.reflect.{Constructor, Executable, Method, Modifier}
5+
6+
import org.springframework.core.{JdkVersion, ParameterNameDiscoverer}
7+
8+
import scala.annotation.tailrec
9+
import scala.ref.WeakReference
10+
import scala.reflect.api.JavaUniverse
11+
import scala.reflect.{ScalaLongSignature, ScalaSignature}
12+
13+
object ScalaParameterNameDiscoverer {
14+
final val ScalaSignatureClasses =
15+
List(classOf[ScalaSignature], classOf[ScalaLongSignature])
16+
17+
final val JdkAtLeast8 =
18+
JdkVersion.getMajorJavaVersion >= JdkVersion.JAVA_18
19+
20+
// we don't want to keep the universe in memory forever, so we don't use scala.reflect.runtime.universe
21+
private var universeRef: WeakReference[JavaUniverse] = _
22+
23+
private def universe: JavaUniverse = {
24+
universeRef.option.flatMap(_.get) match {
25+
case Some(result) => result
26+
case None =>
27+
val result = new scala.reflect.runtime.JavaUniverse
28+
universeRef = new WeakReference[JavaUniverse](result)
29+
result
30+
}
31+
}
32+
}
33+
34+
class ScalaParameterNameDiscoverer extends ParameterNameDiscoverer {
35+
36+
import ScalaParameterNameDiscoverer._
37+
38+
@tailrec private def isScala(cls: Class[_]): Boolean = cls.getEnclosingClass match {
39+
case null => ScalaSignatureClasses.exists(ac => cls.getAnnotation(ac) != null)
40+
case encls => isScala(encls)
41+
}
42+
43+
private def discoverNames(u: JavaUniverse)(executable: Executable, symbolPredicate: u.Symbol => Boolean): Array[String] = {
44+
import u._
45+
46+
val declaringClass = executable.getDeclaringClass
47+
val mirror = runtimeMirror(declaringClass.getClassLoader)
48+
val ownerSymbol =
49+
if (Modifier.isStatic(executable.getModifiers)) mirror.moduleSymbol(declaringClass).moduleClass.asType
50+
else mirror.classSymbol(declaringClass)
51+
52+
def argErasuresMatch(ms: MethodSymbol) =
53+
ms.paramLists.flatten.map(s => mirror.runtimeClass(s.typeSignature)) == executable.getParameterTypes.toList
54+
55+
def paramNames(ms: MethodSymbol) =
56+
ms.paramLists.flatten.map(_.name.toString).toArray
57+
58+
ownerSymbol.toType.members
59+
.find(s => symbolPredicate(s) && argErasuresMatch(s.asMethod))
60+
.map(s => paramNames(s.asMethod))
61+
.orNull
62+
}
63+
64+
def getParameterNames(ctor: Constructor[_]): Array[String] =
65+
if (JdkAtLeast8 && ctor.getParameters.forall(_.isNamePresent))
66+
ctor.getParameters.map(_.getName)
67+
else if (isScala(ctor.getDeclaringClass))
68+
discoverNames(universe)(ctor, s => s.isConstructor)
69+
else null
70+
71+
def getParameterNames(method: Method): Array[String] = {
72+
val declaringCls = method.getDeclaringClass
73+
if (JdkAtLeast8 && method.getParameters.forall(_.isNamePresent))
74+
method.getParameters.map(_.getName)
75+
else if (isScala(declaringCls)) {
76+
// https://github.com/scala/bug/issues/10650
77+
val forStaticForwarder =
78+
if (Modifier.isStatic(method.getModifiers))
79+
Class.forName(declaringCls.getName + "$", false, declaringCls.getClassLoader)
80+
.recoverToOpt[ClassNotFoundException]
81+
.flatMap(_.getMethod(method.getName, method.getParameterTypes: _*).recoverToOpt[NoSuchMethodException])
82+
.map(getParameterNames)
83+
else
84+
Opt.Empty
85+
forStaticForwarder.getOrElse(
86+
discoverNames(universe)(method, s => s.isMethod && s.name.toString == method.getName))
87+
}
88+
else null
89+
}
90+
}
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,58 @@
1-
beanClass = com.avsystem.commons.spring.TestBean
1+
abstract {
2+
testBean {
3+
%class = com.avsystem.commons.spring.TestBean
4+
}
5+
constrTestBean = ${abstract.testBean} {
6+
%construct = true
7+
}
8+
fmTestBean = ${abstract.constrTestBean} {
9+
%factory-method = create
10+
}
11+
}
212

313
beans {
4-
testBean {
5-
%class = ${beanClass}
6-
%constructor-args = [42, "lolzsy"]
14+
testBean = ${abstract.testBean} {
15+
%constructor-args = [42, lolzsy]
716
int = 5
8-
string = "lol"
17+
string = lol
918
strIntMap {
10-
"fuu" = 42
19+
fuu = 42
1120
}
12-
strList = ["a", "b"]
13-
strSet = ["A", "B"]
14-
nestedBean {
15-
%class = ${beanClass}
21+
strList = [a, b]
22+
strSet = [A, B]
23+
nestedBean = ${abstract.testBean} {
1624
%constructor-args {
17-
constrString = "wut"
25+
constrString = wut
1826
constrInt = 1
1927
}
2028
int = 6
21-
nestedBean {
22-
%class = ${beanClass}
23-
%construct = true
24-
constrString = "yes"
29+
nestedBean = ${abstract.constrTestBean} {
30+
constrString = yes
2531
constrInt = 2
2632
}
2733
}
2834
config.%config {
2935
srsly = dafuq
3036
}
3137
}
38+
39+
testBeanDefInt = ${abstract.constrTestBean} {
40+
constrString = constrNonDefault
41+
}
42+
43+
testBeanDefString = ${abstract.constrTestBean} {
44+
constrInt = 2
45+
}
46+
47+
testBeanDefAll = ${abstract.constrTestBean}
48+
49+
testBeanFMDefInt = ${abstract.fmTestBean} {
50+
theString = factoryNonDefault
51+
}
52+
53+
testBeanFMDefString = ${abstract.fmTestBean} {
54+
theInt = -2
55+
}
56+
57+
testBeanFMDefAll = ${abstract.fmTestBean}
3258
}

commons-spring/src/test/scala/com/avsystem/commons/spring/AnnotationParameterNameDiscoverer.scala

-14
This file was deleted.

0 commit comments

Comments
 (0)