Add secured sockets with mbedtls and HTTPS support

This commit is contained in:
Thomas Edvalson
2016-02-04 15:08:47 -05:00
parent 681dc96361
commit 263f48ec13
9 changed files with 234 additions and 85 deletions
+22 -2
View File
@@ -31,6 +31,9 @@
#include <cpp3ds/Network/SocketHandle.hpp>
#include <cpp3ds/System/NonCopyable.hpp>
#include <vector>
#include <mbedtls/net.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
namespace cpp3ds
@@ -67,6 +70,16 @@ public:
AnyPort = 0 ///< Special value that tells the system to pick any available port
};
struct SecureData
{
mbedtls_net_context socket;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_x509_crt cacert;
};
public:
////////////////////////////////////////////////////////////
@@ -104,6 +117,9 @@ public:
////////////////////////////////////////////////////////////
bool isBlocking() const;
void setSecure(bool secure);
bool isSecure() const;
protected:
////////////////////////////////////////////////////////////
@@ -124,7 +140,7 @@ protected:
/// \param type Type of the socket (TCP or UDP)
///
////////////////////////////////////////////////////////////
Socket(Type type);
Socket(Type type, bool secure);
////////////////////////////////////////////////////////////
/// \brief Return the internal handle of the socket
@@ -138,6 +154,8 @@ protected:
////////////////////////////////////////////////////////////
SocketHandle getHandle() const;
SecureData& getSecureData() const;
////////////////////////////////////////////////////////////
/// \brief Create the internal representation of the socket
///
@@ -173,8 +191,10 @@ private:
// Member data
////////////////////////////////////////////////////////////
Type m_type; ///< Type of the socket (TCP or UDP)
SocketHandle m_socket; ///< Socket descriptor
bool m_isBlocking; ///< Current blocking mode of the socket
bool m_isSecure; ///< Socket using SSL
SocketHandle m_socket; ///< Socket descriptor
mutable SecureData m_secureData; ///< Data needed for secure socket
};
} // namespace cpp3ds
+2 -2
View File
@@ -30,7 +30,7 @@
////////////////////////////////////////////////////////////
#include <cpp3ds/Config.hpp>
#if defined(SFML_SYSTEM_WINDOWS)
#if defined(CPP3DS_SYSTEM_WINDOWS)
#include <basetsd.h>
#endif
@@ -41,7 +41,7 @@ namespace cpp3ds
// Define the low-level socket handle type, specific to
// each platform
////////////////////////////////////////////////////////////
#if defined(SFML_SYSTEM_WINDOWS)
#if defined(CPP3DS_SYSTEM_WINDOWS)
typedef UINT_PTR SocketHandle;
+1 -1
View File
@@ -50,7 +50,7 @@ public:
/// \brief Default constructor
///
////////////////////////////////////////////////////////////
TcpSocket();
TcpSocket(bool secure = false);
////////////////////////////////////////////////////////////
/// \brief Get the port to which the socket is bound locally
+1 -1
View File
@@ -57,7 +57,7 @@ public:
/// \brief Default constructor
///
////////////////////////////////////////////////////////////
UdpSocket();
UdpSocket(bool secure = false);
////////////////////////////////////////////////////////////
/// \brief Get the port to which the socket is bound locally
+6 -4
View File
@@ -319,19 +319,21 @@ void Http::setHost(const std::string& host, unsigned short port)
if (toLower(host.substr(0, 7)) == "http://")
{
// HTTP protocol
m_connection.setSecure(false);
m_hostName = host.substr(7);
m_port = (port != 0 ? port : 80);
}
else if (toLower(host.substr(0, 8)) == "https://")
{
// HTTPS protocol -- unsupported (requires encryption and certificates and stuff...)
err() << "HTTPS protocol is not supported by cpp3ds::Http" << std::endl;
m_hostName = "";
m_port = 0;
// HTTPS protocol
m_connection.setSecure(true);
m_hostName = host.substr(8);
m_port = (port != 0 ? port : 443);
}
else
{
// Undefined protocol - use HTTP
m_connection.setSecure(false);
m_hostName = host;
m_port = (port != 0 ? port : 80);
}
+113 -9
View File
@@ -28,17 +28,31 @@
#include <cpp3ds/Network/Socket.hpp>
#include <cpp3ds/Network/SocketImpl.hpp>
#include <cpp3ds/System/Err.hpp>
#include <mbedtls/certs.h>
#include <stdlib.h>
namespace
{
static int entropy_func(void *data, unsigned char *output, size_t len)
{
for (int i = 0; i < len; ++i)
output[i] = rand() % 256;
return 0;
}
}
namespace cpp3ds
{
////////////////////////////////////////////////////////////
Socket::Socket(Type type) :
Socket::Socket(Type type, bool secure) :
m_type (type),
m_socket (priv::SocketImpl::invalidSocket()),
m_isBlocking(true)
m_isBlocking(true),
m_isSecure (secure)
{
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
}
@@ -54,8 +68,21 @@ Socket::~Socket()
void Socket::setBlocking(bool blocking)
{
// Apply if the socket is already created
if (m_socket != priv::SocketImpl::invalidSocket())
priv::SocketImpl::setBlocking(m_socket, blocking);
if (m_isSecure)
{
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
{
if (blocking)
mbedtls_net_set_block(&m_secureData.socket);
else
mbedtls_net_set_nonblock(&m_secureData.socket);
}
}
else
{
if (m_socket != priv::SocketImpl::invalidSocket())
priv::SocketImpl::setBlocking(m_socket, blocking);
}
m_isBlocking = blocking;
}
@@ -68,10 +95,39 @@ bool Socket::isBlocking() const
}
////////////////////////////////////////////////////////////
void Socket::setSecure(bool secure)
{
bool oldSecure = m_isSecure;
m_isSecure = secure;
if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) ||
(!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket()))
{
close();
create();
}
}
////////////////////////////////////////////////////////////
bool Socket::isSecure() const
{
return m_isSecure;
}
////////////////////////////////////////////////////////////
SocketHandle Socket::getHandle() const
{
return m_socket;
return m_isSecure ? m_secureData.socket.fd : m_socket;
}
////////////////////////////////////////////////////////////
Socket::SecureData& Socket::getSecureData() const
{
return m_secureData;
}
@@ -79,10 +135,46 @@ SocketHandle Socket::getHandle() const
void Socket::create()
{
// Don't create the socket if it already exists
if (m_socket == priv::SocketImpl::invalidSocket())
if (m_isSecure)
{
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
create(handle);
if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket())
{
int ret;
mbedtls_net_init(&m_secureData.socket);
mbedtls_ssl_init(&m_secureData.ssl);
mbedtls_ssl_config_init(&m_secureData.conf);
mbedtls_x509_crt_init(&m_secureData.cacert);
mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg);
mbedtls_entropy_init( &m_secureData.entropy );
srand(time(nullptr));
if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 ))
err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl;
if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0)
err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl;
// use mbedtls_ssl_conf_endpoint to set socket as server
if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl;
mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL);
mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg);
if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0)
err() << "mbedtls_ssl_setup failed: " << ret << std::endl;
if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0)
err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl;
mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL);
setBlocking(m_isBlocking);
}
}
else
{
if (m_socket == priv::SocketImpl::invalidSocket())
{
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
create(handle);
}
}
}
@@ -141,6 +233,18 @@ void Socket::close()
priv::SocketImpl::close(m_socket);
m_socket = priv::SocketImpl::invalidSocket();
}
// Close the secure socket
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
{
mbedtls_net_free(&m_secureData.socket);
mbedtls_ssl_free(&m_secureData.ssl);
mbedtls_ssl_config_free(&m_secureData.conf);
mbedtls_x509_crt_free(&m_secureData.cacert);
mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg);
mbedtls_entropy_free(&m_secureData.entropy);
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
}
}
} // namespace cpp3ds
+1 -1
View File
@@ -35,7 +35,7 @@ namespace cpp3ds
{
////////////////////////////////////////////////////////////
TcpListener::TcpListener() :
Socket(Tcp)
Socket(Tcp, false)
{
}
+86 -63
View File
@@ -33,6 +33,7 @@
#include <sys/select.h>
#include <algorithm>
#include <cstring>
#include <mbedtls/net.h>
#ifdef _MSC_VER
#pragma warning(disable: 4127) // "conditional expression is constant" generated by the FD_SET macro
@@ -42,7 +43,7 @@
namespace
{
// Define the low-level send/receive flags, which depend on the OS
#ifdef SFML_SYSTEM_LINUX
#ifdef CPP3DS_SYSTEM_LINUX
const int flags = MSG_NOSIGNAL;
#else
const int flags = 0;
@@ -52,8 +53,8 @@ namespace
namespace cpp3ds
{
////////////////////////////////////////////////////////////
TcpSocket::TcpSocket() :
Socket(Tcp)
TcpSocket::TcpSocket(bool secure) :
Socket(Tcp, secure)
{
}
@@ -122,86 +123,101 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short
// Create the internal socket if it doesn't exist
create();
// Create the remote address
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
if (timeout <= Time::Zero)
if (isSecure())
{
// ----- We're not using a timeout: just try to connect -----
// Connect the socket
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
char port[5];
sprintf(port, "%d", remotePort);
if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0)
return priv::SocketImpl::getErrorStatus();
// Connection succeeded
return Done;
int ret;
while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0)
if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
return Error;
}
else
{
// ----- We're using a timeout: we'll need a few tricks to make it work -----
// Create the remote address
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
// Save the previous blocking state
bool blocking = isBlocking();
// Switch to non-blocking to enable our connection timeout
if (blocking)
setBlocking(false);
// Try to connect to the remote address
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
if (timeout <= Time::Zero)
{
// We got instantly connected! (it may no happen a lot...)
setBlocking(blocking);
// ----- We're not using a timeout: just try to connect -----
// Connect the socket
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
return priv::SocketImpl::getErrorStatus();
// Connection succeeded
return Done;
}
// Get the error status
Status status = priv::SocketImpl::getErrorStatus();
// If we were in non-blocking mode, return immediately
if (!blocking)
return status;
// Otherwise, wait until something happens to our socket (success, timeout or error)
if (status == Socket::NotReady)
else
{
// Setup the selector
fd_set selector;
FD_ZERO(&selector);
FD_SET(getHandle(), &selector);
// ----- We're using a timeout: we'll need a few tricks to make it work -----
// Setup the timeout
timeval time;
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
// Save the previous blocking state
bool blocking = isBlocking();
// Wait for something to write on our socket (which means that the connection request has returned)
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
// Switch to non-blocking to enable our connection timeout
if (blocking)
setBlocking(false);
// Try to connect to the remote address
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
{
// At this point the connection may have been either accepted or refused.
// To know whether it's a success or a failure, we must check the address of the connected peer
if (getRemoteAddress() != cpp3ds::IpAddress::None)
// We got instantly connected! (it may no happen a lot...)
setBlocking(blocking);
return Done;
}
// Get the error status
Status status = priv::SocketImpl::getErrorStatus();
// If we were in non-blocking mode, return immediately
if (!blocking)
return status;
// Otherwise, wait until something happens to our socket (success, timeout or error)
if (status == Socket::NotReady)
{
// Setup the selector
fd_set selector;
FD_ZERO(&selector);
FD_SET(getHandle(), &selector);
// Setup the timeout
timeval time;
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
// Wait for something to write on our socket (which means that the connection request has returned)
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
{
// Connection accepted
status = Done;
// At this point the connection may have been either accepted or refused.
// To know whether it's a success or a failure, we must check the address of the connected peer
if (getRemoteAddress() != cpp3ds::IpAddress::None)
{
// Connection accepted
status = Done;
}
else
{
// Connection refused
status = priv::SocketImpl::getErrorStatus();
}
}
else
{
// Connection refused
// Failed to connect before timeout is over
status = priv::SocketImpl::getErrorStatus();
}
}
else
{
// Failed to connect before timeout is over
status = priv::SocketImpl::getErrorStatus();
}
// Switch back to blocking mode
setBlocking(true);
return status;
}
// Switch back to blocking mode
setBlocking(true);
return status;
}
}
@@ -244,7 +260,10 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t&
for (sent = 0; sent < size; sent += result)
{
// Send a chunk of data
result = ::send(getHandle(), static_cast<const char*>(data) + sent, size - sent, flags);
if (isSecure())
result = mbedtls_ssl_write(&getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
else
result = ::send(getHandle(), static_cast<const char*>(data) + sent, size - sent, flags);
// Check for errors
if (result < 0)
@@ -276,7 +295,11 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec
}
// Receive a chunk of bytes
int sizeReceived = recv(getHandle(), static_cast<char*>(data), static_cast<int>(size), flags);
int sizeReceived;
if (isSecure())
sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast<unsigned char*>(data), size);
else
sizeReceived = recv(getHandle(), static_cast<char*>(data), static_cast<int>(size), flags);
// Check the number of bytes received
if (sizeReceived > 0)
+2 -2
View File
@@ -36,8 +36,8 @@
namespace cpp3ds
{
////////////////////////////////////////////////////////////
UdpSocket::UdpSocket() :
Socket (Udp),
UdpSocket::UdpSocket(bool secure) :
Socket (Udp, secure),
m_buffer(MaxDatagramSize)
{