diff options
Diffstat (limited to 'adb/pairing_connection/pairing_connection.cpp')
-rw-r--r-- | adb/pairing_connection/pairing_connection.cpp | 491 |
1 files changed, 491 insertions, 0 deletions
diff --git a/adb/pairing_connection/pairing_connection.cpp b/adb/pairing_connection/pairing_connection.cpp new file mode 100644 index 000000000..a26a6b4d2 --- /dev/null +++ b/adb/pairing_connection/pairing_connection.cpp @@ -0,0 +1,491 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb/pairing/pairing_connection.h" + +#include <stddef.h> +#include <stdint.h> + +#include <functional> +#include <memory> +#include <string_view> +#include <thread> +#include <vector> + +#include <adb/pairing/pairing_auth.h> +#include <adb/tls/tls_connection.h> +#include <android-base/endian.h> +#include <android-base/logging.h> +#include <android-base/macros.h> +#include <android-base/unique_fd.h> + +#include "pairing.pb.h" + +using namespace adb; +using android::base::unique_fd; +using TlsError = tls::TlsConnection::TlsError; + +const uint8_t kCurrentKeyHeaderVersion = 1; +const uint8_t kMinSupportedKeyHeaderVersion = 1; +const uint8_t kMaxSupportedKeyHeaderVersion = 1; +const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2; + +struct PairingPacketHeader { + uint8_t version; // PairingPacket version + uint8_t type; // the type of packet (PairingPacket.Type) + uint32_t payload; // Size of the payload in bytes +} __attribute__((packed)); + +struct PairingAuthDeleter { + void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); } +}; // PairingAuthDeleter +using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>; + +// PairingConnectionCtx encapsulates the protocol to authenticate two peers with +// each other. This class will open the tcp sockets and handle the pairing +// process. On completion, both sides will have each other's public key +// (certificate) if successful, otherwise, the pairing failed. The tcp port +// number is hardcoded (see pairing_connection.cpp). +// +// Each PairingConnectionCtx instance represents a different device trying to +// pair. So for the device, we can have multiple PairingConnectionCtxs while the +// host may have only one (unless host has a PairingServer). +// +// See pairing_connection_test.cpp for example usage. +// +struct PairingConnectionCtx { + public: + using Data = std::vector<uint8_t>; + using ResultCallback = pairing_result_cb; + enum class Role { + Client, + Server, + }; + + explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, + const Data& certificate, const Data& priv_key); + virtual ~PairingConnectionCtx(); + + // Starts the pairing connection on a separate thread. + // Upon completion, if the pairing was successful, + // |cb| will be called with the peer information and certificate. + // Otherwise, |cb| will be called with empty data. |fd| should already + // be opened. PairingConnectionCtx will take ownership of the |fd|. + // + // Pairing is successful if both server/client uses the same non-empty + // |pswd|, and they are able to exchange the information. |pswd| and + // |certificate| must be non-empty. Start() can only be called once in the + // lifetime of this object. + // + // Returns true if the thread was successfully started, false otherwise. + bool Start(int fd, ResultCallback cb, void* opaque); + + private: + // Setup the tls connection. + bool SetupTlsConnection(); + + /************ PairingPacketHeader methods ****************/ + // Tries to write out the header and payload. + bool WriteHeader(const PairingPacketHeader* header, std::string_view payload); + // Tries to parse incoming data into the |header|. Returns true if header + // is valid and header version is supported. |header| is filled on success. + // |header| may contain garbage if unsuccessful. + bool ReadHeader(PairingPacketHeader* header); + // Creates a PairingPacketHeader. + void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type, + uint32_t payload_size); + // Checks if actual matches expected. + bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual); + + /*********** State related methods **************/ + // Handles the State::ExchangingMsgs state. + bool DoExchangeMsgs(); + // Handles the State::ExchangingPeerInfo state. + bool DoExchangePeerInfo(); + + // The background task to do the pairing. + void StartWorker(); + + // Calls |cb_| and sets the state to Stopped. + void NotifyResult(const PeerInfo* p); + + static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd); + + enum class State { + Ready, + ExchangingMsgs, + ExchangingPeerInfo, + Stopped, + }; + + std::atomic<State> state_{State::Ready}; + Role role_; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + + // Peer's info + PeerInfo their_info_; + + ResultCallback cb_; + void* opaque_ = nullptr; + std::unique_ptr<tls::TlsConnection> tls_; + PairingAuthPtr auth_; + unique_fd fd_; + std::thread thread_; + static constexpr size_t kExportedKeySize = 64; +}; // PairingConnectionCtx + +PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key) + : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); +} + +PairingConnectionCtx::~PairingConnectionCtx() { + // Force close the fd and wait for the worker thread to finish. + fd_.reset(); + if (thread_.joinable()) { + thread_.join(); + } +} + +bool PairingConnectionCtx::SetupTlsConnection() { + tls_ = tls::TlsConnection::Create( + role_ == Role::Server ? tls::TlsConnection::Role::Server + : tls::TlsConnection::Role::Client, + std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()), + std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()), + fd_); + + if (tls_ == nullptr) { + LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get(); + return false; + } + + // Allow any peer certificate + tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; }); + + // SSL doesn't seem to behave correctly with fdevents so just do a blocking + // read for the pairing data. + if (tls_->DoHandshake() != TlsError::Success) { + LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get(); + return false; + } + + // To ensure the connection is not stolen while we do the PAKE, append the + // exported key material from the tls connection to the password. + std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize); + if (exportedKeyMaterial.empty()) { + LOG(ERROR) << "Failed to export key material"; + return false; + } + pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()), + std::make_move_iterator(exportedKeyMaterial.end())); + auth_ = CreatePairingAuthPtr(role_, pswd_); + + return true; +} + +bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header, + std::string_view payload) { + PairingPacketHeader network_header = *header; + network_header.payload = htonl(network_header.payload); + if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header), + sizeof(PairingPacketHeader))) || + !tls_->WriteFully(payload)) { + LOG(ERROR) << "Failed to write out PairingPacketHeader"; + state_ = State::Stopped; + return false; + } + return true; +} + +bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) { + auto data = tls_->ReadFully(sizeof(PairingPacketHeader)); + if (data.empty()) { + return false; + } + + uint8_t* p = data.data(); + // First byte is always PairingPacketHeader version + header->version = *p; + ++p; + if (header->version < kMinSupportedKeyHeaderVersion || + header->version > kMaxSupportedKeyHeaderVersion) { + LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion + << " them=" << header->version << ")"; + return false; + } + // Next byte is the PairingPacket::Type + if (!adb::proto::PairingPacket::Type_IsValid(*p)) { + LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p); + return false; + } + header->type = *p; + ++p; + // Last, the payload size + header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p))); + if (header->payload == 0 || header->payload > kMaxPayloadSize) { + LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload + << ")"; + return false; + } + + return true; +} + +void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header, + adb::proto::PairingPacket::Type type, + uint32_t payload_size) { + header->version = kCurrentKeyHeaderVersion; + uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type)); + header->type = type8; + header->payload = payload_size; +} + +bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type, + uint8_t actual) { + uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type); + if (actual != expected) { + LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected) + << " actual=" << static_cast<uint32_t>(actual) << ")"; + return false; + } + return true; +} + +void PairingConnectionCtx::NotifyResult(const PeerInfo* p) { + cb_(p, fd_.get(), opaque_); + state_ = State::Stopped; +} + +bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) { + if (fd < 0) { + return false; + } + + State expected = State::Ready; + if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) { + return false; + } + + fd_.reset(fd); + cb_ = cb; + opaque_ = opaque; + + thread_ = std::thread([this] { StartWorker(); }); + return true; +} + +bool PairingConnectionCtx::DoExchangeMsgs() { + uint32_t payload = pairing_auth_msg_size(auth_.get()); + std::vector<uint8_t> msg(payload); + pairing_auth_get_spake2_msg(auth_.get(), msg.data()); + + PairingPacketHeader header; + CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload); + + // Write our SPAKE2 msg + if (!WriteHeader(&header, + std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) { + LOG(ERROR) << "Failed to write SPAKE2 msg."; + return false; + } + + // Read the peer's SPAKE2 msg header + if (!ReadHeader(&header)) { + LOG(ERROR) << "Invalid PairingPacketHeader."; + return false; + } + if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) { + return false; + } + + // Read the SPAKE2 msg payload and initialize the cipher for + // encrypting the PeerInfo and certificate. + auto their_msg = tls_->ReadFully(header.payload); + if (their_msg.empty() || + !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) { + LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size() + << "]"; + return false; + } + + return true; +} + +bool PairingConnectionCtx::DoExchangePeerInfo() { + // Encrypt PeerInfo + std::vector<uint8_t> buf; + uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_); + buf.assign(p, p + sizeof(peer_info_)); + std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size())); + CHECK(!outbuf.empty()); + size_t outsize; + if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { + LOG(ERROR) << "Failed to encrypt peer info"; + return false; + } + outbuf.resize(outsize); + + // Write out the packet header + PairingPacketHeader out_header; + out_header.version = kCurrentKeyHeaderVersion; + out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO)); + out_header.payload = htonl(outbuf.size()); + if (!tls_->WriteFully( + std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) { + LOG(ERROR) << "Unable to write PairingPacketHeader"; + return false; + } + + // Write out the encrypted payload + if (!tls_->WriteFully( + std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) { + LOG(ERROR) << "Unable to write encrypted peer info"; + return false; + } + + // Read in the peer's packet header + PairingPacketHeader header; + if (!ReadHeader(&header)) { + LOG(ERROR) << "Invalid PairingPacketHeader."; + return false; + } + + if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) { + return false; + } + + // Read in the encrypted peer certificate + buf = tls_->ReadFully(header.payload); + if (buf.empty()) { + return false; + } + + // Try to decrypt the certificate + outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size())); + if (outbuf.empty()) { + LOG(ERROR) << "Unsupported payload while decrypting peer info."; + return false; + } + + if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { + LOG(ERROR) << "Failed to decrypt"; + return false; + } + outbuf.resize(outsize); + + // The decrypted message should contain the PeerInfo. + if (outbuf.size() != sizeof(PeerInfo)) { + LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo); + return false; + } + + p = outbuf.data(); + ::memcpy(&their_info_, p, sizeof(PeerInfo)); + p += sizeof(PeerInfo); + + return true; +} + +void PairingConnectionCtx::StartWorker() { + // Setup the secure transport + if (!SetupTlsConnection()) { + NotifyResult(nullptr); + return; + } + + for (;;) { + switch (state_) { + case State::ExchangingMsgs: + if (!DoExchangeMsgs()) { + NotifyResult(nullptr); + return; + } + state_ = State::ExchangingPeerInfo; + break; + case State::ExchangingPeerInfo: + if (!DoExchangePeerInfo()) { + NotifyResult(nullptr); + return; + } + NotifyResult(&their_info_); + return; + case State::Ready: + case State::Stopped: + LOG(FATAL) << __func__ << ": Got invalid state"; + return; + } + } +} + +// static +PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) { + switch (role) { + case Role::Client: + return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size())); + break; + case Role::Server: + return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size())); + break; + } +} + +static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd, + size_t pswd_len, const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + CHECK(pswd); + CHECK_GT(pswd_len, 0U); + CHECK(x509_cert_pem); + CHECK_GT(x509_size, 0U); + CHECK(priv_key_pem); + CHECK_GT(priv_size, 0U); + CHECK(peer_info); + std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len); + std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); + std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size); + return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key); +} + +PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info, + x509_cert_pem, x509_size, priv_key_pem, priv_size); +} + +PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info, + x509_cert_pem, x509_size, priv_key_pem, priv_size); +} + +void pairing_connection_destroy(PairingConnectionCtx* ctx) { + CHECK(ctx); + delete ctx; +} + +bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb, + void* opaque) { + return ctx->Start(fd, cb, opaque); +} |