Skip to content

Commit 0a6d884

Browse files
authored
Merge pull request #19749 from paldepind/rust/impl-parameter-resolution
Rust: Disambiguate some method calls based on argument types
2 parents 6cca016 + ef15df3 commit 0a6d884

File tree

4 files changed

+1407
-1202
lines changed

4 files changed

+1407
-1202
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,11 +1231,133 @@ private Function getTypeParameterMethod(TypeParameter tp, string name) {
12311231
result = getMethodSuccessor(tp.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr(), name)
12321232
}
12331233

1234+
bindingset[t1, t2]
1235+
private predicate typeMentionEqual(TypeMention t1, TypeMention t2) {
1236+
forex(TypePath path, Type type | t1.resolveTypeAt(path) = type | t2.resolveTypeAt(path) = type)
1237+
}
1238+
1239+
pragma[nomagic]
1240+
private predicate implSiblingCandidate(
1241+
Impl impl, TraitItemNode trait, Type rootType, TypeMention selfTy
1242+
) {
1243+
trait = impl.(ImplItemNode).resolveTraitTy() and
1244+
// If `impl` has an expansion from a macro attribute, then it's been
1245+
// superseded by the output of the expansion (and usually the expansion
1246+
// contains the same `impl` block so considering both would give spurious
1247+
// siblings).
1248+
not exists(impl.getAttributeMacroExpansion()) and
1249+
// We use this for resolving methods, so exclude traits that do not have methods.
1250+
exists(Function f | f = trait.getASuccessor(_) and f.getParamList().hasSelfParam()) and
1251+
selfTy = impl.getSelfTy() and
1252+
rootType = selfTy.resolveType()
1253+
}
1254+
1255+
/**
1256+
* Holds if `impl1` and `impl2` are a sibling implementations of `trait`. We
1257+
* consider implementations to be siblings if they implement the same trait for
1258+
* the same type. In that case `Self` is the same type in both implementations,
1259+
* and method calls to the implementations cannot be resolved unambiguously
1260+
* based only on the receiver type.
1261+
*/
1262+
pragma[inline]
1263+
private predicate implSiblings(TraitItemNode trait, Impl impl1, Impl impl2) {
1264+
exists(Type rootType, TypeMention selfTy1, TypeMention selfTy2 |
1265+
impl1 != impl2 and
1266+
implSiblingCandidate(impl1, trait, rootType, selfTy1) and
1267+
implSiblingCandidate(impl2, trait, rootType, selfTy2) and
1268+
// In principle the second conjunct below should be superflous, but we still
1269+
// have ill-formed type mentions for types that we don't understand. For
1270+
// those checking both directions restricts further. Note also that we check
1271+
// syntactic equality, whereas equality up to renaming would be more
1272+
// correct.
1273+
typeMentionEqual(selfTy1, selfTy2) and
1274+
typeMentionEqual(selfTy2, selfTy1)
1275+
)
1276+
}
1277+
1278+
/**
1279+
* Holds if `impl` is an implementation of `trait` and if another implementation
1280+
* exists for the same type.
1281+
*/
1282+
pragma[nomagic]
1283+
private predicate implHasSibling(Impl impl, Trait trait) { implSiblings(trait, impl, _) }
1284+
1285+
/**
1286+
* Holds if a type parameter of `trait` occurs in the method with the name
1287+
* `methodName` at the `pos`th parameter at `path`.
1288+
*/
1289+
bindingset[trait]
1290+
pragma[inline_late]
1291+
private predicate traitTypeParameterOccurrence(
1292+
TraitItemNode trait, string methodName, int pos, TypePath path
1293+
) {
1294+
exists(Function f | f = trait.getASuccessor(methodName) |
1295+
f.getParam(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) =
1296+
trait.(TraitTypeAbstraction).getATypeParameter()
1297+
)
1298+
}
1299+
1300+
bindingset[f, pos, path]
1301+
pragma[inline_late]
1302+
private predicate methodTypeAtPath(Function f, int pos, TypePath path, Type type) {
1303+
f.getParam(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) = type
1304+
}
1305+
1306+
/**
1307+
* Holds if resolving the method `f` in `impl` with the name `methodName`
1308+
* requires inspecting the types of applied _arguments_ in order to determine
1309+
* whether it is the correct resolution.
1310+
*/
1311+
pragma[nomagic]
1312+
private predicate methodResolutionDependsOnArgument(
1313+
Impl impl, string methodName, Function f, int pos, TypePath path, Type type
1314+
) {
1315+
/*
1316+
* As seen in the example below, when an implementation has a sibling for a
1317+
* trait we find occurrences of a type parameter of the trait in a method
1318+
* signature in the trait. We then find the type given in the implementation
1319+
* at the same position, which is a position that might disambiguate the
1320+
* method from its siblings.
1321+
*
1322+
* ```rust
1323+
* trait MyTrait<T> {
1324+
* fn method(&self, value: Foo<T>) -> Self;
1325+
* // ^^^^^^^^^^^^^ `pos` = 0
1326+
* // ^ `path` = "T"
1327+
* }
1328+
* impl MyAdd<i64> for i64 {
1329+
* fn method(&self, value: Foo<i64>) -> Self { ... }
1330+
* // ^^^ `type` = i64
1331+
* }
1332+
* ```
1333+
*
1334+
* Note that we only check the root type symbol at the position. If the type
1335+
* at that position is a type constructor (for instance `Vec<..>`) then
1336+
* inspecting the entire type tree could be necessary to disambiguate the
1337+
* method. In that case we will still resolve several methods.
1338+
*/
1339+
1340+
exists(TraitItemNode trait |
1341+
implHasSibling(impl, trait) and
1342+
traitTypeParameterOccurrence(trait, methodName, pos, path) and
1343+
methodTypeAtPath(getMethodSuccessor(impl, methodName), pos, path, type) and
1344+
f = getMethodSuccessor(impl, methodName)
1345+
)
1346+
}
1347+
12341348
/** Gets a method from an `impl` block that matches the method call `mc`. */
12351349
private Function getMethodFromImpl(MethodCall mc) {
12361350
exists(Impl impl |
12371351
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
12381352
result = getMethodSuccessor(impl, mc.getMethodName())
1353+
|
1354+
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
1355+
result = getMethodSuccessor(impl, mc.getMethodName())
1356+
or
1357+
exists(int pos, TypePath path, Type type |
1358+
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1359+
inferType(mc.getArgument(pos), path) = type
1360+
)
12391361
)
12401362
}
12411363

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,7 @@ mod method_call_type_conversion {
10991099
println!("{:?}", x5.0); // $ fieldof=S
11001100

11011101
let x6 = &S(S2); // $ SPURIOUS: type=x6:&T.&T.S
1102+
11021103
// explicit dereference
11031104
println!("{:?}", (*x6).m1()); // $ method=m1 method=deref
11041105

@@ -1668,17 +1669,18 @@ mod async_ {
16681669
}
16691670

16701671
fn f2() -> impl Future<Output = S1> {
1671-
async {
1672-
S1
1673-
}
1672+
async { S1 }
16741673
}
16751674

16761675
struct S2;
16771676

16781677
impl Future for S2 {
16791678
type Output = S1;
16801679

1681-
fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
1680+
fn poll(
1681+
self: std::pin::Pin<&mut Self>,
1682+
_cx: &mut std::task::Context<'_>,
1683+
) -> std::task::Poll<Self::Output> {
16821684
std::task::Poll::Ready(S1)
16831685
}
16841686
}
@@ -1692,14 +1694,11 @@ mod async_ {
16921694
f2().await.f(); // $ method=S1f
16931695
f3().await.f(); // $ method=S1f
16941696
S2.await.f(); // $ method=S1f
1695-
let b = async {
1696-
S1
1697-
};
1697+
let b = async { S1 };
16981698
b.await.f(); // $ method=S1f
16991699
}
17001700
}
17011701

1702-
17031702
mod impl_trait {
17041703
struct S1;
17051704
struct S2;
@@ -1816,6 +1815,44 @@ mod macros {
18161815
}
18171816
}
18181817

1818+
mod method_determined_by_argument_type {
1819+
trait MyAdd<T> {
1820+
fn my_add(&self, value: T) -> Self;
1821+
}
1822+
1823+
impl MyAdd<i64> for i64 {
1824+
// MyAdd<i64>::my_add
1825+
fn my_add(&self, value: i64) -> Self {
1826+
value
1827+
}
1828+
}
1829+
1830+
impl MyAdd<&i64> for i64 {
1831+
// MyAdd<&i64>::my_add
1832+
fn my_add(&self, value: &i64) -> Self {
1833+
*value // $ method=deref
1834+
}
1835+
}
1836+
1837+
impl MyAdd<bool> for i64 {
1838+
// MyAdd<bool>::my_add
1839+
fn my_add(&self, value: bool) -> Self {
1840+
if value {
1841+
1
1842+
} else {
1843+
0
1844+
}
1845+
}
1846+
}
1847+
1848+
pub fn f() {
1849+
let x: i64 = 73;
1850+
x.my_add(5i64); // $ method=MyAdd<i64>::my_add
1851+
x.my_add(&5i64); // $ method=MyAdd<&i64>::my_add
1852+
x.my_add(true); // $ method=MyAdd<bool>::my_add
1853+
}
1854+
}
1855+
18191856
fn main() {
18201857
field_access::f();
18211858
method_impl::f();
@@ -1839,4 +1876,5 @@ fn main() {
18391876
impl_trait::f();
18401877
indexers::f();
18411878
macros::f();
1879+
method_determined_by_argument_type::f();
18421880
}

0 commit comments

Comments
 (0)