// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
namespace System.Data.Entity
{
using System.Collections.Generic;
using System.Configuration;
using System.Data.Common;
using System.Data.Entity.Core.Metadata.Edm;
using System.Data.Entity.Core.Objects;
using System.Data.Entity.Infrastructure;
using System.Data.SqlClient;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Xml;
using System.Xml.Linq;
using Xunit;
public static class ModelHelpers
{
#region State entry helpers
/// <summary>
/// Gets the state entries tracked for the given DbContext.
/// </summary>
/// <param name="dbContext">A DbContext instance.</param>
/// <returns>The entries returned by the ObjectContext overload for the underlying ObjectContext.</returns>
public static IEnumerable<ObjectStateEntry> GetStateEntries(DbContext dbContext)
{
    return GetStateEntries(TestBase.GetObjectContext(dbContext));
}

/// <summary>
/// Gets the state entries tracked by the ObjectStateManager of the given ObjectContext.
/// </summary>
/// <param name="objectContext">An ObjectContext instance.</param>
/// <returns>All state entries in the ObjectStateManager whose state is anything other than Detached.</returns>
public static IEnumerable<ObjectStateEntry> GetStateEntries(ObjectContext objectContext)
{
    // The complement of Detached selects every other EntityState flag.
    return objectContext.ObjectStateManager.GetObjectStateEntries(~EntityState.Detached);
}
/// <summary>
/// Looks up the ObjectStateEntry that tracks an entity in a DbContext.
/// </summary>
/// <param name="dbContext">The DbContext whose underlying ObjectContext is searched.</param>
/// <param name="entity">The entity to look up.</param>
/// <returns>The ObjectStateEntry for the entity.</returns>
public static ObjectStateEntry GetStateEntry(DbContext dbContext, object entity)
{
    // Unwrap the DbContext and defer to the ObjectContext overload.
    var underlyingContext = TestBase.GetObjectContext(dbContext);
    return GetStateEntry(underlyingContext, entity);
}
/// <summary>
/// Looks up the ObjectStateEntry that tracks an entity in an ObjectContext.
/// </summary>
/// <param name="objectContext">The ObjectContext whose ObjectStateManager is queried.</param>
/// <param name="entity">The entity to look up.</param>
/// <returns>The ObjectStateEntry for the entity.</returns>
public static ObjectStateEntry GetStateEntry(ObjectContext objectContext, object entity)
{
    var stateManager = objectContext.ObjectStateManager;
    return stateManager.GetObjectStateEntry(entity);
}
/// <summary>
/// Asserts that the given DbContext holds no ObjectStateEntry for the given entity.
/// </summary>
/// <param name="dbContext">The DbContext whose underlying ObjectContext is checked.</param>
/// <param name="entity">The entity to look up.</param>
public static void AssertNoStateEntry(DbContext dbContext, object entity)
{
    // Unwrap the DbContext and defer to the ObjectContext overload.
    var underlyingContext = TestBase.GetObjectContext(dbContext);
    AssertNoStateEntry(underlyingContext, entity);
}
/// <summary>
/// Asserts that the given ObjectContext holds no ObjectStateEntry for the given entity.
/// </summary>
/// <param name="objectContext">The ObjectContext whose ObjectStateManager is checked.</param>
/// <param name="entity">The entity to look up.</param>
public static void AssertNoStateEntry(ObjectContext objectContext, object entity)
{
    // Only the boolean result matters; the out value is ignored.
    ObjectStateEntry ignoredEntry;
    var entryFound = objectContext.ObjectStateManager.TryGetObjectStateEntry(entity, out ignoredEntry);

    Assert.False(entryFound, "The context contains an unexpected entry for the given entity");
}
#endregion
#region Connection helpers
// Read once from the "BaseConnectionString" app setting; when the setting is absent,
// falls back to a local SQL Express instance using integrated security.
// Marked readonly because nothing in this class reassigns it after initialization.
private static readonly string _baseConnectionString = ConfigurationManager.AppSettings["BaseConnectionString"]
                                                       ?? @"Data Source=.\SQLEXPRESS; Integrated Security=True;";

/// <summary>
/// Gets the base connection string that the connection-string helpers in this class build upon.
/// </summary>
public static string BaseConnectionString
{
    get { return _baseConnectionString; }
}
/// <summary>
/// Builds a SQL Server connection string from the base connection string,
/// with the initial catalog set to the given database name.
/// </summary>
/// <param name="databaseName">The database name.</param>
/// <returns>The connection string.</returns>
public static string SimpleConnectionString(string databaseName)
{
    var connectionBuilder = new SqlConnectionStringBuilder(_baseConnectionString);
    connectionBuilder.InitialCatalog = databaseName;

    return connectionBuilder.ConnectionString;
}
/// <summary>
/// Builds a SQL Server connection string from the base connection string that attaches
/// a database file named after the given database.
/// </summary>
/// <param name="databaseName">The database name; also used as the .mdf file name.</param>
/// <returns>The connection string.</returns>
public static string SimpleAttachConnectionString(string databaseName)
{
    // The .mdf file is placed in the current AppDomain's base directory.
    var mdfFileName = databaseName + ".mdf";
    var mdfPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, mdfFileName);

    var connectionBuilder = new SqlConnectionStringBuilder(_baseConnectionString);
    connectionBuilder.InitialCatalog = databaseName;
    connectionBuilder.AttachDBFilename = mdfPath;

    return connectionBuilder.ConnectionString;
}
/// <summary>
/// Builds a SQL Server connection string from the base connection string that uses
/// the supplied user ID and password instead of integrated security.
/// </summary>
/// <param name="databaseName">The database name.</param>
/// <param name="userId">User ID to use when connecting to SQL Server.</param>
/// <param name="password">Password for the SQL Server account.</param>
/// <param name="persistSecurityInfo">
/// Value assigned to the builder's PersistSecurityInfo setting; defaults to false.
/// </param>
/// <returns>The connection string.</returns>
public static string SimpleConnectionStringWithCredentials(
    string databaseName,
    string userId,
    string password,
    bool persistSecurityInfo = false)
{
    var connectionBuilder = new SqlConnectionStringBuilder(_baseConnectionString);
    connectionBuilder.InitialCatalog = databaseName;
    connectionBuilder.UserID = userId;
    connectionBuilder.Password = password;
    connectionBuilder.PersistSecurityInfo = persistSecurityInfo;

    // Drop any integrated-security setting inherited from the base connection string.
    connectionBuilder.Remove("Integrated Security");

    return connectionBuilder.ConnectionString;
}
/// <summary>
/// Builds a SQL CE connection string pointing at an .sdf file named after the given
/// database, located in the current AppDomain's base directory.
/// </summary>
/// <param name="databaseName">Name of the database.</param>
/// <returns>The connection string.</returns>
public static string SimpleCeConnectionString(string databaseName)
{
    // Path of the database file without its ".sdf" extension; the format string appends it.
    var databasePathWithoutExtension = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, databaseName);

    return String.Format(
        CultureInfo.InvariantCulture,
        "Data Source={0}.sdf;Persist Security Info=False;",
        databasePathWithoutExtension);
}
/// <summary>
/// Returns the name these helpers use as the default database name for a context type,
/// which is the full name of the context type.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a name for.</typeparam>
/// <returns>The name.</returns>
public static string DefaultDbName<TContext>() where TContext : DbContext
{
    return typeof(TContext).FullName;
}

/// <summary>
/// Returns a simple SQL Server connection string for the given context type,
/// using the context type's default database name.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
/// <returns>The connection string.</returns>
public static string SimpleConnectionString<TContext>() where TContext : DbContext
{
    return SimpleConnectionString(DefaultDbName<TContext>());
}

/// <summary>
/// Returns a simple SQL Server connection string using an attachable database for the
/// given context type, using the context type's default database name.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
/// <returns>The connection string.</returns>
public static string SimpleAttachConnectionString<TContext>() where TContext : DbContext
{
    return SimpleAttachConnectionString(DefaultDbName<TContext>());
}

/// <summary>
/// Returns a simple SQL CE connection string for the given context type,
/// using the context type's default database name.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a connection string for.</typeparam>
/// <returns>The connection string.</returns>
public static string SimpleCeConnectionString<TContext>() where TContext : DbContext
{
    return SimpleCeConnectionString(DefaultDbName<TContext>());
}

/// <summary>
/// Returns a new, unopened SQL Server connection for the given context type.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a connection for.</typeparam>
/// <returns>The connection.</returns>
public static SqlConnection SimpleConnection<TContext>() where TContext : DbContext
{
    return new SqlConnection(SimpleConnectionString<TContext>());
}

/// <summary>
/// Returns a SQL CE connection for the given context type, created by a
/// SqlCeConnectionFactory from the context type's default database name.
/// </summary>
/// <typeparam name="TContext">The type of the context to create a connection for.</typeparam>
/// <returns>The connection.</returns>
public static DbConnection SimpleCeConnection<TContext>() where TContext : DbContext
{
    return
        new SqlCeConnectionFactory("System.Data.SqlServerCe.4.0", AppDomain.CurrentDomain.BaseDirectory, "").
            CreateConnection(DefaultDbName<TContext>());
}
#endregion
#region Entity set name helpers
/// <summary>
/// Finds the entity set name for a CLR type in a DbContext, assuming no MEST.
/// </summary>
/// <param name="dbContext">The context to look in.</param>
/// <param name="clrType">The CLR type to look up.</param>
/// <returns>The entity set name.</returns>
public static string GetEntitySetName(DbContext dbContext, Type clrType)
{
    // Unwrap the DbContext and defer to the ObjectContext overload.
    var underlyingContext = TestBase.GetObjectContext(dbContext);
    return GetEntitySetName(underlyingContext, clrType);
}
/// <summary>
/// Gets the entity set name for the given CLR type, assuming no MEST.
/// </summary>
/// <param name="objectContext">The context to look in.</param>
/// <param name="clrType">The type to lookup.</param>
/// <returns>
/// The entity set name, or null when no entity type is found for the CLR type or
/// no entity set is found for that type or any of its base types.
/// </returns>
public static string GetEntitySetName(ObjectContext objectContext, Type clrType)
{
    var cspaceType = GetStructuralType<EntityType>(objectContext, clrType);
    if (cspaceType == null)
    {
        return null;
    }

    // Push the type and then each base type, so that popping visits the
    // root of the hierarchy first and the given type last.
    var inverseHierarchy = new Stack<EntityType>();
    do
    {
        inverseHierarchy.Push(cspaceType);
        cspaceType = (EntityType)cspaceType.BaseType;
    }
    while (cspaceType != null);

    while (inverseHierarchy.Count > 0)
    {
        cspaceType = inverseHierarchy.Pop();
        foreach (var container in objectContext.MetadataWorkspace.GetItems<EntityContainer>(DataSpace.CSpace))
        {
            var entitySet = container.BaseEntitySets.FirstOrDefault(s => s.ElementType == cspaceType);
            if (entitySet != null)
            {
                return entitySet.Name;
            }
        }
    }

    return null;
}
#endregion
#region Entity type helpers
/// <summary>
/// Gets the Entity Type of the entity, given the CLR type
/// </summary>
/// <param name="dbContext">The context to look in.</param>
/// <param name="clrType">The CLR type.</param>
/// <returns>The EntityType, or null when none is found for the CLR type.</returns>
public static EntityType GetEntityType(DbContext dbContext, Type clrType)
{
    return GetStructuralType<EntityType>(TestBase.GetObjectContext(dbContext), clrType);
}

/// <summary>
/// Gets the structural type of the entity type or complex type given the CLR type
/// </summary>
/// <typeparam name="TStructural">The kind of structural type to look for.</typeparam>
/// <param name="objectContext">The context to look in.</param>
/// <param name="clrType">The CLR type.</param>
/// <returns>The EntityType or ComplexType, or null when none is found for the CLR type.</returns>
public static TStructural GetStructuralType<TStructural>(ObjectContext objectContext, Type clrType)
    where TStructural : StructuralType
{
    var workspace = objectContext.MetadataWorkspace;
    var objectItemCollection = (ObjectItemCollection)workspace.GetItemCollection(DataSpace.OSpace);

    var ospaceType = workspace.GetItems<TStructural>(DataSpace.OSpace)
        .FirstOrDefault(t => objectItemCollection.GetClrType(t) == clrType);
    if (ospaceType == null)
    {
        // Not found yet: load metadata from the type's assembly and look again.
        workspace.LoadFromAssembly(clrType.Assembly);

        // Re-read the O-space items in case the earlier collection was a snapshot
        // taken before the assembly was loaded.
        ospaceType = workspace.GetItems<TStructural>(DataSpace.OSpace)
            .FirstOrDefault(t => objectItemCollection.GetClrType(t) == clrType);
        if (ospaceType == null)
        {
            return null;
        }
    }

    // Translate the O-space type into its EDM-space counterpart.
    return (TStructural)workspace.GetEdmSpaceType(ospaceType);
}
#endregion
#region Helpers for creating metadata (csdl/ssdl/msl) files
/// <summary>
/// Writes an edmx file for the model generated from the given model builder.
/// The model is built for the "System.Data.SqlClient" provider with manifest token "2008".
/// </summary>
/// <param name="builder">The builder.</param>
/// <param name="filename">The filename to use for the edmx file.</param>
public static void WriteEdmx(DbModelBuilder builder, string filename)
{
    var model = builder.Build(new DbProviderInfo("System.Data.SqlClient", "2008"));

    // Dispose the writer so the output is flushed and the file handle is released.
    using (var writer = XmlWriter.Create(filename))
    {
        EdmxWriter.WriteEdmx(model, writer);
    }
}
/// <summary>
/// Writes csdl, msl, and ssdl files for the model generated from the given model builder.
/// The model is built for the "System.Data.SqlClient" provider with manifest token "2008".
/// </summary>
/// <param name="builder">The builder.</param>
/// <param name="filename">The base filename to use for csdl, msl, and ssdl files.</param>
public static void WriteMetadataFiles(DbModelBuilder builder, string filename)
{
    var xml = new StringBuilder();
    var model = builder.Build(new DbProviderInfo("System.Data.SqlClient", "2008"));

    // Dispose the writer before reading the StringBuilder so that any buffered
    // output has been flushed into it.
    using (var writer = XmlWriter.Create(xml))
    {
        EdmxWriter.WriteEdmx(model, writer);
    }

    WriteMetadataFiles(xml.ToString(), filename);
}
/// <summary>
/// Splits the given edmx content into its conceptual, storage, and mapping parts and
/// writes them to "filename.csdl", "filename.ssdl", and "filename.msl" respectively.
/// Both the V2 and V3 schema namespaces are recognized.
/// </summary>
/// <param name="edmx">The edmx. (Note that this is NOT the filename of an edmx file; it is the actual edmx.)</param>
/// <param name="filename">The base filename to use for csdl, msl, and ssdl files.</param>
public static void WriteMetadataFiles(string edmx, string filename)
{
    XNamespace csdlV2 = "http://schemas.microsoft.com/ado/2008/09/edm";
    XNamespace ssdlV2 = "http://schemas.microsoft.com/ado/2009/02/edm/ssdl";
    XNamespace mslV2 = "http://schemas.microsoft.com/ado/2008/09/mapping/cs";
    XNamespace csdlV3 = "http://schemas.microsoft.com/ado/2009/11/edm";
    XNamespace ssdlV3 = "http://schemas.microsoft.com/ado/2009/11/edm/ssdl";
    XNamespace mslV3 = "http://schemas.microsoft.com/ado/2009/11/mapping/cs";

    var document = XDocument.Load(new StringReader(edmx));

    // Each part is extracted and written before the next one is looked at.
    var conceptual = ExtractMetadataContent(document, "ConceptualModels", csdlV2 + "Schema", csdlV3 + "Schema");
    WriteMetadataFile(filename + ".csdl", conceptual);

    var storage = ExtractMetadataContent(document, "StorageModels", ssdlV2 + "Schema", ssdlV3 + "Schema");
    WriteMetadataFile(filename + ".ssdl", storage);

    var mapping = ExtractMetadataContent(document, "Mappings", mslV2 + "Mapping", mslV3 + "Mapping");
    WriteMetadataFile(filename + ".msl", mapping);
}
/// <summary>
/// Saves the given element as an XML file with the given filename.
/// </summary>
/// <param name="filename">The file to write.</param>
/// <param name="element">The element to save; expected to be non-null.</param>
private static void WriteMetadataFile(string filename, XElement element)
{
    Debug.Assert(element != null, "Expected to find element");

    using (var xmlWriter = XmlWriter.Create(filename))
    {
        element.Save(xmlWriter);
    }
}
/// <summary>
/// Navigates Edmx/Runtime/<paramref name="part"/> in the given edmx document, trying the
/// V2 edmx namespace first and then V3 at each level, and returns the first child of the
/// part node that matches one of the candidate names.
/// </summary>
/// <param name="edmxDoc">The parsed edmx document.</param>
/// <param name="part">Local name of the runtime section, e.g. "ConceptualModels".</param>
/// <param name="elements">Candidate element names, tried in order; any number may be given.</param>
/// <returns>The first matching element, or null when none of the candidates is present.</returns>
private static XElement ExtractMetadataContent(XDocument edmxDoc, string part, params XName[] elements)
{
    XNamespace edmxnsV2 = "http://schemas.microsoft.com/ado/2008/10/edmx";
    XNamespace edmxnsV3 = "http://schemas.microsoft.com/ado/2009/11/edmx";

    var edmxNode = edmxDoc.Element(edmxnsV2 + "Edmx") ?? edmxDoc.Element(edmxnsV3 + "Edmx");
    Debug.Assert(edmxNode != null, "Expected to find edmx node.");

    var runtimeNode = edmxNode.Element(edmxnsV2 + "Runtime") ?? edmxNode.Element(edmxnsV3 + "Runtime");
    Debug.Assert(runtimeNode != null, "Expected to find runtime node.");

    var partNode = runtimeNode.Element(edmxnsV2 + part) ?? runtimeNode.Element(edmxnsV3 + part);
    Debug.Assert(partNode != null, "Expected to find " + part);

    // Try every candidate in order rather than assuming exactly two were supplied.
    foreach (var candidate in elements)
    {
        var match = partNode.Element(candidate);
        if (match != null)
        {
            return match;
        }
    }

    return null;
}
#endregion
}
}