Skip to content

Commit

Permalink
Rust: Use PathResolution module in data flow
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Jan 28, 2025
1 parent 9bfd713 commit 88186da
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 72 deletions.
188 changes: 118 additions & 70 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ private import codeql.rust.controlflow.CfgNodes
private import codeql.rust.dataflow.Ssa
private import codeql.rust.dataflow.FlowSummary
private import FlowSummaryImpl as FlowSummaryImpl
private import codeql.rust.elements.internal.PathResolution as PathResolution

/**
* A return kind. A return kind describes how a value can be returned from a
Expand Down Expand Up @@ -652,6 +653,8 @@ private predicate resolveExtendedCanonicalPath(Resolvable r, CrateOriginOption c
abstract class Content extends TContent {
/** Gets a textual representation of this content. */
abstract string toString();

abstract Location getLocation();

Check warning on line 657 in rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

View workflow job for this annotation

GitHub Actions / qldoc

Missing QLdoc for member-predicate DataFlowImpl::Content::getLocation/0
}

/** A canonical path pointing to an enum variant. */
Expand Down Expand Up @@ -680,65 +683,96 @@ abstract class VariantContent extends Content { }

/** A tuple variant. */
private class VariantPositionContent extends VariantContent, TVariantPositionContent {
private VariantCanonicalPath v;
private Variant v;
private int pos_;

VariantPositionContent() { this = TVariantPositionContent(v, pos_) }

Variant getVariant(int pos) { result = v and pos = pos_ }

final override string toString() {
exists(string name |
name = v.getName().getText() and
// only print indices when the arity is > 1
if exists(TVariantPositionContent(v, 1))
then result = name + "(" + pos_ + ")"
else result = name
)
}

final override Location getLocation() {
result = v.getFieldList().(TupleFieldList).getField(pos_).getLocation()
}
}

/**
* TODO: Remove once library types are extracted
*
* A tuple variant from library code.
*/
private class VariantPositionContentLib extends VariantContent, TVariantPositionContentLib {
private VariantCanonicalPath v;
private int pos_;

VariantPositionContentLib() { this = TVariantPositionContentLib(v, pos_) }

VariantCanonicalPath getVariantCanonicalPath(int pos) { result = v and pos = pos_ }

final override string toString() {
// only print indices when the arity is > 1
if exists(TVariantPositionContent(v, 1))
if exists(TVariantPositionContentLib(v, 1))
then result = v.toString() + "(" + pos_ + ")"
else result = v.toString()
}

final override Location getLocation() { result instanceof EmptyLocation }
}

/** A record variant. */
private class VariantFieldContent extends VariantContent, TVariantFieldContent {
private VariantCanonicalPath v;
private Variant v;
private string field_;

VariantFieldContent() { this = TVariantFieldContent(v, field_) }

VariantCanonicalPath getVariantCanonicalPath(string field) { result = v and field = field_ }
Variant getVariant(string field) { result = v and field = field_ }

final override string toString() {
// only print field when the arity is > 1
if strictcount(string f | exists(TVariantFieldContent(v, f))) > 1
then result = v.toString() + "{" + field_ + "}"
else result = v.toString()
exists(string name |
name = v.getName().getText() and
// only print field when the arity is > 1
if strictcount(string f | exists(TVariantFieldContent(v, f))) > 1
then result = name + "{" + field_ + "}"
else result = name
)
}
}

/** A canonical path pointing to a struct. */
class StructCanonicalPath extends MkStructCanonicalPath {
CrateOriginOption crate;
string path;

StructCanonicalPath() { this = MkStructCanonicalPath(crate, path) }

/** Gets the underlying struct. */
Struct getStruct() { hasExtendedCanonicalPath(result, crate, path) }

string getExtendedCanonicalPath() { result = path }

string toString() { result = this.getStruct().getName().getText() }

Location getLocation() { result = this.getStruct().getLocation() }
final override Location getLocation() {
exists(Name f |
f = v.getFieldList().(RecordFieldList).getAField().getName() and
f.getText() = field_ and
result = f.getLocation()
)
}
}

/** Content stored in a field on a struct. */
private class StructFieldContent extends Content, TStructFieldContent {
private StructCanonicalPath s;
private Struct s;
private string field_;

StructFieldContent() { this = TStructFieldContent(s, field_) }

StructCanonicalPath getStructCanonicalPath(string field) { result = s and field = field_ }
Struct getStruct(string field) { result = s and field = field_ }

override string toString() { result = s.getName().getText() + "." + field_.toString() }

override string toString() { result = s.toString() + "." + field_.toString() }
override Location getLocation() {
exists(Name f | f = s.getFieldList().(RecordFieldList).getAField().getName() |
f.getText() = field_ and
result = f.getLocation()
)
}
}

/** A captured variable. */
Expand All @@ -751,11 +785,15 @@ private class CapturedVariableContent extends Content, TCapturedVariableContent
Variable getVariable() { result = v }

override string toString() { result = "captured " + v }

override Location getLocation() { result = v.getLocation() }
}

/** A value referred to by a reference. */
final class ReferenceContent extends Content, TReferenceContent {
override string toString() { result = "&ref" }

override Location getLocation() { result instanceof EmptyLocation }
}

/**
Expand All @@ -768,6 +806,8 @@ final class ReferenceContent extends Content, TReferenceContent {
*/
final class ElementContent extends Content, TElementContent {
override string toString() { result = "element" }

override Location getLocation() { result instanceof EmptyLocation }
}

/**
Expand All @@ -784,6 +824,8 @@ final class TuplePositionContent extends Content, TTuplePositionContent {
int getPosition() { result = pos }

override string toString() { result = "tuple." + pos.toString() }

override Location getLocation() { result instanceof EmptyLocation }
}

/** Holds if `access` indexes a tuple at an index corresponding to `c`. */
Expand Down Expand Up @@ -1011,11 +1053,8 @@ module RustDataFlow implements InputSig<Location> {
}

/** Holds if path `p` resolves to struct `s`. */
private predicate pathResolveToStructCanonicalPath(PathAstNode p, StructCanonicalPath s) {
exists(CrateOriginOption crate, string path |
resolveExtendedCanonicalPath(p, crate, path) and
s = MkStructCanonicalPath(crate, path)
)
private predicate pathResolveToStruct(PathAstNode p, Struct s) {
s = PathResolution::resolvePath(p.getPath())
}

/** Holds if path `p` resolves to variant `v`. */
Expand All @@ -1026,23 +1065,30 @@ module RustDataFlow implements InputSig<Location> {
)
}

/** Holds if path `p` resolves to variant `v`. */
private predicate pathResolveToVariant(PathAstNode p, Variant v) {
v = PathResolution::resolvePath(p.getPath())
}

/** Holds if `p` destructs an enum variant `v`. */
pragma[nomagic]
private predicate tupleVariantDestruction(TupleStructPat p, VariantCanonicalPath v) {
private predicate tupleVariantCanonicalDestruction(TupleStructPat p, VariantCanonicalPath v) {
pathResolveToVariantCanonicalPath(p, v)
}

/** Holds if `p` destructs an enum variant `v`. */
pragma[nomagic]
private predicate recordVariantDestruction(RecordPat p, VariantCanonicalPath v) {
pathResolveToVariantCanonicalPath(p, v)
private predicate tupleVariantDestruction(TupleStructPat p, Variant v) {
pathResolveToVariant(p, v)
}

/** Holds if `p` destructs an enum variant `v`. */
pragma[nomagic]
private predicate recordVariantDestruction(RecordPat p, Variant v) { pathResolveToVariant(p, v) }

/** Holds if `p` destructs a struct `s`. */
pragma[nomagic]
private predicate structDestruction(RecordPat p, StructCanonicalPath s) {
pathResolveToStructCanonicalPath(p, s)
}
private predicate structDestruction(RecordPat p, Struct s) { pathResolveToStruct(p, s) }

/**
* Holds if data can flow from `node1` to `node2` via a read of `c`. Thus,
Expand All @@ -1053,9 +1099,12 @@ module RustDataFlow implements InputSig<Location> {
exists(Content c | c = cs.(SingletonContentSet).getContent() |
exists(TupleStructPatCfgNode pat, int pos |
pat = node1.asPat() and
tupleVariantDestruction(pat.getPat(),
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node2.asPat() = pat.getField(pos)
|
tupleVariantDestruction(pat.getPat(), c.(VariantPositionContent).getVariant(pos))
or
tupleVariantCanonicalDestruction(pat.getPat(),
c.(VariantPositionContentLib).getVariantCanonicalPath(pos))
)
or
exists(TuplePatCfgNode pat, int pos |
Expand All @@ -1068,11 +1117,10 @@ module RustDataFlow implements InputSig<Location> {
pat = node1.asPat() and
(
// Pattern destructs a struct-like variant.
recordVariantDestruction(pat.getPat(),
c.(VariantFieldContent).getVariantCanonicalPath(field))
recordVariantDestruction(pat.getPat(), c.(VariantFieldContent).getVariant(field))
or
// Pattern destructs a struct.
structDestruction(pat.getPat(), c.(StructFieldContent).getStructCanonicalPath(field))
structDestruction(pat.getPat(), c.(StructFieldContent).getStruct(field))
) and
node2.asPat() = pat.getFieldPat(field)
)
Expand Down Expand Up @@ -1109,7 +1157,7 @@ module RustDataFlow implements InputSig<Location> {
exists(TryExprCfgNode try |
node1.asExpr() = try.getExpr() and
node2.asExpr() = try and
c.(VariantPositionContent).getVariantCanonicalPath(0).getExtendedCanonicalPath() =
c.(VariantPositionContentLib).getVariantCanonicalPath(0).getExtendedCanonicalPath() =
["crate::option::Option::Some", "crate::result::Result::Ok"]
)
or
Expand All @@ -1129,21 +1177,25 @@ module RustDataFlow implements InputSig<Location> {

/** Holds if `ce` constructs an enum value of type `v`. */
pragma[nomagic]
private predicate tupleVariantConstruction(CallExpr ce, VariantCanonicalPath v) {
private predicate tupleVariantConstruction(CallExpr ce, Variant v) {
pathResolveToVariant(ce.getFunction().(PathExpr), v)
}

/** Holds if `ce` constructs an enum value of type `v`. */
pragma[nomagic]
private predicate tupleVariantCanonicalConstruction(CallExpr ce, VariantCanonicalPath v) {
pathResolveToVariantCanonicalPath(ce.getFunction().(PathExpr), v)
}

/** Holds if `re` constructs an enum value of type `v`. */
pragma[nomagic]
private predicate recordVariantConstruction(RecordExpr re, VariantCanonicalPath v) {
pathResolveToVariantCanonicalPath(re, v)
private predicate recordVariantConstruction(RecordExpr re, Variant v) {
pathResolveToVariant(re, v)
}

/** Holds if `re` constructs a struct value of type `s`. */
pragma[nomagic]
private predicate structConstruction(RecordExpr re, StructCanonicalPath s) {
pathResolveToStructCanonicalPath(re, s)
}
private predicate structConstruction(RecordExpr re, Struct s) { pathResolveToStruct(re, s) }

private predicate tupleAssignment(Node node1, Node node2, TuplePositionContent c) {
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
Expand All @@ -1157,20 +1209,22 @@ module RustDataFlow implements InputSig<Location> {
pragma[nomagic]
private predicate storeContentStep(Node node1, Content c, Node node2) {
exists(CallExprCfgNode call, int pos |
tupleVariantConstruction(call.getCallExpr(),
c.(VariantPositionContent).getVariantCanonicalPath(pos)) and
node1.asExpr() = call.getArgument(pos) and
node2.asExpr() = call
|
tupleVariantConstruction(call.getCallExpr(), c.(VariantPositionContent).getVariant(pos))
or
tupleVariantCanonicalConstruction(call.getCallExpr(),
c.(VariantPositionContentLib).getVariantCanonicalPath(pos))
)
or
exists(RecordExprCfgNode re, string field |
(
// Expression is for a struct-like enum variant.
recordVariantConstruction(re.getRecordExpr(),
c.(VariantFieldContent).getVariantCanonicalPath(field))
recordVariantConstruction(re.getRecordExpr(), c.(VariantFieldContent).getVariant(field))
or
// Expression is for a struct.
structConstruction(re.getRecordExpr(), c.(StructFieldContent).getStructCanonicalPath(field))
structConstruction(re.getRecordExpr(), c.(StructFieldContent).getStruct(field))
) and
node1.asExpr() = re.getFieldExpr(field) and
node2.asExpr() = re
Expand Down Expand Up @@ -1542,27 +1596,21 @@ private module Cached {
)
}

cached
newtype TStructCanonicalPath =
MkStructCanonicalPath(CrateOriginOption crate, string path) {
exists(Struct s | hasExtendedCanonicalPath(s, crate, path))
}

cached
newtype TContent =
TVariantPositionContent(VariantCanonicalPath v, int pos) {
pos in [0 .. v.getVariant().getFieldList().(TupleFieldList).getNumberOfFields() - 1]
or
// TODO: Remove once library types are extracted
TVariantPositionContent(Variant v, int pos) {
pos in [0 .. v.getFieldList().(TupleFieldList).getNumberOfFields() - 1]
} or
// TODO: Remove once library types are extracted
TVariantPositionContentLib(VariantCanonicalPath v, int pos) {
v = MkVariantCanonicalPath(langCoreCrate(), "crate::option::Option", "Some") and
pos = 0
or
// TODO: Remove once library types are extracted
v = MkVariantCanonicalPath(langCoreCrate(), "crate::result::Result", ["Ok", "Err"]) and
pos = 0
} or
TVariantFieldContent(VariantCanonicalPath v, string field) {
field = v.getVariant().getFieldList().(RecordFieldList).getAField().getName().getText()
TVariantFieldContent(Variant v, string field) {
field = v.getFieldList().(RecordFieldList).getAField().getName().getText()
} or
TElementContent() or
TTuplePositionContent(int pos) {
Expand All @@ -1572,8 +1620,8 @@ private module Cached {
]
)]
} or
TStructFieldContent(StructCanonicalPath s, string field) {
field = s.getStruct().getFieldList().(RecordFieldList).getAField().getName().getText()
TStructFieldContent(Struct s, string field) {
field = s.getFieldList().(RecordFieldList).getAField().getName().getText()
} or
TCapturedVariableContent(VariableCapture::CapturedVariable v) or
TReferenceContent()
Expand Down
Loading

0 comments on commit 88186da

Please sign in to comment.