diff --git a/include/cpp3ds/Network/Socket.hpp b/include/cpp3ds/Network/Socket.hpp index 771edfa..02ce33e 100644 --- a/include/cpp3ds/Network/Socket.hpp +++ b/include/cpp3ds/Network/Socket.hpp @@ -32,9 +32,9 @@ #include #include #ifdef EMULATION -#include -#include -#include +#define OPENSSL_NO_BIO +#include +#include #else #include <3ds.h> #endif @@ -76,12 +76,9 @@ public: struct SecureData { #ifdef EMULATION - mbedtls_net_context socket; - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_ssl_context ssl; - mbedtls_ssl_config conf; - mbedtls_x509_crt cacert; + SSL_CTX *sslContext; + const SSL_METHOD *sslMethod; + ::SSL *ssl; #else sslcContext sslContext; u32 rootCertChain; diff --git a/src/cpp3ds/Network/TcpSocket.cpp b/src/cpp3ds/Network/TcpSocket.cpp index 70f135d..a9c9d19 100644 --- a/src/cpp3ds/Network/TcpSocket.cpp +++ b/src/cpp3ds/Network/TcpSocket.cpp @@ -34,9 +34,7 @@ #include #include #include -#ifdef EMULATION -#include -#else +#ifndef EMULATION #include <3ds.h> #endif @@ -143,110 +141,99 @@ Socket::Status TcpSocket::connect(const IpAddress& remoteAddress, unsigned short // Create the internal socket if it doesn't exist create(); -#ifdef EMULATION - if (isSecure()) - { - char port[5]; - sprintf(port, "%d", remotePort); - if (mbedtls_net_connect(&getSecureData().socket, remoteAddress.toString().c_str(), port, MBEDTLS_NET_PROTO_TCP) != 0) - return priv::SocketImpl::getErrorStatus(); + // Create the remote address + sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort); - int ret; - while((ret = mbedtls_ssl_handshake(&getSecureData().ssl)) != 0) - if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) - return Error; + if (timeout <= Time::Zero) + { + // ----- We're not using a timeout: just try to connect ----- + + // Connect the socket + if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) == -1) + return priv::SocketImpl::getErrorStatus(); +#ifndef EMULATION + if (isSecure()) + sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str()); +#else + if (isSecure()) + SSL_connect(getSecureData().ssl); +#endif + // Connection succeeded + return Done; } else -#endif { - // Create the remote address - sockaddr_in address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort); + // ----- We're using a timeout: we'll need a few tricks to make it work ----- - if (timeout <= Time::Zero) + // Save the previous blocking state + bool blocking = isBlocking(); + + // Switch to non-blocking to enable our connection timeout + if (blocking) + setBlocking(false); + + // Try to connect to the remote address + if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) >= 0) { - // ----- We're not using a timeout: just try to connect ----- - - // Connect the socket - if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) == -1) - return priv::SocketImpl::getErrorStatus(); + // We got instantly connected! (it may no happen a lot...) + setBlocking(blocking); #ifndef EMULATION - if (isSecure()) + if (isSecure()) sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str()); +#else + if (isSecure()) + SSL_connect(getSecureData().ssl); #endif - // Connection succeeded return Done; } - else + + // Get the error status + Status status = priv::SocketImpl::getErrorStatus(); + + // If we were in non-blocking mode, return immediately + if (!blocking) + return status; + + // Otherwise, wait until something happens to our socket (success, timeout or error) + if (status == Socket::NotReady) { - // ----- We're using a timeout: we'll need a few tricks to make it work ----- + // Setup the selector + fd_set selector; + FD_ZERO(&selector); + FD_SET(getHandle(), &selector); - // Save the previous blocking state - bool blocking = isBlocking(); + // Setup the timeout + timeval time; + time.tv_sec = static_cast(timeout.asMicroseconds() / 1000000); + time.tv_usec = static_cast(timeout.asMicroseconds() % 1000000); - // Switch to non-blocking to enable our connection timeout - if (blocking) - setBlocking(false); - - // Try to connect to the remote address - if (::connect(getHandle(), reinterpret_cast(&address), sizeof(address)) >= 0) + // Wait for something to write on our socket (which means that the connection request has returned) + if (select(static_cast(getHandle() + 1), NULL, &selector, NULL, &time) > 0) { - // We got instantly connected! (it may no happen a lot...) - setBlocking(blocking); -#ifndef EMULATION - if (isSecure()) - sslc_init(getSecureData(), getHandle(), remoteAddress.toString().c_str()); -#endif - return Done; - } - - // Get the error status - Status status = priv::SocketImpl::getErrorStatus(); - - // If we were in non-blocking mode, return immediately - if (!blocking) - return status; - - // Otherwise, wait until something happens to our socket (success, timeout or error) - if (status == Socket::NotReady) - { - // Setup the selector - fd_set selector; - FD_ZERO(&selector); - FD_SET(getHandle(), &selector); - - // Setup the timeout - timeval time; - time.tv_sec = static_cast(timeout.asMicroseconds() / 1000000); - time.tv_usec = static_cast(timeout.asMicroseconds() % 1000000); - - // Wait for something to write on our socket (which means that the connection request has returned) - if (select(static_cast(getHandle() + 1), NULL, &selector, NULL, &time) > 0) + // At this point the connection may have been either accepted or refused. + // To know whether it's a success or a failure, we must check the address of the connected peer + if (getRemoteAddress() != cpp3ds::IpAddress::None) { - // At this point the connection may have been either accepted or refused. - // To know whether it's a success or a failure, we must check the address of the connected peer - if (getRemoteAddress() != cpp3ds::IpAddress::None) - { - // Connection accepted - status = Done; - } - else - { - // Connection refused - status = priv::SocketImpl::getErrorStatus(); - } + // Connection accepted + status = Done; } else { - // Failed to connect before timeout is over + // Connection refused status = priv::SocketImpl::getErrorStatus(); } } - - // Switch back to blocking mode - setBlocking(true); - - return status; + else + { + // Failed to connect before timeout is over + status = priv::SocketImpl::getErrorStatus(); + } } + + // Switch back to blocking mode + setBlocking(true); + + return status; } } @@ -291,7 +278,7 @@ Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t& // Send a chunk of data if (isSecure()) #ifdef EMULATION - result = mbedtls_ssl_write(&getSecureData().ssl, static_cast(data) + sent, size - sent); + result = SSL_write(getSecureData().ssl, static_cast(data) + sent, size - sent); #else result = sslcWrite(&getSecureData().sslContext, static_cast(data) + sent, size - sent); #endif @@ -331,7 +318,7 @@ Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& rec int sizeReceived; if (isSecure()) #ifdef EMULATION - sizeReceived = mbedtls_ssl_read(&getSecureData().ssl, static_cast(data), size); + sizeReceived = SSL_read(getSecureData().ssl, static_cast(data), size); #else sizeReceived = sslcRead(&getSecureData().sslContext, data, size, false); #endif diff --git a/src/emu3ds/CMakeLists.txt b/src/emu3ds/CMakeLists.txt index 4451013..753d0ce 100644 --- a/src/emu3ds/CMakeLists.txt +++ b/src/emu3ds/CMakeLists.txt @@ -135,8 +135,9 @@ unset(FREETYPE_INCLUDE_DIRS CACHE) find_package(JPEG REQUIRED) find_package(Freetype REQUIRED) +find_package(OpenSSL REQUIRED) -include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR}) +include_directories(${FREETYPE_INCLUDE_DIRS} ${JPEG_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR}) add_library(cpp3ds-emu STATIC ${SRC} diff --git a/src/emu3ds/Network/Socket.cpp b/src/emu3ds/Network/Socket.cpp index a654101..36a1606 100644 --- a/src/emu3ds/Network/Socket.cpp +++ b/src/emu3ds/Network/Socket.cpp @@ -28,17 +28,21 @@ #include #include #include -#include -#include namespace { - static int entropy_func(void *data, unsigned char *output, size_t len) + bool initializedSSL = false; + + void initSSL() { - for (int i = 0; i < len; ++i) - output[i] = rand() % 256; - return 0; + if (!initializedSSL) + { + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + initializedSSL = true; + } } } @@ -49,10 +53,13 @@ namespace cpp3ds Socket::Socket(Type type, bool secure) : m_type (type), m_socket (priv::SocketImpl::invalidSocket()), -m_isBlocking(true), -m_isSecure (secure) +m_isBlocking(true) { - m_secureData.socket.fd = priv::SocketImpl::invalidSocket(); + m_secureData.ssl = nullptr; + m_secureData.sslMethod = nullptr; + m_secureData.sslContext = nullptr; + + setSecure(secure); } @@ -68,21 +75,8 @@ Socket::~Socket() void Socket::setBlocking(bool blocking) { // Apply if the socket is already created - if (m_isSecure) - { - if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket()) - { - if (blocking) - mbedtls_net_set_block(&m_secureData.socket); - else - mbedtls_net_set_nonblock(&m_secureData.socket); - } - } - else - { - if (m_socket != priv::SocketImpl::invalidSocket()) - priv::SocketImpl::setBlocking(m_socket, blocking); - } + if (m_socket != priv::SocketImpl::invalidSocket()) + priv::SocketImpl::setBlocking(m_socket, blocking); m_isBlocking = blocking; } @@ -102,7 +96,7 @@ void Socket::setSecure(bool secure) m_isSecure = secure; if ((secure && !oldSecure && m_socket != priv::SocketImpl::invalidSocket()) || - (!secure && oldSecure && m_secureData.socket.fd != priv::SocketImpl::invalidSocket())) + (!secure && oldSecure && m_secureData.ssl)) { close(); create(); @@ -120,7 +114,7 @@ bool Socket::isSecure() const //////////////////////////////////////////////////////////// SocketHandle Socket::getHandle() const { - return m_isSecure ? m_secureData.socket.fd : m_socket; + return m_socket; } @@ -135,46 +129,20 @@ Socket::SecureData& Socket::getSecureData() const void Socket::create() { // Don't create the socket if it already exists - if (m_isSecure) + if (m_socket == priv::SocketImpl::invalidSocket()) { - if (m_secureData.socket.fd == priv::SocketImpl::invalidSocket()) - { - int ret; - mbedtls_net_init(&m_secureData.socket); - mbedtls_ssl_init(&m_secureData.ssl); - mbedtls_ssl_config_init(&m_secureData.conf); - mbedtls_x509_crt_init(&m_secureData.cacert); - mbedtls_ctr_drbg_init(&m_secureData.ctr_drbg); - mbedtls_entropy_init( &m_secureData.entropy ); - - srand(time(nullptr)); - if (ret = mbedtls_ctr_drbg_seed(&m_secureData.ctr_drbg, entropy_func, &m_secureData.entropy, NULL, 0 )) - err() << "mbedtls_ctr_drbg_seed failed: " << ret << std::endl; - if ((ret = mbedtls_x509_crt_parse(&m_secureData.cacert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len)) < 0) - err() << "mbedtls_x509_crt_parse failed: -0x" << std::hex << ret << std::endl; - - // use mbedtls_ssl_conf_endpoint to set socket as server - if ((ret = mbedtls_ssl_config_defaults(&m_secureData.conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) - err() << "mbedtls_ssl_config_defaults failed: " << ret << std::endl; - mbedtls_ssl_conf_authmode(&m_secureData.conf, MBEDTLS_SSL_VERIFY_OPTIONAL); - mbedtls_ssl_conf_ca_chain(&m_secureData.conf, &m_secureData.cacert, NULL); - mbedtls_ssl_conf_rng(&m_secureData.conf, mbedtls_ctr_drbg_random, &m_secureData.ctr_drbg); - if ((ret = mbedtls_ssl_setup(&m_secureData.ssl, &m_secureData.conf)) != 0) - err() << "mbedtls_ssl_setup failed: " << ret << std::endl; - if ((ret = mbedtls_ssl_set_hostname(&m_secureData.ssl, "mbed TLS Server")) != 0) - err() << "mbedtls_ssl_set_hostname failed: " << ret << std::endl; - mbedtls_ssl_set_bio(&m_secureData.ssl, &m_secureData.socket, mbedtls_net_send, mbedtls_net_recv, NULL); - - setBlocking(m_isBlocking); - } + SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP); + create(handle); } - else + if (m_isSecure && !m_secureData.ssl) { - if (m_socket == priv::SocketImpl::invalidSocket()) - { - SocketHandle handle = socket(PF_INET, m_type == Tcp ? SOCK_STREAM : SOCK_DGRAM, IPPROTO_IP); - create(handle); - } + // TODO: handle errors + initSSL(); + m_secureData.sslMethod = TLSv1_client_method(); + m_secureData.sslContext = SSL_CTX_new(m_secureData.sslMethod); + m_secureData.ssl = SSL_new(m_secureData.sslContext); + SSL_set_verify(m_secureData.ssl, SSL_VERIFY_NONE, nullptr); + SSL_set_fd(m_secureData.ssl, m_socket); } } @@ -235,16 +203,16 @@ void Socket::close() } // Close the secure socket - if (m_secureData.socket.fd != priv::SocketImpl::invalidSocket()) + if (m_secureData.ssl) { - mbedtls_net_free(&m_secureData.socket); - mbedtls_ssl_free(&m_secureData.ssl); - mbedtls_ssl_config_free(&m_secureData.conf); - mbedtls_x509_crt_free(&m_secureData.cacert); - mbedtls_ctr_drbg_free(&m_secureData.ctr_drbg); - mbedtls_entropy_free(&m_secureData.entropy); - m_secureData.socket.fd = priv::SocketImpl::invalidSocket(); + SSL_shutdown(m_secureData.ssl); + SSL_free (m_secureData.ssl); + SSL_CTX_free (m_secureData.sslContext); } + + m_secureData.ssl = nullptr; + m_secureData.sslMethod = nullptr; + m_secureData.sslContext = nullptr; } } // namespace cpp3ds