Skip to content

Commit ce366fd

Browse files
Adding support to felt add mul and sub in constant propogation (#8594)
1 parent ab659e2 commit ce366fd

File tree

15 files changed

+625
-936
lines changed

15 files changed

+625
-936
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cairo-lang-lowering/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ num-integer = { workspace = true, default-features = true }
2626
num-traits = { workspace = true, default-features = true }
2727
salsa.workspace = true
2828
serde = { workspace = true, default-features = true }
29+
starknet-types-core.workspace = true
2930
thiserror.workspace = true
3031

3132
[dev-dependencies]

crates/cairo-lang-lowering/src/lower/test_data/closure

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ test_generated_function
307307
fn foo(a: u32) {
308308
let f = |a: felt252| {
309309
let mut b = @0;
310-
if 1 == 2 {
310+
if 1 == bar(2) {
311311
b = @a;
312312
} else {
313313
b = @a;
@@ -320,6 +320,10 @@ fn foo(a: u32) {
320320
foo
321321

322322
//! > module_code
323+
#[inline(never)]
324+
fn bar(a: felt252) -> felt252 {
325+
a + 1
326+
}
323327

324328
//! > semantic_diagnostics
325329

@@ -330,11 +334,11 @@ Main:
330334
Parameters: v0: core::integer::u32
331335
blk0 (root):
332336
Statements:
333-
(v1: {[email protected]:2:13: 2:25}) <- struct_construct()
334-
(v2: {[email protected]:2:13: 2:25}, v3: @{[email protected]:2:13: 2:25}) <- snapshot(v1)
337+
(v1: {[email protected]:6:13: 6:25}) <- struct_construct()
338+
(v2: {[email protected]:6:13: 6:25}, v3: @{[email protected]:6:13: 6:25}) <- snapshot(v1)
335339
(v4: core::felt252) <- 0
336340
(v5: (core::felt252,)) <- struct_construct(v4)
337-
(v6: ()) <- Generated `core::ops::function::Fn::call` for {[email protected]:2:13: 2:25}(v3, v5)
341+
(v6: ()) <- Generated `core::ops::function::Fn::call` for {[email protected]:6:13: 6:25}(v3, v5)
338342
(v7: ()) <- struct_construct()
339343
End:
340344
Return(v7)
@@ -344,13 +348,14 @@ Final lowering:
344348
Parameters: v0: core::integer::u32
345349
blk0 (root):
346350
Statements:
347-
(v1: core::felt252) <- 1
348-
(v2: core::felt252) <- 2
349-
(v3: core::felt252) <- core::felt252_sub(v1, v2)
351+
(v1: core::felt252) <- 2
352+
(v2: core::felt252) <- test::bar(v1)
353+
(v3: core::felt252) <- 1
354+
(v4: core::felt252) <- core::felt252_sub(v3, v2)
350355
End:
351-
Match(match core::felt252_is_zero(v3) {
356+
Match(match core::felt252_is_zero(v4) {
352357
IsZeroResult::Zero => blk1,
353-
IsZeroResult::NonZero(v4) => blk2,
358+
IsZeroResult::NonZero(v5) => blk2,
354359
})
355360

356361
blk1:
@@ -368,7 +373,7 @@ Generated core::traits::Destruct::destruct lowering for source location:
368373
let f = |a: felt252| {
369374
^^^^^^^^^^^^
370375

371-
Parameters: v0: {[email protected]:2:13: 2:25}
376+
Parameters: v0: {[email protected]:6:13: 6:25}
372377
blk0 (root):
373378
Statements:
374379
() <- struct_destructure(v0)
@@ -378,7 +383,7 @@ End:
378383

379384

380385
Final lowering:
381-
Parameters: v0: {[email protected]:2:13: 2:25}
386+
Parameters: v0: {[email protected]:6:13: 6:25}
382387
blk0 (root):
383388
Statements:
384389
End:
@@ -389,55 +394,57 @@ Generated core::ops::function::Fn::call lowering for source location:
389394
let f = |a: felt252| {
390395
^^^^^^^^^^^^
391396

392-
Parameters: v0: @{[email protected]:2:13: 2:25}, v2: (core::felt252,)
397+
Parameters: v0: @{[email protected]:6:13: 6:25}, v2: (core::felt252,)
393398
blk0 (root):
394399
Statements:
395-
(v1: {[email protected]:2:13: 2:25}) <- desnap(v0)
400+
(v1: {[email protected]:6:13: 6:25}) <- desnap(v0)
396401
() <- struct_destructure(v1)
397402
(v3: core::felt252) <- struct_destructure(v2)
398403
(v4: core::felt252) <- 0
399404
(v5: core::felt252, v6: @core::felt252) <- snapshot(v4)
400405
(v7: core::felt252) <- 1
401406
(v8: core::felt252, v9: @core::felt252) <- snapshot(v7)
402407
(v10: core::felt252) <- 2
403-
(v11: core::felt252, v12: @core::felt252) <- snapshot(v10)
404-
(v13: core::bool) <- core::Felt252PartialEq::eq(v9, v12)
408+
(v11: core::felt252) <- test::bar(v10)
409+
(v12: core::felt252, v13: @core::felt252) <- snapshot(v11)
410+
(v14: core::bool) <- core::Felt252PartialEq::eq(v9, v13)
405411
End:
406-
Match(match_enum(v13) {
407-
bool::False(v15) => blk2,
408-
bool::True(v14) => blk1,
412+
Match(match_enum(v14) {
413+
bool::False(v16) => blk2,
414+
bool::True(v15) => blk1,
409415
})
410416

411417
blk1:
412418
Statements:
413-
(v18: core::felt252, v19: @core::felt252) <- snapshot(v3)
419+
(v19: core::felt252, v20: @core::felt252) <- snapshot(v3)
414420
End:
415-
Goto(blk3, {v18 -> v20, v19 -> v21})
421+
Goto(blk3, {v19 -> v21, v20 -> v22})
416422

417423
blk2:
418424
Statements:
419-
(v16: core::felt252, v17: @core::felt252) <- snapshot(v3)
425+
(v17: core::felt252, v18: @core::felt252) <- snapshot(v3)
420426
End:
421-
Goto(blk3, {v16 -> v20, v17 -> v21})
427+
Goto(blk3, {v17 -> v21, v18 -> v22})
422428

423429
blk3:
424430
Statements:
425-
(v22: ()) <- struct_construct()
431+
(v23: ()) <- struct_construct()
426432
End:
427-
Return(v22)
433+
Return(v23)
428434

429435

430436
Final lowering:
431-
Parameters: v0: @{[email protected]:2:13: 2:25}, v1: (core::felt252,)
437+
Parameters: v0: @{[email protected]:6:13: 6:25}, v1: (core::felt252,)
432438
blk0 (root):
433439
Statements:
434-
(v2: core::felt252) <- 1
435-
(v3: core::felt252) <- 2
436-
(v4: core::felt252) <- core::felt252_sub(v2, v3)
440+
(v2: core::felt252) <- 2
441+
(v3: core::felt252) <- test::bar(v2)
442+
(v4: core::felt252) <- 1
443+
(v5: core::felt252) <- core::felt252_sub(v4, v3)
437444
End:
438-
Match(match core::felt252_is_zero(v4) {
445+
Match(match core::felt252_is_zero(v5) {
439446
IsZeroResult::Zero => blk1,
440-
IsZeroResult::NonZero(v5) => blk2,
447+
IsZeroResult::NonZero(v6) => blk2,
441448
})
442449

443450
blk1:

crates/cairo-lang-lowering/src/lower/test_data/for

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
5555
blk0 (root):
5656
Statements:
5757
(v2: core::felt252) <- 1
58-
(v3: core::felt252) <- core::felt252_add(v2, v2)
58+
(v3: core::felt252) <- 2
5959
(v4: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
6060
(v5: core::felt252) <- 10
6161
(v6: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v4, v5)

crates/cairo-lang-lowering/src/optimizations/const_folding.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use num_integer::Integer;
2929
use num_traits::cast::ToPrimitive;
3030
use num_traits::{Num, One, Zero};
3131
use salsa::Database;
32+
use starknet_types_core::felt::Felt as Felt252;
3233

3334
use crate::db::LoweringGroup;
3435
use crate::ids::{
@@ -493,12 +494,62 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
493494
}
494495
let (id, _generic_args) = stmt.function.get_extern(db)?;
495496
if id == self.felt_sub {
496-
// (a - 0) can be replaced by a.
497-
let val = self.as_int(stmt.inputs[1].var_id)?;
498-
if val.is_zero() {
497+
if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
498+
&& rhs.is_zero()
499+
{
499500
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
501+
None
502+
} else if let (Some(lhs), Some(rhs)) =
503+
(self.as_int(stmt.inputs[0].var_id), self.as_int(stmt.inputs[1].var_id))
504+
{
505+
let value = Felt252::from(lhs - rhs).to_bigint();
506+
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
507+
} else {
508+
None
509+
}
510+
} else if id == self.felt_add {
511+
if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
512+
&& lhs.is_zero()
513+
{
514+
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
515+
None
516+
} else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
517+
&& rhs.is_zero()
518+
{
519+
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
520+
None
521+
} else if let (Some(lhs), Some(rhs)) =
522+
(self.as_int(stmt.inputs[0].var_id), self.as_int(stmt.inputs[1].var_id))
523+
{
524+
let value = Felt252::from(lhs + rhs).to_bigint();
525+
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
526+
} else {
527+
None
528+
}
529+
} else if id == self.felt_mul {
530+
let lhs = self.as_int_ex(stmt.inputs[0].var_id);
531+
let rhs = self.as_int_ex(stmt.inputs[1].var_id);
532+
if lhs.map(|(v, _)| v.is_zero()).unwrap_or(false)
533+
|| rhs.map(|(v, _)| v.is_zero()).unwrap_or(false)
534+
{
535+
Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
536+
} else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
537+
&& rhs.is_one()
538+
{
539+
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
540+
None
541+
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
542+
&& lhs.is_one()
543+
{
544+
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
545+
None
546+
} else if let (Some((lhs_val, lhs_nz)), Some((rhs_val, rhs_nz))) = (lhs, rhs) {
547+
let value = Felt252::from(lhs_val * rhs_val).to_bigint();
548+
let nz_ty = lhs_nz && rhs_nz;
549+
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], nz_ty))
550+
} else {
551+
None
500552
}
501-
None
502553
} else if self.wide_mul_fns.contains(&id) {
503554
let lhs = self.as_int_ex(stmt.inputs[0].var_id);
504555
let rhs = self.as_int(stmt.inputs[1].var_id);
@@ -1211,6 +1262,10 @@ fn priv_const_folding_info<'db>(
12111262
pub struct ConstFoldingLibfuncInfo<'db> {
12121263
/// The `felt252_sub` libfunc.
12131264
felt_sub: ExternFunctionId<'db>,
1265+
/// The `felt252_add` libfunc.
1266+
felt_add: ExternFunctionId<'db>,
1267+
/// The `felt252_mul` libfunc.
1268+
felt_mul: ExternFunctionId<'db>,
12141269
/// The `into_box` libfunc.
12151270
into_box: ExternFunctionId<'db>,
12161271
/// The `unbox` libfunc.
@@ -1380,6 +1435,8 @@ impl<'db> ConstFoldingLibfuncInfo<'db> {
13801435
);
13811436
Self {
13821437
felt_sub: core.extern_function_id("felt252_sub"),
1438+
felt_add: core.extern_function_id("felt252_add"),
1439+
felt_mul: core.extern_function_id("felt252_mul"),
13831440
into_box: box_module.extern_function_id("into_box"),
13841441
unbox: box_module.extern_function_id("unbox"),
13851442
box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),

0 commit comments

Comments
 (0)