343 lines
15 KiB
C#
Raw Normal View History

//----------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace System.ServiceModel.Security
{
using System.Net;
using System.ServiceModel.Channels;
using System.ServiceModel;
using System.Net.Sockets;
using System.Collections.ObjectModel;
using System.IdentityModel.Selectors;
using System.IdentityModel.Claims;
using System.IdentityModel.Policy;
using System.IdentityModel.Tokens;
using System.Security.Principal;
using System.ServiceModel.Security.Tokens;
using System.Collections.Generic;
using System.Runtime.Serialization;
using System.Globalization;
using System.ServiceModel.Diagnostics;
using System.ServiceModel.Diagnostics.Application;
using System.Runtime.Diagnostics;
public abstract class IdentityVerifier
{
// Initializes the base verifier. Protected so that only derived verifiers can be constructed.
protected IdentityVerifier()
{
    // Intentionally empty: this base class declares no instance state.
}
// Returns the shared DefaultIdentityVerifier instance as the base IdentityVerifier type.
public static IdentityVerifier CreateDefault()
{
    IdentityVerifier defaultVerifier = DefaultIdentityVerifier.Instance;
    return defaultVerifier;
}
// Checks whether the security context attached to a message satisfies the identity
// expected for the given endpoint address.
//   reference - address whose expected identity is looked up; must not be null.
//   message   - message whose security property supplies the authorization context; must not be null.
// Returns false when no identity can be determined for the address, when the message has no
// security property or no service security context, or when the identity check itself fails.
internal bool CheckAccess(EndpointAddress reference, Message message)
{
    if (reference == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("reference");
    }
    if (message == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
    }

    // The identity is resolved first, before the message is inspected.
    EndpointIdentity expectedIdentity;
    bool identityKnown = this.TryGetIdentity(reference, out expectedIdentity);
    if (!identityKnown)
    {
        return false;
    }

    MessageProperties properties = message.Properties;
    SecurityMessageProperty security = (properties != null) ? properties.Security : null;
    ServiceSecurityContext securityContext = (security != null) ? security.ServiceSecurityContext : null;
    if (securityContext == null)
    {
        return false;
    }

    return this.CheckAccess(expectedIdentity, securityContext.AuthorizationContext);
}
// Implemented by derived verifiers: decides whether the given authorization context
// satisfies the expected endpoint identity. Callers in this class treat a false result
// as an identity check failure (EnsureIdentity raises a MessageSecurityException for it).
public abstract bool CheckAccess(EndpointIdentity identity, AuthorizationContext authContext);
// Implemented by derived verifiers: attempts to determine the identity to expect for the
// given endpoint address, returning it through the out parameter. Callers in this class
// treat a false result as "no identity could be determined" and fail the check.
public abstract bool TryGetIdentity(EndpointAddress reference, out EndpointIdentity identity);
// Replaces the address with one built from the Via when the address carries no identity
// and its Uri differs from the Via; otherwise leaves the address untouched.
static void AdjustAddress(ref EndpointAddress reference, Uri via)
{
    // An explicit identity always wins; a Uri that already equals the Via needs no change.
    bool hasExplicitIdentity = reference.Identity != null;
    if (hasExplicitIdentity || reference.Uri == via)
    {
        return;
    }

    reference = new EndpointAddress(via);
}
// Resolves the expected identity for an address, first substituting the Via-based address
// when the supplied address has no identity and a different Uri (see AdjustAddress).
internal bool TryGetIdentity(EndpointAddress reference, Uri via, out EndpointIdentity identity)
{
    EndpointAddress effectiveAddress = reference;
    AdjustAddress(ref effectiveAddress, via);
    return this.TryGetIdentity(effectiveAddress, out identity);
}
// Verifies the identity for an incoming message; failures are reported using the
// incoming-message error resource.
internal void EnsureIncomingIdentity(EndpointAddress serviceReference, AuthorizationContext authorizationContext)
{
    string failureResource = SR.IdentityCheckFailedForIncomingMessage;
    this.EnsureIdentity(serviceReference, authorizationContext, failureResource);
}
// Verifies the identity for an outgoing message, using the Via-based address when the
// service address has no identity and a different Uri; failures are reported using the
// outgoing-message error resource.
internal void EnsureOutgoingIdentity(EndpointAddress serviceReference, Uri via, AuthorizationContext authorizationContext)
{
    EndpointAddress effectiveAddress = serviceReference;
    AdjustAddress(ref effectiveAddress, via);

    string failureResource = SR.IdentityCheckFailedForOutgoingMessage;
    this.EnsureIdentity(effectiveAddress, authorizationContext, failureResource);
}
// Verifies the identity for an outgoing message against an authorization context built
// from the supplied policies. Throws an argument-null exception when the policy collection is null.
internal void EnsureOutgoingIdentity(EndpointAddress serviceReference, ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies)
{
    if (authorizationPolicies == null)
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authorizationPolicies");

    AuthorizationContext policyContext = AuthorizationContext.CreateDefaultAuthorizationContext(authorizationPolicies);
    this.EnsureIdentity(serviceReference, policyContext, SR.IdentityCheckFailedForOutgoingMessage);
}
// Throws unless the authorization context satisfies the identity expected for the address.
//   serviceReference     - address whose expected identity is resolved via TryGetIdentity.
//   authorizationContext - context checked against that identity; must not be null.
//   errorString          - resource name used to format the failure message.
// Raises a MessageSecurityException (through ThrowHelperWarning) when no identity can be
// determined or when CheckAccess rejects the context.
void EnsureIdentity(EndpointAddress serviceReference, AuthorizationContext authorizationContext, String errorString)
{
    if (authorizationContext == null)
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authorizationContext");

    EndpointIdentity expectedIdentity;
    bool identityFound = this.TryGetIdentity(serviceReference, out expectedIdentity);
    if (!identityFound)
    {
        // No identity to verify against: trace the failure here, then report it.
        SecurityTraceRecordHelper.TraceIdentityVerificationFailure(expectedIdentity, authorizationContext, this.GetType());
        MessageSecurityException undetermined = new MessageSecurityException(SR.GetString(errorString, expectedIdentity, serviceReference));
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(undetermined);
    }

    if (this.CheckAccess(expectedIdentity, authorizationContext))
    {
        return;
    }

    // CheckAccess performs a Trace on failure, no need to do it twice
    Exception failure = this.CreateIdentityCheckException(expectedIdentity, authorizationContext, errorString, serviceReference);
    throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(failure);
}
// Builds the MessageSecurityException describing a failed identity check.
// For a DNS identity (Dns claim type, PossessProperty right, string resource) checked with one
// of the two known error resources, a DNS-specific message is produced that names the expected
// DNS name and, when the authorization context holds a string-valued DNS claim, the actual one.
// In every other case the generic message for errorString is used.
Exception CreateIdentityCheckException(EndpointIdentity identity, AuthorizationContext authorizationContext, string errorString, EndpointAddress serviceReference)
{
    Claim expectedClaim = identity.IdentityClaim;
    bool isDnsIdentity = expectedClaim != null
        && expectedClaim.ClaimType == ClaimTypes.Dns
        && expectedClaim.Right == Rights.PossessProperty
        && expectedClaim.Resource is string;
    if (!isDnsIdentity)
    {
        return new MessageSecurityException(SR.GetString(errorString, identity, serviceReference));
    }

    string expectedDnsName = (string)expectedClaim.Resource;

    // Pick up the first string-valued DNS claim found in the authorization context, if any.
    string actualDnsName = null;
    foreach (ClaimSet claimSet in authorizationContext.ClaimSets)
    {
        foreach (Claim candidate in claimSet.FindClaims(ClaimTypes.Dns, Rights.PossessProperty))
        {
            actualDnsName = candidate.Resource as string;
            if (actualDnsName != null)
            {
                break;
            }
        }
        if (actualDnsName != null)
        {
            break;
        }
    }

    bool hasActualName = actualDnsName != null;
    if (SR.IdentityCheckFailedForIncomingMessage.Equals(errorString))
    {
        return hasActualName
            ? new MessageSecurityException(SR.GetString(SR.DnsIdentityCheckFailedForIncomingMessage, expectedDnsName, actualDnsName))
            : new MessageSecurityException(SR.GetString(SR.DnsIdentityCheckFailedForIncomingMessageLackOfDnsClaim, expectedDnsName));
    }
    if (SR.IdentityCheckFailedForOutgoingMessage.Equals(errorString))
    {
        return hasActualName
            ? new MessageSecurityException(SR.GetString(SR.DnsIdentityCheckFailedForOutgoingMessage, expectedDnsName, actualDnsName))
            : new MessageSecurityException(SR.GetString(SR.DnsIdentityCheckFailedForOutgoingMessageLackOfDnsClaim, expectedDnsName));
    }

    // Unknown error resource: fall back to the generic message.
    return new MessageSecurityException(SR.GetString(errorString, identity, serviceReference));
}
class DefaultIdentityVerifier : IdentityVerifier
{
// Single shared instance, created when the type is initialized.
static readonly DefaultIdentityVerifier instance = new DefaultIdentityVerifier();

// Gets the shared DefaultIdentityVerifier instance.
public static DefaultIdentityVerifier Instance
{
    get
    {
        return instance;
    }
}
// Determines the identity to expect for an endpoint address: the identity carried by the
// address when present, otherwise a DNS identity derived from the address Uri.
// Traces the outcome and returns false (with a null identity) when neither is available.
public override bool TryGetIdentity(EndpointAddress reference, out EndpointIdentity identity)
{
    if (reference == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("reference");
    }

    EndpointIdentity resolved = reference.Identity;
    if (resolved == null)
    {
        resolved = this.TryCreateDnsIdentity(reference);
    }
    identity = resolved;

    bool determined = resolved != null;
    if (determined)
    {
        SecurityTraceRecordHelper.TraceIdentityDeterminationSuccess(reference, resolved, typeof(DefaultIdentityVerifier));
    }
    else
    {
        SecurityTraceRecordHelper.TraceIdentityDeterminationFailure(reference, typeof(DefaultIdentityVerifier));
    }
    return determined;
}
// Creates a DNS identity from the DnsSafeHost of the address Uri, or returns null when the
// Uri is not absolute.
EndpointIdentity TryCreateDnsIdentity(EndpointAddress reference)
{
    Uri target = reference.Uri;
    return target.IsAbsoluteUri
        ? EndpointIdentity.CreateDnsIdentity(target.DnsSafeHost)
        : null;
}
// Extracts a SecurityIdentifier from a claim's resource, which may be a WindowsIdentity,
// a WindowsSidIdentity or a SecurityIdentifier itself; returns null for any other resource.
// Original note: if the incoming claim is a SID and the EndpointIdentity is UPN/SPN/DNS,
// try to find the SID corresponding to the UPN/SPN/DNS (transactions case).
SecurityIdentifier GetSecurityIdentifier(Claim claim)
{
    object resource = claim.Resource;

    WindowsIdentity windowsIdentity = resource as WindowsIdentity;
    if (windowsIdentity != null)
    {
        return windowsIdentity.User;
    }

    WindowsSidIdentity sidIdentity = resource as WindowsSidIdentity;
    if (sidIdentity != null)
    {
        return sidIdentity.SecurityIdentifier;
    }

    return resource as SecurityIdentifier;
}
// Looks for an SPN claim (PossessProperty right) in the claim set whose resource equals the
// expected SPN, compared ordinally and case-insensitively.
//   claimSet    - claim set to search.
//   expectedSpn - SPN of the form host/<machine-name> derived from the expected DNS identity.
// Returns the matching claim, or null when there is none.
Claim CheckDnsEquivalence(ClaimSet claimSet, string expectedSpn)
{
    // host/<machine-name> satisfies the DNS identity claim
    IEnumerable<Claim> spnClaims = claimSet.FindClaims(ClaimTypes.Spn, Rights.PossessProperty);
    foreach (Claim claim in spnClaims)
    {
        // Claim.Resource is typed as object: a safe cast makes an SPN claim carrying a
        // non-string resource a simple non-match instead of an InvalidCastException,
        // consistent with the "Resource is string" guards used when building failure messages.
        string actualSpn = claim.Resource as string;
        if (actualSpn != null && expectedSpn.Equals(actualSpn, StringComparison.OrdinalIgnoreCase))
        {
            return claim;
        }
    }
    return null;
}
// Returns the first claim in the claim set whose resource yields a SecurityIdentifier equal
// to identitySid, or null when no claim matches.
Claim CheckSidEquivalence(SecurityIdentifier identitySid, ClaimSet claimSet)
{
    foreach (Claim candidate in claimSet)
    {
        SecurityIdentifier candidateSid = GetSecurityIdentifier(candidate);
        if (candidateSid == null)
        {
            // This claim does not carry a SID-bearing resource.
            continue;
        }
        if (identitySid.Equals(candidateSid))
        {
            return candidate;
        }
    }
    return null;
}
// Checks whether any claim set in the authorization context satisfies the expected identity.
// A claim set satisfies the identity when, in this order:
//   1. it contains the identity claim itself; or
//   2. the identity is a DNS identity and the claim set holds the SPN claim host/<dns-name>; or
//   3. the identity maps to a SID (Sid, Upn, Spn or Dns identities) and the claim set holds a
//      claim whose resource yields that same SID.
//   identity    - expected identity; must not be null.
//   authContext - authorization context whose claim sets are examined; must not be null.
// Returns true on the first match (tracing success); otherwise traces the failure and returns false.
// Identities whose shape does not allow an equivalence check (a Dns claim with a non-string
// resource, or a Upn/Spn claim on an identity that is not a UpnEndpointIdentity/SpnEndpointIdentity)
// skip that check instead of throwing InvalidCastException, so verification fails closed.
public override bool CheckAccess(EndpointIdentity identity, AuthorizationContext authContext)
{
    EventTraceActivity eventTraceActivity = null;
    if (identity == null)
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("identity");
    if (authContext == null)
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authContext");
    if (FxTrace.Trace.IsEnd2EndActivityTracingEnabled)
    {
        eventTraceActivity = EventTraceActivityHelper.TryExtractActivity((OperationContext.Current != null) ? OperationContext.Current.IncomingMessage : null);
    }

    // Both values depend only on the identity, not on the claim set being examined, so they
    // are computed lazily (on first need) and reused for the remaining claim sets.
    string expectedSpn = null;
    SecurityIdentifier identitySid = null;
    bool identitySidResolved = false;

    for (int i = 0; i < authContext.ClaimSets.Count; ++i)
    {
        ClaimSet claimSet = authContext.ClaimSets[i];
        if (claimSet.ContainsClaim(identity.IdentityClaim))
        {
            SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, identity.IdentityClaim, this.GetType());
            return true;
        }

        // try Claim equivalence
        if (expectedSpn == null && ClaimTypes.Dns.Equals(identity.IdentityClaim.ClaimType))
        {
            string dnsName = identity.IdentityClaim.Resource as string;
            if (dnsName != null)
            {
                expectedSpn = string.Format(CultureInfo.InvariantCulture, "host/{0}", dnsName);
            }
        }
        if (expectedSpn != null)
        {
            Claim claim = CheckDnsEquivalence(claimSet, expectedSpn);
            if (claim != null)
            {
                SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, claim, this.GetType());
                return true;
            }
        }

        // Allow a Sid claim to support UPN, and SPN identities
        if (!identitySidResolved)
        {
            // Resolved only after the cheaper checks above have failed for a claim set,
            // and then at most once per call.
            identitySid = ResolveIdentitySid(identity, expectedSpn);
            identitySidResolved = true;
        }
        if (identitySid != null)
        {
            Claim claim = CheckSidEquivalence(identitySid, claimSet);
            if (claim != null)
            {
                SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, claim, this.GetType());
                return true;
            }
        }
    }

    SecurityTraceRecordHelper.TraceIdentityVerificationFailure(identity, authContext, this.GetType());
    if (TD.SecurityIdentityVerificationFailureIsEnabled())
    {
        TD.SecurityIdentityVerificationFailure(eventTraceActivity);
    }
    return false;
}

// Maps the expected identity to a SID according to its claim type (Sid, Upn, Spn or Dns);
// returns null when the claim type or the identity's concrete type offers no SID mapping.
SecurityIdentifier ResolveIdentitySid(EndpointIdentity identity, string expectedSpn)
{
    string claimType = identity.IdentityClaim.ClaimType;
    if (ClaimTypes.Sid.Equals(claimType))
    {
        return GetSecurityIdentifier(identity.IdentityClaim);
    }
    if (ClaimTypes.Upn.Equals(claimType))
    {
        // Safe cast: an identity that merely carries a Upn claim has no SID mapping here.
        UpnEndpointIdentity upnIdentity = identity as UpnEndpointIdentity;
        return (upnIdentity != null) ? upnIdentity.GetUpnSid() : null;
    }
    if (ClaimTypes.Spn.Equals(claimType))
    {
        // Safe cast: an identity that merely carries an Spn claim has no SID mapping here.
        SpnEndpointIdentity spnIdentity = identity as SpnEndpointIdentity;
        return (spnIdentity != null) ? spnIdentity.GetSpnSid() : null;
    }
    if (expectedSpn != null && ClaimTypes.Dns.Equals(claimType))
    {
        // A DNS identity is looked up through its equivalent host/<dns-name> SPN.
        return new SpnEndpointIdentity(expectedSpn).GetSpnSid();
    }
    return null;
}
}
}
}