diff options
author | Joshua Duong <joshuaduong@google.com> | 2020-02-14 17:42:40 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2020-02-14 17:42:40 +0000 |
commit | fd8b4ea996ea94e56a43594e2abeeb82b4bfb01e (patch) | |
tree | 9d7c0329a17a3b76ad106a7f409a13eec85a5fff /adb/tls/tls_connection.cpp | |
parent | c7b92ea120a963e85e5aa0257aab050ad64c124e (diff) | |
parent | 4293e322f27cdef47f9cf629990d361974e9e1b7 (diff) |
Merge changes from topic "adbwifi-syscore-base"
* changes:
[adbwifi] Add tls_connection library.
Move adb RSA utilities into its own library.
[adbwifi] Add adb protos.
Diffstat (limited to 'adb/tls/tls_connection.cpp')
-rw-r--r-- | adb/tls/tls_connection.cpp | 387 |
1 files changed, 387 insertions, 0 deletions
diff --git a/adb/tls/tls_connection.cpp b/adb/tls/tls_connection.cpp new file mode 100644 index 000000000..7df6ef410 --- /dev/null +++ b/adb/tls/tls_connection.cpp @@ -0,0 +1,387 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb/tls/tls_connection.h" + +#include <algorithm> +#include <vector> + +#include <android-base/logging.h> +#include <android-base/strings.h> +#include <openssl/err.h> +#include <openssl/ssl.h> + +using android::base::borrowed_fd; + +namespace adb { +namespace tls { + +namespace { + +static constexpr char kExportedKeyLabel[] = "adb-label"; + +class TlsConnectionImpl : public TlsConnection { + public: + explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key, + borrowed_fd fd); + ~TlsConnectionImpl() override; + + bool AddTrustedCertificate(std::string_view cert) override; + void SetCertVerifyCallback(CertVerifyCb cb) override; + void SetCertificateCallback(SetCertCb cb) override; + void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override; + std::vector<uint8_t> ExportKeyingMaterial(size_t length) override; + void EnableClientPostHandshakeCheck(bool enable) override; + TlsError DoHandshake() override; + std::vector<uint8_t> ReadFully(size_t size) override; + bool ReadFully(void* buf, size_t size) override; + bool WriteFully(std::string_view data) override; + + static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem); + static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem); + + private: + static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque); + static int SSLSetCertCb(SSL* ssl, void* opaque); + + static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer); + static const char* SSLErrorString(); + void Invalidate(); + TlsError GetFailureReason(int err); + + Role role_; + bssl::UniquePtr<EVP_PKEY> priv_key_; + bssl::UniquePtr<CRYPTO_BUFFER> cert_; + + bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_; + bssl::UniquePtr<SSL_CTX> ssl_ctx_; + bssl::UniquePtr<SSL> ssl_; + std::vector<bssl::UniquePtr<X509>> known_certificates_; + bool client_verify_post_handshake_ = false; + + CertVerifyCb cert_verify_cb_; + SetCertCb set_cert_cb_; + borrowed_fd fd_; +}; // TlsConnectionImpl + +TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key, + borrowed_fd fd) + : role_(role), fd_(fd) { + CHECK(!cert.empty() && !priv_key.empty()); + LOG(INFO) << "Initializing adbwifi TlsConnection"; + cert_ = BufferFromPEM(cert); + priv_key_ = EvpPkeyFromPEM(priv_key); +} + +TlsConnectionImpl::~TlsConnectionImpl() { + // shutdown the SSL connection + if (ssl_ != nullptr) { + SSL_shutdown(ssl_.get()); + } +} + +// static +const char* TlsConnectionImpl::SSLErrorString() { + auto sslerr = ERR_peek_last_error(); + return ERR_reason_error_string(sslerr); +} + +// static +bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) { + bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size())); + return bssl::UniquePtr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); +} + +// static +bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) { + bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size())); + char* name = nullptr; + char* header = nullptr; + uint8_t* data = nullptr; + long data_len = 0; + + if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) { + LOG(ERROR) << "Failed to read certificate"; + return nullptr; + } + OPENSSL_free(name); + OPENSSL_free(header); + + auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr)); + OPENSSL_free(data); + return ret; +} + +// static +bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) { + if (!buffer) { + return nullptr; + } + return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get())); +} + +// static +int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) { + auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque); + return p->cert_verify_cb_(ctx); +} + +// static +int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) { + auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque); + return p->set_cert_cb_(ssl); +} + +bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) { + // Create X509 buffer from the certificate string + auto buf = X509FromBuffer(BufferFromPEM(cert)); + if (buf == nullptr) { + LOG(ERROR) << "Failed to create a X509 buffer for the certificate."; + return false; + } + known_certificates_.push_back(std::move(buf)); + return true; +} + +void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) { + cert_verify_cb_ = cb; +} + +void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) { + set_cert_cb_ = cb; +} + +void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) { + CHECK(role_ == Role::Server); + ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr); +} + +std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) { + if (ssl_.get() == nullptr) { + return {}; + } + + std::vector<uint8_t> out(length); + if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel, + sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) { + return {}; + } + return out; +} + +void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) { + client_verify_post_handshake_ = enable; +} + +TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) { + switch (ERR_GET_REASON(err)) { + case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE: + case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE: + case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED: + case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED: + case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN: + case SSL_R_TLSV1_ALERT_ACCESS_DENIED: + case SSL_R_TLSV1_ALERT_UNKNOWN_CA: + case SSL_R_TLSV1_CERTIFICATE_REQUIRED: + return TlsError::PeerRejectedCertificate; + case SSL_R_CERTIFICATE_VERIFY_FAILED: + return TlsError::CertificateRejected; + default: + return TlsError::UnknownFailure; + } +} + +TlsConnection::TlsError TlsConnectionImpl::DoHandshake() { + int err = -1; + LOG(INFO) << "Starting adbwifi tls handshake"; + ssl_ctx_.reset(SSL_CTX_new(TLS_method())); + // TODO: Remove set_max_proto_version() once external/boringssl is updated + // past + // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877 + if (ssl_ctx_.get() == nullptr || + !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) || + !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) { + LOG(ERROR) << "Failed to create SSL context"; + return TlsError::UnknownFailure; + } + + // Register user-supplied known certificates + for (auto const& cert : known_certificates_) { + if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) { + LOG(ERROR) << "Unable to add certificates into the X509_STORE"; + return TlsError::UnknownFailure; + } + } + + // Custom certificate verification + if (cert_verify_cb_) { + SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this); + } + + // set select certificate callback, if any. + if (set_cert_cb_) { + SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this); + } + + // Server-allowed client CA list + if (ca_list_ != nullptr) { + bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get())); + SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release()); + } + + // Register our certificate and private key. + std::vector<CRYPTO_BUFFER*> cert_chain = { + cert_.get(), + }; + if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(), + priv_key_.get(), nullptr)) { + LOG(ERROR) << "Unable to register the certificate chain file and private key [" + << SSLErrorString() << "]"; + Invalidate(); + return TlsError::UnknownFailure; + } + + SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + + // Okay! Let's try to do the handshake! + ssl_.reset(SSL_new(ssl_ctx_.get())); + if (!SSL_set_fd(ssl_.get(), fd_.get())) { + LOG(ERROR) << "SSL_set_fd failed. [" << SSLErrorString() << "]"; + return TlsError::UnknownFailure; + } + switch (role_) { + case Role::Server: + err = SSL_accept(ssl_.get()); + break; + case Role::Client: + err = SSL_connect(ssl_.get()); + break; + } + if (err != 1) { + LOG(ERROR) << "Handshake failed in SSL_accept/SSL_connect [" << SSLErrorString() << "]"; + auto sslerr = ERR_get_error(); + Invalidate(); + return GetFailureReason(sslerr); + } + + if (client_verify_post_handshake_ && role_ == Role::Client) { + uint8_t check; + // Try to peek one byte for any failures. This assumes on success that + // the server actually sends something. + err = SSL_peek(ssl_.get(), &check, 1); + if (err <= 0) { + LOG(ERROR) << "Post-handshake SSL_peek failed [" << SSLErrorString() << "]"; + auto sslerr = ERR_get_error(); + Invalidate(); + return GetFailureReason(sslerr); + } + } + + LOG(INFO) << "Handshake succeeded."; + return TlsError::Success; +} + +void TlsConnectionImpl::Invalidate() { + ssl_.reset(); + ssl_ctx_.reset(); +} + +std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) { + std::vector<uint8_t> buf(size); + if (!ReadFully(buf.data(), buf.size())) { + return {}; + } + + return buf; +} + +bool TlsConnectionImpl::ReadFully(void* buf, size_t size) { + CHECK_GT(size, 0U); + if (!ssl_) { + LOG(ERROR) << "Tried to read on a null SSL connection"; + return false; + } + + size_t offset = 0; + uint8_t* p8 = reinterpret_cast<uint8_t*>(buf); + while (size > 0) { + int bytes_read = + SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size)); + if (bytes_read <= 0) { + LOG(WARNING) << "SSL_read failed [" << SSLErrorString() << "]"; + return false; + } + size -= bytes_read; + offset += bytes_read; + } + return true; +} + +bool TlsConnectionImpl::WriteFully(std::string_view data) { + CHECK(!data.empty()); + if (!ssl_) { + LOG(ERROR) << "Tried to read on a null SSL connection"; + return false; + } + + while (!data.empty()) { + int bytes_out = SSL_write(ssl_.get(), data.data(), + std::min(static_cast<size_t>(INT_MAX), data.size())); + if (bytes_out <= 0) { + LOG(WARNING) << "SSL_write failed [" << SSLErrorString() << "]"; + return false; + } + data = data.substr(bytes_out); + } + return true; +} +} // namespace + +// static +std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role, + std::string_view cert, + std::string_view priv_key, borrowed_fd fd) { + CHECK(!cert.empty()); + CHECK(!priv_key.empty()); + + return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd); +} + +// static +bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) { + CHECK(ssl); + // Note: declaring these in local scope is okay because + // SSL_set_chain_and_key will increase the refcount (bssl::UpRef). + auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert); + auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key); + if (x509_cert == nullptr || evp_pkey == nullptr) { + return false; + } + + std::vector<CRYPTO_BUFFER*> cert_chain = { + x509_cert.get(), + }; + if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(), + nullptr)) { + LOG(ERROR) << "SSL_set_chain_and_key failed"; + return false; + } + + return true; +} + +} // namespace tls +} // namespace adb |