Skip to content

Commit

Permalink
analyze: support rewriting field projections on nullable pointers (#1096
Browse files Browse the repository at this point in the history
)

This adds support for rewriting field projections like `&(*p).x` when
`p` is a nullable pointer. The result looks like
`Some(&(*p.unwrap()).x)`.

I initially tried to avoid a panic when `p` is null by implementing a
rewrite using `Option::map`: `p.map(|ptr| &ptr.x)`. However,
implementing this correctly wound up being quite complex. It's undefined
behavior in C to do `&p->x` when `p == NULL`, so it seems reasonable to
introduce a panic in that case.

The `mir_op` changes for this are relatively straightforward, but
`unlower`, `distribute`, and `convert` needed some work. In particular,
`unlower` now has a new variant
`MirOriginDesc::LoadFromTempForAdjustment(i)`, which disambiguates cases
like this:
```Rust
// Rust:
f(&(*p).x)

// MIR:
// Evaluate the main expression:
_tmp1 = &(*_p).x;
// unlower_map entries for &(*p).x:
// * Expr
// * StoreIntoLocal

// Adjustments inserted to coerce `&T` to `*const T`:
_tmp2 = addr_of!(*_tmp1);
// * LoadFromTempForAdjustment(0) (load of _tmp1)
// * Adjustment(0) (deref)
// * Adjustment(1) (addr-of)
// * StoreIntoLocal

// The actual function call:
_result = f(_tmp2);
// * LoadFromTemp (load final result of &(*p).x from _tmp2)
```
Previously, the `LoadFromTempForAdjustment` would be recorded as
`LoadFromTemp`, meaning there would be two `LoadFromTemp` entries in the
unlower_map for the expression `&(*p).x`. Rewrites attached to the first
`LoadFromTemp` (in this case, the use of `_tmp1` in the second
statement) would be wrongly applied at the site of the last
`LoadFromTemp`. This caused `unwrap()` and `Some(_)` rewrites to be
applied in the wrong order, causing type errors in the rewritten code.
  • Loading branch information
spernsteiner authored Jul 22, 2024
2 parents d59dbd3 + 8434b34 commit 7bf22e0
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 100 deletions.
27 changes: 20 additions & 7 deletions c2rust-analyze/src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1581,29 +1581,42 @@ fn run2<'tcx>(
let acx = gacx.function_context_with_data(&mir, info.acx_data.take());
let asn = gasn.and(&mut info.lasn);

// Generate inline annotations for pointer-typed locals
for (local, decl) in mir.local_decls.iter_enumerated() {
let span = local_span(decl);
let mut emit_lty_annotations = |span, lty: LTy, desc: &str| {
let mut ptrs = Vec::new();
let ty_str = context::print_ty_with_pointer_labels(acx.local_tys[local], |ptr| {
let ty_str = context::print_ty_with_pointer_labels(lty, |ptr| {
if ptr.is_none() {
return String::new();
}
ptrs.push(ptr);
format!("{{{}}}", ptr)
});
if ptrs.is_empty() {
continue;
return;
}
// TODO: emit addr_of when it's nontrivial
// TODO: emit pointee_types when nontrivial
ann.emit(span, format_args!("typeof({:?}) = {}", local, ty_str));
ann.emit(span, format_args!("typeof({}) = {}", desc, ty_str));
for ptr in ptrs {
ann.emit(
span,
format_args!(" {} = {:?}, {:?}", ptr, asn.perms()[ptr], asn.flags()[ptr]),
);
}
};

// Generate inline annotations for pointer-typed locals
for (local, decl) in mir.local_decls.iter_enumerated() {
let span = local_span(decl);
// TODO: emit addr_of when it's nontrivial
let desc = format!("{:?}", local);
emit_lty_annotations(span, acx.local_tys[local], &desc);
}

for (&loc, &rv_lty) in &acx.rvalue_tys {
// `loc` must refer to a statement. Terminators don't have `Rvalue`s and thus never
// appear in `rvalue_tys`.
let stmt = mir.stmt_at(loc).either(|stmt| stmt, |_term| unreachable!());
let span = stmt.source_info.span;
emit_lty_annotations(span, rv_lty, &format!("{:?}", stmt));
}

info.acx_data.set(acx.into_data());
Expand Down
73 changes: 43 additions & 30 deletions c2rust-analyze/src/rewrite/expr/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,20 +440,34 @@ impl<'tcx> Visitor<'tcx> for ConvertVisitor<'tcx> {
.any(|x| matches!(x.desc, MirOriginDesc::Adjustment(_)));
if self.materialize_adjustments || has_adjustment_rewrites {
let adjusts = self.typeck_results.expr_adjustments(ex);
hir_rw = materialize_adjustments(self.tcx, adjusts, hir_rw, |i, hir_rw| {
let adj_rws =
take_prefix_while(&mut mir_rws, |x| x.desc == MirOriginDesc::Adjustment(i));
self.rewrite_from_mir_rws(Some(ex), adj_rws, hir_rw)
});
hir_rw =
materialize_adjustments(self.tcx, adjusts, hir_rw, |step, hir_rw| match step {
AdjustmentStep::Before(i) => {
let load_rws = take_prefix_while(&mut mir_rws, |x| {
x.desc == MirOriginDesc::LoadFromTempForAdjustment(i)
});
self.rewrite_from_mir_rws(Some(ex), load_rws, hir_rw)
}
AdjustmentStep::After(i) => {
let adj_rws = take_prefix_while(&mut mir_rws, |x| {
x.desc == MirOriginDesc::Adjustment(i)
});
self.rewrite_from_mir_rws(Some(ex), adj_rws, hir_rw)
}
});
}

// Apply late rewrites.
assert!(mir_rws.iter().all(|mir_rw| {
matches!(
for mir_rw in mir_rws {
assert!(
matches!(
mir_rw.desc,
MirOriginDesc::StoreIntoLocal | MirOriginDesc::LoadFromTemp
),
"bad desc {:?} for late rewrite: {mir_rw:?}",
mir_rw.desc,
MirOriginDesc::StoreIntoLocal | MirOriginDesc::LoadFromTemp
)
}));
);
}
hir_rw = self.rewrite_from_mir_rws(Some(ex), mir_rws, hir_rw);

if !matches!(hir_rw, Rewrite::Identity) {
Expand All @@ -474,7 +488,7 @@ fn mutbl_from_bool(m: bool) -> hir::Mutability {
}
}

fn apply_identity_adjustment<'tcx>(
fn apply_adjustment<'tcx>(
tcx: TyCtxt<'tcx>,
adjustment: &Adjustment<'tcx>,
rw: Rewrite,
Expand Down Expand Up @@ -509,41 +523,40 @@ fn apply_identity_adjustment<'tcx>(
}
}

enum AdjustmentStep {
Before(usize),
After(usize),
}

fn materialize_adjustments<'tcx>(
tcx: TyCtxt<'tcx>,
adjustments: &[Adjustment<'tcx>],
hir_rw: Rewrite,
mut callback: impl FnMut(usize, Rewrite) -> Rewrite,
mut callback: impl FnMut(AdjustmentStep, Rewrite) -> Rewrite,
) -> Rewrite {
let adj_kinds: Vec<&_> = adjustments.iter().map(|a| &a.kind).collect();
match (hir_rw, &adj_kinds[..]) {
(Rewrite::Identity, []) => Rewrite::Identity,
(Rewrite::Identity, _) => {
let mut hir_rw = Rewrite::Identity;
for (i, adj) in adjustments.iter().enumerate() {
hir_rw = apply_identity_adjustment(tcx, adj, hir_rw);
hir_rw = callback(i, hir_rw);
}
hir_rw
}
// TODO: ideally we should always materialize all adjustments (removing these special
// cases), and use `MirRewrite`s to eliminate any adjustments we no longer need.
(rw @ Rewrite::Ref(..), &[Adjust::Deref(..), Adjust::Borrow(..)]) => rw,
(rw @ Rewrite::MethodCall(..), &[Adjust::Deref(..), Adjust::Borrow(..)]) => rw,
// The mut-to-const cast should be unneeded once the inner rewrite switches to a safe
// reference type appropriate for the pointer's uses. However, we still want to give
// `callback` a chance to remove the cast itself so that if there's a `RemoveCast` rewrite
// on this adjustment, we don't get an error about it failing to apply.
(rw, &[Adjust::Pointer(PointerCast::MutToConstPointer)]) => {
let mut hir_rw = Rewrite::RemovedCast(Box::new(rw));
hir_rw = callback(0, hir_rw);
(mut hir_rw, &[Adjust::Pointer(PointerCast::MutToConstPointer)]) => {
hir_rw = callback(AdjustmentStep::Before(0), hir_rw);
hir_rw = Rewrite::RemovedCast(Box::new(hir_rw));
hir_rw = callback(AdjustmentStep::After(0), hir_rw);
match hir_rw {
Rewrite::RemovedCast(rw) => *rw,
rw => rw,
}
}
(rw, &[]) => rw,
(rw, adjs) => panic!("rewrite {rw:?} and materializations {adjs:?} NYI"),
(mut hir_rw, _) => {
for (i, adj) in adjustments.iter().enumerate() {
hir_rw = callback(AdjustmentStep::Before(i), hir_rw);
hir_rw = apply_adjustment(tcx, adj, hir_rw);
hir_rw = callback(AdjustmentStep::After(i), hir_rw);
}
hir_rw
}
}
}

Expand Down
57 changes: 53 additions & 4 deletions c2rust-analyze/src/rewrite/expr/distribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use log::*;
use rustc_hir::HirId;
use rustc_middle::mir::Location;
use rustc_middle::ty::TyCtxt;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap};

struct RewriteInfo {
Expand All @@ -19,18 +20,65 @@ struct RewriteInfo {
///
/// The order of variants follows the order of operations we typically see in generated MIR code.
/// For a given HIR `Expr`, the MIR will usually evaluate the expression ([`Priority::Eval`]),
/// apply zero or more adjustments ([`Priority::Adjust(i)`][Priority::Adjust]), store the result
/// into a temporary ([`Priority::_StoreResult`]; currently unused), and later load the result back
/// from the temporary ([`Priority::LoadResult`]).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
/// apply zero or more adjustments ([`Priority::LoadForAdjust(i)`][Priority::LoadForAdjust] and
/// [`Priority::Adjust(i)`][Priority::Adjust]), store the result into a temporary
/// ([`Priority::_StoreResult`]; currently unused), and later load the result back from the
/// temporary ([`Priority::LoadResult`]).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Priority {
Eval,
/// Load from a temporary for use in the adjustment at index `i`.
LoadForAdjust(usize),
/// Apply the rewrite just after the adjustment at index `i`.
Adjust(usize),
_StoreResult,
LoadResult,
}

impl PartialOrd for Priority {
fn partial_cmp(&self, other: &Priority) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for Priority {
fn cmp(&self, other: &Priority) -> Ordering {
use Priority::*;
match (*self, *other) {
// 1. Eval
(Eval, Eval) => Ordering::Equal,
(Eval, _) => Ordering::Less,
(_, Eval) => Ordering::Greater,

// 2. LoadForAdjust(0), Adjust(0), LoadForAdjust(1), Adjust(1), ...
(LoadForAdjust(i), LoadForAdjust(j)) => i.cmp(&j),
(LoadForAdjust(i), Adjust(j)) => match i.cmp(&j) {
Ordering::Equal => Ordering::Less,
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
},
(Adjust(i), Adjust(j)) => i.cmp(&j),
(Adjust(i), LoadForAdjust(j)) => match i.cmp(&j) {
Ordering::Equal => Ordering::Greater,
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
},
(LoadForAdjust(_), _) => Ordering::Less,
(_, LoadForAdjust(_)) => Ordering::Greater,
(Adjust(_), _) => Ordering::Less,
(_, Adjust(_)) => Ordering::Greater,

// 3. _StoreResult
(_StoreResult, _StoreResult) => Ordering::Equal,
(_StoreResult, _) => Ordering::Less,
(_, _StoreResult) => Ordering::Greater,

// 4. LoadResult
(LoadResult, LoadResult) => Ordering::Equal,
}
}
}

#[derive(Clone, Debug)]
pub struct DistRewrite {
pub rw: mir_op::RewriteKind,
Expand Down Expand Up @@ -96,6 +144,7 @@ pub fn distribute(
MirOriginDesc::Expr => Priority::Eval,
MirOriginDesc::Adjustment(i) => Priority::Adjust(i),
MirOriginDesc::LoadFromTemp => Priority::LoadResult,
MirOriginDesc::LoadFromTempForAdjustment(i) => Priority::LoadForAdjust(i),
_ => {
panic!(
"can't distribute rewrites onto {:?} origin {:?}\n\
Expand Down
25 changes: 19 additions & 6 deletions c2rust-analyze/src/rewrite/expr/mir_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
Some(lty.args[0])
}

fn is_nullable(&self, ptr: PointerId) -> bool {
!ptr.is_none()
&& !self.perms[ptr].contains(PermissionSet::NON_NULL)
&& !self.flags[ptr].contains(FlagSet::FIXED)
}

fn visit_statement(&mut self, stmt: &Statement<'tcx>, loc: Location) {
let _g = panic_detail::set_current_span(stmt.source_info.span);
eprintln!(
Expand Down Expand Up @@ -558,6 +564,14 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
BorrowKind::Shared | BorrowKind::Shallow | BorrowKind::Unique => false,
};
self.enter_rvalue_place(0, |v| v.visit_place(pl, mutbl));

if let Some(expect_ty) = expect_ty {
if self.is_nullable(expect_ty.label) {
// Nullable (`Option`) output is expected, but `Ref` always produces a
// `NON_NULL` pointer. Cast rvalue from `&T` to `Option<&T>` or similar.
self.emit(RewriteKind::OptionSome);
}
}
}
Rvalue::ThreadLocalRef(_def_id) => {
// TODO
Expand All @@ -579,6 +593,9 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
}),
_ => (),
}
if desc.option {
self.emit(RewriteKind::OptionSome);
}
}
}
Rvalue::Len(pl) => {
Expand All @@ -588,9 +605,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
if util::is_null_const_operand(op) && ty.is_unsafe_ptr() {
// Special case: convert `0 as *const T` to `None`.
if let Some(rv_lty) = expect_ty {
if !self.perms[rv_lty.label].contains(PermissionSet::NON_NULL)
&& !self.flags[rv_lty.label].contains(FlagSet::FIXED)
{
if self.is_nullable(rv_lty.label) {
self.emit(RewriteKind::ZeroAsPtrToNone);
}
}
Expand Down Expand Up @@ -730,9 +745,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
PlaceElem::Deref => {
self.enter_place_deref_pointer(|v| {
v.visit_place_ref(base_pl, proj_ltys, in_mutable_context);
if !v.perms[base_lty.label].contains(PermissionSet::NON_NULL)
&& !v.flags[base_lty.label].contains(FlagSet::FIXED)
{
if v.is_nullable(base_lty.label) {
// If the pointer type is non-copy, downgrade (borrow) before calling
// `unwrap()`.
let desc = type_desc::perms_to_desc(
Expand Down
Loading

0 comments on commit 7bf22e0

Please sign in to comment.