Skip to content

Commit

Permalink
Check ensures on early return due to Try / Yeet
Browse files Browse the repository at this point in the history
Expand these two expressions to include a call to contract checking
  • Loading branch information
celinval committed Jan 31, 2025
1 parent 8d841fd commit 52024f9
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 33 deletions.
48 changes: 31 additions & 17 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
}
ExprKind::Ret(e) => {
let mut e = e.as_ref().map(|x| self.lower_expr(x));
if let Some(Some((span, fresh_ident))) = self
.contract
.as_ref()
.map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2);
let args = if let Some(e) = e {
std::slice::from_ref(e)
} else {
std::slice::from_ref(self.expr_unit(span))
};
e = Some(self.expr_call(span, checker_fn, args));
}
hir::ExprKind::Ret(e)
let expr = e.as_ref().map(|x| self.lower_expr(x));
self.checked_return(expr)
}
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
ExprKind::Become(sub_expr) => {
Expand Down Expand Up @@ -395,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
})
}

/// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause.
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
let checked_ret = if let Some(Some((span, fresh_ident))) =
self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span));
Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2))
} else {
opt_expr
};
hir::ExprKind::Ret(checked_ret)
}

/// Wraps an expression with a call to the ensures check before it gets returned.
pub(crate) fn inject_ensures_check(
&mut self,
expr: &'hir hir::Expr<'hir>,
span: Span,
check_ident: Ident,
check_hir_id: HirId,
) -> &'hir hir::Expr<'hir> {
let checker_fn = self.expr_ident(span, check_ident, check_hir_id);
let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None);
self.expr_call(span, checker_fn, std::slice::from_ref(expr))
}

pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
self.with_new_scopes(c.value.span, |this| {
let def_id = this.local_def_id(c.id);
Expand Down Expand Up @@ -1983,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
),
))
} else {
self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr))))
let ret_expr = self.checked_return(Some(from_residual_expr));
self.arena.alloc(self.expr(try_span, ret_expr))
};
self.lower_attrs(ret_expr.hir_id, &attrs);

Expand Down Expand Up @@ -2032,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let target_id = Ok(catch_id);
hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr))
} else {
hir::ExprKind::Ret(Some(from_yeet_expr))
self.checked_return(Some(from_yeet_expr))
}
}

Expand Down
25 changes: 9 additions & 16 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// ==>
// { rustc_contract_requires(PRECOND); { body } }
let result: hir::Expr<'hir> = if let Some(contract) = opt_contract {
let result_ref = this.arena.alloc(result);
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
this.expr(contract.span, hir::ExprKind::Tup(&[]))
};
Expand Down Expand Up @@ -1131,30 +1132,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
this.arena.alloc(checker_binding_pat),
hir::LocalSource::Contract,
),
{
let checker_fn =
this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2);
let span = this.mark_span_with_reason(
DesugaringKind::Contract,
ens.span,
None,
);
this.expr_call_mut(
span,
checker_fn,
std::slice::from_ref(this.arena.alloc(result)),
)
},
this.inject_ensures_check(
result_ref,
ens.span,
fresh_ident.0,
fresh_ident.2,
),
)
} else {
let u = lit_unit(this);
(this.stmt_expr(contract.span, u), result)
(this.stmt_expr(contract.span, u), &*result_ref)
};

let block = this.block_all(
contract.span,
arena_vec![this; precond, postcond_checker],
Some(this.arena.alloc(result)),
Some(result),
);
this.expr_block(block)
} else {
Expand Down
48 changes: 48 additions & 0 deletions tests/ui/contracts/contracts-ensures-early-fn-exit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//@ revisions: unchk_pass chk_pass chk_fail_try chk_fail_ret chk_fail_yeet
//
//@ [unchk_pass] run-pass
//@ [chk_pass] run-pass
//@ [chk_fail_try] run-fail
//@ [chk_fail_ret] run-fail
//@ [chk_fail_yeet] run-fail
//
//@ [unchk_pass] compile-flags: -Zcontract-checks=no
//@ [chk_pass] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_try] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_ret] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_yeet] compile-flags: -Zcontract-checks=yes
//! This test ensures that ensures clauses are checked for different return points of a function.
#![feature(rustc_contracts)]
#![feature(yeet_expr)]

/// This ensures will fail in different return points depending on the input.
#[core::contracts::ensures(|ret: &Option<u32>| ret.is_some())]
fn try_sum(x: u32, y: u32, z: u32) -> Option<u32> {
// Use Yeet to return early.
if x == u32::MAX && (y > 0 || z > 0) { do yeet }

// Use `?` to early return.
let partial = x.checked_add(y)?;

// Explicitly use `return` clause.
if u32::MAX - partial < z {
return None;
}

Some(partial + z)
}

fn main() {
// This should always succeed
assert_eq!(try_sum(0, 1, 2), Some(3));

#[cfg(any(unchk_pass, chk_fail_yeet))]
assert_eq!(try_sum(u32::MAX, 1, 1), None);

#[cfg(any(unchk_pass, chk_fail_try))]
assert_eq!(try_sum(u32::MAX - 10, 12, 0), None);

#[cfg(any(unchk_pass, chk_fail_ret))]
assert_eq!(try_sum(u32::MAX - 10, 2, 100), None);
}

0 comments on commit 52024f9

Please sign in to comment.