From 263f48ec139f8813149c5634b5cee275bee75ae5 Mon Sep 17 00:00:00 2001 From: Thomas Edvalson Date: Thu, 4 Feb 2016 15:08:47 -0500 Subject: [PATCH] Add secured sockets with mbedtls and HTTPS support --- include/cpp3ds/Network/Socket.hpp | 24 +++- include/cpp3ds/Network/SocketHandle.hpp | 4 +- include/cpp3ds/Network/TcpSocket.hpp | 2 +- include/cpp3ds/Network/UdpSocket.hpp | 2 +- src/cpp3ds/Network/Http.cpp | 10 +- src/cpp3ds/Network/Socket.cpp | 122 +++++++++++++++++-- src/cpp3ds/Network/TcpListener.cpp | 2 +- src/cpp3ds/Network/TcpSocket.cpp | 149 ++++++++++++++---------- src/cpp3ds/Network/UdpSocket.cpp | 4 +- 9 files changed, 234 insertions(+), 85 deletions(-) diff --git a/include/cpp3ds/Network/Socket.hpp b/include/cpp3ds/Network/Socket.hpp index 096f25f..1182679 100644 --- a/include/cpp3ds/Network/Socket.hpp +++ b/include/cpp3ds/Network/Socket.hpp @@ -31,6 +31,9 @@ #include #include #include +#include +#include +#include namespace cpp3ds @@ -67,6 +70,16 @@ public: AnyPort = 0 ///< Special value that tells the system to pick any available port }; + struct SecureData + { + mbedtls_net_context socket; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_ssl_context ssl; + mbedtls_ssl_config conf; + mbedtls_x509_crt cacert; + }; + public: //////////////////////////////////////////////////////////// @@ -104,6 +117,9 @@ public: //////////////////////////////////////////////////////////// bool isBlocking() const; + void setSecure(bool secure); + bool isSecure() const; + protected: //////////////////////////////////////////////////////////// @@ -124,7 +140,7 @@ protected: /// \param type Type of the socket (TCP or UDP) /// //////////////////////////////////////////////////////////// - Socket(Type type); + Socket(Type type, bool secure); //////////////////////////////////////////////////////////// /// \brief Return the internal handle of the socket @@ -138,6 +154,8 @@ protected: //////////////////////////////////////////////////////////// SocketHandle getHandle() const; + SecureData& getSecureData() const; + //////////////////////////////////////////////////////////// /// \brief Create the internal representation of the socket /// @@ -173,8 +191,10 @@ private: // Member data //////////////////////////////////////////////////////////// Type m_type; ///< Type of the socket (TCP or UDP) - SocketHandle m_socket; ///< Socket descriptor bool m_isBlocking; ///< Current blocking mode of the socket + bool m_isSecure; ///< Socket using SSL + SocketHandle m_socket; ///< Socket descriptor + mutable SecureData m_secureData; ///< Data needed for secure socket }; } // namespace cpp3ds diff --git a/include/cpp3ds/Network/SocketHandle.hpp b/include/cpp3ds/Network/SocketHandle.hpp index 2be0748..eb17937 100644 --- a/include/cpp3ds/Network/SocketHandle.hpp +++ b/include/cpp3ds/Network/SocketHandle.hpp @@ -30,7 +30,7 @@ //////////////////////////////////////////////////////////// #include -#if defined(SFML_SYSTEM_WINDOWS) +#if defined(CPP3DS_SYSTEM_WINDOWS) #include #endif @@ -41,7 +41,7 @@ namespace cpp3ds // Define the low-level socket handle type, specific to // each platform //////////////////////////////////////////////////////////// -#if defined(SFML_SYSTEM_WINDOWS) +#if defined(CPP3DS_SYSTEM_WINDOWS) typedef UINT_PTR SocketHandle; diff --git a/include/cpp3ds/Network/TcpSocket.hpp b/include/cpp3ds/Network/TcpSocket.hpp index 7006a38..2900e06 100644 --- a/include/cpp3ds/Network/TcpSocket.hpp +++ b/include/cpp3ds/Network/TcpSocket.hpp @@ -50,7 +50,7 @@ public: /// \brief Default constructor /// //////////////////////////////////////////////////////////// - TcpSocket(); + TcpSocket(bool secure = false); //////////////////////////////////////////////////////////// /// \brief Get the port to which the socket is bound locally diff --git a/include/cpp3ds/Network/UdpSocket.hpp b/include/cpp3ds/Network/UdpSocket.hpp index e8234ed..35162ba 100644 --- a/include/cpp3ds/Network/UdpSocket.hpp +++ b/include/cpp3ds/Network/UdpSocket.hpp @@ -57,7 +57,7 @@ public: /// \brief Default constructor /// //////////////////////////////////////////////////////////// - UdpSocket(); + UdpSocket(bool secure = false); //////////////////////////////////////////////////////////// /// \brief Get the port to which the socket is bound locally diff --git a/src/cpp3ds/Network/Http.cpp b/src/cpp3ds/Network/Http.cpp index b5b878a..c06780b 100644 --- a/src/cpp3ds/Network/Http.cpp +++ b/src/cpp3ds/Network/Http.cpp @@ -319,19 +319,21 @@ void Http::setHost(const std::string& host, unsigned short port) if (toLower(host.substr(0, 7)) == "http://") { // HTTP protocol + m_connection.setSecure(false); m_hostName = host.substr(7); m_port = (port != 0 ? port : 80); } else if (toLower(host.substr(0, 8)) == "https://") { - // HTTPS protocol -- unsupported (requires encryption and certificates and stuff...) - err() << "HTTPS protocol is not supported by cpp3ds::Http" << std::endl; - m_hostName = ""; - m_port = 0; + // HTTPS protocol + m_connection.setSecure(true); + m_hostName = host.substr(8); + m_port = (port != 0 ? port : 443); } else { // Undefined protocol - use HTTP + m_connection.setSecure(false); m_hostName = host; m_port = (port != 0 ? port : 80); } diff --git a/src/cpp3ds/Network/Socket.cpp b/src/cpp3ds/Network/Socket.cpp index 583ce9f..a654101 100644 --- a/src/cpp3ds/Network/Socket.cpp +++ b/src/cpp3ds/Network/Socket.cpp @@ -28,17 +28,31 @@ #include #include #include +#include +#include + + +namespace +{ + static int entropy_func(void *data, unsigned char *output, size_t len) + { + for (int i = 0; i < len; ++i) + output[i] = rand() % 256; + return 0; + } +} namespace cpp3ds { //////////////////////////////////////////////////////////// -Socket::Socket(Type type) : +Socket::Socket(Type type, bool secure) : m_type (type), m_socket (priv::SocketImpl::invalidSocket()), -m_isBlocking(true) +m_isBlocking(true), +m_isSecure (secure) { - + m_secureData.socket.fd = priv::SocketImpl::invalidSocket(); } @@ -54,8 +68,21 @@ Socket::~Socket() void Socket::setBlocking(bool blocking) { // Apply if the socket is already created - if (m_socket != priv::SocketImpl::invalidSocket()) - priv::SocketImpl::setBlocking(m_socket, blocking); + if (m_isSecure) + { + if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket()) + { + if (blocking) + mbedtls_net_set_block(&m_secureData.socket); + else + mbedtls_net_set_nonblock(&m_secureData.socket); + } + } + else + { + if (m_socket != priv::SocketImpl::invalidSocket()) + priv::SocketImpl::setBlocking(m_socket, blocking); + } m_isBlocking = blocking; } @@ -68,10 +95,39 @@ bool Socket::isBlocking() const } +//////////////////////////////////////////////////////////// +void Socket::setSecure(bool secure) +{ + bool oldSecure = m_isSecure; + m_isSecure = secure; + + if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) || + (!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket())) + { + close(); + create(); + } +} + + +//////////////////////////////////////////////////////////// +bool Socket::isSecure() const +{ + return m_isSecure; +} + + //////////////////////////////////////////////////////////// SocketHandle Socket::getHandle() const { - return m_socket; + return m_isSecure ? m_secureData.socket.fd : m_socket; +} + + +//////////////////////////////////////////////////////////// +Socket::SecureData& Socket::getSecureData() const +{ + return m_secureData; } @@ -79,10 +135,46 @@ SocketHandle Socket::getHandle() const void Socket::create() { // Don't create the socket if it already exists - if (m_socket == priv::SocketImpl::invalidSocket()) + if (m_isSecure) { - SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, 0); - create(handle); + if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket()) + { + int ret; + mbedtls_net_init(&m_secureData.socket); + mbedtls_ssl_init(&m_secureData.ssl); + mbedtls_ssl_config_init(&m_secureData.conf); + mbedtls_x509_crt_init(&m_secureData.cacert); + mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg); + mbedtls_entropy_init( &m_secureData.entropy ); + + srand(time(nullptr)); + if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 )) + err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl; + if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0) + err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl; + + // use mbedtls_ssl_conf_endpoint to set socket as server + if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) + err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl; + mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL); + mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg); + if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0) + err() << "mbedtls_ssl_setup failed: " << ret << std::endl; + if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0) + err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl; + mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL); + + setBlocking(m_isBlocking); + } + } + else + { + if (m_socket == priv::SocketImpl::invalidSocket()) + { + SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP); + create(handle); + } } } @@ -141,6 +233,18 @@ void Socket::close() priv::SocketImpl::close(m_socket); m_socket = priv::SocketImpl::invalidSocket(); } + + // Close the secure socket + if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket()) + { + mbedtls_net_free(&m_secureData.socket); + mbedtls_ssl_free(&m_secureData.ssl); + mbedtls_ssl_config_free(&m_secureData.conf); + mbedtls_x509_crt_free(&m_secureData.cacert); + mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg); + mbedtls_entropy_free(&m_secureData.entropy); + m_secureData.socket.fd = priv::SocketImpl::invalidSocket(); + } } } // namespace cpp3ds diff --git a/src/cpp3ds/Network/TcpListener.cpp b/src/cpp3ds/Network/TcpListener.cpp index 6af8f67..c3605bd 100644 --- a/src/cpp3ds/Network/TcpListener.cpp +++ b/src/cpp3ds/Network/TcpListener.cpp @@ -35,7 +35,7 @@ namespace cpp3ds { //////////////////////////////////////////////////////////// TcpListener::TcpListener() : -Socket(Tcp) +Socket(Tcp, false) { } diff --git a/src/cpp3ds/Network/TcpSocket.cpp b/src/cpp3ds/Network/TcpSocket.cpp index 1ec4390..3a30af6 100644 --- a/src/cpp3ds/Network/TcpSocket.cpp +++ b/src/cpp3ds/Network/TcpSocket.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #ifdef _MSC_VER #pragma warning(disable: 4127) // "conditional expression is constant" generated by the FD_SET macro @@ -42,7 +43,7 @@ namespace { // Define the low-level send/receive flags, which depend on the OS - #ifdef SFML_SYSTEM_LINUX + #ifdef CPP3DS_SYSTEM_LINUX const int flags = MSG_NOSIGNAL; #else const int flags = 0; @@ -52,8 +53,8 @@ namespace namespace cpp3ds { //////////////////////////////////////////////////////////// -TcpSocket::TcpSocket() : -Socket(Tcp) +TcpSocket::TcpSocket(bool secure) : +Socket(Tcp, secure) { } @@ -122,86 +123,101 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short // Create the internal socket if it doesn't exist create(); - // Create the remote address - sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort); - - if (timeout <= Time::Zero) + if (isSecure()) { - // ----- We're not using a timeout: just try to connect ----- - - // Connect the socket - if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) == -1) + char port[5]; + sprintf(port, "%d", remotePort); + if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0) return priv::SocketImpl::getErrorStatus(); - // Connection succeeded - return Done; + int ret; + while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0) + if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) + return Error; } else { - // ----- We're using a timeout: we'll need a few tricks to make it work ----- + // Create the remote address + sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort); - // Save the previous blocking state - bool blocking = isBlocking(); - - // Switch to non-blocking to enable our connection timeout - if (blocking) - setBlocking(false); - - // Try to connect to the remote address - if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) >= 0) + if (timeout <= Time::Zero) { - // We got instantly connected! (it may no happen a lot...) - setBlocking(blocking); + // ----- We're not using a timeout: just try to connect ----- + + // Connect the socket + if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) == -1) + return priv::SocketImpl::getErrorStatus(); + + // Connection succeeded return Done; } - - // Get the error status - Status status = priv::SocketImpl::getErrorStatus(); - - // If we were in non-blocking mode, return immediately - if (!blocking) - return status; - - // Otherwise, wait until something happens to our socket (success, timeout or error) - if (status == Socket::NotReady) + else { - // Setup the selector - fd_set selector; - FD_ZERO(&selector); - FD_SET(getHandle(), &selector); + // ----- We're using a timeout: we'll need a few tricks to make it work ----- - // Setup the timeout - timeval time; - time.tv_sec = static_cast(timeout.asMicroseconds() / 1000000); - time.tv_usec = static_cast(timeout.asMicroseconds() % 1000000); + // Save the previous blocking state + bool blocking = isBlocking(); - // Wait for something to write on our socket (which means that the connection request has returned) - if (select(static_cast(getHandle() + 1), NULL, &selector, NULL, &time) > 0) + // Switch to non-blocking to enable our connection timeout + if (blocking) + setBlocking(false); + + // Try to connect to the remote address + if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) >= 0) { - // At this point the connection may have been either accepted or refused. - // To know whether it's a success or a failure, we must check the address of the connected peer - if (getRemoteAddress() != cpp3ds::IpAddress::None) + // We got instantly connected! (it may no happen a lot...) + setBlocking(blocking); + return Done; + } + + // Get the error status + Status status = priv::SocketImpl::getErrorStatus(); + + // If we were in non-blocking mode, return immediately + if (!blocking) + return status; + + // Otherwise, wait until something happens to our socket (success, timeout or error) + if (status == Socket::NotReady) + { + // Setup the selector + fd_set selector; + FD_ZERO(&selector); + FD_SET(getHandle(), &selector); + + // Setup the timeout + timeval time; + time.tv_sec = static_cast(timeout.asMicroseconds() / 1000000); + time.tv_usec = static_cast(timeout.asMicroseconds() % 1000000); + + // Wait for something to write on our socket (which means that the connection request has returned) + if (select(static_cast(getHandle() + 1), NULL, &selector, NULL, &time) > 0) { - // Connection accepted - status = Done; + // At this point the connection may have been either accepted or refused. + // To know whether it's a success or a failure, we must check the address of the connected peer + if (getRemoteAddress() != cpp3ds::IpAddress::None) + { + // Connection accepted + status = Done; + } + else + { + // Connection refused + status = priv::SocketImpl::getErrorStatus(); + } } else { - // Connection refused + // Failed to connect before timeout is over status = priv::SocketImpl::getErrorStatus(); } } - else - { - // Failed to connect before timeout is over - status = priv::SocketImpl::getErrorStatus(); - } + + // Switch back to blocking mode + setBlocking(true); + + return status; } - - // Switch back to blocking mode - setBlocking(true); - - return status; } } @@ -244,7 +260,10 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t& for (sent = 0; sent < size; sent += result) { // Send a chunk of data - result = ::send(getHandle(), static_cast(data) + sent, size - sent, flags); + if (isSecure()) + result = mbedtls_ssl_write(&getSecureData().ssl, static_cast(data) + sent, size - sent); + else + result = ::send(getHandle(), static_cast(data) + sent, size - sent, flags); // Check for errors if (result < 0) @@ -276,7 +295,11 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec } // Receive a chunk of bytes - int sizeReceived = recv(getHandle(), static_cast(data), static_cast(size), flags); + int sizeReceived; + if (isSecure()) + sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast(data), size); + else + sizeReceived = recv(getHandle(), static_cast(data), static_cast(size), flags); // Check the number of bytes received if (sizeReceived > 0) diff --git a/src/cpp3ds/Network/UdpSocket.cpp b/src/cpp3ds/Network/UdpSocket.cpp index d178f9b..0fd32c9 100644 --- a/src/cpp3ds/Network/UdpSocket.cpp +++ b/src/cpp3ds/Network/UdpSocket.cpp @@ -36,8 +36,8 @@ namespace cpp3ds { //////////////////////////////////////////////////////////// -UdpSocket::UdpSocket() : -Socket (Udp), +UdpSocket::UdpSocket(bool secure) : +Socket (Udp, secure), m_buffer(MaxDatagramSize) {