Skip to content

Make parameter length based on underlying column length except for complex ops #2627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/Config/DatabasePrimitives/DatabaseObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,16 @@ public bool IsAnyColumnNullable(List<string> columnsToCheck)

return null;
}

public virtual int? GetLengthForParam(string paramName)
{
if (Columns.TryGetValue(paramName, out ColumnDefinition? columnDefinition))
{
return columnDefinition.Length;
}

return null;
}
}

/// <summary>
Expand Down Expand Up @@ -264,6 +274,7 @@ public class ColumnDefinition
public bool IsNullable { get; set; }
public bool IsReadOnly { get; set; }
public object? DefaultValue { get; set; }
public int? Length { get; set; }

public ColumnDefinition() { }

Expand Down
6 changes: 5 additions & 1 deletion src/Core/Models/DbConnectionParam.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ namespace Azure.DataApiBuilder.Core.Models;
/// </summary>
public class DbConnectionParam
{
public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null)
public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null, int? length = null)
{
Value = value;
DbType = dbType;
SqlDbType = sqlDbType;
Length = length;
}

/// <summary>
Expand All @@ -31,4 +32,7 @@ public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbT
// This is being made nullable
// because it's not populated for DB's other than MSSQL.
public SqlDbType? SqlDbType { get; set; }

// Nullable integer parameter representing length. nullable for back compatibility and for where its not needed
public int? Length { get; set; }
}
16 changes: 11 additions & 5 deletions src/Core/Models/GraphQLFilterParsers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ private static Predicate ParseScalarType(
string schemaName,
string tableName,
string tableAlias,
Func<object, string?, string> processLiterals,
Func<object, string?, bool, string> processLiterals,
bool isListType = false)
{
Column column = new(schemaName, tableName, columnName: fieldName, tableAlias);
Expand Down Expand Up @@ -614,7 +614,7 @@ public static Predicate Parse(
IInputField argumentSchema,
Column column,
List<ObjectFieldNode> fields,
Func<object, string?, string> processLiterals,
Func<object, string?, bool, string> processLiterals,
bool isListType = false)
{
List<PredicateOperand> predicates = new();
Expand All @@ -635,6 +635,8 @@ public static Predicate Parse(
continue;
}

bool lengthOverride = false;

PredicateOperation op;
switch (name)
{
Expand Down Expand Up @@ -665,6 +667,7 @@ public static Predicate Parse(
{
op = PredicateOperation.LIKE;
value = $"%{EscapeLikeString((string)value)}%";
lengthOverride = true;
}

break;
Expand All @@ -677,16 +680,19 @@ public static Predicate Parse(
{
op = PredicateOperation.NOT_LIKE;
value = $"%{EscapeLikeString((string)value)}%";
lengthOverride = true;
}

break;
case "startsWith":
op = PredicateOperation.LIKE;
value = $"{EscapeLikeString((string)value)}%";
lengthOverride = true;
break;
case "endsWith":
op = PredicateOperation.LIKE;
value = $"%{EscapeLikeString((string)value)}";
lengthOverride = true;
break;
case "isNull":
processLiteral = false;
Expand All @@ -699,10 +705,10 @@ public static Predicate Parse(
}

predicates.Push(new PredicateOperand(new Predicate(
new PredicateOperand(column),
new(column),
op,
new PredicateOperand(processLiteral ? $"{processLiterals(value, column.ColumnName)}" : value.ToString()))
));
new(processLiteral ? $"{processLiterals(value, column.ColumnName, lengthOverride)}" : value.ToString())
)));
}

return GQLFilterParser.MakeChainPredicate(predicates, PredicateOperation.AND);
Expand Down
5 changes: 3 additions & 2 deletions src/Core/Resolvers/BaseQueryStructure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,16 @@ public BaseQueryStructure(
/// </summary>
/// <param name="value">Value to be assigned to parameter, which can be null for nullable columns.</param>
/// <param name="paramName"> The name of the parameter - backing column name for table/views or parameter name for stored procedures.</param>
public virtual string MakeDbConnectionParam(object? value, string? paramName = null)
public virtual string MakeDbConnectionParam(object? value, string? paramName = null, bool lengthOverride = false)
{
string encodedParamName = GetEncodedParamName(Counter.Next());
if (!string.IsNullOrEmpty(paramName))
{
Parameters.Add(encodedParamName,
new(value,
dbType: GetUnderlyingSourceDefinition().GetDbTypeForParam(paramName),
sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName)));
sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName),
length: lengthOverride ? -1 : GetUnderlyingSourceDefinition().GetLengthForParam(paramName)));
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/Core/Resolvers/CosmosQueryStructure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public CosmosQueryStructure(
}

/// <inheritdoc/>
public override string MakeDbConnectionParam(object? value, string? columnName = null)
public override string MakeDbConnectionParam(object? value, string? columnName = null, bool lengthOverride = false)
{
string encodedParamName = $"{PARAM_NAME_PREFIX}param{Counter.Next()}";
Parameters.Add(encodedParamName, new(value));
Expand Down
10 changes: 9 additions & 1 deletion src/Core/Resolvers/MsSqlQueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,16 @@ public override SqlCommand PrepareDbCommand(
{
SqlParameter parameter = cmd.CreateParameter();
parameter.ParameterName = parameterEntry.Key;
parameter.Value = parameterEntry.Value.Value ?? DBNull.Value;
parameter.Value = parameterEntry.Value?.Value ?? DBNull.Value;

PopulateDbTypeForParameter(parameterEntry, parameter);

//if sqldbtype is varchar, nvarchar then set the length
if (parameter.SqlDbType is SqlDbType.VarChar or SqlDbType.NVarChar or SqlDbType.Char or SqlDbType.NChar)
{
parameter.Size = parameterEntry.Value?.Length ?? -1;
}

cmd.Parameters.Add(parameter);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,8 @@ private async Task PopulateSourceDefinitionAsync(
SystemType = (Type)columnInfoFromAdapter["DataType"],
// An auto-increment column is also considered as a read-only column. For other types of read-only columns,
// the flag is populated later via PopulateColumnDefinitionsWithReadOnlyFlag() method.
IsReadOnly = (bool)columnInfoFromAdapter["IsAutoIncrement"]
IsReadOnly = (bool)columnInfoFromAdapter["IsAutoIncrement"],
Length = GetDatabaseType() is DatabaseType.MSSQL ? (int)columnInfoFromAdapter["ColumnSize"] : null
};

// Tests may try to add the same column simultaneously
Expand Down
4 changes: 2 additions & 2 deletions src/Service.Tests/DatabaseSchema-MsSql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ CREATE TABLE publishers_mm(

CREATE TABLE books(
id int IDENTITY(5001, 1) PRIMARY KEY,
title varchar(max) NOT NULL,
title varchar(30) NOT NULL,
publisher_id int NOT NULL
);

Expand Down Expand Up @@ -514,7 +514,7 @@ SET IDENTITY_INSERT books ON
INSERT INTO books(id, title, publisher_id)
VALUES (1, 'Awesome book', 1234),
(2, 'Also Awesome book', 1234),
(3, 'Great wall of china explained', 2345),
(3, 'Great wall of china explained]', 2345),
(4, 'US history in a nutshell', 2345),
(5, 'Chernobyl Diaries', 2323),
(6, 'The Palace Door', 2324),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,109 @@ SELECT TOP 1 content FROM reviews
await QueryWithMultipleColumnPrimaryKey(msSqlQuery);
}

/// <sumary>
/// Test if filter param successfully filters when string filter
/// </summary>
[TestMethod]
public virtual async Task TestFilterParamForStringFilter()
{
string graphQLQueryName = "books";
string graphQLQuery = @"{
books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: {eq:""Awesome book""}}) {
items {
id
title
}
}
}";

string expected = @"
[
{
""id"": 1,
""title"": ""Awesome book""
}
]";

JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false);

SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString());
}

/// <sumary>
/// Test if filter param successfully filters when string filter results in a value longer than the column
/// </summary>
/// <remarks>
/// When using complex operators i.e. NotContains due to wildcards being added or special characters being escaped
/// the string being passed as a parameter maybe longer than the length of the column. The parameter data type
/// can't be fixed to the length of the underlying column, otherwise the parameter value would be truncated and
/// we'd get incorrect results
/// Thus checking the parameter length is overridden to cater for the extra length i.e. lengthOverride = true codepath.
/// </remarks>
[DataTestMethod]
[DataRow("contains")]
[DataRow("startsWith")]
[DataRow("endsWith")]
public virtual async Task TestFilterParamForStringFilterWorkWithComplexOp(string op)
{
string graphQLQueryName = "books";

//using a lookup value that is the length of the title column AND includes special characters
string graphQLQuery = @"{
books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: {" + op + @":""Great wall of china explained]""}}) {
items {
id
title
}
}
}";

string expected = @"
[
{
""id"": 3,
""title"": ""Great wall of china explained]""
}
]";

JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false);

SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString());
}

/// <sumary>
/// Test if filter param successfully filters when string filter results in a value longer than the column
/// </summary>
/// <remarks>
/// When using complex operators i.e. NotContains due to wildcards being added or special characters being escaped
/// the string being passed as a parameter maybe longer than the length of the column. The parameter data type
/// can't be fixed to the length of the underlying column, otherwise the parameter value would be truncated and
/// we'd get incorrect results.
/// Thus checking the parameter length is overridden to cater for the extra length i.e. lengthOverride = true codepath.
/// </remarks>
[TestMethod]
public virtual async Task TestFilterParamForStringFilterWorkWithNotContains(string op)
{
string graphQLQueryName = "books";
//using a lookup value that is the length of the title column AND includes special characters
string graphQLQuery = @"{
books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: { notContains:""Great wall of china explained]""},id:{eq:3} }) {
items {
id
title
}
}
}";

string expected = @"
[
]";

JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false);

SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString());
}

[TestMethod]
public async Task QueryWithNullableForeignKey()
{
Expand Down Expand Up @@ -421,8 +524,8 @@ public async Task TestStoredProcedureQueryWithNoDefaultInConfig()
public async Task TestSupportForAggregationsWithAliases()
{
string msSqlQuery = @"
SELECT
MAX(categoryid) AS max,
SELECT
MAX(categoryid) AS max,
MAX(price) AS max_price,
MIN(price) AS min_price,
AVG(price) AS avg_price,
Expand Down
Loading