Imported Upstream version 5.16.0.100

Former-commit-id: 38faa55fb9669e35e7d8448b15c25dc447f25767
This commit is contained in:
Xamarin Public Jenkins (auto-signing)
2018-08-07 15:19:03 +00:00
parent 0a9828183b
commit 7d7f676260
4419 changed files with 170950 additions and 90273 deletions

View File

@@ -32,6 +32,8 @@ namespace System.Net.Security
RequireEncryption = 0,
}
public delegate System.Security.Cryptography.X509Certificates.X509Certificate LocalCertificateSelectionCallback(object sender, string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection localCertificates, System.Security.Cryptography.X509Certificates.X509Certificate remoteCertificate, string[] acceptableIssuers);
public delegate System.Security.Cryptography.X509Certificates.X509Certificate ServerCertificateSelectionCallback(object sender, string hostName);
public partial class NegotiateStream : AuthenticatedStream
{
public NegotiateStream(System.IO.Stream innerStream) : base(innerStream, false) { }
@@ -108,6 +110,7 @@ namespace System.Net.Security
public X509RevocationMode CertificateRevocationCheckMode { get { throw null; } set { } }
public List<SslApplicationProtocol> ApplicationProtocols { get { throw null; } set { } }
public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get { throw null; } set { } }
public ServerCertificateSelectionCallback ServerCertificateSelectionCallback { get { throw null; } set { } }
public EncryptionPolicy EncryptionPolicy { get { throw null; } set { } }
}
public partial class SslClientAuthenticationOptions
@@ -186,7 +189,7 @@ namespace System.Net.Security
public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate) { throw null; }
public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { throw null; }
public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) { throw null; }
public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken) { throw null; }
public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken) { throw null; }
public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.AsyncCallback asyncCallback, object asyncState) { throw null; }
public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation, System.AsyncCallback asyncCallback, object asyncState) { throw null; }
public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection clientCertificates, bool checkCertificateRevocation, System.AsyncCallback asyncCallback, object asyncState) { throw null; }

View File

@@ -208,7 +208,7 @@
<value>Once authentication is attempted as the client or server, additional authentication attempts must use the same client or server role.</value>
</data>
<data name="net_auth_SSPI" xml:space="preserve">
<value>A call to SSPI failed, see inner exception.</value>
<value>Authentication failed, see inner exception.</value>
</data>
<data name="net_auth_eof" xml:space="preserve">
<value>Authentication failed because the remote party has closed the transport stream.</value>

View File

@@ -5,6 +5,7 @@
<AssemblyName>System.Net.Security</AssemblyName>
<ProjectGuid>{89F37791-6254-4D60-AB96-ACD3CCA0E771}</ProjectGuid>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<ILLinkClearInitLocals>true</ILLinkClearInitLocals>
</PropertyGroup>
<PropertyGroup Condition=" '$(TargetsOSX)' == 'true' ">
<DefineConstants>$(DefineConstants);SYSNETSECURITY_NO_OPENSSL</DefineConstants>
@@ -22,6 +23,7 @@
<Compile Include="System\Net\FixedSizeReader.cs" />
<Compile Include="System\Net\HelperAsyncResults.cs" />
<Compile Include="System\Net\Logging\NetEventSource.cs" />
<Compile Include="System\Net\Security\SniHelper.cs" />
<Compile Include="System\Net\Security\SslApplicationProtocol.cs" />
<Compile Include="System\Net\Security\SslAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
@@ -76,9 +78,6 @@
<Compile Include="$(CommonPath)\System\Net\ExceptionCheck.cs">
<Link>Common\System\Net\ExceptionCheck.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\System\Net\IntPtrHelper.cs">
<Link>Common\System\Net\IntPtrHelper.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\System\Net\LazyAsyncResult.cs">
<Link>Common\System\Net\LazyAsyncResult.cs</Link>
</Compile>
@@ -329,9 +328,6 @@
<Compile Include="$(CommonPath)\Microsoft\Win32\SafeHandles\SafeX509Handles.Unix.cs">
<Link>Common\Microsoft\Win32\SafeHandles\SafeX509Handles.Unix.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\Microsoft\Win32\SafeHandles\SafeX509NameHandle.Unix.cs">
<Link>Common\Microsoft\Win32\SafeHandles\SafeX509NameHandle.Unix.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\Microsoft\Win32\SafeHandles\X509ExtensionSafeHandles.Unix.cs">
<Link>Common\Microsoft\Win32\SafeHandles\X509ExtensionSafeHandles.Unix.cs</Link>
</Compile>

View File

@@ -53,7 +53,7 @@ namespace System.Net
int remainingCount = request.Count, offset = request.Offset;
do
{
int bytes = await transport.ReadAsync(request.Buffer, offset, remainingCount, CancellationToken.None).ConfigureAwait(false);
int bytes = await transport.ReadAsync(new Memory<byte>(request.Buffer, offset, remainingCount), CancellationToken.None).ConfigureAwait(false);
if (bytes == 0)
{
if (remainingCount != request.Count)

View File

@@ -14,21 +14,12 @@ namespace System.Net
[EventSource(Name = "Microsoft-System-Net-Security", LocalizationResources = "FxResources.System.Net.Security.SR")]
internal sealed partial class NetEventSource
{
private const int EnumerateSecurityPackagesId = NextAvailableEventId;
private const int SspiPackageNotFoundId = EnumerateSecurityPackagesId + 1;
private const int AcquireDefaultCredentialId = SspiPackageNotFoundId + 1;
private const int AcquireCredentialsHandleId = AcquireDefaultCredentialId + 1;
private const int SecureChannelCtorId = AcquireCredentialsHandleId + 1;
private const int SecureChannelCtorId = NextAvailableEventId;
private const int LocatingPrivateKeyId = SecureChannelCtorId + 1;
private const int CertIsType2Id = LocatingPrivateKeyId + 1;
private const int FoundCertInStoreId = CertIsType2Id + 1;
private const int NotFoundCertInStoreId = FoundCertInStoreId + 1;
private const int InitializeSecurityContextId = NotFoundCertInStoreId + 1;
private const int SecurityContextInputBufferId = InitializeSecurityContextId + 1;
private const int SecurityContextInputBuffersId = SecurityContextInputBufferId + 1;
private const int AcceptSecuritContextId = SecurityContextInputBuffersId + 1;
private const int OperationReturnedSomethingId = AcceptSecuritContextId + 1;
private const int RemoteCertificateId = OperationReturnedSomethingId + 1;
private const int RemoteCertificateId = NotFoundCertInStoreId + 1;
private const int CertificateFromDelegateId = RemoteCertificateId + 1;
private const int NoDelegateNoClientCertId = CertificateFromDelegateId + 1;
private const int NoDelegateButClientCertId = NoDelegateNoClientCertId + 1;
@@ -343,29 +334,46 @@ namespace System.Net
const int NumEventDatas = 8;
var descrs = stackalloc EventData[NumEventDatas];
descrs[0].DataPointer = (IntPtr)(arg1Ptr);
descrs[0].Size = (arg1.Length + 1) * sizeof(char);
descrs[1].DataPointer = (IntPtr)(&arg2);
descrs[1].Size = sizeof(int);
descrs[2].DataPointer = (IntPtr)(&arg3);
descrs[2].Size = sizeof(int);
descrs[3].DataPointer = (IntPtr)(&arg4);
descrs[3].Size = sizeof(int);
descrs[4].DataPointer = (IntPtr)(&arg5);
descrs[4].Size = sizeof(int);
descrs[5].DataPointer = (IntPtr)(&arg6);
descrs[5].Size = sizeof(int);
descrs[6].DataPointer = (IntPtr)(&arg7);
descrs[6].Size = sizeof(int);
descrs[7].DataPointer = (IntPtr)(&arg8);
descrs[7].Size = sizeof(int);
descrs[0] = new EventData
{
DataPointer = (IntPtr)(arg1Ptr),
Size = (arg1.Length + 1) * sizeof(char)
};
descrs[1] = new EventData
{
DataPointer = (IntPtr)(&arg2),
Size = sizeof(int)
};
descrs[2] = new EventData
{
DataPointer = (IntPtr)(&arg3),
Size = sizeof(int)
};
descrs[3] = new EventData
{
DataPointer = (IntPtr)(&arg4),
Size = sizeof(int)
};
descrs[4] = new EventData
{
DataPointer = (IntPtr)(&arg5),
Size = sizeof(int)
};
descrs[5] = new EventData
{
DataPointer = (IntPtr)(&arg6),
Size = sizeof(int)
};
descrs[6] = new EventData
{
DataPointer = (IntPtr)(&arg7),
Size = sizeof(int)
};
descrs[7] = new EventData
{
DataPointer = (IntPtr)(&arg8),
Size = sizeof(int)
};
WriteEventCore(eventId, NumEventDatas, descrs);
}

View File

@@ -66,7 +66,7 @@ namespace System.Net
// let it pass.
break;
default:
throw new PlatformNotSupportedException(SR.net_encryptionpolicy_notsupported);
throw new PlatformNotSupportedException(SR.Format(SR.net_encryptionpolicy_notsupported, credential.Policy));
}
SafeSslHandle sslContext = Interop.AppleCrypto.SslCreateContext(isServer ? 1 : 0);
@@ -114,11 +114,6 @@ namespace System.Net
_sslContext.Dispose();
_sslContext = null;
}
_toConnection = null;
_fromConnection = null;
_writeCallback = null;
_readCallback = null;
}
base.Dispose(disposing);
@@ -239,44 +234,81 @@ namespace System.Net
}
}
private static readonly SslProtocols[] s_orderedSslProtocols = new SslProtocols[5]
{
#pragma warning disable 0618
SslProtocols.Ssl2,
SslProtocols.Ssl3,
#pragma warning restore 0618
SslProtocols.Tls,
SslProtocols.Tls11,
SslProtocols.Tls12
};
private static void SetProtocols(SafeSslHandle sslContext, SslProtocols protocols)
{
const SslProtocols SupportedProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;
SslProtocols minProtocolId;
SslProtocols maxProtocolId;
// A contiguous range of protocols is required. Find the min and max of the range,
// or throw if it's non-contiguous or if no protocols are specified.
switch (protocols & SupportedProtocols)
// First, mark all of the specified protocols.
SslProtocols[] orderedSslProtocols = s_orderedSslProtocols;
Span<bool> protocolSet = stackalloc bool[orderedSslProtocols.Length];
for (int i = 0; i < orderedSslProtocols.Length; i++)
{
case SslProtocols.None:
throw new PlatformNotSupportedException(SR.net_securityprotocolnotsupported);
case SslProtocols.Tls:
minProtocolId = SslProtocols.Tls;
maxProtocolId = SslProtocols.Tls;
break;
case SslProtocols.Tls11:
minProtocolId = SslProtocols.Tls11;
maxProtocolId = SslProtocols.Tls11;
break;
case SslProtocols.Tls12:
minProtocolId = SslProtocols.Tls12;
maxProtocolId = SslProtocols.Tls12;
break;
case SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12:
minProtocolId = SslProtocols.Tls;
maxProtocolId = SslProtocols.Tls12;
break;
case SslProtocols.Tls11 | SslProtocols.Tls12:
minProtocolId = SslProtocols.Tls11;
maxProtocolId = SslProtocols.Tls12;
break;
case SslProtocols.Tls | SslProtocols.Tls11:
minProtocolId = SslProtocols.Tls;
maxProtocolId = SslProtocols.Tls11;
break;
default:
throw new PlatformNotSupportedException(SR.net_security_sslprotocol_contiguous);
protocolSet[i] = (protocols & orderedSslProtocols[i]) != 0;
}
SslProtocols minProtocolId = (SslProtocols)(-1);
SslProtocols maxProtocolId = (SslProtocols)(-1);
// Loop through them, starting from the lowest.
for (int min = 0; min < protocolSet.Length; min++)
{
if (protocolSet[min])
{
// We found the first one that's set; that's the bottom of the range.
minProtocolId = orderedSslProtocols[min];
// Now loop from there to look for the max of the range.
for (int max = min + 1; max < protocolSet.Length; max++)
{
if (!protocolSet[max])
{
// We found the first one after the min that's not set; the top of the range
// is the one before this (which might be the same as the min).
maxProtocolId = orderedSslProtocols[max - 1];
// Finally, verify that nothing beyond this one is set, as that would be
// a discontiguous set of protocols.
for (int verifyNotSet = max + 1; verifyNotSet < protocolSet.Length; verifyNotSet++)
{
if (protocolSet[verifyNotSet])
{
throw new PlatformNotSupportedException(SR.Format(SR.net_security_sslprotocol_contiguous, protocols));
}
}
break;
}
}
break;
}
}
// If no protocols were set, throw.
if (minProtocolId == (SslProtocols)(-1))
{
throw new PlatformNotSupportedException(SR.net_securityprotocolnotsupported);
}
// If we didn't find an unset protocol after the min, go all the way to the last one.
if (maxProtocolId == (SslProtocols)(-1))
{
maxProtocolId = orderedSslProtocols[orderedSslProtocols.Length - 1];
}
// Finally set this min and max.
Interop.AppleCrypto.SslSetMinProtocolVersion(sslContext, minProtocolId);
Interop.AppleCrypto.SslSetMaxProtocolVersion(sslContext, maxProtocolId);
}

View File

@@ -624,7 +624,7 @@ namespace System.Net.Security
//
// Acquire Server Side Certificate information and set it on the class.
//
private bool AcquireServerCredentials(ref byte[] thumbPrint)
private bool AcquireServerCredentials(ref byte[] thumbPrint, byte[] clientHello)
{
if (NetEventSource.IsEnabled)
NetEventSource.Enter(this);
@@ -632,10 +632,25 @@ namespace System.Net.Security
X509Certificate localCertificate = null;
bool cachedCred = false;
if (_sslAuthenticationOptions.CertSelectionDelegate != null)
// There are three options for selecting the server certificate. When
// selecting which to use, we prioritize the new ServerCertSelectionDelegate
// API. If the new API isn't used we call LocalCertSelectionCallback (for compat
// with .NET Framework), and if neither is set we fall back to using ServerCertificate.
if (_sslAuthenticationOptions.ServerCertSelectionDelegate != null)
{
string serverIdentity = SniHelper.GetServerName(clientHello);
localCertificate = _sslAuthenticationOptions.ServerCertSelectionDelegate(serverIdentity);
if (localCertificate == null)
{
throw new AuthenticationException(SR.net_ssl_io_no_server_cert);
}
}
else if (_sslAuthenticationOptions.CertSelectionDelegate != null)
{
X509CertificateCollection tempCollection = new X509CertificateCollection();
tempCollection.Add(_sslAuthenticationOptions.ServerCertificate);
// We pass string.Empty here to maintain strict compatability with .NET Framework.
localCertificate = _sslAuthenticationOptions.CertSelectionDelegate(string.Empty, tempCollection, null, Array.Empty<string>());
if (NetEventSource.IsEnabled)
NetEventSource.Info(this, "Use delegate selected Cert");
@@ -744,7 +759,6 @@ namespace System.Net.Security
#if TRACE_VERBOSE
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, $"_refreshCredentialNeeded = {_refreshCredentialNeeded}");
#endif
if (offset < 0 || offset > (input == null ? 0 : input.Length))
{
NetEventSource.Fail(this, "Argument 'offset' out of range.");
@@ -786,7 +800,7 @@ namespace System.Net.Security
if (_refreshCredentialNeeded)
{
cachedCreds = _sslAuthenticationOptions.IsServer
? AcquireServerCredentials(ref thumbPrint)
? AcquireServerCredentials(ref thumbPrint, input)
: AcquireClientCredentials(ref thumbPrint);
}
@@ -870,24 +884,20 @@ namespace System.Net.Security
if (NetEventSource.IsEnabled)
NetEventSource.Enter(this);
StreamSizes streamSizes;
SslStreamPal.QueryContextStreamSizes(_securityContext, out streamSizes);
SslStreamPal.QueryContextStreamSizes(_securityContext, out StreamSizes streamSizes);
if (streamSizes != null)
try
{
try
{
_headerSize = streamSizes.Header;
_trailerSize = streamSizes.Trailer;
_maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize));
_headerSize = streamSizes.Header;
_trailerSize = streamSizes.Trailer;
_maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize));
Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0");
}
catch (Exception e) when (!ExceptionCheck.IsFatal(e))
{
NetEventSource.Fail(this, "StreamSizes out of range.");
throw;
}
Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0");
}
catch (Exception e) when (!ExceptionCheck.IsFatal(e))
{
NetEventSource.Fail(this, "StreamSizes out of range.");
throw;
}
SslStreamPal.QueryContextConnectionInfo(_securityContext, out _connectionInfo);

View File

@@ -0,0 +1,391 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Buffers.Binary;
using System.Globalization;
using System.Text;
namespace System.Net.Security
{
internal class SniHelper
{
private const int ProtocolVersionSize = 2;
private const int UInt24Size = 3;
private const int RandomSize = 32;
private readonly static IdnMapping s_idnMapping = CreateIdnMapping();
private readonly static Encoding s_encoding = CreateEncoding();
public static string GetServerName(byte[] clientHello)
{
return GetSniFromSslPlainText(clientHello);
}
private static string GetSniFromSslPlainText(ReadOnlySpan<byte> sslPlainText)
{
// https://tools.ietf.org/html/rfc6101#section-5.2.1
// struct {
// ContentType type; // enum with max value 255
// ProtocolVersion version; // 2x uint8
// uint16 length;
// opaque fragment[SSLPlaintext.length];
// } SSLPlaintext;
const int ContentTypeOffset = 0;
const int ProtocolVersionOffset = ContentTypeOffset + sizeof(ContentType);
const int LengthOffset = ProtocolVersionOffset + ProtocolVersionSize;
const int HandshakeOffset = LengthOffset + sizeof(ushort);
// SSL v2's ContentType has 0x80 bit set.
// We do not care about SSL v2 here because it does not support client hello extensions
if (sslPlainText.Length < HandshakeOffset || (ContentType)sslPlainText[ContentTypeOffset] != ContentType.Handshake)
{
return null;
}
// Skip ContentType and ProtocolVersion
int handshakeLength = BinaryPrimitives.ReadUInt16BigEndian(sslPlainText.Slice(LengthOffset));
ReadOnlySpan<byte> sslHandshake = sslPlainText.Slice(HandshakeOffset);
if (handshakeLength != sslHandshake.Length)
{
return null;
}
return GetSniFromSslHandshake(sslHandshake);
}
private static string GetSniFromSslHandshake(ReadOnlySpan<byte> sslHandshake)
{
// https://tools.ietf.org/html/rfc6101#section-5.6
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
// select (HandshakeType) {
// ...
// case client_hello: ClientHello;
// ...
// } body;
// } Handshake;
const int HandshakeTypeOffset = 0;
const int ClientHelloLengthOffset = HandshakeTypeOffset + sizeof(HandshakeType);
const int ClientHelloOffset = ClientHelloLengthOffset + UInt24Size;
if (sslHandshake.Length < ClientHelloOffset || (HandshakeType)sslHandshake[HandshakeTypeOffset] != HandshakeType.ClientHello)
{
return null;
}
int clientHelloLength = ReadUInt24BigEndian(sslHandshake.Slice(ClientHelloLengthOffset));
ReadOnlySpan<byte> clientHello = sslHandshake.Slice(ClientHelloOffset);
if (clientHello.Length != clientHelloLength)
{
return null;
}
return GetSniFromClientHello(clientHello);
}
private static string GetSniFromClientHello(ReadOnlySpan<byte> clientHello)
{
// Basic structure: https://tools.ietf.org/html/rfc6101#section-5.6.1.2
// Extended structure: https://tools.ietf.org/html/rfc3546#section-2.1
// struct {
// ProtocolVersion client_version; // 2x uint8
// Random random; // 32 bytes
// SessionID session_id; // opaque type
// CipherSuite cipher_suites<2..2^16-1>; // opaque type
// CompressionMethod compression_methods<1..2^8-1>; // opaque type
// Extension client_hello_extension_list<0..2^16-1>;
// } ClientHello;
ReadOnlySpan<byte> p = SkipBytes(clientHello, ProtocolVersionSize + RandomSize);
// Skip SessionID (max size 32 => size fits in 1 byte)
p = SkipOpaqueType1(p);
// Skip cipher suites (max size 2^16-1 => size fits in 2 bytes)
p = SkipOpaqueType2(p, out _);
// Skip compression methods (max size 2^8-1 => size fits in 1 byte)
p = SkipOpaqueType1(p);
// is invalid structure or no extensions?
if (p.IsEmpty)
{
return null;
}
// client_hello_extension_list (max size 2^16-1 => size fits in 2 bytes)
int extensionListLength = BinaryPrimitives.ReadUInt16BigEndian(p);
p = SkipBytes(p, sizeof(ushort));
if (extensionListLength != p.Length)
{
return null;
}
string ret = null;
while (!p.IsEmpty)
{
bool invalid;
string sni = GetSniFromExtension(p, out p, out invalid);
if (invalid)
{
return null;
}
if (ret != null && sni != null)
{
return null;
}
if (sni != null)
{
ret = sni;
}
}
return ret;
}
private static string GetSniFromExtension(ReadOnlySpan<byte> extension, out ReadOnlySpan<byte> remainingBytes, out bool invalid)
{
// https://tools.ietf.org/html/rfc3546#section-2.3
// struct {
// ExtensionType extension_type;
// opaque extension_data<0..2^16-1>;
// } Extension;
const int ExtensionDataOffset = sizeof(ExtensionType);
if (extension.Length < ExtensionDataOffset)
{
remainingBytes = ReadOnlySpan<byte>.Empty;
invalid = true;
return null;
}
ExtensionType extensionType = (ExtensionType)BinaryPrimitives.ReadUInt16BigEndian(extension);
ReadOnlySpan<byte> extensionData = extension.Slice(ExtensionDataOffset);
if (extensionType == ExtensionType.ServerName)
{
return GetSniFromServerNameList(extensionData, out remainingBytes, out invalid);
}
else
{
remainingBytes = SkipOpaqueType2(extensionData, out invalid);
return null;
}
}
private static string GetSniFromServerNameList(ReadOnlySpan<byte> serverNameListExtension, out ReadOnlySpan<byte> remainingBytes, out bool invalid)
{
// https://tools.ietf.org/html/rfc3546#section-3.1
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
// ServerNameList is an opaque type (length of sufficient size for max data length is prepended)
const int ServerNameListOffset = sizeof(ushort);
if (serverNameListExtension.Length < ServerNameListOffset)
{
remainingBytes = ReadOnlySpan<byte>.Empty;
invalid = true;
return null;
}
int serverNameListLength = BinaryPrimitives.ReadUInt16BigEndian(serverNameListExtension);
ReadOnlySpan<byte> serverNameList = serverNameListExtension.Slice(ServerNameListOffset);
if (serverNameListLength > serverNameList.Length)
{
remainingBytes = ReadOnlySpan<byte>.Empty;
invalid = true;
return null;
}
remainingBytes = serverNameList.Slice(serverNameListLength);
ReadOnlySpan<byte> serverName = serverNameList.Slice(0, serverNameListLength);
return GetSniFromServerName(serverName, out invalid);
}
private static string GetSniFromServerName(ReadOnlySpan<byte> serverName, out bool invalid)
{
// https://tools.ietf.org/html/rfc3546#section-3.1
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
// ServerName is an opaque type (length of sufficient size for max data length is prepended)
const int ServerNameLengthOffset = 0;
const int NameTypeOffset = ServerNameLengthOffset + sizeof(ushort);
const int HostNameStructOffset = NameTypeOffset + sizeof(NameType);
if (serverName.Length < HostNameStructOffset)
{
invalid = true;
return null;
}
// Following can underflow but it is ok due to equality check below
int hostNameStructLength = BinaryPrimitives.ReadUInt16BigEndian(serverName) - sizeof(NameType);
NameType nameType = (NameType)serverName[NameTypeOffset];
ReadOnlySpan<byte> hostNameStruct = serverName.Slice(HostNameStructOffset);
if (hostNameStructLength != hostNameStruct.Length || nameType != NameType.HostName)
{
invalid = true;
return null;
}
return GetSniFromHostNameStruct(hostNameStruct, out invalid);
}
private static string GetSniFromHostNameStruct(ReadOnlySpan<byte> hostNameStruct, out bool invalid)
{
// https://tools.ietf.org/html/rfc3546#section-3.1
// HostName is an opaque type (length of sufficient size for max data length is prepended)
const int HostNameLengthOffset = 0;
const int HostNameOffset = HostNameLengthOffset + sizeof(ushort);
int hostNameLength = BinaryPrimitives.ReadUInt16BigEndian(hostNameStruct);
ReadOnlySpan<byte> hostName = hostNameStruct.Slice(HostNameOffset);
if (hostNameLength != hostName.Length)
{
invalid = true;
return null;
}
invalid = false;
return DecodeString(hostName);
}
private static string DecodeString(ReadOnlySpan<byte> bytes)
{
// https://tools.ietf.org/html/rfc3546#section-3.1
// Per spec:
// If the hostname labels contain only US-ASCII characters, then the
// client MUST ensure that labels are separated only by the byte 0x2E,
// representing the dot character U+002E (requirement 1 in section 3.1
// of [IDNA] notwithstanding). If the server needs to match the HostName
// against names that contain non-US-ASCII characters, it MUST perform
// the conversion operation described in section 4 of [IDNA], treating
// the HostName as a "query string" (i.e. the AllowUnassigned flag MUST
// be set). Note that IDNA allows labels to be separated by any of the
// Unicode characters U+002E, U+3002, U+FF0E, and U+FF61, therefore
// servers MUST accept any of these characters as a label separator. If
// the server only needs to match the HostName against names containing
// exclusively ASCII characters, it MUST compare ASCII names case-
// insensitively.
string idnEncodedString;
try
{
idnEncodedString = s_encoding.GetString(bytes);
}
catch (DecoderFallbackException)
{
return null;
}
try
{
return s_idnMapping.GetUnicode(idnEncodedString);
}
catch (ArgumentException)
{
// client has not done IDN mapping
return idnEncodedString;
}
}
private static int ReadUInt24BigEndian(ReadOnlySpan<byte> bytes)
{
return (bytes[0] << 16) | (bytes[1] << 8) | bytes[2];
}
private static ReadOnlySpan<byte> SkipBytes(ReadOnlySpan<byte> bytes, int numberOfBytesToSkip)
{
return (numberOfBytesToSkip < bytes.Length) ? bytes.Slice(numberOfBytesToSkip) : ReadOnlySpan<byte>.Empty;
}
// Opaque type is of structure:
// - length (minimum number of bytes to hold the max value)
// - data (length bytes)
// We will only use opaque types which are of max size: 255 (length = 1) or 2^16-1 (length = 2).
// We will call them SkipOpaqueType`length`
private static ReadOnlySpan<byte> SkipOpaqueType1(ReadOnlySpan<byte> bytes)
{
const int OpaqueTypeLengthSize = sizeof(byte);
if (bytes.Length < OpaqueTypeLengthSize)
{
return ReadOnlySpan<byte>.Empty;
}
byte length = bytes[0];
int totalBytes = OpaqueTypeLengthSize + length;
return SkipBytes(bytes, totalBytes);
}
private static ReadOnlySpan<byte> SkipOpaqueType2(ReadOnlySpan<byte> bytes, out bool invalid)
{
const int OpaqueTypeLengthSize = sizeof(ushort);
if (bytes.Length < OpaqueTypeLengthSize)
{
invalid = true;
return ReadOnlySpan<byte>.Empty;
}
ushort length = BinaryPrimitives.ReadUInt16BigEndian(bytes);
int totalBytes = OpaqueTypeLengthSize + length;
invalid = bytes.Length < totalBytes;
if (invalid)
{
return ReadOnlySpan<byte>.Empty;
}
else
{
return bytes.Slice(totalBytes);
}
}
private static IdnMapping CreateIdnMapping()
{
return new IdnMapping()
{
// Per spec "AllowUnassigned flag MUST be set". See comment above GetSniFromServerNameList for more details.
AllowUnassigned = true
};
}
private static Encoding CreateEncoding()
{
return Encoding.GetEncoding("utf-8", new EncoderExceptionFallback(), new DecoderExceptionFallback());
}
private enum ContentType : byte
{
Handshake = 0x16
}
private enum HandshakeType : byte
{
ClientHello = 0x01
}
private enum ExtensionType : ushort
{
ServerName = 0x00
}
private enum NameType : byte
{
HostName = 0x00
}
}
}

View File

@@ -115,17 +115,17 @@ namespace System.Net.Security
}
return new string(byteChars, 0, byteCharsLength - 1);
char GetHexValue(int i)
{
if (i < 10)
return (char)(i + '0');
return (char)(i - 10 + 'a');
}
}
}
static char GetHexValue(int i)
{
if (i < 10)
return (char)(i + '0');
return (char)(i - 10 + 'a');
}
public static bool operator ==(SslApplicationProtocol left, SslApplicationProtocol right)
{
return left.Equals(right);

View File

@@ -11,12 +11,12 @@ namespace System.Net.Security
{
internal class SslAuthenticationOptions
{
internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions)
internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback localCallback)
{
// Common options.
AllowRenegotiation = sslClientAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslClientAuthenticationOptions.ApplicationProtocols;
CertValidationDelegate = sslClientAuthenticationOptions._certValidationDelegate;
CertValidationDelegate = remoteCallback;
CheckCertName = true;
EnabledSslProtocols = sslClientAuthenticationOptions.EnabledSslProtocols;
EncryptionPolicy = sslClientAuthenticationOptions.EncryptionPolicy;
@@ -26,7 +26,7 @@ namespace System.Net.Security
TargetHost = sslClientAuthenticationOptions.TargetHost;
// Client specific options.
CertSelectionDelegate = sslClientAuthenticationOptions._certSelectionDelegate;
CertSelectionDelegate = localCallback;
CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode;
ClientCertificates = sslClientAuthenticationOptions.ClientCertificates;
LocalCertificateSelectionCallback = sslClientAuthenticationOptions.LocalCertificateSelectionCallback;
@@ -37,7 +37,6 @@ namespace System.Net.Security
// Common options.
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
CertValidationDelegate = sslServerAuthenticationOptions._certValidationDelegate;
CheckCertName = false;
EnabledSslProtocols = sslServerAuthenticationOptions.EnabledSslProtocols;
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
@@ -66,6 +65,7 @@ namespace System.Net.Security
internal bool CheckCertName { get; set; }
internal RemoteCertValidationCallback CertValidationDelegate { get; set; }
internal LocalCertSelectionCallback CertSelectionDelegate { get; set; }
internal ServerCertSelectionCallback ServerCertSelectionDelegate { get; set; }
}
}

View File

@@ -16,9 +16,6 @@ namespace System.Net.Security
private SslProtocols _enabledSslProtocols = SecurityProtocol.SystemDefaultSecurityProtocols;
private bool _allowRenegotiation = true;
internal RemoteCertValidationCallback _certValidationDelegate;
internal LocalCertSelectionCallback _certSelectionDelegate;
public bool AllowRenegotiation
{
get => _allowRenegotiation;

View File

@@ -15,8 +15,6 @@ namespace System.Net.Security
private EncryptionPolicy _encryptionPolicy = EncryptionPolicy.RequireEncryption;
private bool _allowRenegotiation = true;
internal RemoteCertValidationCallback _certValidationDelegate;
public bool AllowRenegotiation
{
get => _allowRenegotiation;
@@ -29,6 +27,8 @@ namespace System.Net.Security
public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; }
public ServerCertificateSelectionCallback ServerCertificateSelectionCallback { get; set; }
public X509Certificate ServerCertificate { get; set; }
public SslProtocols EnabledSslProtocols

View File

@@ -77,12 +77,29 @@ namespace System.Net.Security
_innerStream = innerStream;
}
internal void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions)
/// <summary>Set as the _exception when the instance is disposed.</summary>
private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream)));
private void ThrowIfExceptional()
{
if (_exception != null)
ExceptionDispatchInfo e = _exception;
if (e != null)
{
_exception.Throw();
// If the stored exception just indicates disposal, throw a new ODE rather than the stored one,
// so as to not continually build onto the shared exception's stack.
if (ReferenceEquals(e, s_disposedSentinel))
{
throw new ObjectDisposedException(nameof(SslStream));
}
// Throw the stored exception.
e.Throw();
}
}
internal void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback localCallback)
{
ThrowIfExceptional();
if (Context != null && Context.IsValidContext)
{
@@ -99,15 +116,14 @@ namespace System.Net.Security
throw new ArgumentNullException(nameof(sslClientAuthenticationOptions.TargetHost));
}
if (sslClientAuthenticationOptions.TargetHost.Length == 0)
{
sslClientAuthenticationOptions.TargetHost = "?" + Interlocked.Increment(ref s_uniqueNameInteger).ToString(NumberFormatInfo.InvariantInfo);
}
_exception = null;
try
{
_sslAuthenticationOptions = new SslAuthenticationOptions(sslClientAuthenticationOptions);
_sslAuthenticationOptions = new SslAuthenticationOptions(sslClientAuthenticationOptions, remoteCallback, localCallback);
if (_sslAuthenticationOptions.TargetHost.Length == 0)
{
_sslAuthenticationOptions.TargetHost = "?" + Interlocked.Increment(ref s_uniqueNameInteger).ToString(NumberFormatInfo.InvariantInfo);
}
_context = new SecureChannel(_sslAuthenticationOptions);
}
catch (Win32Exception e)
@@ -116,12 +132,9 @@ namespace System.Net.Security
}
}
internal void ValidateCreateContext(SslServerAuthenticationOptions sslServerAuthenticationOptions)
internal void ValidateCreateContext(SslAuthenticationOptions sslAuthenticationOptions)
{
if (_exception != null)
{
_exception.Throw();
}
ThrowIfExceptional();
if (Context != null && Context.IsValidContext)
{
@@ -133,15 +146,11 @@ namespace System.Net.Security
throw new InvalidOperationException(SR.net_auth_client_server);
}
if (sslServerAuthenticationOptions.ServerCertificate == null)
{
throw new ArgumentNullException(nameof(sslServerAuthenticationOptions.ServerCertificate));
}
_exception = null;
_sslAuthenticationOptions = sslAuthenticationOptions;
try
{
_sslAuthenticationOptions = new SslAuthenticationOptions(sslServerAuthenticationOptions);
_context = new SecureChannel(_sslAuthenticationOptions);
}
catch (Win32Exception e)
@@ -401,7 +410,7 @@ namespace System.Net.Security
}
}
private ExceptionDispatchInfo SetException(Exception e)
private void SetException(Exception e)
{
Debug.Assert(e != null, $"Expected non-null Exception to be passed to {nameof(SetException)}");
@@ -410,12 +419,7 @@ namespace System.Net.Security
_exception = ExceptionDispatchInfo.Capture(e);
}
if (_exception != null && Context != null)
{
Context.Close();
}
return _exception;
Context?.Close();
}
private bool HandshakeCompleted
@@ -436,10 +440,7 @@ namespace System.Net.Security
internal void CheckThrow(bool authSuccessCheck, bool shutdownCheck = false)
{
if (_exception != null)
{
_exception.Throw();
}
ThrowIfExceptional();
if (authSuccessCheck && !IsAuthenticated)
{
@@ -467,11 +468,9 @@ namespace System.Net.Security
//
internal void Close()
{
_exception = ExceptionDispatchInfo.Capture(new ObjectDisposedException("SslStream"));
if (Context != null)
{
Context.Close();
}
_exception = s_disposedSentinel;
Context?.Close();
_secureStream?.Dispose();
}
internal SecurityStatusPal EncryptData(ReadOnlyMemory<byte> buffer, ref byte[] outBuffer, out int outSize)
@@ -668,14 +667,12 @@ namespace System.Net.Security
_Framing = Framing.Unknown;
_handshakeCompleted = false;
if (SetException(e).SourceException == e)
SetException(e);
if (_exception.SourceException != e)
{
throw;
}
else
{
_exception.Throw();
ThrowIfExceptional();
}
throw;
}
finally
{
@@ -734,7 +731,8 @@ namespace System.Net.Security
_Framing = Framing.Unknown;
_handshakeCompleted = false;
SetException(e).Throw();
SetException(e);
ThrowIfExceptional();
}
}
@@ -1286,7 +1284,7 @@ namespace System.Net.Security
AsyncProtocolRequest request = (AsyncProtocolRequest)_queuedReadStateRequest;
request.Buffer = renegotiateBuffer;
_queuedReadStateRequest = null;
ThreadPool.QueueUserWorkItem(new WaitCallback(AsyncResumeHandshakeRead), request);
ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshakeRead(s.request), (sslState: this, request), preferLocal: false);
}
}
}
@@ -1384,7 +1382,7 @@ namespace System.Net.Security
taskCompletionSource.SetResult(0);
break;
default:
ThreadPool.QueueUserWorkItem(new WaitCallback(AsyncResumeHandshake), obj);
ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshake(s.obj), (sslState: this, obj), preferLocal: false);
break;
}
}
@@ -1748,9 +1746,8 @@ namespace System.Net.Security
//
// Called with no user stack.
//
private void AsyncResumeHandshakeRead(object state)
private void AsyncResumeHandshakeRead(AsyncProtocolRequest asyncRequest)
{
AsyncProtocolRequest asyncRequest = (AsyncProtocolRequest)state;
try
{
if (_pendingReHandshake)
@@ -1776,22 +1773,6 @@ namespace System.Net.Security
}
}
//
// Called with no user stack.
//
private void CompleteRequestWaitCallback(object state)
{
AsyncProtocolRequest request = (AsyncProtocolRequest)state;
// Force async completion.
if (request.MustCompleteSynchronously)
{
throw new InternalException();
}
request.CompleteRequest(0);
}
private void RehandshakeCompleteCallback(IAsyncResult result)
{
LazyAsyncResult lazyAsyncResult = (LazyAsyncResult)result;

View File

@@ -31,9 +31,12 @@ namespace System.Net.Security
// A user delegate used to select local SSL certificate.
public delegate X509Certificate LocalCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers);
public delegate X509Certificate ServerCertificateSelectionCallback(object sender, string hostName);
// Internal versions of the above delegates.
internal delegate bool RemoteCertValidationCallback(string host, X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors);
internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2 remoteCertificate, string[] acceptableIssuers);
internal delegate X509Certificate ServerCertSelectionCallback(string hostName);
public class SslStream : AuthenticatedStream
{
@@ -42,6 +45,7 @@ namespace System.Net.Security
internal RemoteCertificateValidationCallback _userCertificateValidationCallback;
internal LocalCertificateSelectionCallback _userCertificateSelectionCallback;
internal ServerCertificateSelectionCallback _userServerCertificateSelectionCallback;
internal RemoteCertValidationCallback _certValidationDelegate;
internal LocalCertSelectionCallback _certSelectionDelegate;
internal EncryptionPolicy _encryptionPolicy;
@@ -141,6 +145,34 @@ namespace System.Net.Security
return _userCertificateSelectionCallback(this, targetHost, localCertificates, remoteCertificate, acceptableIssuers);
}
private X509Certificate ServerCertSelectionCallbackWrapper(string targetHost)
{
return _userServerCertificateSelectionCallback(this, targetHost);
}
private SslAuthenticationOptions CreateAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
if (sslServerAuthenticationOptions.ServerCertificate == null && sslServerAuthenticationOptions.ServerCertificateSelectionCallback == null && _certSelectionDelegate == null)
{
throw new ArgumentNullException(nameof(sslServerAuthenticationOptions.ServerCertificate));
}
if ((sslServerAuthenticationOptions.ServerCertificate != null || _certSelectionDelegate != null) && sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
{
throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(ServerCertificateSelectionCallback)));
}
var authOptions = new SslAuthenticationOptions(sslServerAuthenticationOptions);
_userServerCertificateSelectionCallback = sslServerAuthenticationOptions.ServerCertificateSelectionCallback;
authOptions.ServerCertSelectionDelegate = _userServerCertificateSelectionCallback == null ? null : new ServerCertSelectionCallback(ServerCertSelectionCallbackWrapper);
authOptions.CertValidationDelegate = _certValidationDelegate;
authOptions.CertSelectionDelegate = _certSelectionDelegate;
return authOptions;
}
//
// Client side auth.
//
@@ -174,15 +206,10 @@ namespace System.Net.Security
internal virtual IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
{
SecurityProtocol.ThrowOnNotAllowed(sslClientAuthenticationOptions.EnabledSslProtocols);
SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
// Set the delegates on the options.
sslClientAuthenticationOptions._certValidationDelegate = _certValidationDelegate;
sslClientAuthenticationOptions._certSelectionDelegate = _certSelectionDelegate;
_sslState.ValidateCreateContext(sslClientAuthenticationOptions);
_sslState.ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
LazyAsyncResult result = new LazyAsyncResult(_sslState, asyncState, asyncCallback);
_sslState.ProcessAuthentication(result);
@@ -230,13 +257,9 @@ namespace System.Net.Security
private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
{
SecurityProtocol.ThrowOnNotAllowed(sslServerAuthenticationOptions.EnabledSslProtocols);
SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
// Set the delegate on the options.
sslServerAuthenticationOptions._certValidationDelegate = _certValidationDelegate;
_sslState.ValidateCreateContext(sslServerAuthenticationOptions);
_sslState.ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
LazyAsyncResult result = new LazyAsyncResult(_sslState, asyncState, asyncCallback);
_sslState.ProcessAuthentication(result);
@@ -298,15 +321,10 @@ namespace System.Net.Security
private void AuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions)
{
SecurityProtocol.ThrowOnNotAllowed(sslClientAuthenticationOptions.EnabledSslProtocols);
SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
// Set the delegates on the options.
sslClientAuthenticationOptions._certValidationDelegate = _certValidationDelegate;
sslClientAuthenticationOptions._certSelectionDelegate = _certSelectionDelegate;
_sslState.ValidateCreateContext(sslClientAuthenticationOptions);
_sslState.ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
_sslState.ProcessAuthentication(null);
}
@@ -336,13 +354,9 @@ namespace System.Net.Security
private void AuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
SecurityProtocol.ThrowOnNotAllowed(sslServerAuthenticationOptions.EnabledSslProtocols);
SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
// Set the delegate on the options.
sslServerAuthenticationOptions._certValidationDelegate = _certValidationDelegate;
_sslState.ValidateCreateContext(sslServerAuthenticationOptions);
_sslState.ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
_sslState.ProcessAuthentication(null);
}
#endregion
@@ -711,9 +725,9 @@ namespace System.Net.Security
return _sslState.SecureStream.WriteAsync(buffer, offset, count, cancellationToken);
}
public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
return _sslState.SecureStream.WriteAsync(source, cancellationToken);
return _sslState.SecureStream.WriteAsync(buffer, cancellationToken);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
@@ -721,9 +735,9 @@ namespace System.Net.Security
return _sslState.SecureStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _sslState.SecureStream.ReadAsync(destination, cancellationToken);
return _sslState.SecureStream.ReadAsync(buffer, cancellationToken);
}
}
}

View File

@@ -12,7 +12,7 @@ namespace System.Net.Security
private interface ISslWriteAdapter
{
Task LockAsync();
Task WriteAsync(byte[] buffer, int offset, int count);
ValueTask WriteAsync(byte[] buffer, int offset, int count);
}
private interface ISslReadAdapter
@@ -61,7 +61,7 @@ namespace System.Net.Security
public Task LockAsync() => _sslState.CheckEnqueueWriteAsync();
public Task WriteAsync(byte[] buffer, int offset, int count) => _sslState.InnerStream.WriteAsync(buffer, offset, count, _cancellationToken);
public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslState.InnerStream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), _cancellationToken);
}
private readonly struct SslWriteSync : ISslWriteAdapter
@@ -76,10 +76,10 @@ namespace System.Net.Security
return Task.CompletedTask;
}
public Task WriteAsync(byte[] buffer, int offset, int count)
public ValueTask WriteAsync(byte[] buffer, int offset, int count)
{
_sslState.InnerStream.Write(buffer, offset, count);
return Task.CompletedTask;
return default;
}
}
}

View File

@@ -14,7 +14,7 @@ namespace System.Net.Security
//
// This is a wrapping stream that does data encryption/decryption based on a successfully authenticated SSPI context.
//
internal partial class SslStreamInternal
internal partial class SslStreamInternal : IDisposable
{
private const int FrameOverhead = 32;
private const int ReadBufferSize = 4096 * 4 + FrameOverhead; // We read in 16K chunks + headers.
@@ -57,10 +57,36 @@ namespace System.Net.Security
~SslStreamInternal()
{
if (_internalBuffer != null)
Dispose(disposing: false);
}
public void Dispose()
{
Dispose(disposing: true);
if (_internalBuffer == null)
{
ArrayPool<byte>.Shared.Return(_internalBuffer);
_internalBuffer = null;
// Suppress finalizer if the read buffer was returned.
GC.SuppressFinalize(this);
}
}
private void Dispose(bool disposing)
{
// Ensure a Read operation is not in progress,
// block potential reads since SslStream is disposing.
// This leaves the _nestedRead = 1, but that's ok, since
// subsequent Reads first check if the context is still available.
if (Interlocked.CompareExchange(ref _nestedRead, 1, 0) == 0)
{
byte[] buffer = _internalBuffer;
if (buffer != null)
{
_internalBuffer = null;
_internalBufferCount = 0;
_internalOffset = 0;
ArrayPool<byte>.Shared.Return(buffer);
}
}
}
@@ -141,7 +167,7 @@ namespace System.Net.Security
internal void EndWrite(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
internal Task WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
internal ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
SslWriteAsync writeAdapter = new SslWriteAsync(_sslState, cancellationToken);
return WriteAsyncInternal(writeAdapter, buffer);
@@ -150,13 +176,12 @@ namespace System.Net.Security
internal Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateParameters(buffer, offset, count);
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
private void ResetReadBuffer()
{
Debug.Assert(_decryptedBytesCount == 0);
Debug.Assert(_internalBuffer == null || _internalBufferCount > 0);
if (_internalBuffer == null)
{
@@ -206,22 +231,21 @@ namespace System.Net.Security
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(ReadAsync), "read"));
}
while (true)
try
{
int copyBytes;
if (_decryptedBytesCount != 0)
while (true)
{
copyBytes = CopyDecryptedData(buffer);
int copyBytes;
if (_decryptedBytesCount != 0)
{
copyBytes = CopyDecryptedData(buffer);
_sslState.FinishRead(null);
_nestedRead = 0;
_sslState.FinishRead(null);
return copyBytes;
}
return copyBytes;
}
copyBytes = await adapter.LockAsync(buffer).ConfigureAwait(false);
try
{
copyBytes = await adapter.LockAsync(buffer).ConfigureAwait(false);
if (copyBytes > 0)
{
return copyBytes;
@@ -241,9 +265,10 @@ namespace System.Net.Security
}
readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize + payloadBytes).ConfigureAwait(false);
if (readBytes < 0)
Debug.Assert(readBytes >= 0);
if (readBytes == 0)
{
throw new IOException(SR.net_frame_read_size);
throw new IOException(SR.net_io_eof);
}
// At this point, readBytes contains the size of the header plus body.
@@ -294,25 +319,25 @@ namespace System.Net.Security
throw new IOException(SR.net_io_decrypt, message.GetException());
}
}
catch (Exception e)
{
_sslState.FinishRead(null);
}
catch (Exception e)
{
_sslState.FinishRead(null);
if (e is IOException)
{
throw;
}
throw new IOException(SR.net_io_read, e);
}
finally
if (e is IOException)
{
_nestedRead = 0;
throw;
}
throw new IOException(SR.net_io_read, e);
}
finally
{
_nestedRead = 0;
}
}
private Task WriteAsyncInternal<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
private ValueTask WriteAsyncInternal<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
where TWriteAdapter : struct, ISslWriteAdapter
{
_sslState.CheckThrow(authSuccessCheck: true, shutdownCheck: true);
@@ -320,7 +345,7 @@ namespace System.Net.Security
if (buffer.Length == 0 && !SslStreamPal.CanEncryptEmptyMessage)
{
// If it's an empty message and the PAL doesn't support that, we're done.
return Task.CompletedTask;
return default;
}
if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
@@ -328,18 +353,18 @@ namespace System.Net.Security
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(WriteAsync), "write"));
}
Task t = buffer.Length < _sslState.MaxDataSize ?
ValueTask t = buffer.Length < _sslState.MaxDataSize ?
WriteSingleChunk(writeAdapter, buffer) :
WriteAsyncChunked(writeAdapter, buffer);
new ValueTask(WriteAsyncChunked(writeAdapter, buffer));
if (t.IsCompletedSuccessfully)
{
_nestedWrite = 0;
return t;
}
return ExitWriteAsync(t);
return new ValueTask(ExitWriteAsync(t));
async Task ExitWriteAsync(Task task)
async Task ExitWriteAsync(ValueTask task)
{
try
{
@@ -363,7 +388,7 @@ namespace System.Net.Security
}
}
private Task WriteSingleChunk<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
private ValueTask WriteSingleChunk<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
where TWriteAdapter : struct, ISslWriteAdapter
{
// Request a write IO slot.
@@ -371,7 +396,7 @@ namespace System.Net.Security
if (!ioSlot.IsCompletedSuccessfully)
{
// Operation is async and has been queued, return.
return WaitForWriteIOSlot(writeAdapter, ioSlot, buffer);
return new ValueTask(WaitForWriteIOSlot(writeAdapter, ioSlot, buffer));
}
byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length + FrameOverhead);
@@ -384,10 +409,10 @@ namespace System.Net.Security
// Re-handshake status is not supported.
ArrayPool<byte>.Shared.Return(rentedBuffer);
ProtocolToken message = new ProtocolToken(null, status);
return Task.FromException(new IOException(SR.net_io_encrypt, message.GetException()));
return new ValueTask(Task.FromException(new IOException(SR.net_io_encrypt, message.GetException())));
}
Task t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes);
ValueTask t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes);
if (t.IsCompletedSuccessfully)
{
ArrayPool<byte>.Shared.Return(rentedBuffer);
@@ -396,7 +421,7 @@ namespace System.Net.Security
}
else
{
return CompleteAsync(t, rentedBuffer);
return new ValueTask(CompleteAsync(t, rentedBuffer));
}
async Task WaitForWriteIOSlot(TWriteAdapter wAdapter, Task lockTask, ReadOnlyMemory<byte> buff)
@@ -405,7 +430,7 @@ namespace System.Net.Security
await WriteSingleChunk(wAdapter, buff).ConfigureAwait(false);
}
async Task CompleteAsync(Task writeTask, byte[] bufferToReturn)
async Task CompleteAsync(ValueTask writeTask, byte[] bufferToReturn)
{
try
{
@@ -445,7 +470,7 @@ namespace System.Net.Security
ValueTask<int> t = adapter.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
if (!t.IsCompletedSuccessfully)
{
return new ValueTask<int>(InternalFillBufferAsync(adapter, t.AsTask(), minSize, initialCount));
return new ValueTask<int>(InternalFillBufferAsync(adapter, t, minSize, initialCount));
}
int bytes = t.Result;
if (bytes == 0)
@@ -464,7 +489,7 @@ namespace System.Net.Security
return new ValueTask<int>(minSize);
async Task<int> InternalFillBufferAsync(TReadAdapter adap, Task<int> task, int min, int initial)
async Task<int> InternalFillBufferAsync(TReadAdapter adap, ValueTask<int> task, int min, int initial)
{
while (true)
{
@@ -485,7 +510,7 @@ namespace System.Net.Security
return min;
}
task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount).AsTask();
task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
}
}
}

View File

@@ -16,8 +16,6 @@ namespace System.Net.Security
{
internal static class SslStreamPal
{
private static readonly StreamSizes s_streamSizes = new StreamSizes();
public static Exception GetException(SecurityStatusPal status)
{
return status.Exception ?? new Win32Exception((int)status.ErrorCode);
@@ -128,7 +126,7 @@ namespace System.Net.Security
unsafe
{
MemoryHandle memHandle = input.Retain(pin: true);
MemoryHandle memHandle = input.Pin();
try
{
PAL_TlsIo status;
@@ -259,7 +257,7 @@ namespace System.Net.Security
SafeDeleteContext securityContext,
out StreamSizes streamSizes)
{
streamSizes = s_streamSizes;
streamSizes = StreamSizes.Default;
}
public static void QueryContextConnectionInfo(

View File

@@ -15,8 +15,6 @@ namespace System.Net.Security
{
internal static class SslStreamPal
{
private static readonly StreamSizes s_streamSizes = new StreamSizes();
public static Exception GetException(SecurityStatusPal status)
{
return status.Exception ?? new Interop.OpenSsl.SslException((int)status.ErrorCode);
@@ -122,7 +120,7 @@ namespace System.Net.Security
public static void QueryContextStreamSizes(SafeDeleteContext securityContext, out StreamSizes streamSizes)
{
streamSizes = s_streamSizes;
streamSizes = StreamSizes.Default;
}
public static void QueryContextConnectionInfo(SafeDeleteContext securityContext, out SslConnectionInfo connectionInfo)
@@ -257,6 +255,11 @@ namespace System.Net.Security
{
return new SecurityStatusPal(SecurityStatusPalErrorCode.OK);
}
else if (code == Interop.Ssl.SslErrorCode.SSL_ERROR_SSL)
{
// OpenSSL failure occurred. The error queue contains more details, when building the exception the queue will be cleared.
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, Interop.Crypto.CreateOpenSslCryptographicException());
}
else
{
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, new Interop.OpenSsl.SslException((int)code));

View File

@@ -153,7 +153,8 @@ namespace System.Net.Security
Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_SEND_AUX_RECORD;
// CoreFX: always opt-in SCH_USE_STRONG_CRYPTO except for SSL3.
if (((protocolFlags & (Interop.SChannel.SP_PROT_TLS1_0 | Interop.SChannel.SP_PROT_TLS1_1 | Interop.SChannel.SP_PROT_TLS1_2)) != 0)
if (((protocolFlags == 0) ||
(protocolFlags & (Interop.SChannel.SP_PROT_TLS1_0 | Interop.SChannel.SP_PROT_TLS1_1 | Interop.SChannel.SP_PROT_TLS1_2)) != 0)
&& (policy != EncryptionPolicy.AllowNoEncryption) && (policy != EncryptionPolicy.NoEncryption))
{
flags |= Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_USE_STRONG_CRYPTO;
@@ -165,6 +166,7 @@ namespace System.Net.Security
flags = Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_SEND_AUX_RECORD;
}
if (NetEventSource.IsEnabled) NetEventSource.Info($"flags=({flags}), ProtocolFlags=({protocolFlags}), EncryptionPolicy={policy}");
Interop.SspiCli.SCHANNEL_CRED secureCredential = CreateSecureCredential(
Interop.SspiCli.SCHANNEL_CRED.CurrentVersion,
certificate,
@@ -465,9 +467,8 @@ namespace System.Net.Security
return SSPIWrapper.AcquireCredentialsHandle(GlobalSSPI.SSPISecureChannel, SecurityPackage, credUsage, secureCredential);
});
}
catch (Exception ex)
catch
{
Debug.Fail("AcquireCredentialsHandle failed.", ex.ToString());
return SSPIWrapper.AcquireCredentialsHandle(GlobalSSPI.SSPISecureChannel, SecurityPackage, credUsage, secureCredential);
}
}

View File

@@ -4,7 +4,7 @@
namespace System.Net
{
internal partial class StreamSizes
internal partial struct StreamSizes
{
// Windows SChannel requires that you pass it a buffer big enough to hold
// the header, the trailer, and the payload. You're also required to do your
@@ -19,11 +19,6 @@ namespace System.Net
// but using a bound of 32k means that if we were to switch from pointers to temporary
// arrays, we'd be maintaining a reasonable upper bound.
public StreamSizes()
{
Header = 0;
Trailer = 0;
MaximumMessage = 32 * 1024;
}
public static StreamSizes Default => new StreamSizes { MaximumMessage = 32 * 1024 };
}
}

Some files were not shown because too many files have changed in this diff Show More