//----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//----------------------------------------------------------------------------
namespace System.ServiceModel.Channels
{
    using System.Diagnostics;
    using System.Net;
    using System.Runtime;
    using System.Security;
    using System.Security.Authentication.ExtendedProtection;
    using System.ServiceModel;
    using System.ServiceModel.Diagnostics;
    using System.ServiceModel.Diagnostics.Application;
    using System.ServiceModel.Dispatcher;
    using System.Threading;
    using System.Runtime.Diagnostics;

    /// <summary>
    /// An <see cref="HttpTransportManager"/> backed by a single <see cref="HttpListener"/>.
    /// It keeps up to <c>maxPendingAccepts</c> outstanding <see cref="HttpListener.BeginGetContext"/>
    /// calls. Each received request is routed to the <see cref="HttpChannelListener"/> that
    /// <c>TryLookupUri</c> matches; requests with no match are answered with 404/405.
    /// </summary>
    class SharedHttpTransportManager : HttpTransportManager
    {
        // Number of concurrent BeginGetContext calls started by StartListening (taken from the channel listener).
        int maxPendingAccepts;

        // The underlying listener. Created in OnOpen and cleared in Cleanup.
        // Accesses in this class are made under listenerRWLock.
        HttpListener listener;

        // Used by OnOpen to wait for the scheduled OnListening callback to finish.
        ManualResetEvent listenStartedEvent;

        // Non-fatal exception captured by OnListening; rethrown by OnOpen.
        Exception listenStartedException;

        AsyncCallback onGetContext;
        AsyncCallback onContextReceived;
        Action onMessageDequeued;

        // Lazily created. Receives the synchronously-completed IAsyncResult as its state argument,
        // hence Action<object> (OnCompleteGetContextLater takes an object parameter).
        Action<object> onCompleteGetContextLater;

        bool unsafeConnectionNtlmAuthentication;

        // Cleanup takes the writer lock; code that reads this.listener takes the reader lock.
        ReaderWriterLockSlim listenerRWLock;

        internal SharedHttpTransportManager(Uri listenUri, HttpChannelListener channelListener)
            : base(listenUri, channelListener.HostNameComparisonMode, channelListener.Realm)
        {
            this.onGetContext = Fx.ThunkCallback(new AsyncCallback(OnGetContext));
            this.onMessageDequeued = new Action(OnMessageDequeued);
            this.unsafeConnectionNtlmAuthentication = channelListener.UnsafeConnectionNtlmAuthentication;
            this.onContextReceived = new AsyncCallback(this.HandleHttpContextReceived);
            this.listenerRWLock = new ReaderWriterLockSlim();
            this.maxPendingAccepts = channelListener.MaxPendingAccepts;
        }

        // We are NOT checking the RequestInitializationTimeout here since the HttpChannelListener should handle them
        // individually. However, some of the scenarios might be impacted, e.g., if we have one endpoint with high RequestInitializationTimeout
        // and the other is just normal, the first endpoint might be occupying all the receiving loops, then the requests to the normal endpoint
        // will experience timeout issues. The mitigation for this issue is that customers should be able to increase the MaxPendingAccepts number.
        /// <summary>
        /// Determines whether <paramref name="channelListener"/> can share this transport manager.
        /// Listeners that inherit base address settings are always compatible. Otherwise the scope id,
        /// MaxPendingAccepts, UnsafeConnectionNtlmAuthentication and the base-class checks must all match.
        /// </summary>
        internal override bool IsCompatible(HttpChannelListener channelListener)
        {
            if (channelListener.InheritBaseAddressSettings)
            {
                return true;
            }

            if (!channelListener.IsScopeIdCompatible(HostNameComparisonMode, this.ListenUri))
            {
                return false;
            }

            if (this.maxPendingAccepts != channelListener.MaxPendingAccepts)
            {
                return false;
            }

            return channelListener.UnsafeConnectionNtlmAuthentication == this.unsafeConnectionNtlmAuthentication
                && base.IsCompatible(channelListener);
        }

        internal override void OnClose(TimeSpan timeout)
        {
            Cleanup(false, timeout);
        }

        internal override void OnAbort()
        {
            Cleanup(true, TimeSpan.Zero);
            base.OnAbort();
        }

        /// <summary>
        /// Stops and closes the listener under the writer lock, then runs base.OnClose (or base.OnAbort
        /// when <paramref name="aborting"/> is true). Returns immediately if the listener is already null.
        /// </summary>
        /// <remarks>
        /// NOTE(review): when the listener is non-null, an abort runs base.OnAbort here AND again in
        /// OnAbort. When the listener is already null, a close never calls base.OnClose.
        /// Confirm both asymmetries are intended.
        /// </remarks>
        void Cleanup(bool aborting, TimeSpan timeout)
        {
            using (LockHelper.TakeWriterLock(this.listenerRWLock))
            {
                HttpListener listenerSnapshot = this.listener;
                if (listenerSnapshot == null)
                {
                    return;
                }

                try
                {
                    listenerSnapshot.Stop();
                }
                finally
                {
                    try
                    {
                        listenerSnapshot.Close();
                    }
                    finally
                    {
                        if (!aborting)
                        {
                            base.OnClose(timeout);
                        }
                        else
                        {
                            base.OnAbort();
                        }
                    }
                }

                this.listener = null;
            }
        }

        /// <summary>
        /// Starts one asynchronous accept on the listener, with execution-context flow suppressed.
        /// </summary>
        /// <param name="startListening">
        /// True when called from StartListening. Non-fatal exceptions are then rethrown to the caller
        /// instead of faulting the transport manager.
        /// </param>
        /// <returns>
        /// The accept's IAsyncResult. Returns null if the listener has already been cleaned up, or if a
        /// non-fatal error faulted the manager (only possible when <paramref name="startListening"/> is false).
        /// </returns>
        [Fx.Tag.SecurityNote(Critical = "Calls into critical method ExecutionContext.SuppressFlow",
            Safe = "Doesn't leak information\\resources; the callback that is invoked is safe")]
        [SecuritySafeCritical]
        IAsyncResult BeginGetContext(bool startListening)
        {
            EventTraceActivity eventTraceActivity = null;
            if (FxTrace.Trace.IsEnd2EndActivityTracingEnabled)
            {
                eventTraceActivity = EventTraceActivity.GetFromThreadOrCreate(true);
                if (TD.HttpGetContextStartIsEnabled())
                {
                    TD.HttpGetContextStart(eventTraceActivity);
                }
            }

            // Loop so that an HttpListenerException that HandleHttpException reports as handled causes a retry.
            while (true)
            {
                Exception unexpectedException = null;
                try
                {
                    try
                    {
                        if (ExecutionContext.IsFlowSuppressed())
                        {
                            return this.BeginGetContextCore(eventTraceActivity);
                        }
                        else
                        {
                            using (ExecutionContext.SuppressFlow())
                            {
                                return this.BeginGetContextCore(eventTraceActivity);
                            }
                        }
                    }
                    catch (HttpListenerException e)
                    {
                        if (!this.HandleHttpException(e))
                        {
                            throw;
                        }
                    }
                }
                catch (Exception e)
                {
                    if (Fx.IsFatal(e))
                    {
                        throw;
                    }
                    if (startListening)
                    {
                        // Since we're under a call to StartListening(), just throw the exception up the stack.
                        throw;
                    }
                    unexpectedException = e;
                }

                if (unexpectedException != null)
                {
                    this.Fault(unexpectedException);
                    return null;
                }
            }
        }

        /// <summary>
        /// Calls listener.BeginGetContext under the reader lock. Returns null if the listener has been cleared.
        /// </summary>
        IAsyncResult BeginGetContextCore(EventTraceActivity eventTraceActivity)
        {
            using (LockHelper.TakeReaderLock(this.listenerRWLock))
            {
                if (this.listener == null)
                {
                    return null;
                }

                return this.listener.BeginGetContext(onGetContext, eventTraceActivity);
            }
        }

        // Async completion callback for BeginGetContext. Synchronous completions are handled by the caller.
        void OnGetContext(IAsyncResult result)
        {
            if (result.CompletedSynchronously)
            {
                return;
            }

            OnGetContextCore(result);
        }

        // Scheduled via ActionItem.Schedule for accepts that completed synchronously; state is the IAsyncResult.
        void OnCompleteGetContextLater(object state)
        {
            OnGetContextCore((IAsyncResult)state);
        }

        /// <summary>
        /// Completes an accepted context and dispatches it. While the context is not enqueued, it posts a
        /// new accept and keeps looping on accepts that complete synchronously.
        /// </summary>
        void OnGetContextCore(IAsyncResult listenerContextResult)
        {
            Fx.Assert(listenerContextResult != null, "listenerContextResult cannot be null.");
            bool enqueued = false;

            while (!enqueued)
            {
                Exception unexpectedException = null;
                try
                {
                    try
                    {
                        enqueued = this.EnqueueContext(listenerContextResult);
                    }
                    catch (HttpListenerException e)
                    {
                        if (!this.HandleHttpException(e))
                        {
                            throw;
                        }
                    }
                }
                catch (Exception exception)
                {
                    if (Fx.IsFatal(exception))
                    {
                        throw;
                    }
                    unexpectedException = exception;
                }

                if (unexpectedException != null)
                {
                    this.Fault(unexpectedException);
                }

                // NormalHttpPipeline calls HttpListener.BeginGetContext() by itself (via its dequeuedCallback) in the short-circuit case
                // when there was no error processing the inbound request (see the comments in the NormalHttpPipeline.Close() for details).
                if (!enqueued) // onMessageDequeued will handle this in the enqueued case
                {
                    // Continue the loop with the async result if it completed synchronously.
                    listenerContextResult = this.BeginGetContext(false);
                    if ((listenerContextResult == null) || !listenerContextResult.CompletedSynchronously)
                    {
                        return;
                    }
                }
            }
        }

        /// <summary>
        /// Ends the accept and hands the request to the matching channel listener. If there is no match,
        /// the request is answered with 404/405.
        /// </summary>
        /// <returns>
        /// True if the request was handed off (or the listener was already cleared, so no new accept is
        /// wanted). False if the caller should post another accept.
        /// </returns>
        bool EnqueueContext(IAsyncResult listenerContextResult)
        {
            EventTraceActivity eventTraceActivity = null;
            HttpListenerContext listenerContext;
            bool enqueued = false;

            if (FxTrace.Trace.IsEnd2EndActivityTracingEnabled)
            {
                eventTraceActivity = (EventTraceActivity)listenerContextResult.AsyncState;
                if (eventTraceActivity == null)
                {
                    eventTraceActivity = EventTraceActivity.GetFromThreadOrCreate(true);
                }
            }

            using (LockHelper.TakeReaderLock(this.listenerRWLock))
            {
                if (this.listener == null)
                {
                    return true;
                }

                listenerContext = this.listener.EndGetContext(listenerContextResult);
            }

            // Grab the activity from the context and set that as the surrounding activity.
            // If a message appears, we will transfer to the message's activity next
            using (DiagnosticUtility.ShouldUseActivity ? ServiceModelActivity.BoundOperation(this.Activity) : null)
            {
                ServiceModelActivity activity = DiagnosticUtility.ShouldUseActivity
                    ? ServiceModelActivity.CreateBoundedActivityWithTransferInOnly(listenerContext.Request.RequestTraceIdentifier)
                    : null;
                try
                {
                    if (activity != null)
                    {
                        StartReceiveBytesActivity(activity, listenerContext.Request.Url);
                    }
                    if (DiagnosticUtility.ShouldTraceInformation)
                    {
                        TraceUtility.TraceHttpConnectionInformation(listenerContext.Request.LocalEndPoint.ToString(),
                            listenerContext.Request.RemoteEndPoint.ToString(), this);
                    }

                    base.TraceMessageReceived(eventTraceActivity, this.ListenUri);

                    HttpChannelListener channelListener;
                    if (base.TryLookupUri(listenerContext.Request.Url,
                                          listenerContext.Request.HttpMethod,
                                          this.HostNameComparisonMode,
                                          listenerContext.Request.IsWebSocketRequest,
                                          out channelListener))
                    {
                        HttpRequestContext context = HttpRequestContext.CreateContext(channelListener, listenerContext, eventTraceActivity);

                        IAsyncResult httpContextReceivedResult = channelListener.BeginHttpContextReceived(
                            context,
                            onMessageDequeued,
                            onContextReceived,
                            DiagnosticUtility.ShouldUseActivity ? (object)new ActivityHolder(activity, context) : (object)context);
                        if (httpContextReceivedResult.CompletedSynchronously)
                        {
                            enqueued = EndHttpContextReceived(httpContextReceivedResult);
                        }
                        else
                        {
                            // The callback has been enqueued.
                            enqueued = true;
                        }
                    }
                    else
                    {
                        HandleMessageReceiveFailed(listenerContext);
                    }
                }
                finally
                {
                    if (DiagnosticUtility.ShouldUseActivity && activity != null)
                    {
                        if (!enqueued)
                        {
                            // Error during enqueuing
                            activity.Dispose();
                        }
                    }
                }
            }

            return enqueued;
        }

        /// <summary>
        /// Async completion of BeginHttpContextReceived. If the context was not enqueued, restarts the
        /// receive loop.
        /// </summary>
        void HandleHttpContextReceived(IAsyncResult httpContextReceivedResult)
        {
            if (httpContextReceivedResult.CompletedSynchronously)
            {
                return;
            }

            bool enqueued = false;
            Exception unexpectedException = null;
            try
            {
                try
                {
                    enqueued = EndHttpContextReceived(httpContextReceivedResult);
                }
                catch (HttpListenerException e)
                {
                    if (!this.HandleHttpException(e))
                    {
                        throw;
                    }
                }
            }
            catch (Exception exception)
            {
                if (Fx.IsFatal(exception))
                {
                    throw;
                }
                unexpectedException = exception;
            }

            if (unexpectedException != null)
            {
                this.Fault(unexpectedException);
            }

            IAsyncResult listenerContextResult = null;
            if (!enqueued) // onMessageDequeued will handle this in the enqueued case
            {
                listenerContextResult = this.BeginGetContext(false);
                if ((listenerContextResult == null) || !listenerContextResult.CompletedSynchronously)
                {
                    return;
                }

                // Handle the context and continue the receive loop.
                this.OnGetContextCore(listenerContextResult);
            }
        }

        /// <summary>
        /// Ends BeginHttpContextReceived on the owning channel listener. The listener is recovered from
        /// AsyncState (an ActivityHolder when activities are in use, otherwise the HttpRequestContext),
        /// and the holder is disposed afterwards.
        /// </summary>
        static bool EndHttpContextReceived(IAsyncResult httpContextReceivedResult)
        {
            using (DiagnosticUtility.ShouldUseActivity ? (ActivityHolder)httpContextReceivedResult.AsyncState : null)
            {
                HttpChannelListener channelListener =
                    (DiagnosticUtility.ShouldUseActivity
                        ? ((ActivityHolder)httpContextReceivedResult.AsyncState).context
                        : (HttpRequestContext)httpContextReceivedResult.AsyncState).Listener;
                return channelListener.EndHttpContextReceived(httpContextReceivedResult);
            }
        }

        /// <summary>
        /// Maps out-of-memory/resource HttpListener error codes to InsufficientMemoryException. All other
        /// codes are deferred to ExceptionHandler.HandleTransportExceptionHelper; a true result means the
        /// exception was handled.
        /// </summary>
        bool HandleHttpException(HttpListenerException e)
        {
            switch (e.ErrorCode)
            {
                case UnsafeNativeMethods.ERROR_NOT_ENOUGH_MEMORY:
                case UnsafeNativeMethods.ERROR_OUTOFMEMORY:
                case UnsafeNativeMethods.ERROR_NO_SYSTEM_RESOURCES:
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InsufficientMemoryException(SR.GetString(SR.InsufficentMemory), e));
                default:
                    return ExceptionHandler.HandleTransportExceptionHelper(e);
            }
        }

        /// <summary>
        /// Answers a request that matched no channel listener: 405 (with an Allow: POST header) for
        /// non-POST methods, otherwise 404. The response is sent with an empty body.
        /// </summary>
        static void HandleMessageReceiveFailed(HttpListenerContext listenerContext)
        {
            TraceMessageReceiveFailed();

            // no match -- 405 or 404
            if (!string.Equals(listenerContext.Request.HttpMethod, "POST", StringComparison.OrdinalIgnoreCase))
            {
                listenerContext.Response.StatusCode = (int)HttpStatusCode.MethodNotAllowed;
                listenerContext.Response.Headers.Add(HttpResponseHeader.Allow, "POST");
            }
            else
            {
                listenerContext.Response.StatusCode = (int)HttpStatusCode.NotFound;
            }

            listenerContext.Response.ContentLength64 = 0;
            listenerContext.Response.Close();
        }

        static void TraceMessageReceiveFailed()
        {
            // Guard with the enablement check that matches the event actually being written.
            if (TD.HttpMessageReceiveFailedIsEnabled())
            {
                TD.HttpMessageReceiveFailed();
            }

            if (DiagnosticUtility.ShouldTraceWarning)
            {
                TraceUtility.TraceEvent(TraceEventType.Warning, TraceCode.HttpChannelMessageReceiveFailed,
                    SR.GetString(SR.TraceCodeHttpChannelMessageReceiveFailed), (object)null);
            }
        }

        /// <summary>
        /// Starts maxPendingAccepts accepts. Accepts that completed synchronously are processed on a
        /// scheduled callback rather than inline.
        /// </summary>
        void StartListening()
        {
            for (int i = 0; i < maxPendingAccepts; i++)
            {
                IAsyncResult result = this.BeginGetContext(true);
                if (result == null)
                {
                    // The listener has been cleaned up (BeginGetContextCore returned null); there is
                    // nothing to accept on, and every further attempt would also return null.
                    break;
                }

                if (result.CompletedSynchronously)
                {
                    if (onCompleteGetContextLater == null)
                    {
                        onCompleteGetContextLater = new Action<object>(OnCompleteGetContextLater);
                    }
                    ActionItem.Schedule(onCompleteGetContextLater, result);
                }
            }
        }

        // Scheduled by OnOpen when it is not on a threadpool thread. Records any non-fatal exception in
        // listenStartedException and always signals listenStartedEvent.
        void OnListening(object state)
        {
            try
            {
                this.StartListening();
            }
            catch (Exception e)
            {
                if (Fx.IsFatal(e))
                {
                    throw;
                }
                this.listenStartedException = e;
            }
            finally
            {
                this.listenStartedEvent.Set();
            }
        }

        // Passed as the dequeued callback to BeginHttpContextReceived; posts the next accept.
        void OnMessageDequeued()
        {
            ThreadTrace.Trace("message dequeued");
            IAsyncResult result = this.BeginGetContext(false);
            if (result != null && result.CompletedSynchronously)
            {
                if (onCompleteGetContextLater == null)
                {
                    onCompleteGetContextLater = new Action<object>(OnCompleteGetContextLater);
                }
                ActionItem.Schedule(onCompleteGetContextLater, result);
            }
        }

        /// <summary>
        /// Creates and configures the HttpListener, registers the URL prefix derived from ListenUri and
        /// HostNameComparisonMode, starts it, and starts the accept loop. Registration failures are
        /// translated into WCF exception types. On any failure the listener is aborted.
        /// </summary>
        internal override void OnOpen()
        {
            listener = new HttpListener();

            string host;

            switch (HostNameComparisonMode)
            {
                case HostNameComparisonMode.Exact:
                    // Uri.DnsSafeHost strips the [], but preserves the scopeid for IPV6 addresses.
                    if (ListenUri.HostNameType == UriHostNameType.IPv6)
                    {
                        host = string.Concat("[", ListenUri.DnsSafeHost, "]");
                    }
                    else
                    {
                        host = ListenUri.NormalizedHost();
                    }
                    break;

                case HostNameComparisonMode.StrongWildcard:
                    host = "+";
                    break;

                case HostNameComparisonMode.WeakWildcard:
                    host = "*";
                    break;

                default:
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.UnrecognizedHostNameComparisonMode, HostNameComparisonMode.ToString())));
            }

            // The prefix path must begin and end with "/".
            string path = ListenUri.GetComponents(UriComponents.Path, UriFormat.Unescaped);
            if (!path.StartsWith("/", StringComparison.Ordinal))
            {
                path = "/" + path;
            }

            if (!path.EndsWith("/", StringComparison.Ordinal))
            {
                path = path + "/";
            }

            string httpListenUrl = string.Concat(Scheme, "://", host, ":", ListenUri.Port, path);

            listener.UnsafeConnectionNtlmAuthentication = this.unsafeConnectionNtlmAuthentication;
            listener.AuthenticationSchemeSelectorDelegate =
                new AuthenticationSchemeSelector(SelectAuthenticationScheme);

            if (ExtendedProtectionPolicy.OSSupportsExtendedProtection)
            {
                //This API will throw if on an unsupported platform.
                listener.ExtendedProtectionSelectorDelegate =
                    new HttpListener.ExtendedProtectionSelector(SelectExtendedProtectionPolicy);
            }

            if (this.Realm != null)
            {
                listener.Realm = this.Realm;
            }

            bool success = false;
            try
            {
                listener.Prefixes.Add(httpListenUrl);
                listener.Start();

                bool startedListening = false;
                try
                {
                    if (Thread.CurrentThread.IsThreadPoolThread)
                    {
                        StartListening();
                    }
                    else
                    {
                        // If we're not on a threadpool thread, then we need to post a callback to start our accepting loop
                        // Otherwise if the calling thread aborts then the async I/O will get inadvertently cancelled
                        listenStartedEvent = new ManualResetEvent(false);
                        ActionItem.Schedule(OnListening, null);
                        listenStartedEvent.WaitOne();
                        listenStartedEvent.Close();
                        listenStartedEvent = null;
                        if (listenStartedException != null)
                        {
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(listenStartedException);
                        }
                    }
                    startedListening = true;
                }
                finally
                {
                    if (!startedListening)
                    {
                        listener.Stop();
                    }
                }

                success = true;
            }
            catch (HttpListenerException listenerException)
            {
                switch (listenerException.NativeErrorCode)
                {
                    case UnsafeNativeMethods.ERROR_ALREADY_EXISTS:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new AddressAlreadyInUseException(SR.GetString(SR.HttpRegistrationAlreadyExists, httpListenUrl), listenerException));

                    case UnsafeNativeMethods.ERROR_SHARING_VIOLATION:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new AddressAlreadyInUseException(SR.GetString(SR.HttpRegistrationPortInUse, httpListenUrl, ListenUri.Port), listenerException));

                    case UnsafeNativeMethods.ERROR_ACCESS_DENIED:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new AddressAccessDeniedException(SR.GetString(SR.HttpRegistrationAccessDenied, httpListenUrl), listenerException));

                    case UnsafeNativeMethods.ERROR_ALLOTTED_SPACE_EXCEEDED:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new CommunicationException(SR.GetString(SR.HttpRegistrationLimitExceeded, httpListenUrl), listenerException));

                    case UnsafeNativeMethods.ERROR_INVALID_PARAMETER:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.HttpInvalidListenURI, ListenUri.OriginalString), listenerException));

                    default:
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                            HttpChannelUtilities.CreateCommunicationException(listenerException));
                }
            }
            finally
            {
                if (!success)
                {
                    listener.Abort();
                }
            }
        }

        /// <summary>
        /// HttpListener authentication selector. Returns the matching channel listener's scheme, or
        /// Anonymous when no listener matches. Exceptions are traced and rethrown.
        /// </summary>
        AuthenticationSchemes SelectAuthenticationScheme(HttpListenerRequest request)
        {
            try
            {
                AuthenticationSchemes result;
                HttpChannelListener channelListener;
                if (base.TryLookupUri(request.Url, request.HttpMethod,
                    this.HostNameComparisonMode, request.IsWebSocketRequest, out channelListener))
                {
                    result = channelListener.AuthenticationScheme;
                }
                else
                {
                    // if we don't match a listener factory, we want to "fall through" the
                    // auth delegate code and run through our normal OnGetContext codepath.
                    // System.Net treats "None" as Access Denied, which is not our intent here.
                    // In most cases this will just fall through to the code that returns a "404 Not Found"
                    result = AuthenticationSchemes.Anonymous;
                }

                return result;
            }
            catch (Exception e)
            {
                DiagnosticUtility.TraceHandledException(e, TraceEventType.Error);
                throw;
            }
        }

        /// <summary>
        /// HttpListener extended-protection selector. Returns the matching channel listener's policy, or
        /// ChannelBindingUtility.DisabledPolicy when no listener matches. Exceptions are traced and rethrown.
        /// </summary>
        ExtendedProtectionPolicy SelectExtendedProtectionPolicy(HttpListenerRequest request)
        {
            ExtendedProtectionPolicy result = null;

            try
            {
                HttpChannelListener channelListener;
                if (base.TryLookupUri(request.Url, request.HttpMethod,
                    this.HostNameComparisonMode, request.IsWebSocketRequest, out channelListener))
                {
                    result = channelListener.ExtendedProtectionPolicy;
                }
                else
                {
                    //if the listener isn't found, then the auth scheme will be anonymous
                    //(see SelectAuthenticationScheme function) and will fall through to the
                    //404 Not Found code path, so it doesn't really matter what we return from here...
                    result = ChannelBindingUtility.DisabledPolicy;
                }

                return result;
            }
            catch (Exception e)
            {
                DiagnosticUtility.TraceHandledException(e, TraceEventType.Error);
                throw;
            }
        }
    }
}