diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 19433047595a0..8125da361b814 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -314,21 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label)) } ExprKind::Ret(e) => { - let mut e = e.as_ref().map(|x| self.lower_expr(x)); - if let Some(Some((span, fresh_ident))) = self - .contract - .as_ref() - .map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident))) - { - let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2); - let args = if let Some(e) = e { - std::slice::from_ref(e) - } else { - std::slice::from_ref(self.expr_unit(span)) - }; - e = Some(self.expr_call(span, checker_fn, args)); - } - hir::ExprKind::Ret(e) + let expr = e.as_ref().map(|x| self.lower_expr(x)); + self.checked_return(expr) } ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()), ExprKind::Become(sub_expr) => { @@ -395,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> { }) } + /// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause. + fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> { + let checked_ret = if let Some(Some((span, fresh_ident))) = + self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident))) + { + let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span)); + Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2)) + } else { + opt_expr + }; + hir::ExprKind::Ret(checked_ret) + } + + /// Wraps an expression with a call to the ensures check before it gets returned. + pub(crate) fn inject_ensures_check( + &mut self, + expr: &'hir hir::Expr<'hir>, + span: Span, + check_ident: Ident, + check_hir_id: HirId, + ) -> &'hir hir::Expr<'hir> { + let checker_fn = self.expr_ident(span, check_ident, check_hir_id); + let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None); + self.expr_call(span, checker_fn, std::slice::from_ref(expr)) + } + pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock { self.with_new_scopes(c.value.span, |this| { let def_id = this.local_def_id(c.id); @@ -1983,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> { ), )) } else { - self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr)))) + let ret_expr = self.checked_return(Some(from_residual_expr)); + self.arena.alloc(self.expr(try_span, ret_expr)) }; self.lower_attrs(ret_expr.hir_id, &attrs); @@ -2032,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> { let target_id = Ok(catch_id); hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr)) } else { - hir::ExprKind::Ret(Some(from_yeet_expr)) + self.checked_return(Some(from_yeet_expr)) } } diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs index 4679ccdddbbe1..3a701b2c9c7a1 100644 --- a/compiler/rustc_ast_lowering/src/item.rs +++ b/compiler/rustc_ast_lowering/src/item.rs @@ -1097,6 +1097,7 @@ impl<'hir> LoweringContext<'_, 'hir> { // ==> // { rustc_contract_requires(PRECOND); { body } } let result: hir::Expr<'hir> = if let Some(contract) = opt_contract { + let result_ref = this.arena.alloc(result); let lit_unit = |this: &mut LoweringContext<'_, 'hir>| { this.expr(contract.span, hir::ExprKind::Tup(&[])) }; @@ -1131,30 +1132,22 @@ impl<'hir> LoweringContext<'_, 'hir> { this.arena.alloc(checker_binding_pat), hir::LocalSource::Contract, ), - { - let checker_fn = - this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2); - let span = this.mark_span_with_reason( - DesugaringKind::Contract, - ens.span, - None, - ); - this.expr_call_mut( - span, - checker_fn, - std::slice::from_ref(this.arena.alloc(result)), - ) - }, + this.inject_ensures_check( + result_ref, + ens.span, + fresh_ident.0, + fresh_ident.2, + ), ) } else { let u = lit_unit(this); - (this.stmt_expr(contract.span, u), result) + (this.stmt_expr(contract.span, u), &*result_ref) }; let block = this.block_all( contract.span, arena_vec![this; precond, postcond_checker], - Some(this.arena.alloc(result)), + Some(result), ); this.expr_block(block) } else { diff --git a/tests/ui/contracts/contracts-ensures-early-fn-exit.rs b/tests/ui/contracts/contracts-ensures-early-fn-exit.rs new file mode 100644 index 0000000000000..faf97473a90f2 --- /dev/null +++ b/tests/ui/contracts/contracts-ensures-early-fn-exit.rs @@ -0,0 +1,48 @@ +//@ revisions: unchk_pass chk_pass chk_fail_try chk_fail_ret chk_fail_yeet +// +//@ [unchk_pass] run-pass +//@ [chk_pass] run-pass +//@ [chk_fail_try] run-fail +//@ [chk_fail_ret] run-fail +//@ [chk_fail_yeet] run-fail +// +//@ [unchk_pass] compile-flags: -Zcontract-checks=no +//@ [chk_pass] compile-flags: -Zcontract-checks=yes +//@ [chk_fail_try] compile-flags: -Zcontract-checks=yes +//@ [chk_fail_ret] compile-flags: -Zcontract-checks=yes +//@ [chk_fail_yeet] compile-flags: -Zcontract-checks=yes +//! This test ensures that ensures clauses are checked for different return points of a function. + +#![feature(rustc_contracts)] +#![feature(yeet_expr)] + +/// This ensures will fail in different return points depending on the input. +#[core::contracts::ensures(|ret: &Option| ret.is_some())] +fn try_sum(x: u32, y: u32, z: u32) -> Option { + // Use Yeet to return early. + if x == u32::MAX && (y > 0 || z > 0) { do yeet } + + // Use `?` to early return. + let partial = x.checked_add(y)?; + + // Explicitly use `return` clause. + if u32::MAX - partial < z { + return None; + } + + Some(partial + z) +} + +fn main() { + // This should always succeed + assert_eq!(try_sum(0, 1, 2), Some(3)); + + #[cfg(any(unchk_pass, chk_fail_yeet))] + assert_eq!(try_sum(u32::MAX, 1, 1), None); + + #[cfg(any(unchk_pass, chk_fail_try))] + assert_eq!(try_sum(u32::MAX - 10, 12, 0), None); + + #[cfg(any(unchk_pass, chk_fail_ret))] + assert_eq!(try_sum(u32::MAX - 10, 2, 100), None); +}