ContainsKey method call will now short circuit expression evaluation.

main
Sean McArde 2024-04-23 15:45:33 -07:00
parent e6389cab1d
commit cd24d36aa7
2 changed files with 55 additions and 9 deletions

View File

@ -1,5 +1,4 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
@ -76,5 +75,19 @@ namespace McRule.Tests {
Assert.NotNull(filteredContexts);
Assert.AreEqual(filteredContexts.Count, 2);
}
[Test]
public void CanSelectDictionaryValueWithContainsKeyCheck() {
var lambda = itPeople.GetPredicateExpression<ContextStringDictionary>();
Console.WriteLine(lambda);
var filter = lambda.Compile();
var filteredContexts = SomeContexts.Select(x => x.Context)
.Where(filter)?.ToList();
Assert.NotNull(filteredContexts);
Assert.AreEqual(filteredContexts.Count, 2);
}
}
}

View File

@ -34,6 +34,29 @@ public static partial class PredicateExpressionPolicyExtensions
return Expression.AndAlso(notNull, expression);
}
/// <summary>
/// Returns an expression which executes a ContainsKey method call on IDictionary types
/// and prepends it to a given expression with an AndAlso operator. Note, this comparison
/// short circuits so the right hand side will not execute when a key is not found.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="left"></param>
/// <param name="right"></param>
/// <param name="dictKey"></param>
/// <returns></returns>
internal static Expression AddContainsKeyCheck<T>(
Expression left,
string dictKey,
Expression<Func<T, bool>> right) {
// Create generic method which is bound with the Call Expression below
var containsKeyRuntimeMethod = left.Type.GetMethod("ContainsKey");
var containsKeyCall = Expression.Call(left, containsKeyRuntimeMethod, Expression.Constant(dictKey));
var methodExpression = Expression.Lambda<Func<T, bool>>(containsKeyCall, false, right.Parameters);
return PredicateBuilder.And<T>(methodExpression, right);
}
/// <summary>
/// Test for null value. This is used to test for null literals.
/// </summary>
@ -293,6 +316,10 @@ public abstract class ExpressionGeneratorBase : ExpressionGenerator {
internal Expression Member { get; set; }
internal bool LOpIsDict { get; set; } = false;
internal Expression LOp { get; set; }
internal string DictKey { get; set; }
internal void AddNewPreCheck(Expression<Func<T, bool>> lambda) {
PreChecks.Add(lambda);
}
@ -313,17 +340,15 @@ public abstract class ExpressionGeneratorBase : ExpressionGenerator {
Expression opLeft = parameter;
foreach (string p in propertyName.Split(".")) {
result.LOpIsDict = false;
if (opLeft.Type.GetInterfaces().Contains(typeof(IDictionary))) {
result.LOpIsDict = true;
result.DictKey = p;
result.LOp = opLeft;
var dictKey = Expression.Constant(p);
// Create generic method which is bound with the Call Expression below
var containsKeyRuntimeMethod = opLeft.Type.GetMethod("ContainsKey");
var containsKeyCall = Expression.Call(opLeft, containsKeyRuntimeMethod, dictKey);
//var preCheckLambda = Expression.Lambda(containsKeyCall, false, new ParameterExpression[] { opLeft, dictKey });
result.AddNewPreCheck(Expression.Lambda<Func<T,bool>>(containsKeyCall, false, Expression.Parameter(opLeft.Type, "x")));
opLeft = Expression.Property(opLeft, "Item", dictKey);
} else {
@ -462,6 +487,14 @@ public abstract class ExpressionGeneratorBase : ExpressionGenerator {
comparison = AddNotNullCheck<T>(opLeft, comparison);
}
// When the left hand side of the comparision implements IDictionary we need to add a ContainsKey
// method call to assert there's a value to compare against before actually retrieving it by name.
// A missing key evaluates to false.
// TODO: use ~ operator to return true for a comparison predicate where a key is missing.
if (resolvedMember.LOpIsDict) {
comparison = AddContainsKeyCheck<T>(resolvedMember.LOp, resolvedMember.DictKey, (Expression<Func<T, bool>> )comparison);
}
// The value may have the right type and should just be returned.
Expression<Func<T, bool>> result = default(Expression<Func<T, bool>>);
if (comparison is Expression<Func<T, bool>> checkedResult && checkedResult != default(Expression<Func<T, bool>>)) {