Skip to content

feat: Multi threading support for circom prover #1120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions codex/codex.nim
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ proc new*(
store = NetworkStore.new(engine, repoStore)
prover =
if config.prover:
let backend =
config.initializeBackend().expect("Unable to create prover backend.")
let backend = config.initializeBackend(taskpool = taskpool).expect(
"Unable to create prover backend."
)
some Prover.new(store, backend, config.numProofSamples)
else:
none Prover
Expand Down
17 changes: 10 additions & 7 deletions codex/slots/proofs/backendfactory.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import os
import strutils
import pkg/chronos
import pkg/chronicles
import pkg/taskpools
import pkg/questionable
import pkg/confutils/defs
import pkg/stew/io2
Expand All @@ -11,7 +12,9 @@ import ../../conf
import ./backends
import ./backendutils

proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend =
proc initializeFromConfig(
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend =
if not fileAccessible($config.circomR1cs, {AccessFlags.Read}) or
not endsWith($config.circomR1cs, ".r1cs"):
return failure("Circom R1CS file not accessible")
Expand All @@ -27,7 +30,7 @@ proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend
trace "Initialized prover backend from cli config"
success(
utils.initializeCircomBackend(
$config.circomR1cs, $config.circomWasm, $config.circomZkey
$config.circomR1cs, $config.circomWasm, $config.circomZkey, taskpool
)
)

Expand All @@ -41,14 +44,14 @@ proc zkeyFilePath(config: CodexConf): string =
config.circuitDir / "proof_main.zkey"

proc initializeFromCircuitDirFiles(
config: CodexConf, utils: BackendUtils
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend {.gcsafe.} =
if fileExists(config.r1csFilePath) and fileExists(config.wasmFilePath) and
fileExists(config.zkeyFilePath):
trace "Initialized prover backend from local files"
return success(
utils.initializeCircomBackend(
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath, taskpool
)
)

Expand All @@ -68,11 +71,11 @@ proc suggestDownloadTool(config: CodexConf) =
instructions

proc initializeBackend*(
config: CodexConf, utils: BackendUtils = BackendUtils()
config: CodexConf, utils: BackendUtils = BackendUtils(), taskpool: Taskpool
): ?!AnyBackend =
without backend =? initializeFromConfig(config, utils), cliErr:
without backend =? initializeFromConfig(config, utils, taskpool), cliErr:
info "Could not initialize prover backend from CLI options...", msg = cliErr.msg
without backend =? initializeFromCircuitDirFiles(config, utils), localErr:
without backend =? initializeFromCircuitDirFiles(config, utils, taskpool), localErr:
info "Could not initialize prover backend from circuit dir files...",
msg = localErr.msg
suggestDownloadTool(config)
Expand Down
194 changes: 158 additions & 36 deletions codex/slots/proofs/backends/circomcompat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

{.push raises: [].}

import std/sugar
import std/[sugar, atomics, locks]

import pkg/chronos
import pkg/taskpools
import pkg/chronos/threadsync
import pkg/questionable/results
import pkg/circomcompat

Expand All @@ -22,6 +24,7 @@
import ./converters

export circomcompat, converters
export taskpools

type
CircomCompat* = object
Expand All @@ -35,9 +38,25 @@
zkeyPath: string # path to the zkey file
backendCfg: ptr CircomBn254Cfg
vkp*: ptr CircomKey
taskpool: Taskpool
lock: ptr Lock

NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H]

ProveTask = object
circom: ptr CircomCompat
ctx: ptr CircomCompatCtx
proof: ptr Proof
success: Atomic[bool]
signal: ThreadSignalPtr

VerifyTask = object
proof: ptr CircomProof
vkp: ptr CircomKey
inputs: ptr CircomInputs
success: VerifyResult
signal: ThreadSignalPtr

func normalizeInput*[H](
self: CircomCompat, input: ProofInputs[H]
): NormalizedProofInputs[H] =
Expand Down Expand Up @@ -79,7 +98,33 @@
if not isNil(self.vkp):
self.vkp.unsafeAddr.release_key()

proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProof =
if not isNil(self.lock):
deinitLock(self.lock[]) # Cleanup the lock
dealloc(self.lock) # Free the memory

proc circomProveTask(task: ptr ProveTask) {.gcsafe.} =
withLock task[].circom.lock[]:
defer:
discard task[].signal.fireSync()

Check warning on line 108 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L108

Added line #L108 was not covered by tests

var proofPtr: ptr Proof = nil
try:
if (
let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr)
res != ERR_OK

Check warning on line 114 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L114

Added line #L114 was not covered by tests
) or proofPtr == nil:
task.success.store(false)
return

Check warning on line 117 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L117

Added line #L117 was not covered by tests

copyProof(task.proof, proofPtr[])
task.success.store(true)
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()

proc asyncProve*[H](
self: CircomCompat, input: NormalizedProofInputs[H], proof: ptr Proof
): Future[?!void] {.async.} =
doAssert input.samples.len == self.numSamples, "Number of samples does not match"

doAssert input.slotProof.len <= self.datasetDepth,
Expand Down Expand Up @@ -143,58 +188,130 @@

for s in input.samples:
var
merklePaths = s.merklePaths.mapIt(it.toBytes)
merklePaths = s.merklePaths.mapIt(@(it.toBytes)).concat
data = s.cellData.mapIt(@(it.toBytes)).concat

if ctx.push_input_u256_array(
"merklePaths".cstring,
merklePaths[0].addr,
uint (merklePaths[0].len * merklePaths.len),
"merklePaths".cstring, merklePaths[0].addr, uint (merklePaths.len)
) != ERR_OK:
return failure("Failed to push merkle paths")

if ctx.push_input_u256_array("cellData".cstring, data[0].addr, data.len.uint) !=
ERR_OK:
return failure("Failed to push cell data")

var proofPtr: ptr Proof = nil
without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")

Check warning on line 204 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L204

Added line #L204 was not covered by tests

let proof =
try:
if (let res = self.backendCfg.prove_circuit(ctx, proofPtr.addr); res != ERR_OK) or
proofPtr == nil:
return failure("Failed to prove - err code: " & $res)
defer:
threadPtr.close().expect("closing once works")

proofPtr[]
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()
var task = ProveTask(circom: addr self, ctx: ctx, proof: proof, signal: threadPtr)

success proof
let taskPtr = addr task

doAssert task.circom.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"

Check warning on line 214 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L214

Added line #L214 was not covered by tests
task.circom.taskpool.spawn circomProveTask(taskPtr)
let threadFut = threadPtr.wait()

if joinErr =? catch(await threadFut.join()).errorOption:
if err =? catch(await noCancel threadFut).errorOption:
return failure(err)
if joinErr of CancelledError:
raise joinErr
else:
return failure(joinErr)

proc prove*[H](self: CircomCompat, input: ProofInputs[H]): ?!CircomProof =
self.prove(self.normalizeInput(input))
if not task.success.load():
return failure("Failed to prove circuit")

Check warning on line 227 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L227

Added line #L227 was not covered by tests

success()

proc prove*[H](
self: CircomCompat, input: ProofInputs[H]
): Future[?!CircomProof] {.async, raises: [CancelledError].} =
var proof = ProofPtr.new()
defer:
destroyProof(proof)

Check warning on line 236 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L236

Added line #L236 was not covered by tests

try:
if error =? (await self.asyncProve(self.normalizeInput(input), proof)).errorOption:
return failure(error)
return success(deepCopy(proof)[])
except CancelledError as exc:
raise exc

proc circomVerifyTask(task: ptr VerifyTask) {.gcsafe.} =
defer:
task[].inputs[].releaseCircomInputs()
discard task[].signal.fireSync()

let res = verify_circuit(task[].proof, task[].inputs, task[].vkp)
if res == ERR_OK:
task[].success[] = true
elif res == ERR_FAILED_TO_VERIFY_PROOF:
task[].success[] = false
else:
task[].success[] = false
error "Failed to verify proof", errorCode = res

Check warning on line 258 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L257-L258

Added lines #L257 - L258 were not covered by tests
proc asyncVerify*[H](
self: CircomCompat,
proof: CircomProof,
inputs: ProofInputs[H],
success: VerifyResult,
): Future[?!void] {.async.} =
var proofPtr = unsafeAddr proof
var inputs = inputs.toCircomInputs()

without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")

Check warning on line 269 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L269

Added line #L269 was not covered by tests

defer:
threadPtr.close().expect("closing once works")

var task = VerifyTask(
proof: proofPtr,
vkp: self.vkp,
inputs: addr inputs,
success: success,
signal: threadPtr,

Check warning on line 279 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L278-L279

Added lines #L278 - L279 were not covered by tests
)

let taskPtr = addr task

doAssert self.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"

self.taskpool.spawn circomVerifyTask(taskPtr)

let threadFut = threadPtr.wait()

if joinErr =? catch(await threadFut.join()).errorOption:
if err =? catch(await noCancel threadFut).errorOption:
return failure(err)
if joinErr of CancelledError:
raise joinErr
else:
return failure(joinErr)

success()

proc verify*[H](
self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H]
): ?!bool =
): Future[?!bool] {.async, raises: [CancelledError].} =
## Verify a proof using a ctx
##

var
proofPtr = unsafeAddr proof
inputs = inputs.toCircomInputs()

var res = VerifyResult.new()
defer:
destroyVerifyResult(res)

Check warning on line 308 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L308

Added line #L308 was not covered by tests
try:
let res = verify_circuit(proofPtr, inputs.addr, self.vkp)
if res == ERR_OK:
success true
elif res == ERR_FAILED_TO_VERIFY_PROOF:
success false
else:
failure("Failed to verify proof - err code: " & $res)
finally:
inputs.releaseCircomInputs()
if error =? (await self.asyncVerify(proof, inputs, res)).errorOption:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just return a bool here (or even void) and set to false on failure, true otherwise?

Copy link
Contributor Author

@munna0908 munna0908 Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dryajov To check task completion upon cancellation, I need the response for validation, so I'm passing it as a parameter to asyncVerify.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The response is either passed or failed, right?

return failure(error)
return success(res[])
except CancelledError as exc:
raise exc

proc init*(
_: type CircomCompat,
Expand All @@ -206,10 +323,13 @@
blkDepth = DefaultBlockDepth,
cellElms = DefaultCellElms,
numSamples = DefaultSamplesNum,
taskpool: Taskpool,
): CircomCompat =
## Create a new ctx
##
# Allocate and initialize the lock
var lockPtr = create(Lock) # Allocate memory for the lock
initLock(lockPtr[]) # Initialize the lock

## Create a new ctx
var cfg: ptr CircomBn254Cfg
var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil

Expand Down Expand Up @@ -237,4 +357,6 @@
numSamples: numSamples,
backendCfg: cfg,
vkp: vkpPtr,
taskpool: taskpool,

Check warning on line 360 in codex/slots/proofs/backends/circomcompat.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/circomcompat.nim#L360

Added line #L360 was not covered by tests
lock: lockPtr,
)
32 changes: 32 additions & 0 deletions codex/slots/proofs/backends/converters.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
{.push raises: [].}

import pkg/circomcompat
import std/atomics

import ../../../contracts
import ../../types
Expand All @@ -22,6 +23,14 @@
CircomProof* = Proof
CircomKey* = VerifyingKey
CircomInputs* = Inputs
VerifyResult* = ptr bool
ProofPtr* = ptr Proof

proc new*(_: type ProofPtr): ProofPtr =
cast[ptr Proof](allocShared0(sizeof(Proof)))

proc new*(_: type VerifyResult): VerifyResult =
cast[ptr bool](allocShared0(sizeof(bool)))

proc toCircomInputs*(inputs: ProofInputs[Poseidon2Hash]): CircomInputs =
var
Expand Down Expand Up @@ -52,3 +61,26 @@

func toGroth16Proof*(proof: CircomProof): Groth16Proof =
Groth16Proof(a: proof.a.toG1, b: proof.b.toG2, c: proof.c.toG1)

proc destroyVerifyResult*(result: VerifyResult) =

Check warning on line 65 in codex/slots/proofs/backends/converters.nim

View check run for this annotation

Codecov / codecov/patch

codex/slots/proofs/backends/converters.nim#L64-L65

Added lines #L64 - L65 were not covered by tests
if result != nil:
deallocShared(result)

proc destroyProof*(proof: ProofPtr) =
if proof != nil:
deallocShared(proof)

proc copyInto*(dest: var G1, src: G1) =
copyMem(addr dest.x[0], addr src.x[0], 32)
copyMem(addr dest.y[0], addr src.y[0], 32)

proc copyInto*(dest: var G2, src: G2) =
for i in 0 .. 1:
copyMem(addr dest.x[i][0], addr src.x[i][0], 32)
copyMem(addr dest.y[i][0], addr src.y[i][0], 32)

proc copyProof*(dest: ptr Proof, src: Proof) =
if not isNil(dest):
copyInto(dest.a, src.a)
copyInto(dest.b, src.b)
copyInto(dest.c, src.c)
Loading
Loading