Replace mbedtls with OpenSSL for emulator

This commit is contained in:
Thomas Edvalson
2016-05-23 12:55:09 -04:00
parent 3e47205423
commit a968c83584
4 changed files with 120 additions and 167 deletions
+6 -9
View File
@@ -32,9 +32,9 @@
#include <cpp3ds/System/NonCopyable.hpp>
#include <vector>
#ifdef EMULATION
#include <mbedtls/net.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#define OPENSSL_NO_BIO
#include <openssl/ssl.h>
#include <openssl/err.h>
#else
#include <3ds.h>
#endif
@@ -76,12 +76,9 @@ public:
struct SecureData
{
#ifdef EMULATION
mbedtls_net_context socket;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_x509_crt cacert;
SSL_CTX *sslContext;
const SSL_METHOD *sslMethod;
::SSL *ssl;
#else
sslcContext sslContext;
u32 rootCertChain;
+73 -86
View File
@@ -34,9 +34,7 @@
#include <algorithm>
#include <cstring>
#include <iostream>
#ifdef EMULATION
#include <mbedtls/net.h>
#else
#ifndef EMULATION
#include <3ds.h>
#endif
@@ -143,110 +141,99 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short
// Create the internal socket if it doesn't exist
create();
#ifdef EMULATION
if (isSecure())
{
char port[5];
sprintf(port, "%d", remotePort);
if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0)
return priv::SocketImpl::getErrorStatus();
// Create the remote address
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
int ret;
while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0)
if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
return Error;
if (timeout <= Time::Zero)
{
// ----- We're not using a timeout: just try to connect -----
// Connect the socket
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
return priv::SocketImpl::getErrorStatus();
#ifndef EMULATION
if (isSecure())
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
#else
if (isSecure())
SSL_connect(getSecureData().ssl);
#endif
// Connection succeeded
return Done;
}
else
#endif
{
// Create the remote address
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
// ----- We're using a timeout: we'll need a few tricks to make it work -----
if (timeout <= Time::Zero)
// Save the previous blocking state
bool blocking = isBlocking();
// Switch to non-blocking to enable our connection timeout
if (blocking)
setBlocking(false);
// Try to connect to the remote address
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
{
// ----- We're not using a timeout: just try to connect -----
// Connect the socket
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
return priv::SocketImpl::getErrorStatus();
// We got instantly connected! (it may no happen a lot...)
setBlocking(blocking);
#ifndef EMULATION
if (isSecure())
if (isSecure())
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
#else
if (isSecure())
SSL_connect(getSecureData().ssl);
#endif
// Connection succeeded
return Done;
}
else
// Get the error status
Status status = priv::SocketImpl::getErrorStatus();
// If we were in non-blocking mode, return immediately
if (!blocking)
return status;
// Otherwise, wait until something happens to our socket (success, timeout or error)
if (status == Socket::NotReady)
{
// ----- We're using a timeout: we'll need a few tricks to make it work -----
// Setup the selector
fd_set selector;
FD_ZERO(&selector);
FD_SET(getHandle(), &selector);
// Save the previous blocking state
bool blocking = isBlocking();
// Setup the timeout
timeval time;
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
// Switch to non-blocking to enable our connection timeout
if (blocking)
setBlocking(false);
// Try to connect to the remote address
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
// Wait for something to write on our socket (which means that the connection request has returned)
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
{
// We got instantly connected! (it may no happen a lot...)
setBlocking(blocking);
#ifndef EMULATION
if (isSecure())
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
#endif
return Done;
}
// Get the error status
Status status = priv::SocketImpl::getErrorStatus();
// If we were in non-blocking mode, return immediately
if (!blocking)
return status;
// Otherwise, wait until something happens to our socket (success, timeout or error)
if (status == Socket::NotReady)
{
// Setup the selector
fd_set selector;
FD_ZERO(&selector);
FD_SET(getHandle(), &selector);
// Setup the timeout
timeval time;
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
// Wait for something to write on our socket (which means that the connection request has returned)
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
// At this point the connection may have been either accepted or refused.
// To know whether it's a success or a failure, we must check the address of the connected peer
if (getRemoteAddress() != cpp3ds::IpAddress::None)
{
// At this point the connection may have been either accepted or refused.
// To know whether it's a success or a failure, we must check the address of the connected peer
if (getRemoteAddress() != cpp3ds::IpAddress::None)
{
// Connection accepted
status = Done;
}
else
{
// Connection refused
status = priv::SocketImpl::getErrorStatus();
}
// Connection accepted
status = Done;
}
else
{
// Failed to connect before timeout is over
// Connection refused
status = priv::SocketImpl::getErrorStatus();
}
}
// Switch back to blocking mode
setBlocking(true);
return status;
else
{
// Failed to connect before timeout is over
status = priv::SocketImpl::getErrorStatus();
}
}
// Switch back to blocking mode
setBlocking(true);
return status;
}
}
@@ -291,7 +278,7 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t&
// Send a chunk of data
if (isSecure())
#ifdef EMULATION
result = mbedtls_ssl_write(&getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
result = SSL_write(getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
#else
result = sslcWrite(&getSecureData().sslContext, static_cast<const unsigned char*>(data) + sent, size - sent);
#endif
@@ -331,7 +318,7 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec
int sizeReceived;
if (isSecure())
#ifdef EMULATION
sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast<unsigned char*>(data), size);
sizeReceived = SSL_read(getSecureData().ssl, static_cast<unsigned char*>(data), size);
#else
sizeReceived = sslcRead(&getSecureData().sslContext, data, size, false);
#endif
+2 -1
View File
@@ -135,8 +135,9 @@ unset(FREETYPE_INCLUDE_DIRS CACHE)
find_package(JPEG REQUIRED)
find_package(Freetype REQUIRED)
find_package(OpenSSL REQUIRED)
include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR})
include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
add_library(cpp3ds-emu STATIC
${SRC}
+39 -71
View File
@@ -28,17 +28,21 @@
#include <cpp3ds/Network/Socket.hpp>
#include <cpp3ds/Network/SocketImpl.hpp>
#include <cpp3ds/System/Err.hpp>
#include <mbedtls/certs.h>
#include <stdlib.h>
namespace
{
static int entropy_func(void *data, unsigned char *output, size_t len)
bool initializedSSL = false;
void initSSL()
{
for (int i = 0; i < len; ++i)
output[i] = rand() % 256;
return 0;
if (!initializedSSL)
{
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_algorithms();
initializedSSL = true;
}
}
}
@@ -49,10 +53,13 @@ namespace cpp3ds
Socket::Socket(Type type, bool secure) :
m_type (type),
m_socket (priv::SocketImpl::invalidSocket()),
m_isBlocking(true),
m_isSecure (secure)
m_isBlocking(true)
{
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
m_secureData.ssl = nullptr;
m_secureData.sslMethod = nullptr;
m_secureData.sslContext = nullptr;
setSecure(secure);
}
@@ -68,21 +75,8 @@ Socket::~Socket()
void Socket::setBlocking(bool blocking)
{
// Apply if the socket is already created
if (m_isSecure)
{
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
{
if (blocking)
mbedtls_net_set_block(&m_secureData.socket);
else
mbedtls_net_set_nonblock(&m_secureData.socket);
}
}
else
{
if (m_socket != priv::SocketImpl::invalidSocket())
priv::SocketImpl::setBlocking(m_socket, blocking);
}
if (m_socket != priv::SocketImpl::invalidSocket())
priv::SocketImpl::setBlocking(m_socket, blocking);
m_isBlocking = blocking;
}
@@ -102,7 +96,7 @@ void Socket::setSecure(bool secure)
m_isSecure = secure;
if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) ||
(!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket()))
(!secure && oldSecure && m_secureData.ssl))
{
close();
create();
@@ -120,7 +114,7 @@ bool Socket::isSecure() const
////////////////////////////////////////////////////////////
SocketHandle Socket::getHandle() const
{
return m_isSecure ? m_secureData.socket.fd : m_socket;
return m_socket;
}
@@ -135,46 +129,20 @@ Socket::SecureData& Socket::getSecureData() const
void Socket::create()
{
// Don't create the socket if it already exists
if (m_isSecure)
if (m_socket == priv::SocketImpl::invalidSocket())
{
if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket())
{
int ret;
mbedtls_net_init(&m_secureData.socket);
mbedtls_ssl_init(&m_secureData.ssl);
mbedtls_ssl_config_init(&m_secureData.conf);
mbedtls_x509_crt_init(&m_secureData.cacert);
mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg);
mbedtls_entropy_init( &m_secureData.entropy );
srand(time(nullptr));
if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 ))
err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl;
if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0)
err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl;
// use mbedtls_ssl_conf_endpoint to set socket as server
if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl;
mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL);
mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg);
if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0)
err() << "mbedtls_ssl_setup failed: " << ret << std::endl;
if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0)
err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl;
mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL);
setBlocking(m_isBlocking);
}
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
create(handle);
}
else
if (m_isSecure && !m_secureData.ssl)
{
if (m_socket == priv::SocketImpl::invalidSocket())
{
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
create(handle);
}
// TODO: handle errors
initSSL();
m_secureData.sslMethod = TLSv1_client_method();
m_secureData.sslContext = SSL_CTX_new(m_secureData.sslMethod);
m_secureData.ssl = SSL_new(m_secureData.sslContext);
SSL_set_verify(m_secureData.ssl, SSL_VERIFY_NONE, nullptr);
SSL_set_fd(m_secureData.ssl, m_socket);
}
}
@@ -235,16 +203,16 @@ void Socket::close()
}
// Close the secure socket
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
if (m_secureData.ssl)
{
mbedtls_net_free(&m_secureData.socket);
mbedtls_ssl_free(&m_secureData.ssl);
mbedtls_ssl_config_free(&m_secureData.conf);
mbedtls_x509_crt_free(&m_secureData.cacert);
mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg);
mbedtls_entropy_free(&m_secureData.entropy);
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
SSL_shutdown(m_secureData.ssl);
SSL_free (m_secureData.ssl);
SSL_CTX_free (m_secureData.sslContext);
}
m_secureData.ssl = nullptr;
m_secureData.sslMethod = nullptr;
m_secureData.sslContext = nullptr;
}
} // namespace cpp3ds