Skip to content

Commit

Permalink
Implement Contains over array fields
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 24, 2025
1 parent 46991f9 commit 11c6b29
Show file tree
Hide file tree
Showing 15 changed files with 373 additions and 44 deletions.
4 changes: 2 additions & 2 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "kernel-functions-generator"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VectorDataIntegrationTests", "VectorDataIntegrationTests", "{4F381919-F1BE-47D8-8558-3187ED04A84F}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QdrantIntegrationTests", "src\VectorDataTests\QdrantIntegrationTests\QdrantIntegrationTests.csproj", "{27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QdrantIntegrationTests", "src\VectorDataIntegrationTests\QdrantIntegrationTests\QdrantIntegrationTests.csproj", "{27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorDataIntegrationTests", "src\VectorDataTests\VectorDataIntegrationTests\VectorDataIntegrationTests.csproj", "{B29A972F-A774-4140-AECF-6B577C476627}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorDataIntegrationTests", "src\VectorDataIntegrationTests\VectorDataIntegrationTests\VectorDataIntegrationTests.csproj", "{B29A972F-A774-4140-AECF-6B577C476627}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using Google.Protobuf.Collections;
using Qdrant.Client.Grpc;
Expand All @@ -22,46 +23,43 @@ internal Filter Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary
Debug.Assert(lambdaExpression.Parameters.Count == 1);
this._recordParameter = lambdaExpression.Parameters[0];

return this.Visit(lambdaExpression.Body);
return this.Translate(lambdaExpression.Body);
}

private Filter Visit(Expression? node)
private Filter Translate(Expression? node)
=> node switch
{
BinaryExpression { NodeType: ExpressionType.Equal } equal => this.VisitEqual(equal, negated: false),
BinaryExpression { NodeType: ExpressionType.NotEqual } notEqual => this.VisitEqual(notEqual, negated: true),
BinaryExpression { NodeType: ExpressionType.Equal } equal => this.TranslateEqual(equal.Left, equal.Right),
BinaryExpression { NodeType: ExpressionType.NotEqual } notEqual => this.TranslateEqual(notEqual.Left, notEqual.Right, negated: true),

BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso => this.VisitAndAlso(andAlso),
BinaryExpression { NodeType: ExpressionType.OrElse } orElse => this.VisitOrElse(orElse),
BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso => this.TranslateAndAlso(andAlso.Left, andAlso.Right),
BinaryExpression { NodeType: ExpressionType.OrElse } orElse => this.TranslateOrElse(orElse.Left, orElse.Right),

// MemberExpression member => this.VisitMember(member),

// null => null, // TODO: Not sure
// TODO: Other Contains variants (e.g. List.Contains)
MethodCallExpression
{
Method.Name: nameof(Enumerable.Contains),
Arguments: [var source, var item]
} contains when contains.Method.DeclaringType == typeof(Enumerable)
=> this.TranslateContains(source, item),

_ => throw new NotSupportedException("Unsupported expression type: " + node.GetType().Name)
_ => throw new NotSupportedException("Qdrant does not support the following expression type in filters: " + node?.GetType().Name)
};

private Filter VisitEqual(BinaryExpression equal, bool negated)
private Filter TranslateEqual(Expression left, Expression right, bool negated = false)
{
return TryProcessEqual(equal.Left, equal.Right, out var result)
return TryProcessEqual(left, right, out var result)
? result
: TryProcessEqual(equal.Right, equal.Left, out result)
: TryProcessEqual(right, left, out result)
? result
: throw new NotSupportedException("Equality expression not supported by Qdrant");

bool TryProcessEqual(Expression first, Expression second, [NotNullWhen(true)] out Filter? result)
{
// TODO: Captured variable
// TODO: Nullable
if (first is MemberExpression memberExpression
&& memberExpression.Expression == this._recordParameter
&& second is ConstantExpression { Value: var constantValue })
if (this.TryTranslateFieldAccess(first, out var storagePropertyName) && second is ConstantExpression { Value: var constantValue })
{
if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out var storagePropertyName))
{
throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name.");
}

var condition = constantValue is null
? new Condition { IsNull = new() { Key = storagePropertyName } }
: new Condition
Expand Down Expand Up @@ -100,27 +98,29 @@ bool TryProcessEqual(Expression first, Expression second, [NotNullWhen(true)] ou
}
}

private Filter VisitAndAlso(BinaryExpression andAlso)
#region Logical operators

private Filter TranslateAndAlso(Expression left, Expression right)
{
var left = this.Visit(andAlso.Left);
var right = this.Visit(andAlso.Right);
var leftFilter = this.Translate(left);
var rightFilter = this.Translate(right);

// Qdrant doesn't allow arbitrary nesting of logical operators, only one MUST list (AND), one SHOULD list (OR), and one MUST_NOT list (AND NOT).
// We can combine MUST and MUST_NOT; but we can only combine SHOULD if it's the *only* thing on the one side (no MUST/MUST_NOT), and there's no SHOULD on the other (only MUST/MUST_NOT).
if (left.Should.Count > 0)
if (leftFilter.Should.Count > 0)
{
return ProcessWithShould(left, right);
return ProcessWithShould(leftFilter, rightFilter);
}

if (right.Should.Count > 0)
if (rightFilter.Should.Count > 0)
{
return ProcessWithShould(right, left);
return ProcessWithShould(rightFilter, leftFilter);
}

left.Must.AddRange(right.Must);
left.MustNot.AddRange(right.MustNot);
leftFilter.Must.AddRange(rightFilter.Must);
leftFilter.MustNot.AddRange(rightFilter.MustNot);

return left;
return leftFilter;

static Filter ProcessWithShould(Filter filterWithShould, Filter otherFilter)
{
Expand All @@ -134,17 +134,17 @@ static Filter ProcessWithShould(Filter filterWithShould, Filter otherFilter)
}
}

private Filter VisitOrElse(BinaryExpression orElse)
private Filter TranslateOrElse(Expression left, Expression right)
{
// Qdrant doesn't allow arbitrary nesting of logical operators, only one MUST list (AND), one SHOULD list (OR), and one MUST_NOT list (AND NOT).
// As a result, we can only combine single conditions with OR - the moment there's a nested AND we can't.

var left = this.Visit(orElse.Left);
var right = this.Visit(orElse.Right);
var leftFilter = this.Translate(left);
var rightFilter = this.Translate(right);

var result = new Filter();
result.Should.AddRange(GetShouldConditions(left));
result.Should.AddRange(GetShouldConditions(right));
result.Should.AddRange(GetShouldConditions(leftFilter));
result.Should.AddRange(GetShouldConditions(rightFilter));
return result;

static RepeatedField<Condition> GetShouldConditions(Filter filter)
Expand All @@ -157,4 +157,34 @@ static RepeatedField<Condition> GetShouldConditions(Filter filter)
_ => throw new NotSupportedException("Qdrant does not support the given logical operator combination")
};
}

#endregion Logical operators

/// <summary>
/// Translates a <c>Contains</c> call whose source is a field of the record into a Qdrant filter.
/// </summary>
/// <param name="source">The expression the <c>Contains</c> call is made over; must be a field access on the record parameter.</param>
/// <param name="item">The expression for the value being searched for.</param>
/// <returns>The filter produced by translating the pair as an equality comparison.</returns>
/// <exception cref="NotSupportedException">Thrown when <paramref name="source"/> is not a field access on the record parameter.</exception>
private Filter TranslateContains(Expression source, Expression item)
{
    // TODO: Inline/parameterized array?
    // Only the success/failure of the field-access check matters here; the resolved storage
    // name is discarded because the equality translation resolves it again itself.
    if (!this.TryTranslateFieldAccess(source, out _))
    {
        throw new NotSupportedException("Contains only supported over Qdrant list fields");
    }

    // Oddly, in Qdrant, tag list contains is handled using a Match condition, just like equality.
    return this.TranslateEqual(source, item);
}

/// <summary>
/// Attempts to resolve an expression to the storage name of a record field.
/// </summary>
/// <param name="expression">The expression to inspect.</param>
/// <param name="storagePropertyName">
/// When the method returns <see langword="true"/>, the storage name registered for the accessed member;
/// otherwise <see langword="null"/>.
/// </param>
/// <returns>
/// <see langword="true"/> if <paramref name="expression"/> is a member access directly on the record parameter;
/// <see langword="false"/> for any other expression shape.
/// </returns>
/// <exception cref="InvalidOperationException">
/// Thrown when the expression is a member access on the record parameter but the member name has no entry
/// in the storage property name map.
/// </exception>
private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName)
{
    // Anything other than "<record parameter>.<member>" is not a field access we can translate.
    if (expression is not MemberExpression memberAccess || memberAccess.Expression != this._recordParameter)
    {
        storagePropertyName = null;
        return false;
    }

    var propertyName = memberAccess.Member.Name;

    if (this._storagePropertyNames.TryGetValue(propertyName, out storagePropertyName))
    {
        return true;
    }

    // The shape matched but the member is unknown: report it as an error rather than "not a field access".
    throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.Extensions.VectorData;
using Qdrant.Client.Grpc;

Expand Down
Loading

0 comments on commit 11c6b29

Please sign in to comment.