summaryrefslogtreecommitdiff
path: root/adb/tls/tls_connection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'adb/tls/tls_connection.cpp')
-rw-r--r--adb/tls/tls_connection.cpp387
1 files changed, 387 insertions, 0 deletions
diff --git a/adb/tls/tls_connection.cpp b/adb/tls/tls_connection.cpp
new file mode 100644
index 000000000..7df6ef410
--- /dev/null
+++ b/adb/tls/tls_connection.cpp
@@ -0,0 +1,387 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except
+ * in compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "adb/tls/tls_connection.h"
+
+#include <algorithm>
+#include <vector>
+
+#include <android-base/logging.h>
+#include <android-base/strings.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+using android::base::borrowed_fd;
+
+namespace adb {
+namespace tls {
+
+namespace {
+
+static constexpr char kExportedKeyLabel[] = "adb-label";
+
+class TlsConnectionImpl : public TlsConnection {
+ public:
+ explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
+ borrowed_fd fd);
+ ~TlsConnectionImpl() override;
+
+ bool AddTrustedCertificate(std::string_view cert) override;
+ void SetCertVerifyCallback(CertVerifyCb cb) override;
+ void SetCertificateCallback(SetCertCb cb) override;
+ void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override;
+ std::vector<uint8_t> ExportKeyingMaterial(size_t length) override;
+ void EnableClientPostHandshakeCheck(bool enable) override;
+ TlsError DoHandshake() override;
+ std::vector<uint8_t> ReadFully(size_t size) override;
+ bool ReadFully(void* buf, size_t size) override;
+ bool WriteFully(std::string_view data) override;
+
+ static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem);
+ static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem);
+
+ private:
+ static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque);
+ static int SSLSetCertCb(SSL* ssl, void* opaque);
+
+ static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer);
+ static const char* SSLErrorString();
+ void Invalidate();
+ TlsError GetFailureReason(int err);
+
+ Role role_;
+ bssl::UniquePtr<EVP_PKEY> priv_key_;
+ bssl::UniquePtr<CRYPTO_BUFFER> cert_;
+
+ bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_;
+ bssl::UniquePtr<SSL_CTX> ssl_ctx_;
+ bssl::UniquePtr<SSL> ssl_;
+ std::vector<bssl::UniquePtr<X509>> known_certificates_;
+ bool client_verify_post_handshake_ = false;
+
+ CertVerifyCb cert_verify_cb_;
+ SetCertCb set_cert_cb_;
+ borrowed_fd fd_;
+}; // TlsConnectionImpl
+
+TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
+ borrowed_fd fd)
+ : role_(role), fd_(fd) {
+ CHECK(!cert.empty() && !priv_key.empty());
+ LOG(INFO) << "Initializing adbwifi TlsConnection";
+ cert_ = BufferFromPEM(cert);
+ priv_key_ = EvpPkeyFromPEM(priv_key);
+}
+
+TlsConnectionImpl::~TlsConnectionImpl() {
+ // shutdown the SSL connection
+ if (ssl_ != nullptr) {
+ SSL_shutdown(ssl_.get());
+ }
+}
+
+// static
+const char* TlsConnectionImpl::SSLErrorString() {
+ auto sslerr = ERR_peek_last_error();
+ return ERR_reason_error_string(sslerr);
+}
+
+// static
+bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) {
+ bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
+ return bssl::UniquePtr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+}
+
+// static
+bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) {
+ bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
+ char* name = nullptr;
+ char* header = nullptr;
+ uint8_t* data = nullptr;
+ long data_len = 0;
+
+ if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) {
+ LOG(ERROR) << "Failed to read certificate";
+ return nullptr;
+ }
+ OPENSSL_free(name);
+ OPENSSL_free(header);
+
+ auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr));
+ OPENSSL_free(data);
+ return ret;
+}
+
+// static
+bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
+ if (!buffer) {
+ return nullptr;
+ }
+ return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get()));
+}
+
+// static
+int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) {
+ auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
+ return p->cert_verify_cb_(ctx);
+}
+
+// static
+int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) {
+ auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
+ return p->set_cert_cb_(ssl);
+}
+
+bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) {
+ // Create X509 buffer from the certificate string
+ auto buf = X509FromBuffer(BufferFromPEM(cert));
+ if (buf == nullptr) {
+ LOG(ERROR) << "Failed to create a X509 buffer for the certificate.";
+ return false;
+ }
+ known_certificates_.push_back(std::move(buf));
+ return true;
+}
+
+void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) {
+ cert_verify_cb_ = cb;
+}
+
+void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) {
+ set_cert_cb_ = cb;
+}
+
+void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) {
+ CHECK(role_ == Role::Server);
+ ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr);
+}
+
+std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) {
+ if (ssl_.get() == nullptr) {
+ return {};
+ }
+
+ std::vector<uint8_t> out(length);
+ if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel,
+ sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) {
+ return {};
+ }
+ return out;
+}
+
+void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) {
+ client_verify_post_handshake_ = enable;
+}
+
+TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) {
+ switch (ERR_GET_REASON(err)) {
+ case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
+ case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
+ case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
+ case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
+ case SSL_R_TLSV1_CERTIFICATE_REQUIRED:
+ return TlsError::PeerRejectedCertificate;
+ case SSL_R_CERTIFICATE_VERIFY_FAILED:
+ return TlsError::CertificateRejected;
+ default:
+ return TlsError::UnknownFailure;
+ }
+}
+
+TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
+ int err = -1;
+ LOG(INFO) << "Starting adbwifi tls handshake";
+ ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
+ // TODO: Remove set_max_proto_version() once external/boringssl is updated
+ // past
+ // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877
+ if (ssl_ctx_.get() == nullptr ||
+ !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
+ !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
+ LOG(ERROR) << "Failed to create SSL context";
+ return TlsError::UnknownFailure;
+ }
+
+ // Register user-supplied known certificates
+ for (auto const& cert : known_certificates_) {
+ if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) {
+ LOG(ERROR) << "Unable to add certificates into the X509_STORE";
+ return TlsError::UnknownFailure;
+ }
+ }
+
+ // Custom certificate verification
+ if (cert_verify_cb_) {
+ SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this);
+ }
+
+ // set select certificate callback, if any.
+ if (set_cert_cb_) {
+ SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this);
+ }
+
+ // Server-allowed client CA list
+ if (ca_list_ != nullptr) {
+ bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get()));
+ SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release());
+ }
+
+ // Register our certificate and private key.
+ std::vector<CRYPTO_BUFFER*> cert_chain = {
+ cert_.get(),
+ };
+ if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
+ priv_key_.get(), nullptr)) {
+ LOG(ERROR) << "Unable to register the certificate chain file and private key ["
+ << SSLErrorString() << "]";
+ Invalidate();
+ return TlsError::UnknownFailure;
+ }
+
+ SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
+
+ // Okay! Let's try to do the handshake!
+ ssl_.reset(SSL_new(ssl_ctx_.get()));
+ if (!SSL_set_fd(ssl_.get(), fd_.get())) {
+ LOG(ERROR) << "SSL_set_fd failed. [" << SSLErrorString() << "]";
+ return TlsError::UnknownFailure;
+ }
+ switch (role_) {
+ case Role::Server:
+ err = SSL_accept(ssl_.get());
+ break;
+ case Role::Client:
+ err = SSL_connect(ssl_.get());
+ break;
+ }
+ if (err != 1) {
+ LOG(ERROR) << "Handshake failed in SSL_accept/SSL_connect [" << SSLErrorString() << "]";
+ auto sslerr = ERR_get_error();
+ Invalidate();
+ return GetFailureReason(sslerr);
+ }
+
+ if (client_verify_post_handshake_ && role_ == Role::Client) {
+ uint8_t check;
+ // Try to peek one byte for any failures. This assumes on success that
+ // the server actually sends something.
+ err = SSL_peek(ssl_.get(), &check, 1);
+ if (err <= 0) {
+ LOG(ERROR) << "Post-handshake SSL_peek failed [" << SSLErrorString() << "]";
+ auto sslerr = ERR_get_error();
+ Invalidate();
+ return GetFailureReason(sslerr);
+ }
+ }
+
+ LOG(INFO) << "Handshake succeeded.";
+ return TlsError::Success;
+}
+
+void TlsConnectionImpl::Invalidate() {
+ ssl_.reset();
+ ssl_ctx_.reset();
+}
+
+std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) {
+ std::vector<uint8_t> buf(size);
+ if (!ReadFully(buf.data(), buf.size())) {
+ return {};
+ }
+
+ return buf;
+}
+
+bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
+ CHECK_GT(size, 0U);
+ if (!ssl_) {
+ LOG(ERROR) << "Tried to read on a null SSL connection";
+ return false;
+ }
+
+ size_t offset = 0;
+ uint8_t* p8 = reinterpret_cast<uint8_t*>(buf);
+ while (size > 0) {
+ int bytes_read =
+ SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size));
+ if (bytes_read <= 0) {
+ LOG(WARNING) << "SSL_read failed [" << SSLErrorString() << "]";
+ return false;
+ }
+ size -= bytes_read;
+ offset += bytes_read;
+ }
+ return true;
+}
+
+bool TlsConnectionImpl::WriteFully(std::string_view data) {
+ CHECK(!data.empty());
+ if (!ssl_) {
+ LOG(ERROR) << "Tried to read on a null SSL connection";
+ return false;
+ }
+
+ while (!data.empty()) {
+ int bytes_out = SSL_write(ssl_.get(), data.data(),
+ std::min(static_cast<size_t>(INT_MAX), data.size()));
+ if (bytes_out <= 0) {
+ LOG(WARNING) << "SSL_write failed [" << SSLErrorString() << "]";
+ return false;
+ }
+ data = data.substr(bytes_out);
+ }
+ return true;
+}
+} // namespace
+
+// static
+std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role,
+ std::string_view cert,
+ std::string_view priv_key, borrowed_fd fd) {
+ CHECK(!cert.empty());
+ CHECK(!priv_key.empty());
+
+ return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd);
+}
+
+// static
+bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) {
+ CHECK(ssl);
+ // Note: declaring these in local scope is okay because
+ // SSL_set_chain_and_key will increase the refcount (bssl::UpRef).
+ auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert);
+ auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key);
+ if (x509_cert == nullptr || evp_pkey == nullptr) {
+ return false;
+ }
+
+ std::vector<CRYPTO_BUFFER*> cert_chain = {
+ x509_cert.get(),
+ };
+ if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(),
+ nullptr)) {
+ LOG(ERROR) << "SSL_set_chain_and_key failed";
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace tls
+} // namespace adb