mirror of
https://github.com/encounter/cpp3ds.git
synced 2026-03-30 11:04:22 -07:00
Add secured sockets with mbedtls and HTTPS support
This commit is contained in:
@@ -31,6 +31,9 @@
|
||||
#include <cpp3ds/Network/SocketHandle.hpp>
|
||||
#include <cpp3ds/System/NonCopyable.hpp>
|
||||
#include <vector>
|
||||
#include <mbedtls/net.h>
|
||||
#include <mbedtls/entropy.h>
|
||||
#include <mbedtls/ctr_drbg.h>
|
||||
|
||||
|
||||
namespace cpp3ds
|
||||
@@ -67,6 +70,16 @@ public:
|
||||
AnyPort = 0 ///< Special value that tells the system to pick any available port
|
||||
};
|
||||
|
||||
struct SecureData
|
||||
{
|
||||
mbedtls_net_context socket;
|
||||
mbedtls_entropy_context entropy;
|
||||
mbedtls_ctr_drbg_context ctr_drbg;
|
||||
mbedtls_ssl_context ssl;
|
||||
mbedtls_ssl_config conf;
|
||||
mbedtls_x509_crt cacert;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
@@ -104,6 +117,9 @@ public:
|
||||
////////////////////////////////////////////////////////////
|
||||
bool isBlocking() const;
|
||||
|
||||
void setSecure(bool secure);
|
||||
bool isSecure() const;
|
||||
|
||||
protected:
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
@@ -124,7 +140,7 @@ protected:
|
||||
/// \param type Type of the socket (TCP or UDP)
|
||||
///
|
||||
////////////////////////////////////////////////////////////
|
||||
Socket(Type type);
|
||||
Socket(Type type, bool secure);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
/// \brief Return the internal handle of the socket
|
||||
@@ -138,6 +154,8 @@ protected:
|
||||
////////////////////////////////////////////////////////////
|
||||
SocketHandle getHandle() const;
|
||||
|
||||
SecureData& getSecureData() const;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
/// \brief Create the internal representation of the socket
|
||||
///
|
||||
@@ -173,8 +191,10 @@ private:
|
||||
// Member data
|
||||
////////////////////////////////////////////////////////////
|
||||
Type m_type; ///< Type of the socket (TCP or UDP)
|
||||
SocketHandle m_socket; ///< Socket descriptor
|
||||
bool m_isBlocking; ///< Current blocking mode of the socket
|
||||
bool m_isSecure; ///< Socket using SSL
|
||||
SocketHandle m_socket; ///< Socket descriptor
|
||||
mutable SecureData m_secureData; ///< Data needed for secure socket
|
||||
};
|
||||
|
||||
} // namespace cpp3ds
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
////////////////////////////////////////////////////////////
|
||||
#include <cpp3ds/Config.hpp>
|
||||
|
||||
#if defined(SFML_SYSTEM_WINDOWS)
|
||||
#if defined(CPP3DS_SYSTEM_WINDOWS)
|
||||
#include <basetsd.h>
|
||||
#endif
|
||||
|
||||
@@ -41,7 +41,7 @@ namespace cpp3ds
|
||||
// Define the low-level socket handle type, specific to
|
||||
// each platform
|
||||
////////////////////////////////////////////////////////////
|
||||
#if defined(SFML_SYSTEM_WINDOWS)
|
||||
#if defined(CPP3DS_SYSTEM_WINDOWS)
|
||||
|
||||
typedef UINT_PTR SocketHandle;
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ public:
|
||||
/// \brief Default constructor
|
||||
///
|
||||
////////////////////////////////////////////////////////////
|
||||
TcpSocket();
|
||||
TcpSocket(bool secure = false);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
/// \brief Get the port to which the socket is bound locally
|
||||
|
||||
@@ -57,7 +57,7 @@ public:
|
||||
/// \brief Default constructor
|
||||
///
|
||||
////////////////////////////////////////////////////////////
|
||||
UdpSocket();
|
||||
UdpSocket(bool secure = false);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
/// \brief Get the port to which the socket is bound locally
|
||||
|
||||
@@ -319,19 +319,21 @@ void Http::setHost(const std::string& host, unsigned short port)
|
||||
if (toLower(host.substr(0, 7)) == "http://")
|
||||
{
|
||||
// HTTP protocol
|
||||
m_connection.setSecure(false);
|
||||
m_hostName = host.substr(7);
|
||||
m_port = (port != 0 ? port : 80);
|
||||
}
|
||||
else if (toLower(host.substr(0, 8)) == "https://")
|
||||
{
|
||||
// HTTPS protocol -- unsupported (requires encryption and certificates and stuff...)
|
||||
err() << "HTTPS protocol is not supported by cpp3ds::Http" << std::endl;
|
||||
m_hostName = "";
|
||||
m_port = 0;
|
||||
// HTTPS protocol
|
||||
m_connection.setSecure(true);
|
||||
m_hostName = host.substr(8);
|
||||
m_port = (port != 0 ? port : 443);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Undefined protocol - use HTTP
|
||||
m_connection.setSecure(false);
|
||||
m_hostName = host;
|
||||
m_port = (port != 0 ? port : 80);
|
||||
}
|
||||
|
||||
@@ -28,17 +28,31 @@
|
||||
#include <cpp3ds/Network/Socket.hpp>
|
||||
#include <cpp3ds/Network/SocketImpl.hpp>
|
||||
#include <cpp3ds/System/Err.hpp>
|
||||
#include <mbedtls/certs.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
static int entropy_func(void *data, unsigned char *output, size_t len)
|
||||
{
|
||||
for (int i = 0; i < len; ++i)
|
||||
output[i] = rand() % 256;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
namespace cpp3ds
|
||||
{
|
||||
////////////////////////////////////////////////////////////
|
||||
Socket::Socket(Type type) :
|
||||
Socket::Socket(Type type, bool secure) :
|
||||
m_type (type),
|
||||
m_socket (priv::SocketImpl::invalidSocket()),
|
||||
m_isBlocking(true)
|
||||
m_isBlocking(true),
|
||||
m_isSecure (secure)
|
||||
{
|
||||
|
||||
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
|
||||
}
|
||||
|
||||
|
||||
@@ -54,8 +68,21 @@ Socket::~Socket()
|
||||
void Socket::setBlocking(bool blocking)
|
||||
{
|
||||
// Apply if the socket is already created
|
||||
if (m_socket != priv::SocketImpl::invalidSocket())
|
||||
priv::SocketImpl::setBlocking(m_socket, blocking);
|
||||
if (m_isSecure)
|
||||
{
|
||||
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
if (blocking)
|
||||
mbedtls_net_set_block(&m_secureData.socket);
|
||||
else
|
||||
mbedtls_net_set_nonblock(&m_secureData.socket);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_socket != priv::SocketImpl::invalidSocket())
|
||||
priv::SocketImpl::setBlocking(m_socket, blocking);
|
||||
}
|
||||
|
||||
m_isBlocking = blocking;
|
||||
}
|
||||
@@ -68,10 +95,39 @@ bool Socket::isBlocking() const
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
void Socket::setSecure(bool secure)
|
||||
{
|
||||
bool oldSecure = m_isSecure;
|
||||
m_isSecure = secure;
|
||||
|
||||
if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) ||
|
||||
(!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket()))
|
||||
{
|
||||
close();
|
||||
create();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
bool Socket::isSecure() const
|
||||
{
|
||||
return m_isSecure;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
SocketHandle Socket::getHandle() const
|
||||
{
|
||||
return m_socket;
|
||||
return m_isSecure ? m_secureData.socket.fd : m_socket;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
Socket::SecureData& Socket::getSecureData() const
|
||||
{
|
||||
return m_secureData;
|
||||
}
|
||||
|
||||
|
||||
@@ -79,10 +135,46 @@ SocketHandle Socket::getHandle() const
|
||||
void Socket::create()
|
||||
{
|
||||
// Don't create the socket if it already exists
|
||||
if (m_socket == priv::SocketImpl::invalidSocket())
|
||||
if (m_isSecure)
|
||||
{
|
||||
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
|
||||
create(handle);
|
||||
if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
int ret;
|
||||
mbedtls_net_init(&m_secureData.socket);
|
||||
mbedtls_ssl_init(&m_secureData.ssl);
|
||||
mbedtls_ssl_config_init(&m_secureData.conf);
|
||||
mbedtls_x509_crt_init(&m_secureData.cacert);
|
||||
mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg);
|
||||
mbedtls_entropy_init( &m_secureData.entropy );
|
||||
|
||||
srand(time(nullptr));
|
||||
if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 ))
|
||||
err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl;
|
||||
if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0)
|
||||
err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl;
|
||||
|
||||
// use mbedtls_ssl_conf_endpoint to set socket as server
|
||||
if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
|
||||
err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl;
|
||||
mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
|
||||
mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL);
|
||||
mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg);
|
||||
if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0)
|
||||
err() << "mbedtls_ssl_setup failed: " << ret << std::endl;
|
||||
if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0)
|
||||
err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl;
|
||||
mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL);
|
||||
|
||||
setBlocking(m_isBlocking);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_socket == priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
|
||||
create(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +233,18 @@ void Socket::close()
|
||||
priv::SocketImpl::close(m_socket);
|
||||
m_socket = priv::SocketImpl::invalidSocket();
|
||||
}
|
||||
|
||||
// Close the secure socket
|
||||
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
mbedtls_net_free(&m_secureData.socket);
|
||||
mbedtls_ssl_free(&m_secureData.ssl);
|
||||
mbedtls_ssl_config_free(&m_secureData.conf);
|
||||
mbedtls_x509_crt_free(&m_secureData.cacert);
|
||||
mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg);
|
||||
mbedtls_entropy_free(&m_secureData.entropy);
|
||||
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cpp3ds
|
||||
|
||||
@@ -35,7 +35,7 @@ namespace cpp3ds
|
||||
{
|
||||
////////////////////////////////////////////////////////////
|
||||
TcpListener::TcpListener() :
|
||||
Socket(Tcp)
|
||||
Socket(Tcp, false)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
#include <sys/select.h>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <mbedtls/net.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(disable: 4127) // "conditional expression is constant" generated by the FD_SET macro
|
||||
@@ -42,7 +43,7 @@
|
||||
namespace
|
||||
{
|
||||
// Define the low-level send/receive flags, which depend on the OS
|
||||
#ifdef SFML_SYSTEM_LINUX
|
||||
#ifdef CPP3DS_SYSTEM_LINUX
|
||||
const int flags = MSG_NOSIGNAL;
|
||||
#else
|
||||
const int flags = 0;
|
||||
@@ -52,8 +53,8 @@ namespace
|
||||
namespace cpp3ds
|
||||
{
|
||||
////////////////////////////////////////////////////////////
|
||||
TcpSocket::TcpSocket() :
|
||||
Socket(Tcp)
|
||||
TcpSocket::TcpSocket(bool secure) :
|
||||
Socket(Tcp, secure)
|
||||
{
|
||||
|
||||
}
|
||||
@@ -122,86 +123,101 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short
|
||||
// Create the internal socket if it doesn't exist
|
||||
create();
|
||||
|
||||
// Create the remote address
|
||||
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
|
||||
|
||||
if (timeout <= Time::Zero)
|
||||
if (isSecure())
|
||||
{
|
||||
// ----- We're not using a timeout: just try to connect -----
|
||||
|
||||
// Connect the socket
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
|
||||
char port[5];
|
||||
sprintf(port, "%d", remotePort);
|
||||
if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0)
|
||||
return priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// Connection succeeded
|
||||
return Done;
|
||||
int ret;
|
||||
while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0)
|
||||
if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||
return Error;
|
||||
}
|
||||
else
|
||||
{
|
||||
// ----- We're using a timeout: we'll need a few tricks to make it work -----
|
||||
// Create the remote address
|
||||
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
|
||||
|
||||
// Save the previous blocking state
|
||||
bool blocking = isBlocking();
|
||||
|
||||
// Switch to non-blocking to enable our connection timeout
|
||||
if (blocking)
|
||||
setBlocking(false);
|
||||
|
||||
// Try to connect to the remote address
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
|
||||
if (timeout <= Time::Zero)
|
||||
{
|
||||
// We got instantly connected! (it may no happen a lot...)
|
||||
setBlocking(blocking);
|
||||
// ----- We're not using a timeout: just try to connect -----
|
||||
|
||||
// Connect the socket
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
|
||||
return priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// Connection succeeded
|
||||
return Done;
|
||||
}
|
||||
|
||||
// Get the error status
|
||||
Status status = priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// If we were in non-blocking mode, return immediately
|
||||
if (!blocking)
|
||||
return status;
|
||||
|
||||
// Otherwise, wait until something happens to our socket (success, timeout or error)
|
||||
if (status == Socket::NotReady)
|
||||
else
|
||||
{
|
||||
// Setup the selector
|
||||
fd_set selector;
|
||||
FD_ZERO(&selector);
|
||||
FD_SET(getHandle(), &selector);
|
||||
// ----- We're using a timeout: we'll need a few tricks to make it work -----
|
||||
|
||||
// Setup the timeout
|
||||
timeval time;
|
||||
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
|
||||
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
|
||||
// Save the previous blocking state
|
||||
bool blocking = isBlocking();
|
||||
|
||||
// Wait for something to write on our socket (which means that the connection request has returned)
|
||||
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
|
||||
// Switch to non-blocking to enable our connection timeout
|
||||
if (blocking)
|
||||
setBlocking(false);
|
||||
|
||||
// Try to connect to the remote address
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
|
||||
{
|
||||
// At this point the connection may have been either accepted or refused.
|
||||
// To know whether it's a success or a failure, we must check the address of the connected peer
|
||||
if (getRemoteAddress() != cpp3ds::IpAddress::None)
|
||||
// We got instantly connected! (it may no happen a lot...)
|
||||
setBlocking(blocking);
|
||||
return Done;
|
||||
}
|
||||
|
||||
// Get the error status
|
||||
Status status = priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// If we were in non-blocking mode, return immediately
|
||||
if (!blocking)
|
||||
return status;
|
||||
|
||||
// Otherwise, wait until something happens to our socket (success, timeout or error)
|
||||
if (status == Socket::NotReady)
|
||||
{
|
||||
// Setup the selector
|
||||
fd_set selector;
|
||||
FD_ZERO(&selector);
|
||||
FD_SET(getHandle(), &selector);
|
||||
|
||||
// Setup the timeout
|
||||
timeval time;
|
||||
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
|
||||
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
|
||||
|
||||
// Wait for something to write on our socket (which means that the connection request has returned)
|
||||
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
|
||||
{
|
||||
// Connection accepted
|
||||
status = Done;
|
||||
// At this point the connection may have been either accepted or refused.
|
||||
// To know whether it's a success or a failure, we must check the address of the connected peer
|
||||
if (getRemoteAddress() != cpp3ds::IpAddress::None)
|
||||
{
|
||||
// Connection accepted
|
||||
status = Done;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Connection refused
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Connection refused
|
||||
// Failed to connect before timeout is over
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Failed to connect before timeout is over
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
|
||||
// Switch back to blocking mode
|
||||
setBlocking(true);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Switch back to blocking mode
|
||||
setBlocking(true);
|
||||
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,7 +260,10 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t&
|
||||
for (sent = 0; sent < size; sent += result)
|
||||
{
|
||||
// Send a chunk of data
|
||||
result = ::send(getHandle(), static_cast<const char*>(data) + sent, size - sent, flags);
|
||||
if (isSecure())
|
||||
result = mbedtls_ssl_write(&getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
|
||||
else
|
||||
result = ::send(getHandle(), static_cast<const char*>(data) + sent, size - sent, flags);
|
||||
|
||||
// Check for errors
|
||||
if (result < 0)
|
||||
@@ -276,7 +295,11 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec
|
||||
}
|
||||
|
||||
// Receive a chunk of bytes
|
||||
int sizeReceived = recv(getHandle(), static_cast<char*>(data), static_cast<int>(size), flags);
|
||||
int sizeReceived;
|
||||
if (isSecure())
|
||||
sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast<unsigned char*>(data), size);
|
||||
else
|
||||
sizeReceived = recv(getHandle(), static_cast<char*>(data), static_cast<int>(size), flags);
|
||||
|
||||
// Check the number of bytes received
|
||||
if (sizeReceived > 0)
|
||||
|
||||
@@ -36,8 +36,8 @@
|
||||
namespace cpp3ds
|
||||
{
|
||||
////////////////////////////////////////////////////////////
|
||||
UdpSocket::UdpSocket() :
|
||||
Socket (Udp),
|
||||
UdpSocket::UdpSocket(bool secure) :
|
||||
Socket (Udp, secure),
|
||||
m_buffer(MaxDatagramSize)
|
||||
{
|
||||
|
||||
|
||||
Reference in New Issue
Block a user