// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

namespace System.Data.Entity
{
    using System.Collections.Generic;
    using System.Configuration;
    using System.Data.Common;
    using System.Data.Entity.Core.Metadata.Edm;
    using System.Data.Entity.Core.Objects;
    using System.Data.Entity.Infrastructure;
    using System.Data.SqlClient;
    using System.Diagnostics;
    using System.Globalization;
    using System.IO;
    using System.Linq;
    using System.Text;
    using System.Xml;
    using System.Xml.Linq;
    using Xunit;

    /// <summary>
    /// Static helpers used by tests to inspect state entries, build connection strings,
    /// look up model metadata, and dump model metadata (edmx/csdl/ssdl/msl) to disk.
    /// </summary>
    public static class ModelHelpers
    {
        #region State entry helpers

        /// <summary>
        /// Gets all state entries that are not Detached for the given DbContext.
        /// </summary>
        /// <param name="dbContext">A DbContext instance.</param>
        /// <returns>All non-detached state entries in the ObjectStateManager.</returns>
        public static IEnumerable<ObjectStateEntry> GetStateEntries(DbContext dbContext)
        {
            return GetStateEntries(TestBase.GetObjectContext(dbContext));
        }

        /// <summary>
        /// Gets all state entries that are not Detached for the given ObjectContext.
        /// </summary>
        /// <param name="objectContext">An ObjectContext instance.</param>
        /// <returns>All non-detached state entries in the ObjectStateManager.</returns>
        public static IEnumerable<ObjectStateEntry> GetStateEntries(ObjectContext objectContext)
        {
            // ~Detached selects every state flag except Detached.
            return objectContext.ObjectStateManager.GetObjectStateEntries(~EntityState.Detached);
        }

        /// <summary>
        /// Gets the ObjectStateEntry for the given entity in the given DbContext.
        /// </summary>
        /// <param name="dbContext">A DbContext instance.</param>
        /// <param name="entity">The entity to look up.</param>
        /// <returns>The ObjectStateEntry.</returns>
        public static ObjectStateEntry GetStateEntry(DbContext dbContext, object entity)
        {
            return GetStateEntry(TestBase.GetObjectContext(dbContext), entity);
        }

        /// <summary>
        /// Gets the ObjectStateEntry for the given entity in the given ObjectContext.
        /// </summary>
        /// <param name="objectContext">An ObjectContext instance.</param>
        /// <param name="entity">The entity to look up.</param>
        /// <returns>The ObjectStateEntry.</returns>
        public static ObjectStateEntry GetStateEntry(ObjectContext objectContext, object entity)
        {
            return objectContext.ObjectStateManager.GetObjectStateEntry(entity);
        }

        /// <summary>
        /// Asserts that there is no ObjectStateEntry for the given entity in the given DbContext.
        /// </summary>
        /// <param name="dbContext">A DbContext instance.</param>
        /// <param name="entity">The entity to look up.</param>
        public static void AssertNoStateEntry(DbContext dbContext, object entity)
        {
            AssertNoStateEntry(TestBase.GetObjectContext(dbContext), entity);
        }

        /// <summary>
        /// Asserts that there is no ObjectStateEntry for the given entity in the given ObjectContext.
        /// </summary>
        /// <param name="objectContext">An ObjectContext instance.</param>
        /// <param name="entity">The entity to look up.</param>
        public static void AssertNoStateEntry(ObjectContext objectContext, object entity)
        {
            ObjectStateEntry entry;
            Assert.False(
                objectContext.ObjectStateManager.TryGetObjectStateEntry(entity, out entry),
                "The context contains an unexpected entry for the given entity");
        }

        #endregion

        #region Connection helpers

        // Read once from app settings; falls back to a local SQL Express instance with integrated security.
        private static readonly string _baseConnectionString =
            ConfigurationManager.AppSettings["BaseConnectionString"]
            ?? @"Data Source=.\SQLEXPRESS; Integrated Security=True;";

        /// <summary>
        /// Gets the base SQL Server connection string that the other helpers build upon.
        /// </summary>
        public static string BaseConnectionString
        {
            get { return _baseConnectionString; }
        }

        /// <summary>
        /// Returns a simple SQL Server connection string to the local machine with the given database name.
        /// </summary>
        /// <param name="databaseName">The database name.</param>
        /// <returns>The connection string.</returns>
        public static string SimpleConnectionString(string databaseName)
        {
            return new SqlConnectionStringBuilder(_baseConnectionString)
                {
                    InitialCatalog = databaseName
                }.ConnectionString;
        }

        /// <summary>
        /// Returns a simple SQL Server connection string to the local machine using an attachable
        /// database (an .mdf file in the application base directory) with the given database name.
        /// </summary>
        /// <param name="databaseName">The database name.</param>
        /// <returns>The connection string.</returns>
        public static string SimpleAttachConnectionString(string databaseName)
        {
            var databasePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, databaseName + ".mdf");

            return new SqlConnectionStringBuilder(_baseConnectionString)
                {
                    InitialCatalog = databaseName,
                    AttachDBFilename = databasePath
                }.ConnectionString;
        }

        /// <summary>
        /// Returns a simple SQL Server connection string that uses SQL authentication with the
        /// specified credentials. Any "Integrated Security" setting from the base connection
        /// string is removed.
        /// </summary>
        /// <param name="databaseName">The database name.</param>
        /// <param name="userId">User ID to be used when connecting to SQL Server.</param>
        /// <param name="password">Password for the SQL Server account.</param>
        /// <param name="persistSecurityInfo">
        /// When true, security-sensitive information such as the password is kept in the connection
        /// string after the connection has been opened; when false (the default) it is removed.
        /// </param>
        /// <returns>The connection string.</returns>
        public static string SimpleConnectionStringWithCredentials(
            string databaseName,
            string userId,
            string password,
            bool persistSecurityInfo = false)
        {
            var builder = new SqlConnectionStringBuilder(_baseConnectionString)
                {
                    InitialCatalog = databaseName,
                    UserID = userId,
                    Password = password,
                    PersistSecurityInfo = persistSecurityInfo
                };

            // SQL authentication credentials are ignored if integrated security is enabled.
            builder.Remove("Integrated Security");

            return builder.ConnectionString;
        }

        /// <summary>
        /// Returns a simple SQL CE connection string for a .sdf file with the given name in the
        /// application base directory.
        /// </summary>
        /// <param name="databaseName">Name of the database.</param>
        /// <returns>The connection string.</returns>
        public static string SimpleCeConnectionString(string databaseName)
        {
            return String.Format(
                CultureInfo.InvariantCulture,
                "Data Source={0}.sdf;Persist Security Info=False;",
                Path.Combine(AppDomain.CurrentDomain.BaseDirectory, databaseName));
        }

        /// <summary>
        /// Returns the default database name that will be used for the context of the given type.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a name for.</typeparam>
        /// <returns>The full name of the context type.</returns>
        public static string DefaultDbName<TContext>() where TContext : DbContext
        {
            return typeof(TContext).FullName;
        }

        /// <summary>
        /// Returns a simple SQL Server connection string to the local machine for the given context type.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
        /// <returns>The connection string.</returns>
        public static string SimpleConnectionString<TContext>() where TContext : DbContext
        {
            return SimpleConnectionString(DefaultDbName<TContext>());
        }

        /// <summary>
        /// Returns a simple SQL Server connection string to the local machine using an attachable
        /// database for the given context type.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
        /// <returns>The connection string.</returns>
        public static string SimpleAttachConnectionString<TContext>() where TContext : DbContext
        {
            return SimpleAttachConnectionString(DefaultDbName<TContext>());
        }

        /// <summary>
        /// Returns a simple SQL CE connection string for the given context type.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
        /// <returns>The connection string.</returns>
        public static string SimpleCeConnectionString<TContext>() where TContext : DbContext
        {
            return SimpleCeConnectionString(DefaultDbName<TContext>());
        }

        /// <summary>
        /// Returns a new (unopened) SQL Server connection to the local machine for the given context type.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a connection for.</typeparam>
        /// <returns>The connection; the caller is responsible for disposing it.</returns>
        public static SqlConnection SimpleConnection<TContext>() where TContext : DbContext
        {
            return new SqlConnection(SimpleConnectionString<TContext>());
        }

        /// <summary>
        /// Returns a new SQL CE connection for the given context type, created by a
        /// SqlCeConnectionFactory rooted at the application base directory.
        /// </summary>
        /// <typeparam name="TContext">The type of the context to create a connection for.</typeparam>
        /// <returns>The connection; the caller is responsible for disposing it.</returns>
        public static DbConnection SimpleCeConnection<TContext>() where TContext : DbContext
        {
            return new SqlCeConnectionFactory("System.Data.SqlServerCe.4.0", AppDomain.CurrentDomain.BaseDirectory, "")
                .CreateConnection(DefaultDbName<TContext>());
        }

        #endregion

        #region Entity set name helpers

        /// <summary>
        /// Gets the entity set name for the given CLR type, assuming no MEST
        /// (multiple entity sets per type).
        /// </summary>
        /// <param name="dbContext">The context to look in.</param>
        /// <param name="clrType">The type to look up.</param>
        /// <returns>The entity set name, or null if none was found.</returns>
        public static string GetEntitySetName(DbContext dbContext, Type clrType)
        {
            return GetEntitySetName(TestBase.GetObjectContext(dbContext), clrType);
        }

        /// <summary>
        /// Gets the entity set name for the given CLR type, assuming no MEST
        /// (multiple entity sets per type).
        /// </summary>
        /// <param name="objectContext">The context to look in.</param>
        /// <param name="clrType">The type to look up.</param>
        /// <returns>The entity set name, or null if none was found.</returns>
        public static string GetEntitySetName(ObjectContext objectContext, Type clrType)
        {
            var cspaceType = GetStructuralType<EntityType>(objectContext, clrType);
            if (cspaceType == null)
            {
                return null;
            }

            // Collect the type and its ancestors so they can be visited from the root downwards.
            var inverseHierarchy = new Stack<EntityType>();
            do
            {
                inverseHierarchy.Push(cspaceType);
                cspaceType = (EntityType)cspaceType.BaseType;
            }
            while (cspaceType != null);

            // The set of containers does not change while searching, so fetch it once.
            var containers = objectContext.MetadataWorkspace.GetItems<EntityContainer>(DataSpace.CSpace);

            while (inverseHierarchy.Count > 0)
            {
                var candidate = inverseHierarchy.Pop();
                foreach (var container in containers)
                {
                    var entitySet = container.BaseEntitySets.FirstOrDefault(s => s.ElementType == candidate);
                    if (entitySet != null)
                    {
                        return entitySet.Name;
                    }
                }
            }

            return null;
        }

        #endregion

        #region Entity type helpers

        /// <summary>
        /// Gets the conceptual (C-space) entity type for the given CLR type.
        /// </summary>
        /// <param name="dbContext">The context to look in.</param>
        /// <param name="clrType">The CLR type.</param>
        /// <returns>The EntityType, or null if the CLR type is not mapped as an entity.</returns>
        public static EntityType GetEntityType(DbContext dbContext, Type clrType)
        {
            return GetStructuralType<EntityType>(TestBase.GetObjectContext(dbContext), clrType);
        }

        /// <summary>
        /// Gets the conceptual (C-space) structural type (entity type or complex type) for the given
        /// CLR type. If the type is not yet known in O-space, the CLR type's assembly is loaded into
        /// the metadata workspace and the lookup is retried.
        /// </summary>
        /// <typeparam name="TStructural">The kind of structural type to look for.</typeparam>
        /// <param name="objectContext">The context to look in.</param>
        /// <param name="clrType">The CLR type.</param>
        /// <returns>The EntityType or ComplexType, or null if none was found.</returns>
        public static TStructural GetStructuralType<TStructural>(ObjectContext objectContext, Type clrType)
            where TStructural : StructuralType
        {
            var workspace = objectContext.MetadataWorkspace;
            var objectItemCollection = (ObjectItemCollection)workspace.GetItemCollection(DataSpace.OSpace);

            var ospaceType = FindOSpaceType<TStructural>(workspace, objectItemCollection, clrType);
            if (ospaceType == null)
            {
                workspace.LoadFromAssembly(clrType.Assembly);

                // GetItems returns a snapshot, so it must be queried again to see the newly loaded types.
                ospaceType = FindOSpaceType<TStructural>(workspace, objectItemCollection, clrType);
                if (ospaceType == null)
                {
                    return null;
                }
            }

            return (TStructural)workspace.GetEdmSpaceType(ospaceType);
        }

        // Returns the O-space type of kind TStructural whose CLR type is clrType, or null.
        private static TStructural FindOSpaceType<TStructural>(
            MetadataWorkspace workspace, ObjectItemCollection objectItemCollection, Type clrType)
            where TStructural : StructuralType
        {
            return workspace.GetItems<TStructural>(DataSpace.OSpace)
                            .FirstOrDefault(t => objectItemCollection.GetClrType(t) == clrType);
        }

        #endregion

        #region Helpers for creating metadata (csdl/ssdl/msl) files

        /// <summary>
        /// Writes an edmx file for the model generated from the given model builder,
        /// built against the SQL Server 2008 provider.
        /// </summary>
        /// <param name="builder">The builder.</param>
        /// <param name="filename">The filename to use for the edmx file.</param>
        public static void WriteEdmx(DbModelBuilder builder, string filename)
        {
            // Build first so a model-building failure does not leave an empty file behind.
            var model = BuildForSqlServer2008(builder);

            // Dispose the writer so the output is flushed and the file handle is released.
            using (var writer = XmlWriter.Create(filename))
            {
                EdmxWriter.WriteEdmx(model, writer);
            }
        }

        /// <summary>
        /// Writes csdl, msl, and ssdl files for the model generated from the given model builder,
        /// built against the SQL Server 2008 provider.
        /// </summary>
        /// <param name="builder">The builder.</param>
        /// <param name="filename">The base filename to use for the csdl, msl, and ssdl files.</param>
        public static void WriteMetadataFiles(DbModelBuilder builder, string filename)
        {
            var model = BuildForSqlServer2008(builder);

            var edmx = new StringBuilder();

            // The writer must be flushed/disposed before the StringBuilder is read.
            using (var writer = XmlWriter.Create(edmx))
            {
                EdmxWriter.WriteEdmx(model, writer);
            }

            WriteMetadataFiles(edmx.ToString(), filename);
        }

        /// <summary>
        /// Takes the given edmx and splits it into csdl, msl, and ssdl files.
        /// Both EF v2 and v3 schema namespaces are recognized.
        /// </summary>
        /// <param name="edmx">The edmx. (Note that this is NOT the filename of an edmx file; it is the actual edmx.)</param>
        /// <param name="filename">The base filename to use for the csdl, msl, and ssdl files.</param>
        public static void WriteMetadataFiles(string edmx, string filename)
        {
            var csdlNameV2 = (XNamespace)"http://schemas.microsoft.com/ado/2008/09/edm" + "Schema";
            var ssdlNameV2 = (XNamespace)"http://schemas.microsoft.com/ado/2009/02/edm/ssdl" + "Schema";
            var mslNameV2 = (XNamespace)"http://schemas.microsoft.com/ado/2008/09/mapping/cs" + "Mapping";
            var csdlNameV3 = (XNamespace)"http://schemas.microsoft.com/ado/2009/11/edm" + "Schema";
            var ssdlNameV3 = (XNamespace)"http://schemas.microsoft.com/ado/2009/11/edm/ssdl" + "Schema";
            var mslNameV3 = (XNamespace)"http://schemas.microsoft.com/ado/2009/11/mapping/cs" + "Mapping";

            var edmxDoc = XDocument.Parse(edmx);

            WriteMetadataFile(
                filename + ".csdl", ExtractMetadataContent(edmxDoc, "ConceptualModels", csdlNameV2, csdlNameV3));
            WriteMetadataFile(
                filename + ".ssdl", ExtractMetadataContent(edmxDoc, "StorageModels", ssdlNameV2, ssdlNameV3));
            WriteMetadataFile(
                filename + ".msl", ExtractMetadataContent(edmxDoc, "Mappings", mslNameV2, mslNameV3));
        }

        // Builds the model for the SQL Server 2008 provider used by all metadata-writing helpers.
        private static DbModel BuildForSqlServer2008(DbModelBuilder builder)
        {
            return builder.Build(new DbProviderInfo("System.Data.SqlClient", "2008"));
        }

        // Saves the given element as a standalone XML file.
        private static void WriteMetadataFile(string filename, XElement element)
        {
            Debug.Assert(element != null, "Expected to find element");

            using (var writer = XmlWriter.Create(filename))
            {
                element.Save(writer);
            }
        }

        // Locates Edmx/Runtime/{part} (v2 or v3 edmx namespace) and returns its first child
        // element matching any of the candidate names, tried in order; null if none match.
        private static XElement ExtractMetadataContent(XDocument edmxDoc, string part, params XName[] elements)
        {
            XNamespace edmxnsV2 = "http://schemas.microsoft.com/ado/2008/10/edmx";
            XNamespace edmxnsV3 = "http://schemas.microsoft.com/ado/2009/11/edmx";

            var edmxNode = edmxDoc.Element(edmxnsV2 + "Edmx") ?? edmxDoc.Element(edmxnsV3 + "Edmx");
            Debug.Assert(edmxNode != null, "Expected to find edmx node.");

            var runtimeNode = edmxNode.Element(edmxnsV2 + "Runtime") ?? edmxNode.Element(edmxnsV3 + "Runtime");
            Debug.Assert(runtimeNode != null, "Expected to find runtime node.");

            var partNode = runtimeNode.Element(edmxnsV2 + part) ?? runtimeNode.Element(edmxnsV3 + part);
            Debug.Assert(partNode != null, "Expected to find " + part);

            foreach (var elementName in elements)
            {
                var content = partNode.Element(elementName);
                if (content != null)
                {
                    return content;
                }
            }

            return null;
        }

        #endregion
    }
}