Imported Upstream version 4.0.0~alpha1

Former-commit-id: 806294f5ded97629b74c85c09952f2a74fe182d9
This commit is contained in:
Jo Shields
2015-04-07 09:35:12 +01:00
parent 283343f570
commit 3c1f479b9d
22469 changed files with 2931443 additions and 869343 deletions

View File

@@ -0,0 +1,75 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace System.Data.Linq.SqlClient {
internal class SqlRowNumberChecker {
Visitor rowNumberVisitor;
internal SqlRowNumberChecker() {
this.rowNumberVisitor = new Visitor();
}
internal bool HasRowNumber(SqlNode node) {
this.rowNumberVisitor.Visit(node);
return rowNumberVisitor.HasRowNumber;
}
internal bool HasRowNumber(SqlRow row) {
foreach (SqlColumn column in row.Columns) {
if (this.HasRowNumber(column)) {
return true;
}
}
return false;
}
internal SqlColumn RowNumberColumn {
get {
return rowNumberVisitor.HasRowNumber ? rowNumberVisitor.CurrentColumn : null;
}
}
private class Visitor: SqlVisitor {
bool hasRowNumber = false;
public bool HasRowNumber {
get { return hasRowNumber; }
}
public SqlColumn CurrentColumn { private set; get; }
internal override SqlRowNumber VisitRowNumber(SqlRowNumber rowNumber) {
this.hasRowNumber = true;
return rowNumber;
}
// shortcuts
internal override SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
return ss;
}
internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
return ss;
}
internal override SqlRow VisitRow(SqlRow row)
{
for (int i = 0, n = row.Columns.Count; i < n; i++) {
row.Columns[i].Expression = this.VisitExpression(row.Columns[i].Expression);
if (this.hasRowNumber) {
this.CurrentColumn = row.Columns[i];
break;
}
}
return row;
}
internal override SqlSelect VisitSelect(SqlSelect select) {
this.Visit(select.Row);
this.Visit(select.Where);
return select;
}
}
}
}

View File

@@ -0,0 +1,445 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
namespace System.Data.Linq.SqlClient {
using System.Data.Linq.Mapping;
using System.Diagnostics.CodeAnalysis;
internal static class Funcletizer {
internal static Expression Funcletize(Expression expression) {
return new Localizer(new LocalMapper().MapLocals(expression)).Localize(expression);
}
class Localizer : ExpressionVisitor {
Dictionary<Expression, bool> locals;
internal Localizer(Dictionary<Expression, bool> locals) {
this.locals = locals;
}
internal Expression Localize(Expression expression) {
return this.Visit(expression);
}
internal override Expression Visit(Expression exp) {
if (exp == null) {
return null;
}
if (this.locals.ContainsKey(exp)) {
return MakeLocal(exp);
}
if (exp.NodeType == (ExpressionType)InternalExpressionType.Known) {
return exp;
}
return base.Visit(exp);
}
private static Expression MakeLocal(Expression e) {
if (e.NodeType == ExpressionType.Constant) {
return e;
}
else if (e.NodeType == ExpressionType.Convert || e.NodeType == ExpressionType.ConvertChecked) {
UnaryExpression ue = (UnaryExpression)e;
if (ue.Type == typeof(object)) {
Expression local = MakeLocal(ue.Operand);
return (e.NodeType == ExpressionType.Convert) ? Expression.Convert(local, e.Type) : Expression.ConvertChecked(local, e.Type);
}
// convert a const null
if (ue.Operand.NodeType == ExpressionType.Constant) {
ConstantExpression c = (ConstantExpression)ue.Operand;
if (c.Value == null) {
return Expression.Constant(null, ue.Type);
}
}
}
return Expression.Invoke(Expression.Constant(Expression.Lambda(e).Compile()));
}
}
class DependenceChecker : ExpressionVisitor {
HashSet<ParameterExpression> inScope = new HashSet<ParameterExpression>();
bool isIndependent = true;
/// <summary>
/// This method returns 'true' when the expression doesn't reference any parameters
/// from outside the scope of the expression.
/// </summary>
static public bool IsIndependent(Expression expression) {
var v = new DependenceChecker();
v.Visit(expression);
return v.isIndependent;
}
internal override Expression VisitLambda(LambdaExpression lambda) {
foreach (var p in lambda.Parameters) {
this.inScope.Add(p);
}
return base.VisitLambda(lambda);
}
internal override Expression VisitParameter(ParameterExpression p) {
this.isIndependent &= this.inScope.Contains(p);
return p;
}
}
class LocalMapper : ExpressionVisitor {
bool isRemote;
Dictionary<Expression, bool> locals;
internal Dictionary<Expression, bool> MapLocals(Expression expression) {
this.locals = new Dictionary<Expression, bool>();
this.isRemote = false;
this.Visit(expression);
return this.locals;
}
internal override Expression Visit(Expression expression) {
if (expression == null) {
return null;
}
bool saveIsRemote = this.isRemote;
switch (expression.NodeType) {
case (ExpressionType)InternalExpressionType.Known:
return expression;
case (ExpressionType)ExpressionType.Constant:
break;
default:
this.isRemote = false;
base.Visit(expression);
if (!this.isRemote
&& expression.NodeType != ExpressionType.Lambda
&& expression.NodeType != ExpressionType.Quote
&& DependenceChecker.IsIndependent(expression)) {
this.locals[expression] = true; // Not 'Add' because the same expression may exist in the tree twice.
}
break;
}
if (typeof(ITable).IsAssignableFrom(expression.Type) ||
typeof(DataContext).IsAssignableFrom(expression.Type)) {
this.isRemote = true;
}
this.isRemote |= saveIsRemote;
return expression;
}
internal override Expression VisitMemberAccess(MemberExpression m) {
base.VisitMemberAccess(m);
this.isRemote |= (m.Expression != null && typeof(ITable).IsAssignableFrom(m.Expression.Type));
return m;
}
internal override Expression VisitMethodCall(MethodCallExpression m) {
base.VisitMethodCall(m);
this.isRemote |= m.Method.DeclaringType == typeof(System.Data.Linq.Provider.DataManipulation)
|| Attribute.IsDefined(m.Method, typeof(FunctionAttribute));
return m;
}
}
}
internal abstract class ExpressionVisitor {
internal ExpressionVisitor() {
}
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "[....]: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
internal virtual Expression Visit(Expression exp) {
if (exp == null)
return exp;
switch (exp.NodeType) {
case ExpressionType.Negate:
case ExpressionType.NegateChecked:
case ExpressionType.Not:
case ExpressionType.Convert:
case ExpressionType.ConvertChecked:
case ExpressionType.ArrayLength:
case ExpressionType.Quote:
case ExpressionType.TypeAs:
return this.VisitUnary((UnaryExpression)exp);
case ExpressionType.Add:
case ExpressionType.AddChecked:
case ExpressionType.Subtract:
case ExpressionType.SubtractChecked:
case ExpressionType.Multiply:
case ExpressionType.MultiplyChecked:
case ExpressionType.Divide:
case ExpressionType.Modulo:
case ExpressionType.Power:
case ExpressionType.And:
case ExpressionType.AndAlso:
case ExpressionType.Or:
case ExpressionType.OrElse:
case ExpressionType.LessThan:
case ExpressionType.LessThanOrEqual:
case ExpressionType.GreaterThan:
case ExpressionType.GreaterThanOrEqual:
case ExpressionType.Equal:
case ExpressionType.NotEqual:
case ExpressionType.Coalesce:
case ExpressionType.ArrayIndex:
case ExpressionType.RightShift:
case ExpressionType.LeftShift:
case ExpressionType.ExclusiveOr:
return this.VisitBinary((BinaryExpression)exp);
case ExpressionType.TypeIs:
return this.VisitTypeIs((TypeBinaryExpression)exp);
case ExpressionType.Conditional:
return this.VisitConditional((ConditionalExpression)exp);
case ExpressionType.Constant:
return this.VisitConstant((ConstantExpression)exp);
case ExpressionType.Parameter:
return this.VisitParameter((ParameterExpression)exp);
case ExpressionType.MemberAccess:
return this.VisitMemberAccess((MemberExpression)exp);
case ExpressionType.Call:
return this.VisitMethodCall((MethodCallExpression)exp);
case ExpressionType.Lambda:
return this.VisitLambda((LambdaExpression)exp);
case ExpressionType.New:
return this.VisitNew((NewExpression)exp);
case ExpressionType.NewArrayInit:
case ExpressionType.NewArrayBounds:
return this.VisitNewArray((NewArrayExpression)exp);
case ExpressionType.Invoke:
return this.VisitInvocation((InvocationExpression)exp);
case ExpressionType.MemberInit:
return this.VisitMemberInit((MemberInitExpression)exp);
case ExpressionType.ListInit:
return this.VisitListInit((ListInitExpression)exp);
case ExpressionType.UnaryPlus:
if (exp.Type == typeof(TimeSpan))
return this.VisitUnary((UnaryExpression)exp);
throw Error.UnhandledExpressionType(exp.NodeType);
default:
throw Error.UnhandledExpressionType(exp.NodeType);
}
}
internal virtual MemberBinding VisitBinding(MemberBinding binding) {
switch (binding.BindingType) {
case MemberBindingType.Assignment:
return this.VisitMemberAssignment((MemberAssignment)binding);
case MemberBindingType.MemberBinding:
return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
case MemberBindingType.ListBinding:
return this.VisitMemberListBinding((MemberListBinding)binding);
default:
throw Error.UnhandledBindingType(binding.BindingType);
}
}
internal virtual ElementInit VisitElementInitializer(ElementInit initializer) {
ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
if (arguments != initializer.Arguments) {
return Expression.ElementInit(initializer.AddMethod, arguments);
}
return initializer;
}
internal virtual Expression VisitUnary(UnaryExpression u) {
Expression operand = this.Visit(u.Operand);
if (operand != u.Operand) {
return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
}
return u;
}
internal virtual Expression VisitBinary(BinaryExpression b) {
Expression left = this.Visit(b.Left);
Expression right = this.Visit(b.Right);
if (left != b.Left || right != b.Right) {
return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
}
return b;
}
internal virtual Expression VisitTypeIs(TypeBinaryExpression b) {
Expression expr = this.Visit(b.Expression);
if (expr != b.Expression) {
return Expression.TypeIs(expr, b.TypeOperand);
}
return b;
}
internal virtual Expression VisitConstant(ConstantExpression c) {
return c;
}
internal virtual Expression VisitConditional(ConditionalExpression c) {
Expression test = this.Visit(c.Test);
Expression ifTrue = this.Visit(c.IfTrue);
Expression ifFalse = this.Visit(c.IfFalse);
if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {
return Expression.Condition(test, ifTrue, ifFalse);
}
return c;
}
internal virtual Expression VisitParameter(ParameterExpression p) {
return p;
}
internal virtual Expression VisitMemberAccess(MemberExpression m) {
Expression exp = this.Visit(m.Expression);
if (exp != m.Expression) {
return Expression.MakeMemberAccess(exp, m.Member);
}
return m;
}
internal virtual Expression VisitMethodCall(MethodCallExpression m) {
Expression obj = this.Visit(m.Object);
IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
if (obj != m.Object || args != m.Arguments) {
return Expression.Call(obj, m.Method, args);
}
return m;
}
internal virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original) {
List<Expression> list = null;
for (int i = 0, n = original.Count; i < n; i++) {
Expression p = this.Visit(original[i]);
if (list != null) {
list.Add(p);
}
else if (p != original[i]) {
list = new List<Expression>(n);
for (int j = 0; j < i; j++) {
list.Add(original[j]);
}
list.Add(p);
}
}
if (list != null)
return new ReadOnlyCollection<Expression>(list);
return original;
}
internal virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
Expression e = this.Visit(assignment.Expression);
if (e != assignment.Expression) {
return Expression.Bind(assignment.Member, e);
}
return assignment;
}
internal virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
if (bindings != binding.Bindings) {
return Expression.MemberBind(binding.Member, bindings);
}
return binding;
}
internal virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
if (initializers != binding.Initializers) {
return Expression.ListBind(binding.Member, initializers);
}
return binding;
}
internal virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {
List<MemberBinding> list = null;
for (int i = 0, n = original.Count; i < n; i++) {
MemberBinding b = this.VisitBinding(original[i]);
if (list != null) {
list.Add(b);
}
else if (b != original[i]) {
list = new List<MemberBinding>(n);
for (int j = 0; j < i; j++) {
list.Add(original[j]);
}
list.Add(b);
}
}
if (list != null)
return list;
return original;
}
internal virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {
List<ElementInit> list = null;
for (int i = 0, n = original.Count; i < n; i++) {
ElementInit init = this.VisitElementInitializer(original[i]);
if (list != null) {
list.Add(init);
}
else if (init != original[i]) {
list = new List<ElementInit>(n);
for (int j = 0; j < i; j++) {
list.Add(original[j]);
}
list.Add(init);
}
}
if (list != null) {
return list;
}
return original;
}
internal virtual Expression VisitLambda(LambdaExpression lambda) {
Expression body = this.Visit(lambda.Body);
if (body != lambda.Body) {
return Expression.Lambda(lambda.Type, body, lambda.Parameters);
}
return lambda;
}
internal virtual NewExpression VisitNew(NewExpression nex) {
IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
if (args != nex.Arguments) {
if (nex.Members != null) {
return Expression.New(nex.Constructor, args, nex.Members);
}
else {
return Expression.New(nex.Constructor, args);
}
}
return nex;
}
internal virtual Expression VisitMemberInit(MemberInitExpression init) {
NewExpression n = this.VisitNew(init.NewExpression);
IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
if (n != init.NewExpression || bindings != init.Bindings) {
return Expression.MemberInit(n, bindings);
}
return init;
}
internal virtual Expression VisitListInit(ListInitExpression init) {
NewExpression n = this.VisitNew(init.NewExpression);
IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
if (n != init.NewExpression || initializers != init.Initializers) {
return Expression.ListInit(n, initializers);
}
return init;
}
internal virtual Expression VisitNewArray(NewArrayExpression na) {
IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
if (exprs != na.Expressions) {
if (na.NodeType == ExpressionType.NewArrayInit) {
return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
}
else {
return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
}
}
return na;
}
internal virtual Expression VisitInvocation(InvocationExpression iv) {
IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
Expression expr = this.Visit(iv.Expression);
if (args != iv.Arguments || expr != iv.Expression) {
return Expression.Invoke(expr, args);
}
return iv;
}
}
}

View File

@@ -0,0 +1,118 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Data.Linq;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Converts expressions of type NText, Text, Image to NVarChar(MAX), VarChar(MAX), VarBinary(MAX)
/// where necessary. This can only be done on SQL2005, so we add a SqlServerCompatibilityAnnotation
/// to the changed nodes.
/// </summary>
internal class LongTypeConverter {
Visitor visitor;
internal LongTypeConverter(SqlFactory sql) {
this.visitor = new Visitor(sql);
}
internal SqlNode AddConversions(SqlNode node, SqlNodeAnnotations annotations) {
visitor.Annotations = annotations;
return visitor.Visit(node);
}
class Visitor : SqlVisitor {
SqlFactory sql;
SqlNodeAnnotations annotations;
internal SqlNodeAnnotations Annotations {
set { this.annotations = value; }
}
internal Visitor(SqlFactory sql) {
this.sql = sql;
}
private SqlExpression ConvertToMax(SqlExpression expr, ProviderType newType) {
return sql.UnaryConvert(expr.ClrType, newType, expr, expr.SourceExpression);
}
// returns CONVERT(VARCHAR/NVARCHAR/VARBINARY(MAX), expr) if provType is one of Text, NText or Image
// otherwise just returns expr
// changed is true if CONVERT(...(MAX),...) was added
private SqlExpression ConvertToMax(SqlExpression expr, out bool changed) {
changed = false;
if (!expr.SqlType.IsLargeType)
return expr;
ProviderType newType = sql.TypeProvider.GetBestLargeType(expr.SqlType);
changed = true;
if (expr.SqlType != newType) {
return ConvertToMax(expr, newType);
}
changed = false;
return expr;
}
private void ConvertColumnsToMax(SqlSelect select, out bool changed, out bool containsLongExpressions) {
SqlRow row = select.Row;
changed = false;
containsLongExpressions = false;
foreach (SqlColumn col in row.Columns) {
bool columnChanged;
containsLongExpressions = containsLongExpressions || col.SqlType.IsLargeType;
col.Expression = ConvertToMax(col.Expression, out columnChanged);
changed = changed || columnChanged;
}
}
internal override SqlSelect VisitSelect(SqlSelect select) {
if (select.IsDistinct) {
bool changed;
bool containsLongExpressions;
ConvertColumnsToMax(select, out changed, out containsLongExpressions);
if (containsLongExpressions) {
this.annotations.Add(select, new SqlServerCompatibilityAnnotation(
Strings.TextNTextAndImageCannotOccurInDistinct(select.SourceExpression), SqlProvider.ProviderMode.Sql2000, SqlProvider.ProviderMode.SqlCE));
}
}
return base.VisitSelect(select);
}
internal override SqlNode VisitUnion(SqlUnion su) {
bool changedLeft = false;
bool containsLongExpressionsLeft = false;
SqlSelect left = su.Left as SqlSelect;
if (left != null) {
ConvertColumnsToMax(left, out changedLeft, out containsLongExpressionsLeft);
}
bool changedRight = false;
bool containsLongExpressionsRight = false;
SqlSelect right = su.Right as SqlSelect;
if (right != null) {
ConvertColumnsToMax(right, out changedRight, out containsLongExpressionsRight);
}
if (!su.All && (containsLongExpressionsLeft || containsLongExpressionsRight)) {
// unless the UNION is 'ALL', the server will perform a DISTINCT operation,
// which isn't valid for large types (text, ntext, image)
this.annotations.Add(su, new SqlServerCompatibilityAnnotation(
Strings.TextNTextAndImageCannotOccurInUnion(su.SourceExpression), SqlProvider.ProviderMode.Sql2000, SqlProvider.ProviderMode.SqlCE));
}
return base.VisitUnion(su);
}
internal override SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
if (fc.Name == "LEN") {
bool changed;
fc.Arguments[0] = ConvertToMax(fc.Arguments[0],out changed);
if (fc.Arguments[0].SqlType.IsLargeType) {
this.annotations.Add(fc, new SqlServerCompatibilityAnnotation(
Strings.LenOfTextOrNTextNotSupported(fc.SourceExpression), SqlProvider.ProviderMode.Sql2000));
}
}
return base.VisitFunctionCall(fc);
}
}
}
}

View File

@@ -0,0 +1,357 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Data.Linq;
using System.Data.Linq.Provider;
using System.Data.Linq.Mapping;
namespace System.Data.Linq.SqlClient {
// convert special method calls and member accesses into known sql nodes
internal static class PreBindDotNetConverter {
internal static SqlNode Convert(SqlNode node, SqlFactory sql, MetaModel model) {
return new Visitor(sql, model).Visit(node);
}
internal static bool CanConvert(SqlNode node) {
SqlBinary bo = node as SqlBinary;
if (bo != null && (IsCompareToValue(bo) || IsVbCompareStringEqualsValue(bo))) {
return true;
}
SqlMember sm = node as SqlMember;
if (sm != null && IsSupportedMember(sm)) {
return true;
}
SqlMethodCall mc = node as SqlMethodCall;
if (mc != null && (IsSupportedMethod(mc) || IsSupportedVbHelperMethod(mc))) {
return true;
}
return false;
}
private static bool IsCompareToValue(SqlBinary bo) {
if (IsComparison(bo.NodeType)
&& bo.Left.NodeType == SqlNodeType.MethodCall
&& bo.Right.NodeType == SqlNodeType.Value) {
SqlMethodCall call = (SqlMethodCall)bo.Left;
return IsCompareToMethod(call) || IsCompareMethod(call);
}
return false;
}
private static bool IsCompareToMethod(SqlMethodCall call) {
return !call.Method.IsStatic && call.Method.Name == "CompareTo" && call.Arguments.Count == 1 && call.Method.ReturnType == typeof(int);
}
private static bool IsCompareMethod(SqlMethodCall call) {
return call.Method.IsStatic && call.Method.Name == "Compare" && call.Arguments.Count > 1 && call.Method.ReturnType == typeof(int);
}
private static bool IsComparison(SqlNodeType nodeType) {
switch (nodeType) {
case SqlNodeType.EQ:
case SqlNodeType.NE:
case SqlNodeType.LT:
case SqlNodeType.LE:
case SqlNodeType.GT:
case SqlNodeType.GE:
case SqlNodeType.EQ2V:
case SqlNodeType.NE2V:
return true;
default:
return false;
}
}
private static bool IsVbCompareStringEqualsValue(SqlBinary bo) {
return IsComparison(bo.NodeType)
&& bo.Left.NodeType == SqlNodeType.MethodCall
&& bo.Right.NodeType == SqlNodeType.Value
&& IsVbCompareString((SqlMethodCall)bo.Left);
}
private static bool IsVbCompareString(SqlMethodCall call) {
return call.Method.IsStatic &&
call.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators" &&
call.Method.Name == "CompareString";
}
private static bool IsSupportedVbHelperMethod(SqlMethodCall mc) {
return IsVbIIF(mc);
}
private static bool IsVbIIF(SqlMethodCall mc) {
return mc.Method.IsStatic &&
mc.Method.DeclaringType.FullName == "Microsoft.VisualBasic.Interaction" && mc.Method.Name == "IIf";
}
private static bool IsSupportedMember(SqlMember m) {
return IsNullableHasValue(m) || IsNullableHasValue(m);
}
private static bool IsNullableValue(SqlMember m) {
return TypeSystem.IsNullableType(m.Expression.ClrType) && m.Member.Name == "Value";
}
private static bool IsNullableHasValue(SqlMember m) {
return TypeSystem.IsNullableType(m.Expression.ClrType) && m.Member.Name == "HasValue";
}
private static bool IsSupportedMethod(SqlMethodCall mc) {
if (mc.Method.IsStatic) {
switch (mc.Method.Name) {
case "op_Equality":
case "op_Inequality":
case "op_LessThan":
case "op_LessThanOrEqual":
case "op_GreaterThan":
case "op_GreaterThanOrEqual":
case "op_Multiply":
case "op_Division":
case "op_Subtraction":
case "op_Addition":
case "op_Modulus":
case "op_BitwiseAnd":
case "op_BitwiseOr":
case "op_ExclusiveOr":
case "op_UnaryNegation":
case "op_OnesComplement":
case "op_False":
return true;
case "Equals":
return mc.Arguments.Count == 2;
case "Concat":
return mc.Method.DeclaringType == typeof(string);
}
}
else {
return mc.Method.Name == "Equals" && mc.Arguments.Count == 1 ||
mc.Method.Name == "GetType" && mc.Arguments.Count == 0;
}
return false;
}
private class Visitor : SqlVisitor {
SqlFactory sql;
MetaModel model;
internal Visitor(SqlFactory sql, MetaModel model) {
this.sql = sql;
this.model = model;
}
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
if (IsCompareToValue(bo)) {
SqlMethodCall call = (SqlMethodCall)bo.Left;
if (IsCompareToMethod(call)) {
int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
bo = this.MakeCompareTo(call.Object, call.Arguments[0], bo.NodeType, iValue) ?? bo;
}
else if (IsCompareMethod(call)) {
int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
bo = this.MakeCompareTo(call.Arguments[0], call.Arguments[1], bo.NodeType, iValue) ?? bo;
}
}
else if (IsVbCompareStringEqualsValue(bo)) {
SqlMethodCall call = (SqlMethodCall)bo.Left;
int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
//in VB, comparing a string with Nothing means comparing with ""
SqlValue strValue = call.Arguments[1] as SqlValue;
if (strValue != null && strValue.Value == null) {
SqlValue emptyStr = new SqlValue(strValue.ClrType, strValue.SqlType, String.Empty, strValue.IsClientSpecified, strValue.SourceExpression);
bo = this.MakeCompareTo(call.Arguments[0], emptyStr, bo.NodeType, iValue) ?? bo;
}
else {
bo = this.MakeCompareTo(call.Arguments[0], call.Arguments[1], bo.NodeType, iValue) ?? bo;
}
}
return base.VisitBinaryOperator(bo);
}
private SqlBinary MakeCompareTo(SqlExpression left, SqlExpression right, SqlNodeType op, int iValue) {
if (iValue == 0) {
return sql.Binary(op, left, right);
}
else if (op == SqlNodeType.EQ || op == SqlNodeType.EQ2V) {
switch (iValue) {
case -1:
return sql.Binary(SqlNodeType.LT, left, right);
case 1:
return sql.Binary(SqlNodeType.GT, left, right);
}
}
return null;
}
private SqlExpression CreateComparison(SqlExpression a, SqlExpression b, Expression source) {
SqlExpression lower = sql.Binary(SqlNodeType.LT, a, b);
SqlExpression equal = sql.Binary(SqlNodeType.EQ2V, a, b);
return sql.SearchedCase(
new SqlWhen[] {
new SqlWhen(lower, sql.ValueFromObject(-1, false, source)),
new SqlWhen(equal, sql.ValueFromObject(0, false, source)),
},
sql.ValueFromObject(1, false, source), source
);
}
internal override SqlNode VisitMember(SqlMember m) {
m.Expression = this.VisitExpression(m.Expression);
if (IsNullableValue(m)) {
return sql.UnaryValueOf(m.Expression, m.SourceExpression);
}
else if (IsNullableHasValue(m)) {
return sql.Unary(SqlNodeType.IsNotNull, m.Expression, m.SourceExpression);
}
return m;
}
internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
mc.Object = this.VisitExpression(mc.Object);
for (int i = 0, n = mc.Arguments.Count; i < n; i++) {
mc.Arguments[i] = this.VisitExpression(mc.Arguments[i]);
}
if (mc.Method.IsStatic) {
if (mc.Method.Name == "Equals" && mc.Arguments.Count == 2) {
return sql.Binary(SqlNodeType.EQ2V, mc.Arguments[0], mc.Arguments[1], mc.Method);
}
else if (mc.Method.DeclaringType == typeof(string) && mc.Method.Name == "Concat") {
SqlClientArray arr = mc.Arguments[0] as SqlClientArray;
List<SqlExpression> exprs = null;
if (arr != null) {
exprs = arr.Expressions;
}
else {
exprs = mc.Arguments;
}
if (exprs.Count == 0) {
return sql.ValueFromObject("", false, mc.SourceExpression);
}
else {
SqlExpression sum;
if (exprs[0].SqlType.IsString || exprs[0].SqlType.IsChar) {
sum = exprs[0];
}
else {
sum = sql.ConvertTo(typeof(string), exprs[0]);
}
for (int i = 1; i < exprs.Count; i++) {
if (exprs[i].SqlType.IsString || exprs[i].SqlType.IsChar) {
sum = sql.Concat(sum, exprs[i]);
}
else {
sum = sql.Concat(sum, sql.ConvertTo(typeof(string), exprs[i]));
}
}
return sum;
}
}
else if (IsVbIIF(mc)) {
return TranslateVbIIF(mc);
}
else {
switch (mc.Method.Name) {
case "op_Equality":
return sql.Binary(SqlNodeType.EQ, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Inequality":
return sql.Binary(SqlNodeType.NE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_LessThan":
return sql.Binary(SqlNodeType.LT, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_LessThanOrEqual":
return sql.Binary(SqlNodeType.LE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_GreaterThan":
return sql.Binary(SqlNodeType.GT, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_GreaterThanOrEqual":
return sql.Binary(SqlNodeType.GE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Multiply":
return sql.Binary(SqlNodeType.Mul, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Division":
return sql.Binary(SqlNodeType.Div, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Subtraction":
return sql.Binary(SqlNodeType.Sub, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Addition":
return sql.Binary(SqlNodeType.Add, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_Modulus":
return sql.Binary(SqlNodeType.Mod, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_BitwiseAnd":
return sql.Binary(SqlNodeType.BitAnd, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_BitwiseOr":
return sql.Binary(SqlNodeType.BitOr, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_ExclusiveOr":
return sql.Binary(SqlNodeType.BitXor, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
case "op_UnaryNegation":
return sql.Unary(SqlNodeType.Negate, mc.Arguments[0], mc.Method, mc.SourceExpression);
case "op_OnesComplement":
return sql.Unary(SqlNodeType.BitNot, mc.Arguments[0], mc.Method, mc.SourceExpression);
case "op_False":
return sql.Unary(SqlNodeType.Not, mc.Arguments[0], mc.Method, mc.SourceExpression);
}
}
}
else {
if (mc.Method.Name == "Equals" && mc.Arguments.Count == 1) {
return sql.Binary(SqlNodeType.EQ, mc.Object, mc.Arguments[0]);
}
else if (mc.Method.Name == "GetType" && mc.Arguments.Count == 0) {
MetaType mt = TypeSource.GetSourceMetaType(mc.Object, this.model);
if (mt.HasInheritance) {
Type discriminatorType = mt.Discriminator.Type;
SqlDiscriminatorOf discriminatorOf = new SqlDiscriminatorOf(mc.Object, discriminatorType, this.sql.TypeProvider.From(discriminatorType), mc.SourceExpression);
return this.VisitExpression(sql.DiscriminatedType(discriminatorOf, mt));
}
return this.VisitExpression(sql.StaticType(mt, mc.SourceExpression));
}
}
return mc;
}
private SqlExpression TranslateVbIIF(SqlMethodCall mc) {
//Check to see if the types can be implicitly converted from one to another.
if (mc.Arguments[1].ClrType == mc.Arguments[2].ClrType) {
List<SqlWhen> whens = new List<SqlWhen>(1);
whens.Add(new SqlWhen(mc.Arguments[0], mc.Arguments[1]));
SqlExpression @else = mc.Arguments[2];
while (@else.NodeType == SqlNodeType.SearchedCase) {
SqlSearchedCase sc = (SqlSearchedCase)@else;
whens.AddRange(sc.Whens);
@else = sc.Else;
}
return sql.SearchedCase(whens.ToArray(), @else, mc.SourceExpression);
}
else {
throw Error.IifReturnTypesMustBeEqual(mc.Arguments[1].ClrType.Name, mc.Arguments[2].ClrType.Name);
}
}
internal override SqlExpression VisitTreat(SqlUnary t) {
t.Operand = this.VisitExpression(t.Operand);
Type treatType = t.ClrType;
Type originalType = model.GetMetaType(t.Operand.ClrType).InheritanceRoot.Type;
// .NET nullability rules are that typeof(int)==typeof(int?). Let's be consistent with that:
treatType = TypeSystem.GetNonNullableType(treatType);
originalType = TypeSystem.GetNonNullableType(originalType);
if (treatType == originalType) {
return t.Operand;
}
else if (treatType.IsAssignableFrom(originalType)) {
t.Operand.SetClrType(treatType);
return t.Operand;
}
else if (!treatType.IsAssignableFrom(originalType) && !originalType.IsAssignableFrom(treatType)) {
if (!treatType.IsInterface && !originalType.IsInterface) { // You can't tell when there's an interface involved.
// We statically know the TREAT will result in NULL.
return sql.TypedLiteralNull(treatType, t.SourceExpression);
}
}
//return base.VisitTreat(t);
return t;
}
}
}
}

View File

@@ -0,0 +1 @@
55ae51ee9571e225d66d051c088e676a1ee4f3e6

View File

@@ -0,0 +1,120 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Data.Linq;
using System.Data.Linq.Mapping;
using System.Data.Linq.Provider;
using System.Linq;
using System.Data.Linq.SqlClient;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Determines whether an expression is simple or not.
/// Simple is a scalar expression that contains only functions, operators and column references
/// </summary>
internal static class SimpleExpression {
internal static bool IsSimple(SqlExpression expr) {
Visitor v = new Visitor();
v.Visit(expr);
return v.IsSimple;
}
class Visitor : SqlVisitor {
bool isSimple = true;
internal bool IsSimple {
get { return this.isSimple; }
}
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal override SqlNode Visit(SqlNode node) {
if (node == null) {
return null;
}
if (!this.isSimple) {
return node;
}
switch (node.NodeType) {
case SqlNodeType.Not:
case SqlNodeType.Not2V:
case SqlNodeType.Negate:
case SqlNodeType.BitNot:
case SqlNodeType.IsNull:
case SqlNodeType.IsNotNull:
case SqlNodeType.ValueOf:
case SqlNodeType.OuterJoinedValue:
case SqlNodeType.ClrLength:
case SqlNodeType.Add:
case SqlNodeType.Sub:
case SqlNodeType.Mul:
case SqlNodeType.Div:
case SqlNodeType.Mod:
case SqlNodeType.BitAnd:
case SqlNodeType.BitOr:
case SqlNodeType.BitXor:
case SqlNodeType.And:
case SqlNodeType.Or:
case SqlNodeType.GE:
case SqlNodeType.GT:
case SqlNodeType.LE:
case SqlNodeType.LT:
case SqlNodeType.EQ:
case SqlNodeType.NE:
case SqlNodeType.EQ2V:
case SqlNodeType.NE2V:
case SqlNodeType.Between:
case SqlNodeType.Concat:
case SqlNodeType.Convert:
case SqlNodeType.Treat:
case SqlNodeType.Member:
case SqlNodeType.TypeCase:
case SqlNodeType.SearchedCase:
case SqlNodeType.SimpleCase:
case SqlNodeType.Like:
case SqlNodeType.FunctionCall:
case SqlNodeType.ExprSet:
case SqlNodeType.OptionalValue:
case SqlNodeType.Parameter:
case SqlNodeType.ColumnRef:
case SqlNodeType.Value:
case SqlNodeType.Variable:
return base.Visit(node);
case SqlNodeType.Column:
case SqlNodeType.ClientCase:
case SqlNodeType.DiscriminatedType:
case SqlNodeType.Link:
case SqlNodeType.Row:
case SqlNodeType.UserQuery:
case SqlNodeType.StoredProcedureCall:
case SqlNodeType.UserRow:
case SqlNodeType.UserColumn:
case SqlNodeType.Multiset:
case SqlNodeType.ScalarSubSelect:
case SqlNodeType.Element:
case SqlNodeType.Exists:
case SqlNodeType.Join:
case SqlNodeType.Select:
case SqlNodeType.New:
case SqlNodeType.ClientQuery:
case SqlNodeType.ClientArray:
case SqlNodeType.Insert:
case SqlNodeType.Update:
case SqlNodeType.Delete:
case SqlNodeType.MemberAssign:
case SqlNodeType.Assign:
case SqlNodeType.Block:
case SqlNodeType.Union:
case SqlNodeType.DoNotVisit:
case SqlNodeType.MethodCall:
case SqlNodeType.Nop:
default:
this.isSimple = false;
return node;
}
}
}
}
}

View File

@@ -0,0 +1,49 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Data.Linq;
namespace System.Data.Linq.SqlClient {
internal class SqlAggregateChecker {
Visitor visitor;
internal SqlAggregateChecker() {
this.visitor = new Visitor();
}
internal bool HasAggregates(SqlNode node) {
visitor.hasAggregates = false;
visitor.Visit(node);
return visitor.hasAggregates;
}
class Visitor : SqlVisitor {
internal bool hasAggregates;
internal Visitor() {
}
internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
return ss;
}
internal override SqlSource VisitSource(SqlSource source) {
return source;
}
internal override SqlExpression VisitUnaryOperator(SqlUnary uo) {
switch (uo.NodeType) {
case SqlNodeType.Min:
case SqlNodeType.Max:
case SqlNodeType.Avg:
case SqlNodeType.Sum:
case SqlNodeType.Count:
case SqlNodeType.LongCount:
this.hasAggregates = true;
return uo;
default:
return base.VisitUnaryOperator(uo);
}
}
}
}
}

View File

@@ -0,0 +1,52 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Data.Linq;
namespace System.Data.Linq.SqlClient {
internal class SqlAliaser {
Visitor visitor;
internal SqlAliaser() {
this.visitor = new Visitor();
}
internal SqlNode AssociateColumnsWithAliases(SqlNode node) {
return this.visitor.Visit(node);
}
class Visitor : SqlVisitor {
SqlAlias alias;
internal Visitor() {
}
internal override SqlAlias VisitAlias(SqlAlias sqlAlias) {
SqlAlias save = this.alias;
this.alias = sqlAlias;
sqlAlias.Node = this.Visit(sqlAlias.Node);
this.alias = save;
return sqlAlias;
}
internal override SqlRow VisitRow(SqlRow row) {
foreach (SqlColumn c in row.Columns) {
c.Alias = alias;
}
return base.VisitRow(row);
}
internal override SqlTable VisitTable(SqlTable tab) {
foreach (SqlColumn c in tab.Columns) {
c.Alias = alias;
}
return base.VisitTable(tab);
}
internal override SqlExpression VisitTableValuedFunctionCall(SqlTableValuedFunctionCall fc) {
foreach (SqlColumn c in fc.Columns) {
c.Alias = this.alias;
}
return base.VisitTableValuedFunctionCall(fc);
}
}
}
}

View File

@@ -0,0 +1,68 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace System.Data.Linq.SqlClient
{
/// <summary>
/// Find referenced Aliases within a node.
/// </summary>
internal static class SqlAliasesReferenced
{
/// <summary>
/// Private visitor which walks the tree and looks for referenced aliases.
/// </summary>
private class Visitor : SqlVisitor {
internal IEnumerable<SqlAlias> aliases;
internal bool referencesAnyMatchingAliases = false;
internal override SqlNode Visit(SqlNode node) {
// Short-circuit when the answer is alreading known
if (this.referencesAnyMatchingAliases) {
return node;
}
return base.Visit(node);
}
internal SqlAlias VisitAliasConsumed(SqlAlias a) {
if (a == null)
return a;
bool match = false;
foreach (SqlAlias alias in aliases)
if (alias == a) {
match = true;
break;
}
if (match) {
this.referencesAnyMatchingAliases = true;
}
return a;
}
internal override SqlExpression VisitColumn(SqlColumn col) {
VisitAliasConsumed(col.Alias);
VisitExpression(col.Expression);
return col;
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
VisitAliasConsumed(cref.Column.Alias);
VisitExpression(cref.Column.Expression);
return cref;
}
}
/// <summary>
/// Returns true iff the given node references any aliases the list of 'aliases'.
/// </summary>
internal static bool ReferencesAny(SqlNode node, IEnumerable<SqlAlias> aliases) {
Visitor visitor = new Visitor();
visitor.aliases = aliases;
visitor.Visit(node);
return visitor.referencesAnyMatchingAliases;
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,150 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Data;
namespace System.Data.Linq.SqlClient {
using System.Data.Linq;
using System.Diagnostics.CodeAnalysis;
/// <summary>
/// This visitor searches for places where 'Predicate' is found but a 'Bit'
/// was expected or vice versa. In response, it will call VisitBitExpectedPredicate
/// and VisitPredicateExpectedBit.
/// </summary>
internal abstract class SqlBooleanMismatchVisitor : SqlVisitor {
internal SqlBooleanMismatchVisitor() {
}
internal abstract SqlExpression ConvertValueToPredicate(SqlExpression valueExpression);
internal abstract SqlExpression ConvertPredicateToValue(SqlExpression predicateExpression);
internal override SqlSelect VisitSelect(SqlSelect select) {
select.From = this.VisitSource(select.From);
select.Where = this.VisitPredicate(select.Where);
for (int i = 0, n = select.GroupBy.Count; i < n; i++) {
select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
}
select.Having = this.VisitPredicate(select.Having);
for (int i = 0, n = select.OrderBy.Count; i < n; i++) {
select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
}
select.Top = this.VisitExpression(select.Top);
select.Row = (SqlRow)this.Visit(select.Row);
// don't visit selection
//select.Selection = this.VisitExpression(select.Selection);
return select;
}
internal override SqlSource VisitJoin(SqlJoin join) {
join.Left = this.VisitSource(join.Left);
join.Right = this.VisitSource(join.Right);
join.Condition = this.VisitPredicate(join.Condition);
return join;
}
internal override SqlExpression VisitUnaryOperator(SqlUnary uo) {
if (uo.NodeType.IsUnaryOperatorExpectingPredicateOperand()) {
uo.Operand = this.VisitPredicate(uo.Operand);
} else {
uo.Operand = this.VisitExpression(uo.Operand);
}
return uo;
}
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
if (bo.NodeType.IsBinaryOperatorExpectingPredicateOperands()) {
bo.Left = this.VisitPredicate(bo.Left);
bo.Right = this.VisitPredicate(bo.Right);
} else {
bo.Left = this.VisitExpression(bo.Left);
bo.Right = this.VisitExpression(bo.Right);
}
return bo;
}
internal override SqlStatement VisitAssign(SqlAssign sa) {
// L-Value of assign is never a 'Bit' nor a 'Predicate'.
sa.LValue = this.VisitExpression(sa.LValue);
sa.RValue = this.VisitExpression(sa.RValue);
return sa;
}
internal override SqlExpression VisitSearchedCase(SqlSearchedCase c) {
for (int i = 0, n = c.Whens.Count; i < n; i++) {
SqlWhen when = c.Whens[i];
when.Match = this.VisitPredicate(when.Match);
when.Value = this.VisitExpression(when.Value);
}
c.Else = this.VisitExpression(c.Else);
return c;
}
internal override SqlExpression VisitLift(SqlLift lift) {
lift.Expression = base.VisitExpression(lift.Expression);
return lift;
}
/// <summary>
/// If an expression is type 'Bit' but a 'Predicate' is expected then
/// call 'VisitBitExpectedPredicate'.
/// </summary>
internal SqlExpression VisitPredicate(SqlExpression exp) {
exp = (SqlExpression)base.Visit(exp);
if (exp != null) {
if (!IsPredicateExpression(exp)) {
exp = ConvertValueToPredicate(exp);
}
}
return exp;
}
/// <summary>
/// Any remaining calls to VisitExpression expect a 'Bit' when there's
/// a boolean expression.
/// </summary>
internal override SqlExpression VisitExpression(SqlExpression exp) {
exp = (SqlExpression)base.Visit(exp);
if (exp != null) {
if (IsPredicateExpression(exp)) {
exp = ConvertPredicateToValue(exp);
}
}
return exp;
}
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private static bool IsPredicateExpression(SqlExpression exp) {
switch (exp.NodeType) {
case SqlNodeType.And:
case SqlNodeType.Or:
case SqlNodeType.Not:
case SqlNodeType.Not2V:
case SqlNodeType.EQ:
case SqlNodeType.EQ2V:
case SqlNodeType.NE:
case SqlNodeType.NE2V:
case SqlNodeType.GE:
case SqlNodeType.GT:
case SqlNodeType.LE:
case SqlNodeType.LT:
case SqlNodeType.Exists:
case SqlNodeType.Between:
case SqlNodeType.In:
case SqlNodeType.Like:
case SqlNodeType.IsNotNull:
case SqlNodeType.IsNull:
return true;
case SqlNodeType.Lift:
return IsPredicateExpression(((SqlLift)exp).Expression);
default:
return false;
}
}
}
}

View File

@@ -0,0 +1,90 @@
using System.Data.Linq.Mapping;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Locate cases in which there is a 'Bit' but a 'Predicate' is expected or vice-versa.
/// Transform these expressions into expressions of the expected type.
/// </summary>
internal class SqlBooleanizer {
private class Booleanizer : SqlBooleanMismatchVisitor {
private SqlFactory sql;
internal Booleanizer(TypeSystemProvider typeProvider, MetaModel model) {
this.sql = new SqlFactory(typeProvider, model);
}
internal override SqlSelect VisitSelect(SqlSelect select) {
// DevDiv 179191
if (select.Where != null && select.Where.NodeType == SqlNodeType.Coalesce) {
SqlBinary bin = (SqlBinary)select.Where;
if (bin.Right.NodeType == SqlNodeType.Value) {
SqlValue value = (SqlValue)bin.Right;
if (value.Value != null && value.Value.GetType() == typeof(bool) && (bool)value.Value == false) {
select.Where = bin.Left;
}
}
}
return base.VisitSelect(select);
}
internal override SqlExpression ConvertValueToPredicate(SqlExpression valueExpression) {
// Transform the 'Bit' expression into a 'Predicate' by forming the
// following operation:
// OriginalExpr = 1
// Yukon and later could also handle:
// OriginalExpr = 'true'
// but Sql2000 does not support this.
return new SqlBinary(SqlNodeType.EQ,
valueExpression.ClrType, sql.TypeProvider.From(typeof(bool)),
valueExpression,
sql.Value(typeof(bool), valueExpression.SqlType, true, false, valueExpression.SourceExpression)
);
}
internal override SqlExpression ConvertPredicateToValue(SqlExpression predicateExpression) {
// Transform the 'Predicate' expression into a 'Bit' by forming the
// following operation:
// CASE
// WHEN predicateExpression THEN 1
// ELSE NOT(predicateExpression) THEN 0
// ELSE NULL
// END
// Possible simplification to the generated SQL would be to detect when 'predicateExpression'
// is SqlUnary(NOT) and use its operand with the literal 1 and 0 below swapped.
SqlExpression valueTrue = sql.ValueFromObject(true, false, predicateExpression.SourceExpression);
SqlExpression valueFalse = sql.ValueFromObject(false, false, predicateExpression.SourceExpression);
if (SqlExpressionNullability.CanBeNull(predicateExpression) != false) {
SqlExpression valueNull = sql.Value(valueTrue.ClrType, valueTrue.SqlType, null, false, predicateExpression.SourceExpression);
return new SqlSearchedCase(
predicateExpression.ClrType,
new SqlWhen[] {
new SqlWhen(predicateExpression, valueTrue),
new SqlWhen(new SqlUnary(SqlNodeType.Not, predicateExpression.ClrType, predicateExpression.SqlType, predicateExpression, predicateExpression.SourceExpression), valueFalse)
},
valueNull,
predicateExpression.SourceExpression
);
}
else {
return new SqlSearchedCase(
predicateExpression.ClrType,
new SqlWhen[] { new SqlWhen(predicateExpression, valueTrue) },
valueFalse,
predicateExpression.SourceExpression
);
}
}
}
/// <summary>
/// Rationalize boolean expressions for the given node.
/// </summary>
internal static SqlNode Rationalize(SqlNode node, TypeSystemProvider typeProvider, MetaModel model) {
return new Booleanizer(typeProvider, model).Visit(node);
}
}
}

View File

@@ -0,0 +1,238 @@
using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Data.Linq.Provider;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// SQL with CASE statements is harder to read. This visitor attempts to reduce CASE
/// statements to equivalent (but easier to read) logic.
/// </summary>
internal class SqlCaseSimplifier {
internal static SqlNode Simplify(SqlNode node, SqlFactory sql) {
return new Visitor(sql).Visit(node);
}
class Visitor : SqlVisitor {
SqlFactory sql;
internal Visitor(SqlFactory sql) {
this.sql = sql;
}
/// <summary>
/// Replace equals and not equals:
///
/// | CASE XXX | CASE XXX CASE XXX
/// | WHEN AAA THEN MMMM | != RRRR ===> WHEN AAA THEN (MMMM != RRRR) ==> WHEN AAA THEN true
/// | WHEN BBB THEN NNNN | WHEN BBB THEN (NNNN != RRRR) WHEN BBB THEN false
/// | etc. | etc. etc.
/// | ELSE OOOO | ELSE (OOOO != RRRR) ELSE true
/// | END END END
///
/// Where MMMM, NNNN and RRRR are constants.
/// </summary>
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
switch (bo.NodeType) {
case SqlNodeType.EQ:
case SqlNodeType.NE:
case SqlNodeType.EQ2V:
case SqlNodeType.NE2V:
if (bo.Left.NodeType == SqlNodeType.SimpleCase &&
bo.Right.NodeType == SqlNodeType.Value &&
AreCaseWhenValuesConstant((SqlSimpleCase)bo.Left)) {
return this.DistributeOperatorIntoCase(bo.NodeType, (SqlSimpleCase)bo.Left, bo.Right);
}
else if (bo.Right.NodeType == SqlNodeType.SimpleCase &&
bo.Left.NodeType==SqlNodeType.Value &&
AreCaseWhenValuesConstant((SqlSimpleCase)bo.Right)) {
return this.DistributeOperatorIntoCase(bo.NodeType, (SqlSimpleCase)bo.Right, bo.Left);
}
break;
}
return base.VisitBinaryOperator(bo);
}
/// <summary>
/// Checks to see if all SqlSimpleCase when values are of Value type.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
internal bool AreCaseWhenValuesConstant(SqlSimpleCase sc) {
foreach (SqlWhen when in sc.Whens) {
if (when.Value.NodeType != SqlNodeType.Value) {
return false;
}
}
return true;
}
/// <summary>
/// Helper for VisitBinaryOperator. Builds the new case with distributed valueds.
/// </summary>
private SqlExpression DistributeOperatorIntoCase(SqlNodeType nt, SqlSimpleCase sc, SqlExpression expr) {
if (nt!=SqlNodeType.EQ && nt!=SqlNodeType.NE && nt!=SqlNodeType.EQ2V && nt!=SqlNodeType.NE2V)
throw Error.ArgumentOutOfRange("nt");
object val = Eval(expr);
List<SqlExpression> values = new List<SqlExpression>();
List<SqlExpression> matches = new List<SqlExpression>();
foreach(SqlWhen when in sc.Whens) {
matches.Add(when.Match);
object whenVal = Eval(when.Value);
bool eq = when.Value.SqlType.AreValuesEqual(whenVal, val);
values.Add(sql.ValueFromObject((nt==SqlNodeType.EQ || nt==SqlNodeType.EQ2V) == eq, false, sc.SourceExpression));
}
return this.VisitExpression(sql.Case(typeof(bool), sc.Expression, matches, values, sc.SourceExpression));
}
internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
c.Expression = this.VisitExpression(c.Expression);
int compareWhen = 0;
// Find the ELSE if it exists.
for (int i = 0, n = c.Whens.Count; i < n; i++) {
if (c.Whens[i].Match == null) {
compareWhen = i;
break;
}
}
c.Whens[compareWhen].Match = VisitExpression(c.Whens[compareWhen].Match);
c.Whens[compareWhen].Value = VisitExpression(c.Whens[compareWhen].Value);
// Compare each other when value to the compare when
List<SqlWhen> newWhens = new List<SqlWhen>();
bool allValuesLiteral = true;
for (int i = 0, n = c.Whens.Count; i < n; i++) {
if (compareWhen != i) {
SqlWhen when = c.Whens[i];
when.Match = this.VisitExpression(when.Match);
when.Value = this.VisitExpression(when.Value);
if (!SqlComparer.AreEqual(c.Whens[compareWhen].Value, when.Value)) {
newWhens.Add(when);
}
allValuesLiteral = allValuesLiteral && when.Value.NodeType == SqlNodeType.Value;
}
}
newWhens.Add(c.Whens[compareWhen]);
// Did everything reduce to a single CASE?
SqlExpression rewrite = TryToConsolidateAllValueExpressions(newWhens.Count, c.Whens[compareWhen].Value);
if (rewrite != null)
return rewrite;
// Can it be a conjuction (or disjunction) of clauses?
rewrite = TryToWriteAsSimpleBooleanExpression(c.ClrType, c.Expression, newWhens, allValuesLiteral);
if (rewrite != null)
return rewrite;
// Can any WHEN clauses be reduced to fall into the ELSE clause?
rewrite = TryToWriteAsReducedCase(c.ClrType, c.Expression, newWhens, c.Whens[compareWhen].Match, c.Whens.Count);
if (rewrite != null)
return rewrite;
return c;
}
/// <summary>
/// When there is exactly one when clause in the CASE:
///
/// CASE XXX
/// WHEN AAA THEN YYY ===> YYY
/// END
///
/// Then, just reduce it to the value.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlExpression TryToConsolidateAllValueExpressions(int valueCount, SqlExpression value) {
if (valueCount == 1) {
return value;
}
return null;
}
/// <summary>
/// For CASE statements which represent boolean values:
///
/// CASE XXX
/// WHEN AAA THEN true ===> (XXX==AAA) || (XXX==BBB)
/// WHEN BBB THEN true
/// ELSE false
/// etc.
/// END
///
/// Also,
///
/// CASE XXX
/// WHEN AAA THEN false ===> (XXX!=AAA) && (XXX!=BBB)
/// WHEN BBB THEN false
/// ELSE true
/// etc.
/// END
///
/// The reduce to a conjunction or disjunction of equality or inequality.
/// The possibility of NULL in XXX is taken into account.
/// </summary>
private SqlExpression TryToWriteAsSimpleBooleanExpression(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, bool allValuesLiteral) {
SqlExpression rewrite = null;
if (caseType == typeof(bool) && allValuesLiteral) {
bool? holdsNull = SqlExpressionNullability.CanBeNull(discriminator);
// The discriminator can't hold a NULL.
// In this case, we don't need the special fallback that CASE-ELSE gives.
// We can just construct a boolean operation.
bool? whenValue = null;
for (int i = 0; i < newWhens.Count; ++i) {
SqlValue lit = (SqlValue)newWhens[i].Value; // Must be SqlValue because of allValuesLiteral.
bool value = (bool)lit.Value; // Must be bool because of caseType==typeof(bool).
if (newWhens[i].Match != null) { // Skip the ELSE
if (value) {
rewrite = sql.OrAccumulate(rewrite, sql.Binary(SqlNodeType.EQ, discriminator, newWhens[i].Match));
}
else {
rewrite = sql.AndAccumulate(rewrite, sql.Binary(SqlNodeType.NE, discriminator, newWhens[i].Match));
}
}
else {
whenValue = value;
}
}
// If it could possibly hold null values.
if (holdsNull != false && whenValue != null) {
if (whenValue == true) {
rewrite = sql.OrAccumulate(rewrite, sql.Unary(SqlNodeType.IsNull, discriminator, discriminator.SourceExpression));
}
else {
rewrite = sql.AndAccumulate(rewrite, sql.Unary(SqlNodeType.IsNotNull, discriminator, discriminator.SourceExpression));
}
}
}
return rewrite;
}
/// <summary>
/// Remove any WHEN clauses which have the same value as ELSE.
///
/// CASE XXX CASE XXX
/// WHEN AAA THEN YYY ===> WHEN AAA THEN YYY
/// WHEN BBB THEN ZZZ WHEN CCC THEN YYY
/// WHEN CCC THEN YYY ELSE ZZZ
/// ELSE ZZZ END
/// END
///
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlExpression TryToWriteAsReducedCase(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, SqlExpression elseCandidate, int originalWhenCount) {
if (newWhens.Count != originalWhenCount) {
// Some whens were the same as the comparand.
if (elseCandidate == null) {
// -and- the comparand is ELSE (value == null).
// In this case, simplify the CASE to eliminate everything equivalent to ELSE.
return new SqlSimpleCase(caseType, discriminator, newWhens, discriminator.SourceExpression);
}
}
return null;
}
}
}
}

View File

@@ -0,0 +1,185 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Data.Linq;
using System.Data.Linq.Mapping;
using System.Data.Linq.Provider;
using System.Linq;
using System.Data.Linq.SqlClient;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics;
namespace System.Data.Linq.SqlClient {
// partions select expressions and common subexpressions into scalar and non-scalar pieces by
// wrapping scalar pieces floating column nodes.
internal class SqlColumnizer {
ColumnNominator nominator;
ColumnDeclarer declarer;
internal SqlColumnizer() {
this.nominator = new ColumnNominator();
this.declarer = new ColumnDeclarer();
}
internal SqlExpression ColumnizeSelection(SqlExpression selection) {
return this.declarer.Declare(selection, this.nominator.Nominate(selection));
}
internal static bool CanBeColumn(SqlExpression expression) {
return ColumnNominator.CanBeColumn(expression);
}
class ColumnDeclarer : SqlVisitor {
HashSet<SqlExpression> candidates;
internal ColumnDeclarer() {
}
internal SqlExpression Declare(SqlExpression expression, HashSet<SqlExpression> candidates) {
this.candidates = candidates;
return (SqlExpression)this.Visit(expression);
}
internal override SqlNode Visit(SqlNode node) {
SqlExpression expr = node as SqlExpression;
if (expr != null) {
if (this.candidates.Contains(expr)) {
if (expr.NodeType == SqlNodeType.Column ||
expr.NodeType == SqlNodeType.ColumnRef) {
return expr;
}
else {
return new SqlColumn(expr.ClrType, expr.SqlType, null, null, expr, expr.SourceExpression);
}
}
}
return base.Visit(node);
}
}
class ColumnNominator : SqlVisitor {
bool isBlocked;
HashSet<SqlExpression> candidates;
internal HashSet<SqlExpression> Nominate(SqlExpression expression) {
this.candidates = new HashSet<SqlExpression>();
this.isBlocked = false;
this.Visit(expression);
return this.candidates;
}
internal override SqlNode Visit(SqlNode node) {
SqlExpression expression = node as SqlExpression;
if (expression != null) {
bool saveIsBlocked = this.isBlocked;
this.isBlocked = false;
if (CanRecurseColumnize(expression)) {
base.Visit(expression);
}
if (!this.isBlocked) {
if (CanBeColumn(expression)) {
this.candidates.Add(expression);
}
else {
this.isBlocked = true;
}
}
this.isBlocked |= saveIsBlocked;
}
return node;
}
internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
c.Expression = this.VisitExpression(c.Expression);
for (int i = 0, n = c.Whens.Count; i < n; i++) {
// Don't walk down the match side. This can't be a column.
c.Whens[i].Value = this.VisitExpression(c.Whens[i].Value);
}
return c;
}
internal override SqlExpression VisitTypeCase(SqlTypeCase tc) {
tc.Discriminator = this.VisitExpression(tc.Discriminator);
for (int i = 0, n = tc.Whens.Count; i < n; i++) {
// Don't walk down the match side. This can't be a column.
tc.Whens[i].TypeBinding = this.VisitExpression(tc.Whens[i].TypeBinding);
}
return tc;
}
internal override SqlExpression VisitClientCase(SqlClientCase c) {
c.Expression = this.VisitExpression(c.Expression);
for (int i = 0, n = c.Whens.Count; i < n; i++) {
// Don't walk down the match side. This can't be a column.
c.Whens[i].Value = this.VisitExpression(c.Whens[i].Value);
}
return c;
}
private static bool CanRecurseColumnize(SqlExpression expr) {
switch (expr.NodeType) {
case SqlNodeType.AliasRef:
case SqlNodeType.ColumnRef:
case SqlNodeType.Column:
case SqlNodeType.Multiset:
case SqlNodeType.Element:
case SqlNodeType.ScalarSubSelect:
case SqlNodeType.Exists:
case SqlNodeType.ClientQuery:
case SqlNodeType.SharedExpressionRef:
case SqlNodeType.Link:
case SqlNodeType.Nop:
case SqlNodeType.Value:
case SqlNodeType.Select:
return false;
default:
return true;
}
}
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private static bool IsClientOnly(SqlExpression expr) {
switch (expr.NodeType) {
case SqlNodeType.ClientCase:
case SqlNodeType.TypeCase:
case SqlNodeType.ClientArray:
case SqlNodeType.Grouping:
case SqlNodeType.DiscriminatedType:
case SqlNodeType.SharedExpression:
case SqlNodeType.SimpleExpression:
case SqlNodeType.AliasRef:
case SqlNodeType.Multiset:
case SqlNodeType.Element:
case SqlNodeType.ClientQuery:
case SqlNodeType.SharedExpressionRef:
case SqlNodeType.Link:
case SqlNodeType.Nop:
return true;
case SqlNodeType.OuterJoinedValue:
return IsClientOnly(((SqlUnary)expr).Operand);
default:
return false;
}
}
internal static bool CanBeColumn(SqlExpression expression) {
if (!IsClientOnly(expression)
&& expression.NodeType != SqlNodeType.Column
&& expression.SqlType.CanBeColumn) {
switch (expression.NodeType) {
case SqlNodeType.MethodCall:
case SqlNodeType.Member:
case SqlNodeType.New:
return PostBindDotNetConverter.CanConvert(expression);
default:
return true;
}
}
return false;
}
}
}
}

View File

@@ -0,0 +1,303 @@
using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Compare two trees for value equality. Implemented as a parallel visitor.
/// </summary>
internal class SqlComparer {
internal SqlComparer() {
}
[SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "[....]: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal static bool AreEqual(SqlNode node1, SqlNode node2) {
if (node1 == node2)
return true;
if (node1 == null || node2 == null)
return false;
if (node1.NodeType == SqlNodeType.SimpleCase)
node1 = UnwrapTrivialCaseExpression((SqlSimpleCase)node1);
if (node2.NodeType == SqlNodeType.SimpleCase)
node2 = UnwrapTrivialCaseExpression((SqlSimpleCase)node2);
if (node1.NodeType != node2.NodeType) {
// allow expression sets to compare against single expressions
if (node1.NodeType == SqlNodeType.ExprSet) {
SqlExprSet eset = (SqlExprSet)node1;
for (int i = 0, n = eset.Expressions.Count; i < n; i++) {
if (AreEqual(eset.Expressions[i], node2))
return true;
}
}
else if (node2.NodeType == SqlNodeType.ExprSet) {
SqlExprSet eset = (SqlExprSet)node2;
for (int i = 0, n = eset.Expressions.Count; i < n; i++) {
if (AreEqual(node1, eset.Expressions[i]))
return true;
}
}
return false;
}
if (node1.Equals(node2))
return true;
switch (node1.NodeType) {
case SqlNodeType.Not:
case SqlNodeType.Not2V:
case SqlNodeType.Negate:
case SqlNodeType.BitNot:
case SqlNodeType.IsNull:
case SqlNodeType.IsNotNull:
case SqlNodeType.Count:
case SqlNodeType.Max:
case SqlNodeType.Min:
case SqlNodeType.Sum:
case SqlNodeType.Avg:
case SqlNodeType.Stddev:
case SqlNodeType.ValueOf:
case SqlNodeType.OuterJoinedValue:
case SqlNodeType.ClrLength:
return AreEqual(((SqlUnary)node1).Operand, ((SqlUnary)node2).Operand);
case SqlNodeType.Add:
case SqlNodeType.Sub:
case SqlNodeType.Mul:
case SqlNodeType.Div:
case SqlNodeType.Mod:
case SqlNodeType.BitAnd:
case SqlNodeType.BitOr:
case SqlNodeType.BitXor:
case SqlNodeType.And:
case SqlNodeType.Or:
case SqlNodeType.GE:
case SqlNodeType.GT:
case SqlNodeType.LE:
case SqlNodeType.LT:
case SqlNodeType.EQ:
case SqlNodeType.NE:
case SqlNodeType.EQ2V:
case SqlNodeType.NE2V:
case SqlNodeType.Concat:
SqlBinary firstNode = (SqlBinary)node1;
SqlBinary secondNode = (SqlBinary)node2;
return AreEqual(firstNode.Left, secondNode.Left)
&& AreEqual(firstNode.Right, secondNode.Right);
case SqlNodeType.Convert:
case SqlNodeType.Treat: {
SqlUnary sun1 = (SqlUnary)node1;
SqlUnary sun2 = (SqlUnary)node2;
return sun1.ClrType == sun2.ClrType && sun1.SqlType == sun2.SqlType && AreEqual(sun1.Operand, sun2.Operand);
}
case SqlNodeType.Between: {
SqlBetween b1 = (SqlBetween)node1;
SqlBetween b2 = (SqlBetween)node1;
return AreEqual(b1.Expression, b2.Expression) &&
AreEqual(b1.Start, b2.Start) &&
AreEqual(b1.End, b2.End);
}
case SqlNodeType.Parameter:
return node1 == node2;
case SqlNodeType.Alias:
return AreEqual(((SqlAlias)node1).Node, ((SqlAlias)node2).Node);
case SqlNodeType.AliasRef:
return AreEqual(((SqlAliasRef)node1).Alias, ((SqlAliasRef)node2).Alias);
case SqlNodeType.Column:
SqlColumn col1 = (SqlColumn)node1;
SqlColumn col2 = (SqlColumn)node2;
return col1 == col2 || (col1.Expression != null && col2.Expression != null && AreEqual(col1.Expression, col2.Expression));
case SqlNodeType.Table:
return ((SqlTable)node1).MetaTable == ((SqlTable)node2).MetaTable;
case SqlNodeType.Member:
return (((SqlMember)node1).Member == ((SqlMember)node2).Member) &&
AreEqual(((SqlMember)node1).Expression, ((SqlMember)node2).Expression);
case SqlNodeType.ColumnRef:
SqlColumnRef cref1 = (SqlColumnRef)node1;
SqlColumnRef cref2 = (SqlColumnRef)node2;
return GetBaseColumn(cref1) == GetBaseColumn(cref2);
case SqlNodeType.Value:
return Object.Equals(((SqlValue)node1).Value, ((SqlValue)node2).Value);
case SqlNodeType.TypeCase: {
SqlTypeCase c1 = (SqlTypeCase)node1;
SqlTypeCase c2 = (SqlTypeCase)node2;
if (!AreEqual(c1.Discriminator, c2.Discriminator)) {
return false;
}
if (c1.Whens.Count != c2.Whens.Count) {
return false;
}
for (int i = 0, c = c1.Whens.Count; i < c; ++i) {
if (!AreEqual(c1.Whens[i].Match, c2.Whens[i].Match)) {
return false;
}
if (!AreEqual(c1.Whens[i].TypeBinding, c2.Whens[i].TypeBinding)) {
return false;
}
}
return true;
}
case SqlNodeType.SearchedCase: {
SqlSearchedCase c1 = (SqlSearchedCase)node1;
SqlSearchedCase c2 = (SqlSearchedCase)node2;
if (c1.Whens.Count != c2.Whens.Count)
return false;
for (int i = 0, n = c1.Whens.Count; i < n; i++) {
if (!AreEqual(c1.Whens[i].Match, c2.Whens[i].Match) ||
!AreEqual(c1.Whens[i].Value, c2.Whens[i].Value))
return false;
}
return AreEqual(c1.Else, c2.Else);
}
case SqlNodeType.ClientCase: {
SqlClientCase c1 = (SqlClientCase)node1;
SqlClientCase c2 = (SqlClientCase)node2;
if (c1.Whens.Count != c2.Whens.Count)
return false;
for (int i = 0, n = c1.Whens.Count; i < n; i++) {
if (!AreEqual(c1.Whens[i].Match, c2.Whens[i].Match) ||
!AreEqual(c1.Whens[i].Value, c2.Whens[i].Value))
return false;
}
return true;
}
case SqlNodeType.DiscriminatedType: {
SqlDiscriminatedType dt1 = (SqlDiscriminatedType)node1;
SqlDiscriminatedType dt2 = (SqlDiscriminatedType)node2;
return AreEqual(dt1.Discriminator, dt2.Discriminator);
}
case SqlNodeType.SimpleCase: {
SqlSimpleCase c1 = (SqlSimpleCase)node1;
SqlSimpleCase c2 = (SqlSimpleCase)node2;
if (c1.Whens.Count != c2.Whens.Count)
return false;
for (int i = 0, n = c1.Whens.Count; i < n; i++) {
if (!AreEqual(c1.Whens[i].Match, c2.Whens[i].Match) ||
!AreEqual(c1.Whens[i].Value, c2.Whens[i].Value))
return false;
}
return true;
}
case SqlNodeType.Like: {
SqlLike like1 = (SqlLike)node1;
SqlLike like2 = (SqlLike)node2;
return AreEqual(like1.Expression, like2.Expression) &&
AreEqual(like1.Pattern, like2.Pattern) &&
AreEqual(like1.Escape, like2.Escape);
}
case SqlNodeType.Variable: {
SqlVariable v1 = (SqlVariable)node1;
SqlVariable v2 = (SqlVariable)node2;
return v1.Name == v2.Name;
}
case SqlNodeType.FunctionCall: {
SqlFunctionCall f1 = (SqlFunctionCall)node1;
SqlFunctionCall f2 = (SqlFunctionCall)node2;
if (f1.Name != f2.Name)
return false;
if (f1.Arguments.Count != f2.Arguments.Count)
return false;
for (int i = 0, n = f1.Arguments.Count; i < n; i++) {
if (!AreEqual(f1.Arguments[i], f2.Arguments[i]))
return false;
}
return true;
}
case SqlNodeType.Link: {
SqlLink l1 = (SqlLink)node1;
SqlLink l2 = (SqlLink)node2;
if (!MetaPosition.AreSameMember(l1.Member.Member, l2.Member.Member)) {
return false;
}
if (!AreEqual(l1.Expansion, l2.Expansion)) {
return false;
}
if (l1.KeyExpressions.Count != l2.KeyExpressions.Count) {
return false;
}
for (int i = 0, c = l1.KeyExpressions.Count; i < c; ++i) {
if (!AreEqual(l1.KeyExpressions[i], l2.KeyExpressions[i])) {
return false;
}
}
return true;
}
case SqlNodeType.ExprSet:
SqlExprSet es1 = (SqlExprSet)node1;
SqlExprSet es2 = (SqlExprSet)node2;
if (es1.Expressions.Count != es2.Expressions.Count)
return false;
for(int i = 0, n = es1.Expressions.Count; i < n; i++) {
if (!AreEqual(es1.Expressions[i], es2.Expressions[i]))
return false;
}
return true;
case SqlNodeType.OptionalValue:
SqlOptionalValue ov1 = (SqlOptionalValue)node1;
SqlOptionalValue ov2 = (SqlOptionalValue)node2;
return AreEqual(ov1.Value, ov2.Value);
case SqlNodeType.Row:
case SqlNodeType.UserQuery:
case SqlNodeType.StoredProcedureCall:
case SqlNodeType.UserRow:
case SqlNodeType.UserColumn:
case SqlNodeType.Multiset:
case SqlNodeType.ScalarSubSelect:
case SqlNodeType.Element:
case SqlNodeType.Exists:
case SqlNodeType.Join:
case SqlNodeType.Select:
case SqlNodeType.New:
case SqlNodeType.ClientQuery:
case SqlNodeType.ClientArray:
case SqlNodeType.Insert:
case SqlNodeType.Update:
case SqlNodeType.Delete:
case SqlNodeType.MemberAssign:
case SqlNodeType.Assign:
case SqlNodeType.Block:
case SqlNodeType.Union:
case SqlNodeType.DoNotVisit:
case SqlNodeType.MethodCall:
case SqlNodeType.Nop:
default:
return false;
}
}
private static SqlColumn GetBaseColumn(SqlColumnRef cref) {
while (cref != null && cref.Column.Expression != null) {
SqlColumnRef cr = cref.Column.Expression as SqlColumnRef;
if (cr != null) {
cref = cr;
continue;
}
else {
break;
}
}
return cref.Column;
}
private static SqlExpression UnwrapTrivialCaseExpression(SqlSimpleCase sc) {
if (sc.Whens.Count != 1) {
return sc;
}
if (!SqlComparer.AreEqual(sc.Expression, sc.Whens[0].Match)) {
return sc;
}
SqlExpression result = sc.Whens[0].Value;
if (result.NodeType == SqlNodeType.SimpleCase) {
return UnwrapTrivialCaseExpression((SqlSimpleCase)result);
}
return result;
}
}
}

View File

@@ -0,0 +1,48 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
namespace System.Data.Linq.SqlClient {
using System.Data.Linq;
/// <summary>
/// Turn CROSS APPLY into CROSS JOIN when the right side
/// of the apply doesn't reference anything on the left side.
///
/// Any query which has a CROSS APPLY which cannot be converted to
/// a CROSS JOIN is annotated so that we can give a meaningful
/// error message later for SQL2K.
/// </summary>
internal class SqlCrossApplyToCrossJoin {
internal static SqlNode Reduce(SqlNode node, SqlNodeAnnotations annotations) {
Reducer r = new Reducer();
r.Annotations = annotations;
return r.Visit(node);
}
class Reducer : SqlVisitor {
internal SqlNodeAnnotations Annotations;
internal override SqlSource VisitJoin(SqlJoin join) {
if (join.JoinType == SqlJoinType.CrossApply) {
// Look down the left side to see what table aliases are produced.
HashSet<SqlAlias> p = SqlGatherProducedAliases.Gather(join.Left);
// Look down the right side to see what table aliases are consumed.
HashSet<SqlAlias> c = SqlGatherConsumedAliases.Gather(join.Right);
// Look at each consumed alias and see if they are mentioned in produced.
if (p.Overlaps(c)) {
Annotations.Add(join, new SqlServerCompatibilityAnnotation(Strings.SourceExpressionAnnotation(join.SourceExpression), SqlProvider.ProviderMode.Sql2000));
// Can't reduce because this consumed alias is produced on the left.
return base.VisitJoin(join);
}
// Can turn this into a CROSS JOIN
join.JoinType = SqlJoinType.Cross;
return VisitJoin(join);
}
return base.VisitJoin(join);
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,448 @@
using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Data.Linq.Provider;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
internal class SqlDuplicator {
DuplicatingVisitor superDuper;
internal SqlDuplicator()
: this(true) {
}
internal SqlDuplicator(bool ignoreExternalRefs) {
this.superDuper = new DuplicatingVisitor(ignoreExternalRefs);
}
internal static SqlNode Copy(SqlNode node) {
if (node == null)
return null;
switch (node.NodeType) {
case SqlNodeType.ColumnRef:
case SqlNodeType.Value:
case SqlNodeType.Parameter:
case SqlNodeType.Variable:
return node;
default:
return new SqlDuplicator().Duplicate(node);
}
}
internal SqlNode Duplicate(SqlNode node) {
return this.superDuper.Visit(node);
}
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal class DuplicatingVisitor : SqlVisitor {
Dictionary<SqlNode, SqlNode> nodeMap;
bool ingoreExternalRefs;
internal DuplicatingVisitor(bool ignoreExternalRefs) {
this.ingoreExternalRefs = ignoreExternalRefs;
this.nodeMap = new Dictionary<SqlNode, SqlNode>();
}
internal override SqlNode Visit(SqlNode node) {
if (node == null) {
return null;
}
SqlNode result = null;
if (this.nodeMap.TryGetValue(node, out result)) {
return result;
}
result = base.Visit(node);
this.nodeMap[node] = result;
return result;
}
internal override SqlExpression VisitDoNotVisit(SqlDoNotVisitExpression expr) {
// duplicator can duplicate through a do-no-visit node
return new SqlDoNotVisitExpression(this.VisitExpression(expr.Expression));
}
internal override SqlAlias VisitAlias(SqlAlias a) {
SqlAlias n = new SqlAlias(a.Node);
this.nodeMap[a] = n;
n.Node = this.Visit(a.Node);
n.Name = a.Name;
return n;
}
internal override SqlExpression VisitAliasRef(SqlAliasRef aref) {
if (this.ingoreExternalRefs && !this.nodeMap.ContainsKey(aref.Alias)) {
return aref;
}
return new SqlAliasRef((SqlAlias)this.Visit(aref.Alias));
}
internal override SqlRowNumber VisitRowNumber(SqlRowNumber rowNumber) {
List<SqlOrderExpression> orderBy = new List<SqlOrderExpression>();
foreach (SqlOrderExpression expr in rowNumber.OrderBy) {
orderBy.Add(new SqlOrderExpression(expr.OrderType, (SqlExpression)this.Visit(expr.Expression)));
}
return new SqlRowNumber(rowNumber.ClrType, rowNumber.SqlType, orderBy, rowNumber.SourceExpression);
}
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
SqlExpression left = (SqlExpression)this.Visit(bo.Left);
SqlExpression right = (SqlExpression)this.Visit(bo.Right);
return new SqlBinary(bo.NodeType, bo.ClrType, bo.SqlType, left, right, bo.Method);
}
internal override SqlExpression VisitClientQuery(SqlClientQuery cq) {
SqlSubSelect query = (SqlSubSelect) this.VisitExpression(cq.Query);
SqlClientQuery nq = new SqlClientQuery(query);
for (int i = 0, n = cq.Arguments.Count; i < n; i++) {
nq.Arguments.Add(this.VisitExpression(cq.Arguments[i]));
}
for (int i = 0, n = cq.Parameters.Count; i < n; i++) {
nq.Parameters.Add((SqlParameter)this.VisitExpression(cq.Parameters[i]));
}
return nq;
}
internal override SqlExpression VisitJoinedCollection(SqlJoinedCollection jc) {
return new SqlJoinedCollection(jc.ClrType, jc.SqlType, this.VisitExpression(jc.Expression), this.VisitExpression(jc.Count), jc.SourceExpression);
}
internal override SqlExpression VisitClientArray(SqlClientArray scar) {
SqlExpression[] exprs = new SqlExpression[scar.Expressions.Count];
for (int i = 0, n = exprs.Length; i < n; i++) {
exprs[i] = this.VisitExpression(scar.Expressions[i]);
}
return new SqlClientArray(scar.ClrType, scar.SqlType, exprs, scar.SourceExpression);
}
internal override SqlExpression VisitTypeCase(SqlTypeCase tc) {
SqlExpression disc = VisitExpression(tc.Discriminator);
List<SqlTypeCaseWhen> whens = new List<SqlTypeCaseWhen>();
foreach(SqlTypeCaseWhen when in tc.Whens) {
whens.Add(new SqlTypeCaseWhen(VisitExpression(when.Match), VisitExpression(when.TypeBinding)));
}
return new SqlTypeCase(tc.ClrType, tc.SqlType, tc.RowType, disc, whens, tc.SourceExpression);
}
internal override SqlExpression VisitNew(SqlNew sox) {
SqlExpression[] args = new SqlExpression[sox.Args.Count];
SqlMemberAssign[] bindings = new SqlMemberAssign[sox.Members.Count];
for (int i = 0, n = args.Length; i < n; i++) {
args[i] = this.VisitExpression(sox.Args[i]);
}
for (int i = 0, n = bindings.Length; i < n; i++) {
bindings[i] = this.VisitMemberAssign(sox.Members[i]);
}
return new SqlNew(sox.MetaType, sox.SqlType, sox.Constructor, args, sox.ArgMembers, bindings, sox.SourceExpression);
}
internal override SqlNode VisitLink(SqlLink link) {
SqlExpression[] exprs = new SqlExpression[link.KeyExpressions.Count];
for (int i = 0, n = exprs.Length; i < n; i++) {
exprs[i] = this.VisitExpression(link.KeyExpressions[i]);
}
SqlLink newLink = new SqlLink(new object(), link.RowType, link.ClrType, link.SqlType, null, link.Member, exprs, null, link.SourceExpression);
this.nodeMap[link] = newLink;
// break the potential cyclic tree by visiting these after adding to the map
newLink.Expression = this.VisitExpression(link.Expression);
newLink.Expansion = this.VisitExpression(link.Expansion);
return newLink;
}
internal override SqlExpression VisitColumn(SqlColumn col) {
SqlColumn n = new SqlColumn(col.ClrType, col.SqlType, col.Name, col.MetaMember, null, col.SourceExpression);
this.nodeMap[col] = n;
n.Expression = this.VisitExpression(col.Expression);
n.Alias = (SqlAlias)this.Visit(col.Alias);
return n;
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
if (this.ingoreExternalRefs && !this.nodeMap.ContainsKey(cref.Column)) {
return cref;
}
return new SqlColumnRef((SqlColumn)this.Visit(cref.Column));
}
internal override SqlStatement VisitDelete(SqlDelete sd) {
return new SqlDelete((SqlSelect)this.Visit(sd.Select), sd.SourceExpression);
}
internal override SqlExpression VisitElement(SqlSubSelect elem) {
return this.VisitMultiset(elem);
}
internal override SqlExpression VisitExists(SqlSubSelect sqlExpr) {
return new SqlSubSelect(sqlExpr.NodeType, sqlExpr.ClrType, sqlExpr.SqlType, (SqlSelect)this.Visit(sqlExpr.Select));
}
internal override SqlStatement VisitInsert(SqlInsert si) {
SqlInsert n = new SqlInsert(si.Table, this.VisitExpression(si.Expression), si.SourceExpression);
n.OutputKey = si.OutputKey;
n.OutputToLocal = si.OutputToLocal;
n.Row = this.VisitRow(si.Row);
return n;
}
internal override SqlSource VisitJoin(SqlJoin join) {
SqlSource left = this.VisitSource(join.Left);
SqlSource right = this.VisitSource(join.Right);
SqlExpression cond = (SqlExpression)this.Visit(join.Condition);
return new SqlJoin(join.JoinType, left, right, cond, join.SourceExpression);
}
internal override SqlExpression VisitValue(SqlValue value) {
return value;
}
internal override SqlNode VisitMember(SqlMember m) {
return new SqlMember(m.ClrType, m.SqlType, (SqlExpression)this.Visit(m.Expression), m.Member);
}
internal override SqlMemberAssign VisitMemberAssign(SqlMemberAssign ma) {
return new SqlMemberAssign(ma.Member, (SqlExpression)this.Visit(ma.Expression));
}
internal override SqlExpression VisitMultiset(SqlSubSelect sms) {
return new SqlSubSelect(sms.NodeType, sms.ClrType, sms.SqlType, (SqlSelect)this.Visit(sms.Select));
}
internal override SqlExpression VisitParameter(SqlParameter p) {
SqlParameter n = new SqlParameter(p.ClrType, p.SqlType, p.Name, p.SourceExpression);
n.Direction = p.Direction;
return n;
}
internal override SqlRow VisitRow(SqlRow row) {
SqlRow nrow = new SqlRow(row.SourceExpression);
foreach (SqlColumn c in row.Columns) {
nrow.Columns.Add((SqlColumn)this.Visit(c));
}
return nrow;
}
internal override SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
return new SqlSubSelect(SqlNodeType.ScalarSubSelect, ss.ClrType, ss.SqlType, this.VisitSequence(ss.Select));
}
internal override SqlSelect VisitSelect(SqlSelect select) {
SqlSource from = this.VisitSource(select.From);
List<SqlExpression> gex = null;
if (select.GroupBy.Count > 0) {
gex = new List<SqlExpression>(select.GroupBy.Count);
foreach (SqlExpression sqlExpr in select.GroupBy) {
gex.Add((SqlExpression)this.Visit(sqlExpr));
}
}
SqlExpression having = (SqlExpression)this.Visit(select.Having);
List<SqlOrderExpression> lex = null;
if (select.OrderBy.Count > 0) {
lex = new List<SqlOrderExpression>(select.OrderBy.Count);
foreach (SqlOrderExpression sox in select.OrderBy) {
SqlOrderExpression nsox = new SqlOrderExpression(sox.OrderType, (SqlExpression)this.Visit(sox.Expression));
lex.Add(nsox);
}
}
SqlExpression top = (SqlExpression)this.Visit(select.Top);
SqlExpression where = (SqlExpression)this.Visit(select.Where);
SqlRow row = (SqlRow)this.Visit(select.Row);
SqlExpression selection = this.VisitExpression(select.Selection);
SqlSelect n = new SqlSelect(selection, from, select.SourceExpression);
if (gex != null)
n.GroupBy.AddRange(gex);
n.Having = having;
if (lex != null)
n.OrderBy.AddRange(lex);
n.OrderingType = select.OrderingType;
n.Row = row;
n.Top = top;
n.IsDistinct = select.IsDistinct;
n.IsPercent = select.IsPercent;
n.Where = where;
n.DoNotOutput = select.DoNotOutput;
return n;
}
internal override SqlTable VisitTable(SqlTable tab) {
SqlTable nt = new SqlTable(tab.MetaTable, tab.RowType, tab.SqlRowType, tab.SourceExpression);
this.nodeMap[tab] = nt;
foreach (SqlColumn c in tab.Columns) {
nt.Columns.Add((SqlColumn)this.Visit(c));
}
return nt;
}
internal override SqlUserQuery VisitUserQuery(SqlUserQuery suq) {
List<SqlExpression> args = new List<SqlExpression>(suq.Arguments.Count);
foreach (SqlExpression expr in suq.Arguments) {
args.Add(this.VisitExpression(expr));
}
SqlExpression projection = this.VisitExpression(suq.Projection);
SqlUserQuery n = new SqlUserQuery(suq.QueryText, projection, args, suq.SourceExpression);
this.nodeMap[suq] = n;
foreach (SqlUserColumn suc in suq.Columns) {
SqlUserColumn dupSuc = new SqlUserColumn(suc.ClrType, suc.SqlType, suc.Query, suc.Name, suc.IsRequired, suc.SourceExpression);
this.nodeMap[suc] = dupSuc;
n.Columns.Add(dupSuc);
}
return n;
}
internal override SqlStoredProcedureCall VisitStoredProcedureCall(SqlStoredProcedureCall spc) {
List<SqlExpression> args = new List<SqlExpression>(spc.Arguments.Count);
foreach (SqlExpression expr in spc.Arguments) {
args.Add(this.VisitExpression(expr));
}
SqlExpression projection = this.VisitExpression(spc.Projection);
SqlStoredProcedureCall n = new SqlStoredProcedureCall(spc.Function, projection, args, spc.SourceExpression);
this.nodeMap[spc] = n;
foreach (SqlUserColumn suc in spc.Columns) {
n.Columns.Add((SqlUserColumn)this.Visit(suc));
}
return n;
}
internal override SqlExpression VisitUserColumn(SqlUserColumn suc) {
if (this.ingoreExternalRefs && !this.nodeMap.ContainsKey(suc)) {
return suc;
}
return new SqlUserColumn(suc.ClrType, suc.SqlType, suc.Query, suc.Name, suc.IsRequired, suc.SourceExpression);
}
internal override SqlExpression VisitUserRow(SqlUserRow row) {
return new SqlUserRow(row.RowType, row.SqlType, (SqlUserQuery)this.Visit(row.Query), row.SourceExpression);
}
internal override SqlExpression VisitTreat(SqlUnary t) {
return new SqlUnary(SqlNodeType.Treat, t.ClrType, t.SqlType, (SqlExpression)this.Visit(t.Operand), t.SourceExpression);
}
internal override SqlExpression VisitUnaryOperator(SqlUnary uo) {
return new SqlUnary(uo.NodeType, uo.ClrType, uo.SqlType, (SqlExpression)this.Visit(uo.Operand), uo.Method, uo.SourceExpression);
}
internal override SqlStatement VisitUpdate(SqlUpdate su) {
SqlSelect ss = (SqlSelect)this.Visit(su.Select);
List<SqlAssign> assignments = new List<SqlAssign>(su.Assignments.Count);
foreach (SqlAssign sa in su.Assignments) {
assignments.Add((SqlAssign)this.Visit(sa));
}
return new SqlUpdate(ss, assignments, su.SourceExpression);
}
internal override SqlStatement VisitAssign(SqlAssign sa) {
return new SqlAssign(this.VisitExpression(sa.LValue), this.VisitExpression(sa.RValue), sa.SourceExpression);
}
internal override SqlExpression VisitSearchedCase(SqlSearchedCase c) {
SqlExpression @else = this.VisitExpression(c.Else);
SqlWhen[] whens = new SqlWhen[c.Whens.Count];
for (int i = 0, n = whens.Length; i < n; i++) {
SqlWhen when = c.Whens[i];
whens[i] = new SqlWhen(this.VisitExpression(when.Match), this.VisitExpression(when.Value));
}
return new SqlSearchedCase(c.ClrType, whens, @else, c.SourceExpression);
}
internal override SqlExpression VisitClientCase(SqlClientCase c) {
SqlExpression expr = this.VisitExpression(c.Expression);
SqlClientWhen[] whens = new SqlClientWhen[c.Whens.Count];
for (int i = 0, n = whens.Length; i < n; i++) {
SqlClientWhen when = c.Whens[i];
whens[i] = new SqlClientWhen(this.VisitExpression(when.Match), this.VisitExpression(when.Value));
}
return new SqlClientCase(c.ClrType, expr, whens, c.SourceExpression);
}
internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
SqlExpression expr = this.VisitExpression(c.Expression);
SqlWhen[] whens = new SqlWhen[c.Whens.Count];
for (int i = 0, n = whens.Length; i < n; i++) {
SqlWhen when = c.Whens[i];
whens[i] = new SqlWhen(this.VisitExpression(when.Match), this.VisitExpression(when.Value));
}
return new SqlSimpleCase(c.ClrType, expr, whens, c.SourceExpression);
}
internal override SqlNode VisitUnion(SqlUnion su) {
return new SqlUnion(this.Visit(su.Left), this.Visit(su.Right), su.All);
}
internal override SqlExpression VisitExprSet(SqlExprSet xs) {
SqlExpression[] exprs = new SqlExpression[xs.Expressions.Count];
for (int i = 0, n = exprs.Length; i < n; i++) {
exprs[i] = this.VisitExpression(xs.Expressions[i]);
}
return new SqlExprSet(xs.ClrType, exprs, xs.SourceExpression);
}
internal override SqlBlock VisitBlock(SqlBlock block) {
SqlBlock nb = new SqlBlock(block.SourceExpression);
foreach (SqlStatement stmt in block.Statements) {
nb.Statements.Add((SqlStatement)this.Visit(stmt));
}
return nb;
}
internal override SqlExpression VisitVariable(SqlVariable v) {
return v;
}
internal override SqlExpression VisitOptionalValue(SqlOptionalValue sov) {
SqlExpression hasValue = this.VisitExpression(sov.HasValue);
SqlExpression value = this.VisitExpression(sov.Value);
return new SqlOptionalValue(hasValue, value);
}
internal override SqlExpression VisitBetween(SqlBetween between) {
SqlBetween nbet = new SqlBetween(
between.ClrType,
between.SqlType,
this.VisitExpression(between.Expression),
this.VisitExpression(between.Start),
this.VisitExpression(between.End),
between.SourceExpression
);
return nbet;
}
internal override SqlExpression VisitIn(SqlIn sin) {
SqlIn nin = new SqlIn(sin.ClrType, sin.SqlType, this.VisitExpression(sin.Expression), sin.Values, sin.SourceExpression);
for (int i = 0, n = nin.Values.Count; i < n; i++) {
nin.Values[i] = this.VisitExpression(nin.Values[i]);
}
return nin;
}
internal override SqlExpression VisitLike(SqlLike like) {
return new SqlLike(
like.ClrType, like.SqlType,
this.VisitExpression(like.Expression),
this.VisitExpression(like.Pattern),
this.VisitExpression(like.Escape),
like.SourceExpression
);
}
internal override SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
SqlExpression[] args = new SqlExpression[fc.Arguments.Count];
for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
args[i] = this.VisitExpression(fc.Arguments[i]);
}
return new SqlFunctionCall(fc.ClrType, fc.SqlType, fc.Name, args, fc.SourceExpression);
}
internal override SqlExpression VisitTableValuedFunctionCall(SqlTableValuedFunctionCall fc) {
SqlExpression[] args = new SqlExpression[fc.Arguments.Count];
for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
args[i] = this.VisitExpression(fc.Arguments[i]);
}
SqlTableValuedFunctionCall nfc = new SqlTableValuedFunctionCall(fc.RowType, fc.ClrType, fc.SqlType, fc.Name, args, fc.SourceExpression);
this.nodeMap[fc] = nfc;
foreach (SqlColumn c in fc.Columns) {
nfc.Columns.Add((SqlColumn)this.Visit(c));
}
return nfc;
}
internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
SqlExpression[] args = new SqlExpression[mc.Arguments.Count];
for (int i = 0, n = mc.Arguments.Count; i < n; i++) {
args[i] = this.VisitExpression(mc.Arguments[i]);
}
return new SqlMethodCall(mc.ClrType, mc.SqlType, mc.Method, this.VisitExpression(mc.Object), args, mc.SourceExpression);
}
internal override SqlExpression VisitSharedExpression(SqlSharedExpression sub) {
SqlSharedExpression n = new SqlSharedExpression(sub.Expression);
this.nodeMap[sub] = n;
n.Expression = this.VisitExpression(sub.Expression);
return n;
}
internal override SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
if (this.ingoreExternalRefs && !this.nodeMap.ContainsKey(sref.SharedExpression)) {
return sref;
}
return new SqlSharedExpressionRef((SqlSharedExpression)this.Visit(sref.SharedExpression));
}
internal override SqlExpression VisitSimpleExpression(SqlSimpleExpression simple) {
SqlSimpleExpression n = new SqlSimpleExpression(this.VisitExpression(simple.Expression));
return n;
}
internal override SqlExpression VisitGrouping(SqlGrouping g) {
SqlGrouping n = new SqlGrouping(g.ClrType, g.SqlType,
this.VisitExpression(g.Key), this.VisitExpression(g.Group), g.SourceExpression
);
return n;
}
internal override SqlExpression VisitDiscriminatedType(SqlDiscriminatedType dt) {
return new SqlDiscriminatedType(dt.SqlType, this.VisitExpression(dt.Discriminator), dt.TargetType, dt.SourceExpression);
}
internal override SqlExpression VisitLift(SqlLift lift) {
return new SqlLift(lift.ClrType, this.VisitExpression(lift.Expression), lift.SourceExpression);
}
internal override SqlExpression VisitDiscriminatorOf(SqlDiscriminatorOf dof) {
return new SqlDiscriminatorOf(this.VisitExpression(dof.Object), dof.ClrType, dof.SqlType, dof.SourceExpression);
}
internal override SqlNode VisitIncludeScope(SqlIncludeScope scope) {
return new SqlIncludeScope(this.Visit(scope.Child), scope.SourceExpression);
}
}
}
}

View File

@@ -0,0 +1,349 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Data.Linq;
using System.Data.Linq.Mapping;
using System.Data.Linq.Provider;
using System.Linq;
using System.Data.Linq.SqlClient;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
// duplicates an expression up until a column or column ref is encountered
// goes 'deep' through alias ref's
// assumes that columnizing has been done already
internal class SqlExpander {
SqlFactory factory;
internal SqlExpander(SqlFactory factory) {
this.factory = factory;
}
internal SqlExpression Expand(SqlExpression exp) {
return (new Visitor(this.factory)).VisitExpression(exp);
}
class Visitor : SqlDuplicator.DuplicatingVisitor {
SqlFactory factory;
Expression sourceExpression;
internal Visitor(SqlFactory factory)
: base(true) {
this.factory = factory;
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
return cref;
}
internal override SqlExpression VisitColumn(SqlColumn col) {
return new SqlColumnRef(col);
}
internal override SqlExpression VisitSharedExpression(SqlSharedExpression shared) {
return this.VisitExpression(shared.Expression);
}
internal override SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
return this.VisitExpression(sref.SharedExpression.Expression);
}
internal override SqlExpression VisitAliasRef(SqlAliasRef aref) {
SqlNode node = aref.Alias.Node;
if (node is SqlTable || node is SqlTableValuedFunctionCall) {
return aref;
}
SqlUnion union = node as SqlUnion;
if (union != null) {
return this.ExpandUnion(union);
}
SqlSelect ss = node as SqlSelect;
if (ss != null) {
return this.VisitExpression(ss.Selection);
}
SqlExpression exp = node as SqlExpression;
if (exp != null)
return this.VisitExpression(exp);
throw Error.CouldNotHandleAliasRef(node.NodeType);
}
internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
return (SqlExpression)new SqlDuplicator().Duplicate(ss);
}
internal override SqlNode VisitLink(SqlLink link) {
SqlExpression expansion = this.VisitExpression(link.Expansion);
SqlExpression[] exprs = new SqlExpression[link.KeyExpressions.Count];
for (int i = 0, n = exprs.Length; i < n; i++) {
exprs[i] = this.VisitExpression(link.KeyExpressions[i]);
}
return new SqlLink(link.Id, link.RowType, link.ClrType, link.SqlType, link.Expression, link.Member, exprs, expansion, link.SourceExpression);
}
private SqlExpression ExpandUnion(SqlUnion union) {
List<SqlExpression> exprs = new List<SqlExpression>(2);
this.GatherUnionExpressions(union, exprs);
this.sourceExpression = union.SourceExpression;
SqlExpression result = this.ExpandTogether(exprs);
return result;
}
private void GatherUnionExpressions(SqlNode node, List<SqlExpression> exprs) {
SqlUnion union = node as SqlUnion;
if (union != null) {
this.GatherUnionExpressions(union.Left, exprs);
this.GatherUnionExpressions(union.Right, exprs);
}
else {
SqlSelect sel = node as SqlSelect;
if (sel != null) {
SqlAliasRef aref = sel.Selection as SqlAliasRef;
if (aref != null) {
this.GatherUnionExpressions(aref.Alias.Node, exprs);
}
else {
exprs.Add(sel.Selection);
}
}
}
}
[SuppressMessage("Microsoft.Performance", "CA1809:AvoidExcessiveLocals", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlExpression ExpandTogether(List<SqlExpression> exprs) {
switch (exprs[0].NodeType) {
case SqlNodeType.MethodCall: {
SqlMethodCall[] mcs = new SqlMethodCall[exprs.Count];
for (int i = 0; i < mcs.Length; ++i) {
mcs[i] = (SqlMethodCall)exprs[i];
}
List<SqlExpression> expandedArgs = new List<SqlExpression>();
for (int i = 0; i < mcs[0].Arguments.Count; ++i) {
List<SqlExpression> args = new List<SqlExpression>();
for (int j = 0; j < mcs.Length; ++j) {
args.Add(mcs[j].Arguments[i]);
}
SqlExpression expanded = this.ExpandTogether(args);
expandedArgs.Add(expanded);
}
return factory.MethodCall(mcs[0].Method, mcs[0].Object, expandedArgs.ToArray(), mcs[0].SourceExpression);
}
case SqlNodeType.ClientCase: {
// Are they all the same?
SqlClientCase[] scs = new SqlClientCase[exprs.Count];
scs[0] = (SqlClientCase)exprs[0];
for (int i = 1; i < scs.Length; ++i) {
scs[i] = (SqlClientCase)exprs[i];
}
// Expand expressions together.
List<SqlExpression> expressions = new List<SqlExpression>();
for (int i = 0; i < scs.Length; ++i) {
expressions.Add(scs[i].Expression);
}
SqlExpression expression = this.ExpandTogether(expressions);
// Expand individual expressions together.
List<SqlClientWhen> whens = new List<SqlClientWhen>();
for (int i = 0; i < scs[0].Whens.Count; ++i) {
List<SqlExpression> scos = new List<SqlExpression>();
for (int j = 0; j < scs.Length; ++j) {
SqlClientWhen when = scs[j].Whens[i];
scos.Add(when.Value);
}
whens.Add(new SqlClientWhen(scs[0].Whens[i].Match, this.ExpandTogether(scos)));
}
return new SqlClientCase(scs[0].ClrType, expression, whens, scs[0].SourceExpression);
}
case SqlNodeType.TypeCase: {
// Are they all the same?
SqlTypeCase[] tcs = new SqlTypeCase[exprs.Count];
tcs[0] = (SqlTypeCase)exprs[0];
for (int i = 1; i < tcs.Length; ++i) {
tcs[i] = (SqlTypeCase)exprs[i];
}
// Expand discriminators together.
List<SqlExpression> discriminators = new List<SqlExpression>();
for (int i = 0; i < tcs.Length; ++i) {
discriminators.Add(tcs[i].Discriminator);
}
SqlExpression discriminator = this.ExpandTogether(discriminators);
// Write expanded discriminators back in.
for (int i = 0; i < tcs.Length; ++i) {
tcs[i].Discriminator = discriminators[i];
}
// Expand individual type bindings together.
List<SqlTypeCaseWhen> whens = new List<SqlTypeCaseWhen>();
for (int i = 0; i < tcs[0].Whens.Count; ++i) {
List<SqlExpression> scos = new List<SqlExpression>();
for (int j = 0; j < tcs.Length; ++j) {
SqlTypeCaseWhen when = tcs[j].Whens[i];
scos.Add(when.TypeBinding);
}
SqlExpression expanded = this.ExpandTogether(scos);
whens.Add(new SqlTypeCaseWhen(tcs[0].Whens[i].Match, expanded));
}
return factory.TypeCase(tcs[0].ClrType, tcs[0].RowType, discriminator, whens, tcs[0].SourceExpression);
}
case SqlNodeType.New: {
// first verify all are similar client objects...
SqlNew[] cobs = new SqlNew[exprs.Count];
cobs[0] = (SqlNew)exprs[0];
for (int i = 1, n = exprs.Count; i < n; i++) {
if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.New)
throw Error.UnionIncompatibleConstruction();
cobs[i] = (SqlNew)exprs[1];
if (cobs[i].Members.Count != cobs[0].Members.Count)
throw Error.UnionDifferentMembers();
for (int m = 0, mn = cobs[0].Members.Count; m < mn; m++) {
if (cobs[i].Members[m].Member != cobs[0].Members[m].Member) {
throw Error.UnionDifferentMemberOrder();
}
}
}
SqlMemberAssign[] bindings = new SqlMemberAssign[cobs[0].Members.Count];
for (int m = 0, mn = bindings.Length; m < mn; m++) {
List<SqlExpression> mexprs = new List<SqlExpression>();
for (int i = 0, n = exprs.Count; i < n; i++) {
mexprs.Add(cobs[i].Members[m].Expression);
}
bindings[m] = new SqlMemberAssign(cobs[0].Members[m].Member, this.ExpandTogether(mexprs));
for (int i = 0, n = exprs.Count; i < n; i++) {
cobs[i].Members[m].Expression = mexprs[i];
}
}
SqlExpression[] arguments = new SqlExpression[cobs[0].Args.Count];
for (int m = 0, mn = arguments.Length; m < mn; ++m) {
List<SqlExpression> mexprs = new List<SqlExpression>();
for (int i = 0, n = exprs.Count; i < n; i++) {
mexprs.Add(cobs[i].Args[m]);
}
arguments[m] = ExpandTogether(mexprs);
}
return factory.New(cobs[0].MetaType, cobs[0].Constructor, arguments, cobs[0].ArgMembers, bindings, exprs[0].SourceExpression);
}
case SqlNodeType.Link: {
SqlLink[] links = new SqlLink[exprs.Count];
links[0] = (SqlLink)exprs[0];
for (int i = 1, n = exprs.Count; i < n; i++) {
if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.Link)
throw Error.UnionIncompatibleConstruction();
links[i] = (SqlLink)exprs[i];
if (links[i].KeyExpressions.Count != links[0].KeyExpressions.Count ||
links[i].Member != links[0].Member ||
(links[i].Expansion != null) != (links[0].Expansion != null))
throw Error.UnionIncompatibleConstruction();
}
SqlExpression[] kexprs = new SqlExpression[links[0].KeyExpressions.Count];
List<SqlExpression> lexprs = new List<SqlExpression>();
for (int k = 0, nk = links[0].KeyExpressions.Count; k < nk; k++) {
lexprs.Clear();
for (int i = 0, n = exprs.Count; i < n; i++) {
lexprs.Add(links[i].KeyExpressions[k]);
}
kexprs[k] = this.ExpandTogether(lexprs);
for (int i = 0, n = exprs.Count; i < n; i++) {
links[i].KeyExpressions[k] = lexprs[i];
}
}
SqlExpression expansion = null;
if (links[0].Expansion != null) {
lexprs.Clear();
for (int i = 0, n = exprs.Count; i < n; i++) {
lexprs.Add(links[i].Expansion);
}
expansion = this.ExpandTogether(lexprs);
for (int i = 0, n = exprs.Count; i < n; i++) {
links[i].Expansion = lexprs[i];
}
}
return new SqlLink(links[0].Id, links[0].RowType, links[0].ClrType, links[0].SqlType, links[0].Expression, links[0].Member, kexprs, expansion, links[0].SourceExpression);
}
case SqlNodeType.Value: {
/*
* ExprSet of all literals of the same value reduce to just a single literal.
*/
SqlValue val0 = (SqlValue)exprs[0];
for (int i = 1; i < exprs.Count; ++i) {
SqlValue val = (SqlValue)exprs[i];
if (!object.Equals(val.Value, val0.Value))
return this.ExpandIntoExprSet(exprs);
}
return val0;
}
case SqlNodeType.OptionalValue: {
if (exprs[0].SqlType.CanBeColumn) {
goto default;
}
List<SqlExpression> hvals = new List<SqlExpression>(exprs.Count);
List<SqlExpression> vals = new List<SqlExpression>(exprs.Count);
for (int i = 0, n = exprs.Count; i < n; i++) {
if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OptionalValue) {
throw Error.UnionIncompatibleConstruction();
}
SqlOptionalValue sov = (SqlOptionalValue)exprs[i];
hvals.Add(sov.HasValue);
vals.Add(sov.Value);
}
return new SqlOptionalValue(this.ExpandTogether(hvals), this.ExpandTogether(vals));
}
case SqlNodeType.OuterJoinedValue: {
if (exprs[0].SqlType.CanBeColumn) {
goto default;
}
List<SqlExpression> values = new List<SqlExpression>(exprs.Count);
for (int i = 0, n = exprs.Count; i < n; i++) {
if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OuterJoinedValue) {
throw Error.UnionIncompatibleConstruction();
}
SqlUnary su = (SqlUnary)exprs[i];
values.Add(su.Operand);
}
return factory.Unary(SqlNodeType.OuterJoinedValue, this.ExpandTogether(values));
}
case SqlNodeType.DiscriminatedType: {
SqlDiscriminatedType sdt0 = (SqlDiscriminatedType)exprs[0];
List<SqlExpression> foos = new List<SqlExpression>(exprs.Count);
foos.Add(sdt0.Discriminator);
for (int i = 1, n = exprs.Count; i < n; i++) {
SqlDiscriminatedType sdtN = (SqlDiscriminatedType)exprs[i];
if (sdtN.TargetType != sdt0.TargetType) {
throw Error.UnionIncompatibleConstruction();
}
foos.Add(sdtN.Discriminator);
}
return factory.DiscriminatedType(this.ExpandTogether(foos), ((SqlDiscriminatedType)exprs[0]).TargetType);
}
case SqlNodeType.ClientQuery:
case SqlNodeType.Multiset:
case SqlNodeType.Element:
case SqlNodeType.Grouping:
throw Error.UnionWithHierarchy();
default:
return this.ExpandIntoExprSet(exprs);
}
}
/// <summary>
/// Expand a set of expressions into a single expr set.
/// This is typically a fallback when there is no other way to unify a set of expressions.
/// </summary>
private SqlExpression ExpandIntoExprSet(List<SqlExpression> exprs) {
SqlExpression[] rexprs = new SqlExpression[exprs.Count];
for (int i = 0, n = exprs.Count; i < n; i++) {
rexprs[i] = this.VisitExpression(exprs[i]);
}
return this.factory.ExprSet(rexprs, this.sourceExpression);
}
}
}
}

View File

@@ -0,0 +1,120 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Reflection;
using System.Text;
using System.Linq;
using System.Linq.Expressions;
using System.Data.Linq;
using System.Data.Linq.Provider;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
internal static class SqlExpressionNullability {
/// <summary>
/// Determines whether the given expression may return a null result.
/// </summary>
/// <param name="expr">The expression to check.</param>
/// <returns>null means that it couldn't be determined</returns>
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal static bool? CanBeNull(SqlExpression expr) {
switch (expr.NodeType) {
case SqlNodeType.ExprSet:
SqlExprSet exprSet = (SqlExprSet)expr;
return CanBeNull(exprSet.Expressions);
case SqlNodeType.SimpleCase:
SqlSimpleCase sc = (SqlSimpleCase)expr;
return CanBeNull(sc.Whens.Select(w => w.Value));
case SqlNodeType.Column:
SqlColumn col = (SqlColumn)expr;
if (col.MetaMember != null) {
return col.MetaMember.CanBeNull;
}
else if (col.Expression != null) {
return CanBeNull(col.Expression);
}
return null; // Don't know.
case SqlNodeType.ColumnRef:
SqlColumnRef cref = (SqlColumnRef)expr;
return CanBeNull(cref.Column);
case SqlNodeType.Value:
return ((SqlValue)expr).Value == null;
case SqlNodeType.New:
case SqlNodeType.Multiset:
case SqlNodeType.Grouping:
case SqlNodeType.DiscriminatedType:
case SqlNodeType.IsNotNull: // IsNull\IsNotNull always return true or false and can never return NULL.
case SqlNodeType.IsNull:
case SqlNodeType.Exists:
return false;
case SqlNodeType.Add:
case SqlNodeType.Sub:
case SqlNodeType.Mul:
case SqlNodeType.Div:
case SqlNodeType.Mod:
case SqlNodeType.BitAnd:
case SqlNodeType.BitOr:
case SqlNodeType.BitXor:
case SqlNodeType.Concat: {
SqlBinary bop = (SqlBinary)expr;
bool? left = CanBeNull(bop.Left);
bool? right = CanBeNull(bop.Right);
return (left != false) || (right != false);
}
case SqlNodeType.Negate:
case SqlNodeType.BitNot: {
SqlUnary uop = (SqlUnary)expr;
return CanBeNull(uop.Operand);
}
case SqlNodeType.Lift: {
SqlLift lift = (SqlLift)expr;
return CanBeNull(lift.Expression);
}
case SqlNodeType.OuterJoinedValue:
return true;
default:
return null; // Don't know.
}
}
/// <summary>
/// Used to determine nullability for a collection of expressions.
/// * If at least one of the expressions is nullable, the collection is nullable.
/// * If no expressions are nullable, but at least one is 'don't know', the collection is 'don't know'.
/// * Otherwise all expressions are non-nullable and the nullability is false.
/// </summary>
private static bool? CanBeNull(IEnumerable<SqlExpression> exprs) {
bool hasAtleastOneUnknown = false;
foreach(SqlExpression e in exprs) {
bool? nullability = CanBeNull(e);
// Even one expression that could return null means the
// collection can return null.
if (nullability == true)
return true;
// If there is one or more 'unknown' and no definitely nullable
// results then the collection nullability is 'unknown'.
if (nullability == null)
hasAtleastOneUnknown = true;
}
if (hasAtleastOneUnknown)
return null;
return false;
}
}
}

Some files were not shown because too many files have changed in this diff Show More