mirror of
https://github.com/encounter/cpp3ds.git
synced 2026-03-30 11:04:22 -07:00
Replace mbedtls with OpenSSL for emulator
This commit is contained in:
@@ -32,9 +32,9 @@
|
||||
#include <cpp3ds/System/NonCopyable.hpp>
|
||||
#include <vector>
|
||||
#ifdef EMULATION
|
||||
#include <mbedtls/net.h>
|
||||
#include <mbedtls/entropy.h>
|
||||
#include <mbedtls/ctr_drbg.h>
|
||||
#define OPENSSL_NO_BIO
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/err.h>
|
||||
#else
|
||||
#include <3ds.h>
|
||||
#endif
|
||||
@@ -76,12 +76,9 @@ public:
|
||||
struct SecureData
|
||||
{
|
||||
#ifdef EMULATION
|
||||
mbedtls_net_context socket;
|
||||
mbedtls_entropy_context entropy;
|
||||
mbedtls_ctr_drbg_context ctr_drbg;
|
||||
mbedtls_ssl_context ssl;
|
||||
mbedtls_ssl_config conf;
|
||||
mbedtls_x509_crt cacert;
|
||||
SSL_CTX *sslContext;
|
||||
const SSL_METHOD *sslMethod;
|
||||
::SSL *ssl;
|
||||
#else
|
||||
sslcContext sslContext;
|
||||
u32 rootCertChain;
|
||||
|
||||
@@ -34,9 +34,7 @@
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#ifdef EMULATION
|
||||
#include <mbedtls/net.h>
|
||||
#else
|
||||
#ifndef EMULATION
|
||||
#include <3ds.h>
|
||||
#endif
|
||||
|
||||
@@ -143,110 +141,99 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short
|
||||
// Create the internal socket if it doesn't exist
|
||||
create();
|
||||
|
||||
#ifdef EMULATION
|
||||
if (isSecure())
|
||||
{
|
||||
char port[5];
|
||||
sprintf(port, "%d", remotePort);
|
||||
if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0)
|
||||
return priv::SocketImpl::getErrorStatus();
|
||||
// Create the remote address
|
||||
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
|
||||
|
||||
int ret;
|
||||
while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0)
|
||||
if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||
return Error;
|
||||
if (timeout <= Time::Zero)
|
||||
{
|
||||
// ----- We're not using a timeout: just try to connect -----
|
||||
|
||||
// Connect the socket
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
|
||||
return priv::SocketImpl::getErrorStatus();
|
||||
#ifndef EMULATION
|
||||
if (isSecure())
|
||||
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
|
||||
#else
|
||||
if (isSecure())
|
||||
SSL_connect(getSecureData().ssl);
|
||||
#endif
|
||||
// Connection succeeded
|
||||
return Done;
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// Create the remote address
|
||||
sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);
|
||||
// ----- We're using a timeout: we'll need a few tricks to make it work -----
|
||||
|
||||
if (timeout <= Time::Zero)
|
||||
// Save the previous blocking state
|
||||
bool blocking = isBlocking();
|
||||
|
||||
// Switch to non-blocking to enable our connection timeout
|
||||
if (blocking)
|
||||
setBlocking(false);
|
||||
|
||||
// Try to connect to the remote address
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
|
||||
{
|
||||
// ----- We're not using a timeout: just try to connect -----
|
||||
|
||||
// Connect the socket
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
|
||||
return priv::SocketImpl::getErrorStatus();
|
||||
// We got instantly connected! (it may no happen a lot...)
|
||||
setBlocking(blocking);
|
||||
#ifndef EMULATION
|
||||
if (isSecure())
|
||||
if (isSecure())
|
||||
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
|
||||
#else
|
||||
if (isSecure())
|
||||
SSL_connect(getSecureData().ssl);
|
||||
#endif
|
||||
// Connection succeeded
|
||||
return Done;
|
||||
}
|
||||
else
|
||||
|
||||
// Get the error status
|
||||
Status status = priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// If we were in non-blocking mode, return immediately
|
||||
if (!blocking)
|
||||
return status;
|
||||
|
||||
// Otherwise, wait until something happens to our socket (success, timeout or error)
|
||||
if (status == Socket::NotReady)
|
||||
{
|
||||
// ----- We're using a timeout: we'll need a few tricks to make it work -----
|
||||
// Setup the selector
|
||||
fd_set selector;
|
||||
FD_ZERO(&selector);
|
||||
FD_SET(getHandle(), &selector);
|
||||
|
||||
// Save the previous blocking state
|
||||
bool blocking = isBlocking();
|
||||
// Setup the timeout
|
||||
timeval time;
|
||||
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
|
||||
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
|
||||
|
||||
// Switch to non-blocking to enable our connection timeout
|
||||
if (blocking)
|
||||
setBlocking(false);
|
||||
|
||||
// Try to connect to the remote address
|
||||
if (::connect(getHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
|
||||
// Wait for something to write on our socket (which means that the connection request has returned)
|
||||
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
|
||||
{
|
||||
// We got instantly connected! (it may no happen a lot...)
|
||||
setBlocking(blocking);
|
||||
#ifndef EMULATION
|
||||
if (isSecure())
|
||||
sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str());
|
||||
#endif
|
||||
return Done;
|
||||
}
|
||||
|
||||
// Get the error status
|
||||
Status status = priv::SocketImpl::getErrorStatus();
|
||||
|
||||
// If we were in non-blocking mode, return immediately
|
||||
if (!blocking)
|
||||
return status;
|
||||
|
||||
// Otherwise, wait until something happens to our socket (success, timeout or error)
|
||||
if (status == Socket::NotReady)
|
||||
{
|
||||
// Setup the selector
|
||||
fd_set selector;
|
||||
FD_ZERO(&selector);
|
||||
FD_SET(getHandle(), &selector);
|
||||
|
||||
// Setup the timeout
|
||||
timeval time;
|
||||
time.tv_sec = static_cast<long>(timeout.asMicroseconds() / 1000000);
|
||||
time.tv_usec = static_cast<long>(timeout.asMicroseconds() % 1000000);
|
||||
|
||||
// Wait for something to write on our socket (which means that the connection request has returned)
|
||||
if (select(static_cast<int>(getHandle() + 1), NULL, &selector, NULL, &time) > 0)
|
||||
// At this point the connection may have been either accepted or refused.
|
||||
// To know whether it's a success or a failure, we must check the address of the connected peer
|
||||
if (getRemoteAddress() != cpp3ds::IpAddress::None)
|
||||
{
|
||||
// At this point the connection may have been either accepted or refused.
|
||||
// To know whether it's a success or a failure, we must check the address of the connected peer
|
||||
if (getRemoteAddress() != cpp3ds::IpAddress::None)
|
||||
{
|
||||
// Connection accepted
|
||||
status = Done;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Connection refused
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
// Connection accepted
|
||||
status = Done;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Failed to connect before timeout is over
|
||||
// Connection refused
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
}
|
||||
|
||||
// Switch back to blocking mode
|
||||
setBlocking(true);
|
||||
|
||||
return status;
|
||||
else
|
||||
{
|
||||
// Failed to connect before timeout is over
|
||||
status = priv::SocketImpl::getErrorStatus();
|
||||
}
|
||||
}
|
||||
|
||||
// Switch back to blocking mode
|
||||
setBlocking(true);
|
||||
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +278,7 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t&
|
||||
// Send a chunk of data
|
||||
if (isSecure())
|
||||
#ifdef EMULATION
|
||||
result = mbedtls_ssl_write(&getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
|
||||
result = SSL_write(getSecureData().ssl, static_cast<const unsigned char*>(data) + sent, size - sent);
|
||||
#else
|
||||
result = sslcWrite(&getSecureData().sslContext, static_cast<const unsigned char*>(data) + sent, size - sent);
|
||||
#endif
|
||||
@@ -331,7 +318,7 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec
|
||||
int sizeReceived;
|
||||
if (isSecure())
|
||||
#ifdef EMULATION
|
||||
sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast<unsigned char*>(data), size);
|
||||
sizeReceived = SSL_read(getSecureData().ssl, static_cast<unsigned char*>(data), size);
|
||||
#else
|
||||
sizeReceived = sslcRead(&getSecureData().sslContext, data, size, false);
|
||||
#endif
|
||||
|
||||
@@ -135,8 +135,9 @@ unset(FREETYPE_INCLUDE_DIRS CACHE)
|
||||
|
||||
find_package(JPEG REQUIRED)
|
||||
find_package(Freetype REQUIRED)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
||||
include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR})
|
||||
include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
|
||||
|
||||
add_library(cpp3ds-emu STATIC
|
||||
${SRC}
|
||||
|
||||
@@ -28,17 +28,21 @@
|
||||
#include <cpp3ds/Network/Socket.hpp>
|
||||
#include <cpp3ds/Network/SocketImpl.hpp>
|
||||
#include <cpp3ds/System/Err.hpp>
|
||||
#include <mbedtls/certs.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
static int entropy_func(void *data, unsigned char *output, size_t len)
|
||||
bool initializedSSL = false;
|
||||
|
||||
void initSSL()
|
||||
{
|
||||
for (int i = 0; i < len; ++i)
|
||||
output[i] = rand() % 256;
|
||||
return 0;
|
||||
if (!initializedSSL)
|
||||
{
|
||||
SSL_library_init();
|
||||
SSL_load_error_strings();
|
||||
OpenSSL_add_all_algorithms();
|
||||
initializedSSL = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,10 +53,13 @@ namespace cpp3ds
|
||||
Socket::Socket(Type type, bool secure) :
|
||||
m_type (type),
|
||||
m_socket (priv::SocketImpl::invalidSocket()),
|
||||
m_isBlocking(true),
|
||||
m_isSecure (secure)
|
||||
m_isBlocking(true)
|
||||
{
|
||||
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
|
||||
m_secureData.ssl = nullptr;
|
||||
m_secureData.sslMethod = nullptr;
|
||||
m_secureData.sslContext = nullptr;
|
||||
|
||||
setSecure(secure);
|
||||
}
|
||||
|
||||
|
||||
@@ -68,21 +75,8 @@ Socket::~Socket()
|
||||
void Socket::setBlocking(bool blocking)
|
||||
{
|
||||
// Apply if the socket is already created
|
||||
if (m_isSecure)
|
||||
{
|
||||
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
if (blocking)
|
||||
mbedtls_net_set_block(&m_secureData.socket);
|
||||
else
|
||||
mbedtls_net_set_nonblock(&m_secureData.socket);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_socket != priv::SocketImpl::invalidSocket())
|
||||
priv::SocketImpl::setBlocking(m_socket, blocking);
|
||||
}
|
||||
if (m_socket != priv::SocketImpl::invalidSocket())
|
||||
priv::SocketImpl::setBlocking(m_socket, blocking);
|
||||
|
||||
m_isBlocking = blocking;
|
||||
}
|
||||
@@ -102,7 +96,7 @@ void Socket::setSecure(bool secure)
|
||||
m_isSecure = secure;
|
||||
|
||||
if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) ||
|
||||
(!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket()))
|
||||
(!secure && oldSecure && m_secureData.ssl))
|
||||
{
|
||||
close();
|
||||
create();
|
||||
@@ -120,7 +114,7 @@ bool Socket::isSecure() const
|
||||
////////////////////////////////////////////////////////////
|
||||
SocketHandle Socket::getHandle() const
|
||||
{
|
||||
return m_isSecure ? m_secureData.socket.fd : m_socket;
|
||||
return m_socket;
|
||||
}
|
||||
|
||||
|
||||
@@ -135,46 +129,20 @@ Socket::SecureData& Socket::getSecureData() const
|
||||
void Socket::create()
|
||||
{
|
||||
// Don't create the socket if it already exists
|
||||
if (m_isSecure)
|
||||
if (m_socket == priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
int ret;
|
||||
mbedtls_net_init(&m_secureData.socket);
|
||||
mbedtls_ssl_init(&m_secureData.ssl);
|
||||
mbedtls_ssl_config_init(&m_secureData.conf);
|
||||
mbedtls_x509_crt_init(&m_secureData.cacert);
|
||||
mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg);
|
||||
mbedtls_entropy_init( &m_secureData.entropy );
|
||||
|
||||
srand(time(nullptr));
|
||||
if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 ))
|
||||
err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl;
|
||||
if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0)
|
||||
err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl;
|
||||
|
||||
// use mbedtls_ssl_conf_endpoint to set socket as server
|
||||
if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
|
||||
err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl;
|
||||
mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
|
||||
mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL);
|
||||
mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg);
|
||||
if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0)
|
||||
err() << "mbedtls_ssl_setup failed: " << ret << std::endl;
|
||||
if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0)
|
||||
err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl;
|
||||
mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL);
|
||||
|
||||
setBlocking(m_isBlocking);
|
||||
}
|
||||
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
|
||||
create(handle);
|
||||
}
|
||||
else
|
||||
if (m_isSecure && !m_secureData.ssl)
|
||||
{
|
||||
if (m_socket == priv::SocketImpl::invalidSocket())
|
||||
{
|
||||
SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP);
|
||||
create(handle);
|
||||
}
|
||||
// TODO: handle errors
|
||||
initSSL();
|
||||
m_secureData.sslMethod = TLSv1_client_method();
|
||||
m_secureData.sslContext = SSL_CTX_new(m_secureData.sslMethod);
|
||||
m_secureData.ssl = SSL_new(m_secureData.sslContext);
|
||||
SSL_set_verify(m_secureData.ssl, SSL_VERIFY_NONE, nullptr);
|
||||
SSL_set_fd(m_secureData.ssl, m_socket);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,16 +203,16 @@ void Socket::close()
|
||||
}
|
||||
|
||||
// Close the secure socket
|
||||
if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket())
|
||||
if (m_secureData.ssl)
|
||||
{
|
||||
mbedtls_net_free(&m_secureData.socket);
|
||||
mbedtls_ssl_free(&m_secureData.ssl);
|
||||
mbedtls_ssl_config_free(&m_secureData.conf);
|
||||
mbedtls_x509_crt_free(&m_secureData.cacert);
|
||||
mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg);
|
||||
mbedtls_entropy_free(&m_secureData.entropy);
|
||||
m_secureData.socket.fd = priv::SocketImpl::invalidSocket();
|
||||
SSL_shutdown(m_secureData.ssl);
|
||||
SSL_free (m_secureData.ssl);
|
||||
SSL_CTX_free (m_secureData.sslContext);
|
||||
}
|
||||
|
||||
m_secureData.ssl = nullptr;
|
||||
m_secureData.sslMethod = nullptr;
|
||||
m_secureData.sslContext = nullptr;
|
||||
}
|
||||
|
||||
} // namespace cpp3ds
|
||||
|
||||
Reference in New Issue
Block a user