// <copyright>
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>

namespace System.ServiceModel.Channels
{
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Net;
    using System.Net.Http;
    using System.Net.Http.Headers;
    using System.Runtime;
    using System.Security.Principal;
    using System.Threading.Tasks;

    /// <summary>
    /// Extension methods for <see cref="HttpRequestMessage"/>: storing a user principal,
    /// adding/setting/getting/removing request or content headers, producing a buffered copy
    /// of a request, and copying WCF message properties onto a request.
    /// </summary>
    public static class HttpRequestMessageExtensionMethods
    {
        // Property-bag key under which CopyPropertiesFromMessage stores the WCF message headers.
        const string MessageHeadersPropertyKey = "System.ServiceModel.Channels.MessageHeaders";

        // Property-bag key under which SetUserPrincipal/GetUserPrincipal store the principal.
        const string PrincipalKey = "MS_UserPrincipal";

        /// <summary>
        /// Stores <paramref name="user"/> in <see cref="HttpRequestMessage.Properties"/>,
        /// overwriting any previously stored principal.
        /// </summary>
        /// <param name="httpRequestMessage">The request to annotate.</param>
        /// <param name="user">The principal to store; a null value is stored as-is.</param>
        /// <exception cref="ArgumentNullException"><paramref name="httpRequestMessage"/> is null.</exception>
        public static void SetUserPrincipal(this HttpRequestMessage httpRequestMessage, IPrincipal user)
        {
            if (httpRequestMessage == null)
            {
                throw FxTrace.Exception.AsError(new ArgumentNullException("httpRequestMessage"));
            }

            httpRequestMessage.Properties[PrincipalKey] = user;
        }

        /// <summary>
        /// Reads the principal previously stored by <see cref="SetUserPrincipal"/>.
        /// </summary>
        /// <param name="httpRequestMessage">The request to inspect.</param>
        /// <returns>
        /// The stored principal, or null when nothing is stored under the principal key or the
        /// stored value is not an <see cref="IPrincipal"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="httpRequestMessage"/> is null.</exception>
        public static IPrincipal GetUserPrincipal(this HttpRequestMessage httpRequestMessage)
        {
            if (httpRequestMessage == null)
            {
                throw FxTrace.Exception.AsError(new ArgumentNullException("httpRequestMessage"));
            }

            object user;
            if (httpRequestMessage.Properties.TryGetValue(PrincipalKey, out user))
            {
                return user as IPrincipal;
            }

            return null;
        }

        /// <summary>
        /// Adds <paramref name="value"/> for <paramref name="header"/> to the request headers, or to
        /// the content headers (creating empty content if necessary) when the request headers do not
        /// accept it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The header is a response-only header.</exception>
        internal static void AddHeader(this HttpRequestMessage httpRequestMessage, string header, string value)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(!string.IsNullOrWhiteSpace(header), "The 'header' parameter should never be null or whitespace.");

            HttpHeaderInfo headerInfo = HttpHeaderInfo.Create(header);
            EnsureNotResponseHeader(headerInfo);
            AddHeader(httpRequestMessage, headerInfo, value);
        }

        /// <summary>
        /// Replaces any existing values of <paramref name="header"/> (in request and content headers)
        /// with <paramref name="value"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">The header is a response-only header.</exception>
        internal static void SetHeader(this HttpRequestMessage httpRequestMessage, string header, string value)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(!string.IsNullOrWhiteSpace(header), "The 'header' parameter should never be null or whitespace.");

            HttpHeaderInfo headerInfo = HttpHeaderInfo.Create(header);
            EnsureNotResponseHeader(headerInfo);
            RemoveHeader(httpRequestMessage, headerInfo);
            AddHeader(httpRequestMessage, headerInfo, value);
        }

        /// <summary>
        /// Gets the values of <paramref name="header"/>, looking first in the request headers and then
        /// in the content headers (when content exists).
        /// </summary>
        /// <returns>
        /// The header values; null when neither lookup produced a result.
        /// NOTE(review): assumes HttpHeaderInfo.TryGetHeader returns null for an absent header.
        /// </returns>
        /// <exception cref="InvalidOperationException">The header is a response-only header.</exception>
        internal static IEnumerable<string> GetHeader(this HttpRequestMessage httpRequestMessage, string header)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(!string.IsNullOrWhiteSpace(header), "The 'header' parameter should never be null or whitespace.");

            HttpHeaderInfo headerInfo = HttpHeaderInfo.Create(header);
            EnsureNotResponseHeader(headerInfo);
            return GetHeader(httpRequestMessage, headerInfo);
        }

        /// <summary>
        /// Removes <paramref name="header"/> from the request headers and, when content exists, from
        /// the content headers.
        /// </summary>
        /// <exception cref="InvalidOperationException">The header is a response-only header.</exception>
        internal static void RemoveHeader(this HttpRequestMessage httpRequestMessage, string header)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(!string.IsNullOrWhiteSpace(header), "The 'header' parameter should never be null or whitespace.");

            HttpHeaderInfo headerInfo = HttpHeaderInfo.Create(header);
            EnsureNotResponseHeader(headerInfo);
            RemoveHeader(httpRequestMessage, headerInfo);
        }

        /// <summary>
        /// Creates a new <see cref="HttpRequestMessage"/> with the same URI, method, version, headers,
        /// properties and a buffered copy of the content.
        /// </summary>
        /// <remarks>
        /// Property values implementing <see cref="IMessageProperty"/> are copied via CreateCopy; all
        /// other property values are shared by reference. The content is copied via
        /// <see cref="CreateBufferedCopyOfContent"/>, which may block while reading it.
        /// </remarks>
        internal static HttpRequestMessage CreateBufferedCopy(this HttpRequestMessage httpRequestMessage)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");

            HttpRequestMessage bufferedHttpRequestMessage = new HttpRequestMessage();

            Uri requestUri = httpRequestMessage.RequestUri;
            if (requestUri == null)
            {
                bufferedHttpRequestMessage.RequestUri = null;
            }
            else if (requestUri.IsAbsoluteUri)
            {
                bufferedHttpRequestMessage.RequestUri = new Uri(requestUri, string.Empty);
            }
            else
            {
                // HttpRequestMessage allows relative request URIs, but Uri(Uri, string) rejects a
                // non-absolute base; re-create relative URIs from their original text instead.
                bufferedHttpRequestMessage.RequestUri = new Uri(requestUri.OriginalString, UriKind.Relative);
            }

            bufferedHttpRequestMessage.Method = httpRequestMessage.Method != null ? new HttpMethod(httpRequestMessage.Method.Method) : null;
            bufferedHttpRequestMessage.Version = (Version)(httpRequestMessage.Version != null ? httpRequestMessage.Version.Clone() : null);

            foreach (KeyValuePair<string, IEnumerable<string>> header in httpRequestMessage.Headers)
            {
                bufferedHttpRequestMessage.Headers.AddHeaderWithoutValidation(header);
            }

            foreach (KeyValuePair<string, object> property in httpRequestMessage.Properties)
            {
                IMessageProperty messageProperty = property.Value as IMessageProperty;
                object value = messageProperty != null ? messageProperty.CreateCopy() : property.Value;
                bufferedHttpRequestMessage.Properties.Add(property.Key, value);
            }

            bufferedHttpRequestMessage.Content = CreateBufferedCopyOfContent(httpRequestMessage.Content);

            return bufferedHttpRequestMessage;
        }

        /// <summary>
        /// Creates a <see cref="SharedByteArrayContent"/> holding the bytes of <paramref name="content"/>
        /// and copies its content headers.
        /// </summary>
        /// <param name="content">The content to copy; may be null.</param>
        /// <returns>The buffered copy, or null when <paramref name="content"/> is null.</returns>
        /// <remarks>
        /// When <paramref name="content"/> is already a <see cref="SharedByteArrayContent"/>, the copy
        /// reuses the same byte array rather than reading it again, so that array must not be mutated.
        /// Otherwise the content is read synchronously by blocking on ReadAsByteArrayAsync (via
        /// <c>.Result</c>, so read failures surface wrapped in an AggregateException).
        /// </remarks>
        internal static HttpContent CreateBufferedCopyOfContent(HttpContent content)
        {
            if (content != null)
            {
                SharedByteArrayContent shareableContent = content as SharedByteArrayContent;
                byte[] contentBytes = shareableContent == null ? content.ReadAsByteArrayAsync().Result : shareableContent.ContentBytes;
                HttpContent bufferedContent = new SharedByteArrayContent(contentBytes);

                foreach (KeyValuePair<string, IEnumerable<string>> header in content.Headers)
                {
                    bufferedContent.Headers.AddHeaderWithoutValidation(header);
                }

                return bufferedContent;
            }

            return null;
        }

        /// <summary>
        /// Copies the WCF <paramref name="message"/>'s properties into the request's property bag
        /// and stores <paramref name="message"/>'s headers under the message-headers key.
        /// </summary>
        /// <remarks>
        /// HttpRequestMessageProperty / HttpResponseMessageProperty entries stored under their
        /// standard names are skipped; existing request properties with the same keys are overwritten.
        /// </remarks>
        internal static void CopyPropertiesFromMessage(this HttpRequestMessage httpRequestMessage, Message message)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should not be null.");
            Fx.Assert(message != null, "The 'message' parameter should not be null.");

            IDictionary<string, object> properties = httpRequestMessage.Properties;
            CopyProperties(message.Properties, properties);
            properties[MessageHeadersPropertyKey] = message.Headers;
        }

        /// <summary>
        /// Adds <paramref name="header"/> to <paramref name="httpHeaders"/> without value validation.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// <see cref="HttpHeaders.TryAddWithoutValidation(string, IEnumerable{string})"/> rejected the
        /// header (e.g. it belongs to a different header collection).
        /// </exception>
        internal static void AddHeaderWithoutValidation(this HttpHeaders httpHeaders, KeyValuePair<string, IEnumerable<string>> header)
        {
            Fx.Assert(httpHeaders != null, "httpHeaders should not be null.");
            if (!httpHeaders.TryAddWithoutValidation(header.Key, header.Value))
            {
                // header.Value is an IEnumerable<string>; format the actual values so the message does
                // not just show the collection's type name. (A null Value would already have made
                // TryAddWithoutValidation throw, so Join is safe here.)
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(
                    SR.CopyHttpHeaderFailed,
                    header.Key,
                    string.Join(", ", header.Value),
                    httpHeaders.GetType().Name)));
            }
        }

        // Copies every entry of messageProperties into properties, skipping the HTTP request/response
        // message properties when stored under their standard names.
        private static void CopyProperties(MessageProperties messageProperties, IDictionary<string, object> properties)
        {
            Fx.Assert(messageProperties != null, "The 'messageProperties' parameter should not be null.");
            Fx.Assert(properties != null, "The 'properties' parameter should not be null.");

            foreach (KeyValuePair<string, object> property in messageProperties)
            {
                object value = property.Value;
                string key = property.Key;

                if ((value is HttpRequestMessageProperty && string.Equals(key, HttpRequestMessageProperty.Name, StringComparison.OrdinalIgnoreCase)) ||
                    (value is HttpResponseMessageProperty && string.Equals(key, HttpResponseMessageProperty.Name, StringComparison.OrdinalIgnoreCase)))
                {
                    continue;
                }

                properties[key] = value;
            }
        }

        // Throws InvalidOperationException when the header is only valid on responses (neither a
        // request nor a content header), since it cannot be placed on a request.
        private static void EnsureNotResponseHeader(HttpHeaderInfo headerInfo)
        {
            if (!headerInfo.IsRequestHeader && !headerInfo.IsContentHeader && headerInfo.IsResponseHeader)
            {
                throw FxTrace.Exception.AsError(
                    new InvalidOperationException(
                        SR.GetString(SR.ResponseHeaderWithRequestHeadersCollection, headerInfo.Name)));
            }
        }

        // Looks the header up in the request headers first, then falls back to the content headers
        // when content exists and the request-header lookup produced nothing.
        private static IEnumerable<string> GetHeader(HttpRequestMessage httpRequestMessage, HttpHeaderInfo headerInfo)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(headerInfo != null, "The 'headerInfo' parameter should never be null.");
            Fx.Assert(headerInfo.IsRequestHeader || headerInfo.IsContentHeader, "The 'headerInfo' parameter should be either a request or content header.");

            IEnumerable<string> values = null;
            if (headerInfo.IsRequestHeader)
            {
                values = headerInfo.TryGetHeader(httpRequestMessage.Headers);
            }

            if (values == null && headerInfo.IsContentHeader && httpRequestMessage.Content != null)
            {
                values = headerInfo.TryGetHeader(httpRequestMessage.Content.Headers);
            }

            return values;
        }

        // Removes the header from every collection it can live in (request headers and, when content
        // exists, content headers).
        private static void RemoveHeader(HttpRequestMessage httpRequestMessage, HttpHeaderInfo headerInfo)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(headerInfo != null, "The 'headerInfo' parameter should never be null.");
            Fx.Assert(headerInfo.IsRequestHeader || headerInfo.IsContentHeader, "The 'headerInfo' parameter should be either a request or content header.");

            if (headerInfo.IsRequestHeader)
            {
                headerInfo.TryRemoveHeader(httpRequestMessage.Headers);
            }

            if (headerInfo.IsContentHeader && httpRequestMessage.Content != null)
            {
                headerInfo.TryRemoveHeader(httpRequestMessage.Content.Headers);
            }
        }

        // Tries the request headers first; if that add does not succeed and the header is a content
        // header, ensures content exists and adds it to the content headers. The result of the
        // content-header add is not checked.
        private static void AddHeader(HttpRequestMessage httpRequestMessage, HttpHeaderInfo headerInfo, string value)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");
            Fx.Assert(headerInfo != null, "The 'headerInfo' parameter should never be null.");
            Fx.Assert(headerInfo.IsRequestHeader || headerInfo.IsContentHeader, "The 'headerInfo' parameter should be either a request or content header.");

            if (headerInfo.IsRequestHeader)
            {
                if (headerInfo.TryAddHeader(httpRequestMessage.Headers, value))
                {
                    return;
                }
            }

            if (headerInfo.IsContentHeader)
            {
                CreateContentIfNull(httpRequestMessage);
                headerInfo.TryAddHeader(httpRequestMessage.Content.Headers, value);
            }
        }

        // Gives the request empty ByteArrayContent when it has none, so content headers can be set.
        // Returns true when content was created.
        private static bool CreateContentIfNull(HttpRequestMessage httpRequestMessage)
        {
            Fx.Assert(httpRequestMessage != null, "The 'httpRequestMessage' parameter should never be null.");

            if (httpRequestMessage.Content == null)
            {
                httpRequestMessage.Content = new ByteArrayContent(EmptyArray<byte>.Instance);
                return true;
            }

            return false;
        }

        // ByteArrayContent that exposes its backing array so buffered copies can share it instead of
        // re-reading the content.
        class SharedByteArrayContent : ByteArrayContent
        {
            public SharedByteArrayContent(byte[] content)
                : base(content)
            {
                this.ContentBytes = content;
            }

            public byte[] ContentBytes { get; private set; }
        }
    }
}