Skip to content

Commit 496dbea

Browse files
authored
fix: Executor memory overhead overriding (#1462)
1 parent 23dccce commit 496dbea

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,11 @@ object CometSparkSessionExtensions extends Logging {
14091409
}
14101410
}
14111411

1412+
/** Calculates Comet shuffle memory size in MiB */
1413+
def getCometShuffleMemorySizeInMiB(sparkConf: SparkConf, conf: SQLConf = SQLConf.get): Long = {
1414+
ByteUnit.BYTE.toMiB(getCometShuffleMemorySize(sparkConf, conf))
1415+
}
1416+
14121417
def cometUnifiedMemoryManagerEnabled(sparkConf: SparkConf): Boolean = {
14131418
sparkConf.getBoolean("spark.memory.offHeap.enabled", false)
14141419
}

spark/src/main/scala/org/apache/spark/Plugins.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
5555
sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
5656
} else {
5757
// By default, executorMemory * spark.executor.memoryOverheadFactor, with a minimum of 384 MiB
58-
val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY_DEFAULT)
58+
val executorMemory =
59+
sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY_DEFAULT)
5960
val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf)
6061
val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf)
6162

@@ -67,7 +68,7 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
6768
CometSparkSessionExtensions.getCometMemoryOverheadInMiB(sc.getConf)
6869
} else {
6970
// comet shuffle unified memory manager is disabled, so we need to add overhead memory
70-
CometSparkSessionExtensions.getCometShuffleMemorySize(sc.getConf)
71+
CometSparkSessionExtensions.getCometShuffleMemorySizeInMiB(sc.getConf)
7172
}
7273
sc.conf.set(EXECUTOR_MEMORY_OVERHEAD.key, s"${execMemOverhead + cometMemOverhead}M")
7374
val newExecMemOverhead = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)

spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,39 @@ class CometPluginsNonOverrideSuite extends CometTestBase {
143143
assert(execMemOverhead4 == "2G")
144144
}
145145
}
146+
147+
class CometPluginsUnifiedModeOverrideSuite extends CometTestBase {
148+
override protected def sparkConf: SparkConf = {
149+
val conf = new SparkConf()
150+
conf.set("spark.driver.memory", "1G")
151+
conf.set("spark.executor.memory", "1G")
152+
conf.set("spark.executor.memoryOverhead", "1G")
153+
conf.set("spark.plugins", "org.apache.spark.CometPlugin")
154+
conf.set("spark.comet.enabled", "true")
155+
conf.set("spark.memory.offHeap.enabled", "true")
156+
conf.set("spark.memory.offHeap.size", "2G")
157+
conf.set("spark.comet.exec.shuffle.enabled", "true")
158+
conf.set("spark.comet.exec.enabled", "true")
159+
conf.set("spark.comet.memory.overhead.factor", "0.5")
160+
conf
161+
}
162+
163+
/*
164+
* Unified memory is enabled, but Comet shuffle does not use the unified memory manager,
165+
* so the executor memory overhead should be overridden by adding the Comet shuffle memory size.
166+
*/
167+
test("executor memory overhead is correctly overridden") {
168+
val execMemOverhead1 = spark.conf.get("spark.executor.memoryOverhead")
169+
val execMemOverhead2 = spark.sessionState.conf.getConfString("spark.executor.memoryOverhead")
170+
val execMemOverhead3 = spark.sparkContext.getConf.get("spark.executor.memoryOverhead")
171+
val execMemOverhead4 = spark.sparkContext.conf.get("spark.executor.memoryOverhead")
172+
173+
// in unified memory mode, comet memory overhead is spark.memory.offHeap.size (2G) * spark.comet.memory.overhead.factor (0.5) = 1G
174+
// so the total executor memory overhead is executor memory overhead (1G) + comet memory overhead (1G) = 2G
175+
// and the overhead is overridden in MiB
176+
assert(execMemOverhead1 == "2048M")
177+
assert(execMemOverhead2 == "2048M")
178+
assert(execMemOverhead3 == "2048M")
179+
assert(execMemOverhead4 == "2048M")
180+
}
181+
}

0 commit comments

Comments
 (0)