You've already forked linux-packaging-mono
							
							
		
			
	
	
		
			343 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
		
		
			
		
	
	
			343 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
|   | //---------------------------------------------------------- | ||
|  | // Copyright (c) Microsoft Corporation.  All rights reserved. | ||
|  | //------------------------------------------------------------ | ||
|  | 
 | ||
|  | namespace System.ServiceModel.Security | ||
|  | { | ||
|  |     using System.Net; | ||
|  |     using System.ServiceModel.Channels; | ||
|  |     using System.ServiceModel; | ||
|  |     using System.Net.Sockets; | ||
|  |     using System.Collections.ObjectModel; | ||
|  |     using System.IdentityModel.Selectors; | ||
|  |     using System.IdentityModel.Claims; | ||
|  |     using System.IdentityModel.Policy; | ||
|  |     using System.IdentityModel.Tokens; | ||
|  |     using System.Security.Principal; | ||
|  |     using System.ServiceModel.Security.Tokens; | ||
|  |     using System.Collections.Generic; | ||
|  |     using System.Runtime.Serialization; | ||
|  |     using System.Globalization; | ||
|  |     using System.ServiceModel.Diagnostics;     | ||
|  |     using System.ServiceModel.Diagnostics.Application; | ||
|  |     using System.Runtime.Diagnostics; | ||
|  | 
 | ||
/// <summary>
/// Base class for components that decide whether the security context of a peer proves the
/// endpoint identity expected for an <see cref="EndpointAddress"/>.
/// </summary>
public abstract class IdentityVerifier
{
    /// <summary>
    /// Initializes a new instance; intended for derived verifiers.
    /// </summary>
    protected IdentityVerifier()
    {
        // empty
    }

    /// <summary>
    /// Returns the built-in verifier. The same shared instance is returned on every call.
    /// </summary>
    public static IdentityVerifier CreateDefault()
    {
        return DefaultIdentityVerifier.Instance;
    }

    /// <summary>
    /// Checks whether the security context attached to <paramref name="message"/> satisfies the
    /// identity expected for <paramref name="reference"/>.
    /// </summary>
    /// <returns>
    /// false when no expected identity can be determined or the message carries no service security
    /// context; otherwise the result of <see cref="CheckAccess(EndpointIdentity, AuthorizationContext)"/>.
    /// </returns>
    /// <remarks>Null arguments are rejected via <c>ThrowHelperArgumentNull</c>.</remarks>
    internal bool CheckAccess(EndpointAddress reference, Message message)
    {
        if (reference == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("reference");
        if (message == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");

        EndpointIdentity identity;
        if (!this.TryGetIdentity(reference, out identity))
            return false;

        SecurityMessageProperty securityContextProperty = null;
        if (message.Properties != null)
            securityContextProperty = message.Properties.Security;

        if (securityContextProperty == null || securityContextProperty.ServiceSecurityContext == null)
            return false;

        return this.CheckAccess(identity, securityContextProperty.ServiceSecurityContext.AuthorizationContext);
    }

    /// <summary>
    /// Decides whether <paramref name="authContext"/> proves the expected <paramref name="identity"/>.
    /// </summary>
    /// <param name="identity">The identity the remote party is expected to have.</param>
    /// <param name="authContext">The claims presented by the remote party.</param>
    /// <returns>true if access is granted.</returns>
    public abstract bool CheckAccess(EndpointIdentity identity, AuthorizationContext authContext);

    /// <summary>
    /// Attempts to determine the identity expected for <paramref name="reference"/>.
    /// </summary>
    /// <param name="reference">The address whose expected identity is requested.</param>
    /// <param name="identity">Receives the determined identity.</param>
    /// <returns>true if an identity was determined.</returns>
    public abstract bool TryGetIdentity(EndpointAddress reference, out EndpointIdentity identity);

    /// <summary>
    /// If <paramref name="reference"/> has no explicit identity and its URI differs from
    /// <paramref name="via"/>, replaces it with an address built from <paramref name="via"/> so that
    /// identity lookup is based on the Via.
    /// </summary>
    /// <remarks>
    /// A null reference is passed through unchanged so that downstream validation reports it (the
    /// default verifier's TryGetIdentity rejects a null reference), instead of failing here with a
    /// NullReferenceException.
    /// </remarks>
    static void AdjustAddress(ref EndpointAddress reference, Uri via)
    {
        if (reference == null)
            return;

        // if we don't have an identity and we have differing Uris, we should use the Via
        if (reference.Identity == null && reference.Uri != via)
        {
            reference = new EndpointAddress(via);
        }
    }

    /// <summary>
    /// Like <see cref="TryGetIdentity(EndpointAddress, out EndpointIdentity)"/>, but first substitutes
    /// <paramref name="via"/> for an identity-less reference whose URI differs from it.
    /// </summary>
    internal bool TryGetIdentity(EndpointAddress reference, Uri via, out EndpointIdentity identity)
    {
        AdjustAddress(ref reference, via);
        return this.TryGetIdentity(reference, out identity);
    }

    /// <summary>
    /// Throws a <see cref="MessageSecurityException"/> (via ThrowHelperWarning) unless
    /// <paramref name="authorizationContext"/> satisfies the identity expected for
    /// <paramref name="serviceReference"/>; uses the incoming-message error text.
    /// </summary>
    internal void EnsureIncomingIdentity(EndpointAddress serviceReference, AuthorizationContext authorizationContext)
    {
        this.EnsureIdentity(serviceReference, authorizationContext, SR.IdentityCheckFailedForIncomingMessage);
    }

    /// <summary>
    /// Outgoing-message identity check; an identity-less <paramref name="serviceReference"/> whose URI
    /// differs from <paramref name="via"/> is checked against <paramref name="via"/> instead.
    /// </summary>
    internal void EnsureOutgoingIdentity(EndpointAddress serviceReference, Uri via, AuthorizationContext authorizationContext)
    {
        AdjustAddress(ref serviceReference, via);
        this.EnsureIdentity(serviceReference, authorizationContext, SR.IdentityCheckFailedForOutgoingMessage);
    }

    /// <summary>
    /// Outgoing-message identity check against an authorization context built from
    /// <paramref name="authorizationPolicies"/> with CreateDefaultAuthorizationContext.
    /// </summary>
    internal void EnsureOutgoingIdentity(EndpointAddress serviceReference, ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies)
    {
        if (authorizationPolicies == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authorizationPolicies");
        }
        AuthorizationContext ac = AuthorizationContext.CreateDefaultAuthorizationContext(authorizationPolicies);
        this.EnsureIdentity(serviceReference, ac, SR.IdentityCheckFailedForOutgoingMessage);
    }

    /// <summary>
    /// Shared implementation of the Ensure*Identity methods.
    /// </summary>
    /// <param name="errorString">SR resource key used to format the failure message.</param>
    void EnsureIdentity(EndpointAddress serviceReference, AuthorizationContext authorizationContext, string errorString)
    {
        if (authorizationContext == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authorizationContext");
        }

        EndpointIdentity identity;
        if (!this.TryGetIdentity(serviceReference, out identity))
        {
            // identity may be null here (the default verifier returns false with a null identity);
            // it is still passed to the trace and to the message format.
            SecurityTraceRecordHelper.TraceIdentityVerificationFailure(identity, authorizationContext, this.GetType());
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(new MessageSecurityException(SR.GetString(errorString, identity, serviceReference)));
        }

        if (!this.CheckAccess(identity, authorizationContext))
        {
            // The default CheckAccess traces on failure, so no second trace is emitted here.
            Exception e = this.CreateIdentityCheckException(identity, authorizationContext, errorString, serviceReference);
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(e);
        }
    }

    /// <summary>
    /// Builds the exception reported when <see cref="CheckAccess(EndpointIdentity, AuthorizationContext)"/> fails.
    /// </summary>
    /// <remarks>
    /// For a DNS identity (PossessProperty right, string resource) and one of the two standard
    /// incoming/outgoing error keys, a DNS-specific message names the expected DNS name and the first
    /// DNS name found in <paramref name="authorizationContext"/> (or reports that none was found).
    /// Every other case uses <paramref name="errorString"/> formatted with the identity and reference.
    /// </remarks>
    Exception CreateIdentityCheckException(EndpointIdentity identity, AuthorizationContext authorizationContext, string errorString, EndpointAddress serviceReference)
    {
        Claim identityClaim = identity.IdentityClaim;
        bool isDnsIdentity = identityClaim != null
            && identityClaim.ClaimType == ClaimTypes.Dns
            && identityClaim.Right == Rights.PossessProperty
            && identityClaim.Resource is string;

        if (isDnsIdentity)
        {
            bool isIncoming = SR.IdentityCheckFailedForIncomingMessage.Equals(errorString);
            bool isOutgoing = SR.IdentityCheckFailedForOutgoingMessage.Equals(errorString);

            if (isIncoming || isOutgoing)
            {
                string expectedDnsName = (string)identityClaim.Resource;
                string actualDnsName = FindFirstDnsClaimValue(authorizationContext);

                if (actualDnsName == null)
                {
                    string lackOfDnsClaimKey = isIncoming
                        ? SR.DnsIdentityCheckFailedForIncomingMessageLackOfDnsClaim
                        : SR.DnsIdentityCheckFailedForOutgoingMessageLackOfDnsClaim;
                    return new MessageSecurityException(SR.GetString(lackOfDnsClaimKey, expectedDnsName));
                }

                string mismatchKey = isIncoming
                    ? SR.DnsIdentityCheckFailedForIncomingMessage
                    : SR.DnsIdentityCheckFailedForOutgoingMessage;
                return new MessageSecurityException(SR.GetString(mismatchKey, expectedDnsName, actualDnsName));
            }
        }

        return new MessageSecurityException(SR.GetString(errorString, identity, serviceReference));
    }

    /// <summary>
    /// Returns the string resource of the first DNS/PossessProperty claim (scanning claim sets in
    /// order) whose resource is a string, or null if there is none.
    /// </summary>
    static string FindFirstDnsClaimValue(AuthorizationContext authorizationContext)
    {
        for (int i = 0; i < authorizationContext.ClaimSets.Count; ++i)
        {
            foreach (Claim claim in authorizationContext.ClaimSets[i].FindClaims(ClaimTypes.Dns, Rights.PossessProperty))
            {
                string dnsName = claim.Resource as string;
                if (dnsName != null)
                {
                    return dnsName;
                }
            }
        }
        return null;
    }

    /// <summary>
    /// The verifier returned by <see cref="CreateDefault"/>.
    /// </summary>
    sealed class DefaultIdentityVerifier : IdentityVerifier
    {
        static readonly DefaultIdentityVerifier instance = new DefaultIdentityVerifier();

        public static DefaultIdentityVerifier Instance
        {
            get { return instance; }
        }

        /// <summary>
        /// Uses the reference's explicit identity, or else a DNS identity for the host of an absolute
        /// URI. Traces success or failure.
        /// </summary>
        public override bool TryGetIdentity(EndpointAddress reference, out EndpointIdentity identity)
        {
            if (reference == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("reference");

            identity = reference.Identity;

            if (identity == null)
            {
                identity = this.TryCreateDnsIdentity(reference);
            }

            if (identity == null)
            {
                SecurityTraceRecordHelper.TraceIdentityDeterminationFailure(reference, typeof(DefaultIdentityVerifier));
                return false;
            }

            SecurityTraceRecordHelper.TraceIdentityDeterminationSuccess(reference, identity, typeof(DefaultIdentityVerifier));
            return true;
        }

        /// <summary>
        /// Creates a DNS identity from the DnsSafeHost of the reference URI; null for a relative URI.
        /// </summary>
        EndpointIdentity TryCreateDnsIdentity(EndpointAddress reference)
        {
            Uri toAddress = reference.Uri;

            if (!toAddress.IsAbsoluteUri)
                return null;

            return EndpointIdentity.CreateDnsIdentity(toAddress.DnsSafeHost);
        }

        /// <summary>
        /// Extracts a SID from a claim whose resource is a WindowsIdentity, a WindowsSidIdentity or a
        /// SecurityIdentifier; returns null for any other resource.
        /// </summary>
        SecurityIdentifier GetSecurityIdentifier(Claim claim)
        {
            object resource = claim.Resource;

            WindowsIdentity windowsIdentity = resource as WindowsIdentity;
            if (windowsIdentity != null)
                return windowsIdentity.User;

            WindowsSidIdentity sidIdentity = resource as WindowsSidIdentity;
            if (sidIdentity != null)
                return sidIdentity.SecurityIdentifier;

            return resource as SecurityIdentifier;
        }

        /// <summary>
        /// Returns the first SPN/PossessProperty claim in <paramref name="claimSet"/> whose resource
        /// equals <paramref name="expectedSpn"/> (ordinal, case-insensitive), or null.
        /// </summary>
        Claim CheckDnsEquivalence(ClaimSet claimSet, string expectedSpn)
        {
            // host/<machine-name> satisfies the DNS identity claim
            foreach (Claim claim in claimSet.FindClaims(ClaimTypes.Spn, Rights.PossessProperty))
            {
                // A non-string resource cannot equal expectedSpn; it is skipped rather than failing
                // the whole check with an InvalidCastException.
                if (expectedSpn.Equals(claim.Resource as string, StringComparison.OrdinalIgnoreCase))
                {
                    return claim;
                }
            }
            return null;
        }

        /// <summary>
        /// Returns the first claim in <paramref name="claimSet"/> whose extracted SID equals
        /// <paramref name="identitySid"/>, or null.
        /// </summary>
        Claim CheckSidEquivalence(SecurityIdentifier identitySid, ClaimSet claimSet)
        {
            foreach (Claim claim in claimSet)
            {
                SecurityIdentifier sid = this.GetSecurityIdentifier(claim);
                if (sid != null && identitySid.Equals(sid))
                {
                    return claim;
                }
            }
            return null;
        }

        /// <summary>
        /// Maps the expected identity to a SID so that a SID claim can satisfy UPN, SPN and DNS
        /// identities: a Sid claim is used directly, UPN/SPN identities use GetUpnSid/GetSpnSid, and a
        /// DNS identity uses the SID of <paramref name="expectedSpn"/> (host/&lt;dns-name&gt;).
        /// Returns null when no SID applies, including a UPN/SPN claim carried by an identity that is not
        /// a UpnEndpointIdentity/SpnEndpointIdentity.
        /// </summary>
        SecurityIdentifier ResolveIdentitySid(EndpointIdentity identity, string expectedSpn)
        {
            string claimType = identity.IdentityClaim.ClaimType;

            if (ClaimTypes.Sid.Equals(claimType))
            {
                return this.GetSecurityIdentifier(identity.IdentityClaim);
            }
            if (ClaimTypes.Upn.Equals(claimType))
            {
                UpnEndpointIdentity upnIdentity = identity as UpnEndpointIdentity;
                return (upnIdentity != null) ? upnIdentity.GetUpnSid() : null;
            }
            if (ClaimTypes.Spn.Equals(claimType))
            {
                SpnEndpointIdentity spnIdentity = identity as SpnEndpointIdentity;
                return (spnIdentity != null) ? spnIdentity.GetSpnSid() : null;
            }
            if (ClaimTypes.Dns.Equals(claimType))
            {
                return new SpnEndpointIdentity(expectedSpn).GetSpnSid();
            }
            return null;
        }

        /// <summary>
        /// Grants access if, for some claim set in <paramref name="authContext"/> (checked in order):
        /// the set contains the identity claim; or, for a DNS identity, it contains a matching
        /// host/&lt;dns-name&gt; SPN claim; or it contains a claim whose SID equals the identity's SID.
        /// Success and failure are traced.
        /// </summary>
        public override bool CheckAccess(EndpointIdentity identity, AuthorizationContext authContext)
        {
            if (identity == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("identity");

            if (authContext == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("authContext");

            EventTraceActivity eventTraceActivity = null;
            if (FxTrace.Trace.IsEnd2EndActivityTracingEnabled)
            {
                OperationContext operationContext = OperationContext.Current;
                eventTraceActivity = EventTraceActivityHelper.TryExtractActivity((operationContext != null) ? operationContext.IncomingMessage : null);
            }

            Claim identityClaim = identity.IdentityClaim;
            bool isDnsIdentity = identityClaim != null && ClaimTypes.Dns.Equals(identityClaim.ClaimType);

            // The expected SPN and the identity SID do not depend on the claim set, so they are
            // computed at most once, the first time they are needed. In the DNS case the SID comes from
            // a freshly constructed SpnEndpointIdentity, so recomputing it per claim set would repeat
            // the GetSpnSid work every iteration (NOTE(review): that work may involve an account
            // lookup — see SpnEndpointIdentity).
            string expectedSpn = null;
            SecurityIdentifier identitySid = null;
            bool identitySidResolved = false;

            for (int i = 0; i < authContext.ClaimSets.Count; ++i)
            {
                ClaimSet claimSet = authContext.ClaimSets[i];

                // Exact match on the identity claim.
                if (claimSet.ContainsClaim(identityClaim))
                {
                    SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, identityClaim, this.GetType());
                    return true;
                }

                // DNS identity: a host/<dns-name> SPN claim is equivalent.
                if (isDnsIdentity)
                {
                    if (expectedSpn == null)
                    {
                        expectedSpn = string.Format(CultureInfo.InvariantCulture, "host/{0}", (string)identityClaim.Resource);
                    }

                    Claim spnClaim = this.CheckDnsEquivalence(claimSet, expectedSpn);
                    if (spnClaim != null)
                    {
                        SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, spnClaim, this.GetType());
                        return true;
                    }
                }

                // Allow a Sid claim to support UPN, SPN and DNS identities.
                if (!identitySidResolved)
                {
                    identitySid = this.ResolveIdentitySid(identity, expectedSpn);
                    identitySidResolved = true;
                }

                if (identitySid != null)
                {
                    Claim sidClaim = this.CheckSidEquivalence(identitySid, claimSet);
                    if (sidClaim != null)
                    {
                        SecurityTraceRecordHelper.TraceIdentityVerificationSuccess(eventTraceActivity, identity, sidClaim, this.GetType());
                        return true;
                    }
                }
            }

            SecurityTraceRecordHelper.TraceIdentityVerificationFailure(identity, authContext, this.GetType());
            if (TD.SecurityIdentityVerificationFailureIsEnabled())
            {
                TD.SecurityIdentityVerificationFailure(eventTraceActivity);
            }

            return false;
        }
    }
}
|  | } |