Skip to content

Rust: Fix type inference for trait objects for traits with associated types #20122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,27 @@ private import codeql.rust.internal.CachedStages
private import codeql.rust.elements.internal.generated.Raw
private import codeql.rust.elements.internal.generated.Synth

/**
* Holds if a dyn trait type should have a type parameter associated with `n`. A
* dyn trait type inherits the type parameters of the trait it implements. That
* includes the type parameters corresponding to associated types.
*
* For instance in
* ```rust
* trait SomeTrait<A> {
* type AssociatedType;
* }
* ```
* this predicate holds for the nodes `A` and `type AssociatedType`.
*/
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
trait = any(DynTraitTypeRepr dt).getTrait() and
(
n = trait.getGenericParamList().getATypeParam() or
n = trait.(TraitItemNode).getAnAssocItem().(TypeAlias)
)
}

cached
newtype TType =
TTuple(int arity) {
Expand All @@ -30,9 +51,7 @@ newtype TType =
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
TDynTraitTypeParameter(TypeParam tp) {
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
} or
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()
Expand Down Expand Up @@ -406,15 +425,35 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
}

class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
private TypeParam typeParam;
private AstNode n;

DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }
DynTraitTypeParameter() { this = TDynTraitTypeParameter(n) }

TypeParam getTypeParam() { result = typeParam }
Trait getTrait() { dynTraitTypeParameter(result, n) }

override string toString() { result = "dyn(" + typeParam.toString() + ")" }
/** Gets the dyn trait type that this type parameter belongs to. */
DynTraitType getDynTraitType() { result.getTrait() = this.getTrait() }

override Location getLocation() { result = typeParam.getLocation() }
/** Gets the `TypeParam` of this dyn trait type parameter, if any. */
TypeParam getTypeParam() { result = n }

/** Gets the `TypeAlias` of this dyn trait type parameter, if any. */
TypeAlias getTypeAlias() { result = n }

/** Gets the trait type parameter that this dyn trait type parameter corresponds to. */
TypeParameter getTraitTypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() = n
or
result.(AssociatedTypeTypeParameter).getTypeAlias() = n
}

private string toStringInner() {
result = [this.getTypeParam().toString(), this.getTypeAlias().getName().toString()]
}

override string toString() { result = "dyn(" + this.toStringInner() + ")" }

override Location getLocation() { result = n.getLocation() }
}

/** An implicit reference type parameter. */
Expand Down Expand Up @@ -503,8 +542,7 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {

final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() =
this.getTrait().getGenericParamList().getATypeParam()
result = any(DynTraitTypeParameter tp | tp.getTrait() = this.getTrait()).getTraitTypeParameter()
}
}

Expand Down
6 changes: 5 additions & 1 deletion rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ private module Input1 implements InputSig1<Location> {
id = 2
or
kind = 1 and
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
id =
idOfTypeParameterAstNode([
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.(AstNode) appears to be redundant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to help QL figure out what the element type of the set is. It complains without the .(AstNode) :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's surprising, but fair enough.

tp0.(DynTraitTypeParameter).getTypeAlias()
])
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
Expand Down
12 changes: 6 additions & 6 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,10 @@ class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
result = dynType
or
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
tp = dynType.getTypeParameter(_) and
dynType = tp.getDynTraitType() and
path = TypePath::cons(tp, suffix) and
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
path0.isCons(tp.getTraitTypeParameter(), suffix)
)
}
}
Expand Down Expand Up @@ -363,10 +363,10 @@ class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
path.isEmpty() and
result.(DynTraitType).getTrait() = trait
or
exists(TypeParam param |
param = trait.getGenericParamList().getATypeParam() and
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
result = TTypeParamTypeParameter(param)
exists(DynTraitTypeParameter tp |
trait = tp.getTrait() and
path = TypePath::singleton(tp) and
result = tp.getTraitTypeParameter()
)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:2213:13:2213:31 | ...::from(...) |
| main.rs:2214:13:2214:31 | ...::from(...) |
| main.rs:2215:13:2215:31 | ...::from(...) |
| main.rs:2221:13:2221:31 | ...::from(...) |
| main.rs:2222:13:2222:31 | ...::from(...) |
| main.rs:2223:13:2223:31 | ...::from(...) |
| main.rs:2253:13:2253:31 | ...::from(...) |
| main.rs:2254:13:2254:31 | ...::from(...) |
| main.rs:2255:13:2255:31 | ...::from(...) |
| main.rs:2261:13:2261:31 | ...::from(...) |
| main.rs:2262:13:2262:31 | ...::from(...) |
| main.rs:2263:13:2263:31 | ...::from(...) |
41 changes: 41 additions & 0 deletions rust/ql/test/library-tests/type-inference/dyn_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ trait GenericGet<A> {
fn get(&self) -> A;
}

trait AssocTrait<GP> {
type AP;
// AssocTrait::get
fn get(&self) -> (GP, Self::AP);
}

#[derive(Clone, Debug)]
struct MyStruct {
value: i32,
Expand All @@ -36,6 +42,17 @@ impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
}
}

impl<GGP> AssocTrait<GGP> for GenStruct<GGP>
where
GGP: Clone + Debug,
{
type AP = bool;
// GenStruct<GGP>::get
fn get(&self) -> (GGP, bool) {
(self.value.clone(), true) // $ fieldof=GenStruct target=clone
}
}

fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
a.get() // $ target=GenericGet::get
}
Expand All @@ -58,10 +75,34 @@ fn test_poly_dyn_trait() {
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
}

fn assoc_dyn_get<A, B>(a: &dyn AssocTrait<A, AP = B>) -> (A, B) {
a.get() // $ target=AssocTrait::get
}

fn assoc_get<A, B, T: AssocTrait<A, AP = B> + ?Sized>(a: &T) -> (A, B) {
a.get() // $ target=AssocTrait::get
}

fn test_assoc_type(obj: &dyn AssocTrait<i64, AP = bool>) {
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = (*obj).get(); // $ target=deref target=AssocTrait::get
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = assoc_dyn_get(obj); // $ target=assoc_dyn_get
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = assoc_get(obj); // $ target=assoc_get
}

pub fn test() {
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
test_generic_dyn_trait(&GenStruct {
value: "".to_string(),
}); // $ target=test_generic_dyn_trait
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
test_assoc_type(&GenStruct { value: 100 }); // $ target=test_assoc_type
}
44 changes: 42 additions & 2 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ mod function_trait_bounds {
}
}

mod trait_associated_type {
mod associated_type_in_trait {
#[derive(Debug)]
struct Wrapper<A> {
field: A,
Expand Down Expand Up @@ -803,6 +803,46 @@ mod trait_associated_type {
}
}

mod associated_type_in_supertrait {
trait Supertrait {
type Content;
fn insert(content: Self::Content);
}

trait Subtrait: Supertrait {
// Subtrait::get_content
fn get_content(&self) -> Self::Content;
}

struct MyType<T>(T);

impl<T> Supertrait for MyType<T> {
type Content = T;
fn insert(_content: Self::Content) {
println!("Inserting content: ");
}
}

impl<T: Clone> Subtrait for MyType<T> {
// MyType::get_content
fn get_content(&self) -> Self::Content {
(*self).0.clone() // $ fieldof=MyType target=clone target=deref
}
}

fn get_content<T: Subtrait>(item: &T) -> T::Content {
item.get_content() // $ target=Subtrait::get_content
}

fn test() {
let item1 = MyType(42i64);
let _content1 = item1.get_content(); // $ target=MyType::get_content MISSING: type=_content1:i64

let item2 = MyType(true);
let _content2 = get_content(&item2); // $ target=get_content MISSING: type=_content2:bool
}
}

mod generic_enum {
#[derive(Debug)]
enum MyEnum<A> {
Expand Down Expand Up @@ -2469,7 +2509,7 @@ fn main() {
method_non_parametric_impl::f(); // $ target=f
method_non_parametric_trait_impl::f(); // $ target=f
function_trait_bounds::f(); // $ target=f
trait_associated_type::f(); // $ target=f
associated_type_in_trait::f(); // $ target=f
generic_enum::f(); // $ target=f
method_supertraits::f(); // $ target=f
function_trait_bounds_2::f(); // $ target=f
Expand Down
Loading