Skip to content

Commit

Permalink
support TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
superwhd committed Mar 16, 2023
1 parent b4f670a commit e12a0c1
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 33 deletions.
5 changes: 5 additions & 0 deletions etc/cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ if(OTBR_DNS_DSO)
target_compile_definitions(otbr-config INTERFACE OTBR_ENABLE_DNS_DSO=1)
endif()

option(OTBR_DNS_DSO_TLS "Enable DSO over TLS" OFF)
if(OTBR_DNS_DSO_TLS)
target_compile_definitions(otbr-config INTERFACE OTBR_ENABLE_DNS_DSO_TLS=1)
endif()

option(OTBR_SRP_REPLICATION "Enable SRP Replication" OFF)
if(OTBR_SRP_REPLICATION)
target_compile_definitions(otbr-config INTERFACE OTBR_ENABLE_SRP_REPLICATION=1)
Expand Down
230 changes: 226 additions & 4 deletions src/dso/dso_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#include <netdb.h>
#include <sys/types.h>

#include "common/code_utils.hpp"
#include "mbedtls/certs.h"
#include "mbedtls/net_sockets.h"
#include "openthread/openthread-system.h"
#include "openthread/platform/dso_transport.h"
Expand Down Expand Up @@ -99,6 +101,11 @@ void DsoAgent::Init(otInstance *aInstance, const std::string &aInfraNetIfName)

mInstance = aInstance;
mInfraNetIfName = aInfraNetIfName;

#if OTBR_ENABLE_DNS_DSO_TLS
mTlsClientConfig.Init(/* aIsClient */ true);
mTlsServerConfig.Init(/* aIsClient */ false);
#endif
}

DsoAgent::DsoConnection *DsoAgent::FindConnection(otPlatDsoConnection *aConnection)
Expand Down Expand Up @@ -216,6 +223,49 @@ void DsoAgent::SetHandlers(AcceptHandler aAcceptHandler,
mReceiveHandler = std::move(aReceiveHandler);
}

DsoAgent::DsoConnection::DsoConnection(DsoAgent *aAgent, otPlatDsoConnection *aConnection)
: mAgent(aAgent)
, mConnection(aConnection)
, mState(State::kDisabled)
{
}

DsoAgent::DsoConnection::DsoConnection(DsoAgent *aAgent, otPlatDsoConnection *aConnection, mbedtls_net_context aCtx)
: mAgent(aAgent)
, mConnection(aConnection)
{
assert(mState == State::kDisabled);

mCtx = aCtx;

#if OTBR_ENABLE_DNS_DSO_TLS
mbedtls_ssl_init(&mTlsCtx);
int ret = mbedtls_ssl_setup(&mTlsCtx, &mAgent->mTlsServerConfig.GetConfig());
if (ret < 0)
{
otbrLogWarning("Failed to setup TLS: %x", ret);
MarkStateAs(State::kDisabled);
}
else
{
MarkStateAs(State::kTlsHandshaking);
}
mbedtls_ssl_set_bio(&mTlsCtx, &mCtx, mbedtls_net_send, mbedtls_net_recv, nullptr);
#else
MarkStateAs(State::kConnected);
mAgent->HandleConnected(this);
#endif
}

DsoAgent::DsoConnection::~DsoConnection(void)
{
Disconnect(OT_PLAT_DSO_DISCONNECT_MODE_FORCIBLY_ABORT);
#if OTBR_ENABLE_DNS_DSO_TLS
mbedtls_ssl_free(&mTlsCtx);
#endif
mbedtls_net_free(&mCtx);
}

const char *DsoAgent::DsoConnection::StateToString(State aState)
{
const char *ret = "";
Expand All @@ -231,6 +281,11 @@ const char *DsoAgent::DsoConnection::StateToString(State aState)
case State::kConnected:
ret = "Connected";
break;
#if OTBR_ENABLE_DNS_DSO_TLS
case State::kTlsHandshaking:
ret = "TLS Handshaking";
break;
#endif
}

return ret;
Expand Down Expand Up @@ -266,9 +321,22 @@ void DsoAgent::DsoConnection::Connect(const otSockAddr *aPeerSockAddr)

if (!ret)
{
otbrLogInfo("Connected [%s]:%s", addrBuf, portString.c_str());
otbrLogInfo("TCP connected [%s]:%s", addrBuf, portString.c_str());

#if OTBR_ENABLE_DNS_DSO_TLS
mbedtls_ssl_init(&mTlsCtx);
ret = mbedtls_ssl_setup(&mTlsCtx, &mAgent->mTlsClientConfig.GetConfig());
if (ret < 0)
{
otbrLogWarning("!!! failed to setup TLS: %d", ret);
exit(1);
}
mbedtls_ssl_set_bio(&mTlsCtx, &mCtx, mbedtls_net_send, mbedtls_net_recv, nullptr);
MarkStateAs(State::kTlsHandshaking);
#else
MarkStateAs(State::kConnected);
mAgent->HandleConnected(this);
#endif
}
else if (errno == EAGAIN || errno == EINPROGRESS)
{
Expand Down Expand Up @@ -351,7 +419,11 @@ void DsoAgent::DsoConnection::HandleReceive(void)
{
size_t wantReadLen = mNeedBytes ? mNeedBytes : (sizeof(uint16_t) - mReceiveLengthBuffer.size());

#if OTBR_ENABLE_DNS_DSO_TLS
readLen = mbedtls_ssl_read(&mTlsCtx, buf, std::min(sizeof(buf), wantReadLen));
#else
readLen = mbedtls_net_recv(&mCtx, buf, std::min(sizeof(buf), wantReadLen));
#endif

assert(readLen <= static_cast<int>(std::min(sizeof(buf), wantReadLen)));

Expand Down Expand Up @@ -417,6 +489,9 @@ void DsoAgent::DsoConnection::HandleMbedTlsError(int aError)
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
break;
case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
break;
default:
MarkStateAs(State::kDisabled);
mAgent->HandleDisconnected(this, OT_PLAT_DSO_DISCONNECT_MODE_FORCIBLY_ABORT);
Expand All @@ -436,8 +511,20 @@ void DsoAgent::DsoConnection::UpdateStateBySocketState(void)
{
if (!optVal)
{
#if OTBR_ENABLE_DNS_DSO_TLS
mbedtls_ssl_init(&mTlsCtx);
int ret = mbedtls_ssl_setup(&mTlsCtx, &mAgent->mTlsClientConfig.GetConfig());
if (ret < 0)
{
otbrLogWarning("!!! failed to setup TLS: %d", ret);
exit(1);
}
mbedtls_ssl_set_bio(&mTlsCtx, &mCtx, mbedtls_net_send, mbedtls_net_recv, nullptr);
MarkStateAs(State::kTlsHandshaking);
#else
MarkStateAs(State::kConnected);
mAgent->HandleConnected(this);
#endif
}
else
{
Expand All @@ -464,12 +551,80 @@ void DsoAgent::DsoConnection::MarkStateAs(State aState)
return;
}

#if OTBR_ENABLE_DNS_DSO_TLS
void UpdateFdSetForHandshaking(mbedtls_ssl_context *aCtx, int aFd, fd_set &aReadSet, fd_set &aWriteSet, int &aMaxFd)
{
bool shouldRead = false, shouldWrite = false;

switch (aCtx->state)
{
case MBEDTLS_SSL_HELLO_REQUEST:
case MBEDTLS_SSL_CLIENT_HELLO:
case MBEDTLS_SSL_CLIENT_CERTIFICATE:
case MBEDTLS_SSL_CLIENT_KEY_EXCHANGE:
case MBEDTLS_SSL_CERTIFICATE_VERIFY:
case MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC:
case MBEDTLS_SSL_CLIENT_FINISHED:
shouldWrite = true;
break;
case MBEDTLS_SSL_SERVER_HELLO:
case MBEDTLS_SSL_SERVER_CERTIFICATE:
case MBEDTLS_SSL_SERVER_KEY_EXCHANGE:
case MBEDTLS_SSL_CERTIFICATE_REQUEST:
case MBEDTLS_SSL_SERVER_HELLO_DONE:
case MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC:
case MBEDTLS_SSL_SERVER_FINISHED:
case MBEDTLS_SSL_FLUSH_BUFFERS:
case MBEDTLS_SSL_HANDSHAKE_WRAPUP:
shouldRead = true;
break;
}
if (aCtx->conf->endpoint == MBEDTLS_SSL_IS_SERVER)
{
std::swap(shouldRead, shouldWrite);
}
if (shouldRead)
{
FD_SET(aFd, &aReadSet);
aMaxFd = std::max(aMaxFd, aFd);
}
if (shouldWrite)
{
FD_SET(aFd, &aWriteSet);
aMaxFd = std::max(aMaxFd, aFd);
}
}
void DsoAgent::DsoConnection::TlsHandshake(void)
{
assert(mState == State::kTlsHandshaking);

int ret = mbedtls_ssl_handshake(&mTlsCtx);

HandleMbedTlsError(ret);
if (!ret)
{
MarkStateAs(State::kConnected);
mAgent->HandleConnected(this);
}
if (mState == State::kDisabled)
{
otbrLogWarning("Failed to handshake: %x", ret);
}
}
#endif

void DsoAgent::DsoConnection::FlushSendBuffer(void)
{
int writeLen = mbedtls_net_send(&mCtx, mSendMessageBuffer.data(), mSendMessageBuffer.size());
int writeLen;

VerifyOrExit(mState == State::kConnected);

#if OTBR_ENABLE_DNS_DSO_TLS
writeLen = mbedtls_ssl_write(&mTlsCtx, mSendMessageBuffer.data(), mSendMessageBuffer.size());
#else
writeLen = mbedtls_net_send(&mCtx, mSendMessageBuffer.data(), mSendMessageBuffer.size());
#endif

if (writeLen < 0)
{
otbrLogWarning("Failed to send DSO message: %d", writeLen);
Expand All @@ -484,6 +639,61 @@ void DsoAgent::DsoConnection::FlushSendBuffer(void)
return;
}

#if OTBR_ENABLE_DNS_DSO_TLS
DsoAgent::TlsConfig::~TlsConfig(void)
{
mbedtls_ssl_config_free(&mConfig);
mbedtls_pk_free(&mPKey);
mbedtls_x509_crt_free(&mSrvCert);
}

void DsoAgent::TlsConfig::Init(bool aIsClient)
{
otPlatCryptoRandomInit();
mbedtls_ssl_config_init(&mConfig);
mbedtls_x509_crt_init(&mSrvCert);
mbedtls_pk_init(&mPKey);
mbedtls_ssl_conf_rng(
&mConfig,
[](void *, unsigned char *aBuffer, size_t aLength) -> int {
OTBR_UNUSED_VARIABLE(otPlatCryptoRandomGet(aBuffer, aLength));
return 0;
},
nullptr);

if (aIsClient)
{
VerifyOrDie(!mbedtls_ssl_config_defaults(&mConfig, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT),
"mbedtls_ssl_config_defaults failed");
VerifyOrDie(!mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(kCasPem), kCasPemLength),
"mbedtls_x509_crt_parse failed");
mbedtls_ssl_conf_authmode(&mConfig, MBEDTLS_SSL_VERIFY_OPTIONAL);
mbedtls_ssl_conf_ca_chain(&mConfig, &mSrvCert, nullptr);
}
else
{
VerifyOrDie(!mbedtls_ssl_config_defaults(&mConfig, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT),
"mbedtls_ssl_config_defaults failed");
VerifyOrDie(!mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(kSrvPem), kSrvPemLength),
"mbedtls_x509_crt_parse failed");
VerifyOrDie(!mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(kCasPem), kCasPemLength),
"mbedtls_x509_crt_parse failed");
VerifyOrDie(
!mbedtls_pk_parse_key(&mPKey, reinterpret_cast<const unsigned char *>(kSrvKey), kSrvKeyLength, nullptr, 0),
"mbedtls_pk_parse_key failed");

mbedtls_ssl_conf_ca_chain(&mConfig, mSrvCert.next, nullptr);

VerifyOrDie(!mbedtls_ssl_conf_own_cert(&mConfig, &mSrvCert, &mPKey), "mbedtls_ssl_conf_own_cert failed");
}

mbedtls_ssl_conf_min_version(&mConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
mbedtls_ssl_conf_max_version(&mConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
}
#endif

void DsoAgent::Update(MainloopContext &aMainloop)
{
if (mListeningEnabled)
Expand All @@ -510,6 +720,12 @@ void DsoAgent::Update(MainloopContext &aMainloop)
}
aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, pair.second->GetFd());
break;
#if OTBR_ENABLE_DNS_DSO_TLS
case DsoConnection::State::kTlsHandshaking:
UpdateFdSetForHandshaking(&pair.second->mTlsCtx, pair.second->GetFd(), aMainloop.mReadFdSet,
aMainloop.mWriteFdSet, aMainloop.mMaxFd);
break;
#endif
default:
break;
}
Expand Down Expand Up @@ -580,8 +796,7 @@ void DsoAgent::ProcessIncomingConnection(mbedtls_net_context aCtx, uint8_t *aAdd
conn = HandleAccept(mInstance, &sockAddr);

VerifyOrExit(conn != nullptr, otbrLogInfo("Failed to accept connection"));

HandleConnected(FindOrCreateConnection(conn, aCtx));
FindOrCreateConnection(conn, aCtx);
successful = true;

exit:
Expand Down Expand Up @@ -623,6 +838,13 @@ void DsoAgent::ProcessConnections(const MainloopContext &aMainloop)
conn->FlushSendBuffer();
}
break;
#if OTBR_ENABLE_DNS_DSO_TLS
case DsoConnection::State::kTlsHandshaking:
if (FD_ISSET(conn->GetFd(), &aMainloop.mReadFdSet) || FD_ISSET(conn->GetFd(), &aMainloop.mWriteFdSet))
{
conn->TlsHandshake();
}
#endif
default:
break;
}
Expand Down
Loading

0 comments on commit e12a0c1

Please sign in to comment.