Skip to content

Commit 9eb898a

Browse files
Adding support to felt add mul and sub in constant propogation
1 parent e0a4e19 commit 9eb898a

File tree

13 files changed

+588
-313
lines changed

13 files changed

+588
-313
lines changed

crates/cairo-lang-lowering/src/lower/test_data/closure

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ test_generated_function
307307
fn foo(a: u32) {
308308
let f = |a: felt252| {
309309
let mut b = @0;
310-
if 1 == 2 {
310+
if 1 == bar(2) {
311311
b = @a;
312312
} else {
313313
b = @a;
@@ -320,6 +320,10 @@ fn foo(a: u32) {
320320
foo
321321

322322
//! > module_code
323+
#[inline(never)]
324+
fn bar(a: felt252) -> felt252 {
325+
a + 1
326+
}
323327

324328
//! > semantic_diagnostics
325329

@@ -330,11 +334,11 @@ Main:
330334
Parameters: v0: core::integer::u32
331335
blk0 (root):
332336
Statements:
333-
(v1: {[email protected]:2:13: 2:25}) <- struct_construct()
334-
(v2: {[email protected]:2:13: 2:25}, v3: @{[email protected]:2:13: 2:25}) <- snapshot(v1)
337+
(v1: {[email protected]:6:13: 6:25}) <- struct_construct()
338+
(v2: {[email protected]:6:13: 6:25}, v3: @{[email protected]:6:13: 6:25}) <- snapshot(v1)
335339
(v4: core::felt252) <- 0
336340
(v5: (core::felt252,)) <- struct_construct(v4)
337-
(v6: ()) <- Generated `core::ops::function::Fn::call` for {[email protected]:2:13: 2:25}(v3, v5)
341+
(v6: ()) <- Generated `core::ops::function::Fn::call` for {[email protected]:6:13: 6:25}(v3, v5)
338342
(v7: ()) <- struct_construct()
339343
End:
340344
Return(v7)
@@ -344,13 +348,14 @@ Final lowering:
344348
Parameters: v0: core::integer::u32
345349
blk0 (root):
346350
Statements:
347-
(v1: core::felt252) <- 1
348-
(v2: core::felt252) <- 2
349-
(v3: core::felt252) <- core::felt252_sub(v1, v2)
351+
(v1: core::felt252) <- 2
352+
(v2: core::felt252) <- test::bar(v1)
353+
(v3: core::felt252) <- 1
354+
(v4: core::felt252) <- core::felt252_sub(v3, v2)
350355
End:
351-
Match(match core::felt252_is_zero(v3) {
356+
Match(match core::felt252_is_zero(v4) {
352357
IsZeroResult::Zero => blk1,
353-
IsZeroResult::NonZero(v4) => blk2,
358+
IsZeroResult::NonZero(v5) => blk2,
354359
})
355360

356361
blk1:
@@ -368,7 +373,7 @@ Generated core::traits::Destruct::destruct lowering for source location:
368373
let f = |a: felt252| {
369374
^^^^^^^^^^^^
370375

371-
Parameters: v0: {[email protected]:2:13: 2:25}
376+
Parameters: v0: {[email protected]:6:13: 6:25}
372377
blk0 (root):
373378
Statements:
374379
() <- struct_destructure(v0)
@@ -378,7 +383,7 @@ End:
378383

379384

380385
Final lowering:
381-
Parameters: v0: {[email protected]:2:13: 2:25}
386+
Parameters: v0: {[email protected]:6:13: 6:25}
382387
blk0 (root):
383388
Statements:
384389
End:
@@ -389,55 +394,57 @@ Generated core::ops::function::Fn::call lowering for source location:
389394
let f = |a: felt252| {
390395
^^^^^^^^^^^^
391396

392-
Parameters: v0: @{[email protected]:2:13: 2:25}, v2: (core::felt252,)
397+
Parameters: v0: @{[email protected]:6:13: 6:25}, v2: (core::felt252,)
393398
blk0 (root):
394399
Statements:
395-
(v1: {[email protected]:2:13: 2:25}) <- desnap(v0)
400+
(v1: {[email protected]:6:13: 6:25}) <- desnap(v0)
396401
() <- struct_destructure(v1)
397402
(v3: core::felt252) <- struct_destructure(v2)
398403
(v4: core::felt252) <- 0
399404
(v5: core::felt252, v6: @core::felt252) <- snapshot(v4)
400405
(v7: core::felt252) <- 1
401406
(v8: core::felt252, v9: @core::felt252) <- snapshot(v7)
402407
(v10: core::felt252) <- 2
403-
(v11: core::felt252, v12: @core::felt252) <- snapshot(v10)
404-
(v13: core::bool) <- core::Felt252PartialEq::eq(v9, v12)
408+
(v11: core::felt252) <- test::bar(v10)
409+
(v12: core::felt252, v13: @core::felt252) <- snapshot(v11)
410+
(v14: core::bool) <- core::Felt252PartialEq::eq(v9, v13)
405411
End:
406-
Match(match_enum(v13) {
407-
bool::False(v15) => blk2,
408-
bool::True(v14) => blk1,
412+
Match(match_enum(v14) {
413+
bool::False(v16) => blk2,
414+
bool::True(v15) => blk1,
409415
})
410416

411417
blk1:
412418
Statements:
413-
(v18: core::felt252, v19: @core::felt252) <- snapshot(v3)
419+
(v19: core::felt252, v20: @core::felt252) <- snapshot(v3)
414420
End:
415-
Goto(blk3, {v18 -> v20, v19 -> v21})
421+
Goto(blk3, {v19 -> v21, v20 -> v22})
416422

417423
blk2:
418424
Statements:
419-
(v16: core::felt252, v17: @core::felt252) <- snapshot(v3)
425+
(v17: core::felt252, v18: @core::felt252) <- snapshot(v3)
420426
End:
421-
Goto(blk3, {v16 -> v20, v17 -> v21})
427+
Goto(blk3, {v17 -> v21, v18 -> v22})
422428

423429
blk3:
424430
Statements:
425-
(v22: ()) <- struct_construct()
431+
(v23: ()) <- struct_construct()
426432
End:
427-
Return(v22)
433+
Return(v23)
428434

429435

430436
Final lowering:
431-
Parameters: v0: @{[email protected]:2:13: 2:25}, v1: (core::felt252,)
437+
Parameters: v0: @{[email protected]:6:13: 6:25}, v1: (core::felt252,)
432438
blk0 (root):
433439
Statements:
434-
(v2: core::felt252) <- 1
435-
(v3: core::felt252) <- 2
436-
(v4: core::felt252) <- core::felt252_sub(v2, v3)
440+
(v2: core::felt252) <- 2
441+
(v3: core::felt252) <- test::bar(v2)
442+
(v4: core::felt252) <- 1
443+
(v5: core::felt252) <- core::felt252_sub(v4, v3)
437444
End:
438-
Match(match core::felt252_is_zero(v4) {
445+
Match(match core::felt252_is_zero(v5) {
439446
IsZeroResult::Zero => blk1,
440-
IsZeroResult::NonZero(v5) => blk2,
447+
IsZeroResult::NonZero(v6) => blk2,
441448
})
442449

443450
blk1:

crates/cairo-lang-lowering/src/lower/test_data/for

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
5555
blk0 (root):
5656
Statements:
5757
(v2: core::felt252) <- 1
58-
(v3: core::felt252) <- core::felt252_add(v2, v2)
58+
(v3: core::felt252) <- 2
5959
(v4: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
6060
(v5: core::felt252) <- 10
6161
(v6: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v4, v5)

crates/cairo-lang-lowering/src/optimizations/const_folding.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
1818
use cairo_lang_semantic::{
1919
ConcreteTypeId, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId, corelib,
2020
};
21+
use cairo_lang_utils::bigint::felt252_mod;
2122
use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
2223
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
2324
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
@@ -493,10 +494,39 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
493494
}
494495
let (id, _generic_args) = stmt.function.get_extern(db)?;
495496
if id == self.felt_sub {
496-
// (a - 0) can be replaced by a.
497-
let val = self.as_int(stmt.inputs[1].var_id)?;
498-
if val.is_zero() {
499-
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
497+
if let Some(rhs) = self.as_int(stmt.inputs[1].var_id) {
498+
if rhs.is_zero() {
499+
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
500+
return None;
501+
}
502+
}
503+
if let (Some(lhs), Some(rhs)) =
504+
(self.as_int(stmt.inputs[0].var_id), self.as_int(stmt.inputs[1].var_id))
505+
{
506+
let value = felt252_mod(&(lhs - rhs));
507+
return Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false));
508+
}
509+
None
510+
} else if id == self.felt_add {
511+
if let (Some(lhs), Some(rhs)) =
512+
(self.as_int(stmt.inputs[0].var_id), self.as_int(stmt.inputs[1].var_id))
513+
{
514+
let value = felt252_mod(&(lhs + rhs));
515+
return Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false));
516+
}
517+
None
518+
} else if id == self.felt_mul {
519+
let lhs = self.as_int_ex(stmt.inputs[0].var_id);
520+
let rhs = self.as_int_ex(stmt.inputs[1].var_id);
521+
if lhs.map(|(v, _)| v.is_zero()).unwrap_or(false)
522+
|| rhs.map(|(v, _)| v.is_zero()).unwrap_or(false)
523+
{
524+
return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
525+
}
526+
if let (Some((lhs_val, lhs_nz)), Some((rhs_val, rhs_nz))) = (lhs, rhs) {
527+
let value = felt252_mod(&(lhs_val * rhs_val));
528+
let nz_ty = lhs_nz && rhs_nz;
529+
return Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], nz_ty));
500530
}
501531
None
502532
} else if self.wide_mul_fns.contains(&id) {
@@ -1211,6 +1241,10 @@ fn priv_const_folding_info<'db>(
12111241
pub struct ConstFoldingLibfuncInfo<'db> {
12121242
/// The `felt252_sub` libfunc.
12131243
felt_sub: ExternFunctionId<'db>,
1244+
/// The `felt252_add` libfunc.
1245+
felt_add: ExternFunctionId<'db>,
1246+
/// The `felt252_mul` libfunc.
1247+
felt_mul: ExternFunctionId<'db>,
12141248
/// The `into_box` libfunc.
12151249
into_box: ExternFunctionId<'db>,
12161250
/// The `unbox` libfunc.
@@ -1380,6 +1414,8 @@ impl<'db> ConstFoldingLibfuncInfo<'db> {
13801414
);
13811415
Self {
13821416
felt_sub: core.extern_function_id("felt252_sub"),
1417+
felt_add: core.extern_function_id("felt252_add"),
1418+
felt_mul: core.extern_function_id("felt252_mul"),
13831419
into_box: box_module.extern_function_id("into_box"),
13841420
unbox: box_module.extern_function_id("unbox"),
13851421
box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),

0 commit comments

Comments
 (0)