Skip to content

Commit 5e9fd04

Browse files
committed
add --increment-seed argument
1 parent 66dde8d commit 5e9fd04

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

swift/StableDiffusionCLI/main.swift

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ struct StableDiffusionSample: ParsableCommand {
6262
@Flag(help: "Reduce memory usage")
6363
var reduceMemory: Bool = false
6464

65+
@Flag(help: "Increse random seed by 1 for each image")
66+
var incrementSeed: Bool = false
67+
6568
mutating func run() throws {
6669
guard FileManager.default.fileExists(atPath: resourcePath) else {
6770
throw RunError.resources("Resource path does not exist \(resourcePath)")
@@ -83,22 +86,36 @@ struct StableDiffusionSample: ParsableCommand {
8386
let sampleTimer = SampleTimer()
8487
sampleTimer.start()
8588

86-
let images = try pipeline.generateImages(
87-
prompt: prompt,
88-
imageCount: imageCount,
89-
stepCount: stepCount,
90-
seed: seed,
91-
scheduler: scheduler.stableDiffusionScheduler
92-
) { progress in
93-
sampleTimer.stop()
94-
handleProgress(progress,sampleTimer)
95-
if progress.stepCount != progress.step {
96-
sampleTimer.start()
89+
let loops = incrementSeed ? imageCount : 1
90+
let imageCountPerBatch = incrementSeed ? 1 : imageCount
91+
92+
for i in 0 ..< loops {
93+
if (incrementSeed) {
94+
log("Generating image \(i+1) of \(imageCount) with seed \(seed)\n")
95+
log("\n")
9796
}
98-
return true
99-
}
10097

101-
_ = try saveImages(images, logNames: true)
98+
let images = try pipeline.generateImages(
99+
prompt: prompt,
100+
imageCount: imageCountPerBatch,
101+
stepCount: stepCount,
102+
seed: seed,
103+
scheduler: scheduler.stableDiffusionScheduler
104+
) { progress in
105+
sampleTimer.stop()
106+
handleProgress(progress,sampleTimer)
107+
if progress.stepCount != progress.step {
108+
sampleTimer.start()
109+
}
110+
return true
111+
}
112+
113+
_ = try saveImages(images, logNames: true)
114+
115+
if (incrementSeed) {
116+
seed += 1
117+
}
118+
}
102119
}
103120

104121
func handleProgress(

0 commit comments

Comments
 (0)