diff options
-rw-r--r-- | common/constants.cc | 2 | ||||
-rw-r--r-- | common/constants.h | 2 | ||||
-rw-r--r-- | common/http_fetcher.h | 8 | ||||
-rw-r--r-- | common/http_fetcher_unittest.cc | 54 | ||||
-rw-r--r-- | common/libcurl_http_fetcher.cc | 41 | ||||
-rw-r--r-- | common/libcurl_http_fetcher.h | 7 | ||||
-rw-r--r-- | common/mock_http_fetcher.cc | 6 | ||||
-rw-r--r-- | common/mock_http_fetcher.h | 7 | ||||
-rw-r--r-- | common/multi_range_http_fetcher.h | 5 | ||||
-rw-r--r-- | test_http_server.cc | 23 | ||||
-rw-r--r-- | update_attempter_android.cc | 7 |
11 files changed, 135 insertions, 27 deletions
diff --git a/common/constants.cc b/common/constants.cc index b15c3f40..3b7aa6eb 100644 --- a/common/constants.cc +++ b/common/constants.cc @@ -91,5 +91,7 @@ const char kPayloadPropertyFileSize[] = "FILE_SIZE"; const char kPayloadPropertyFileHash[] = "FILE_HASH"; const char kPayloadPropertyMetadataSize[] = "METADATA_SIZE"; const char kPayloadPropertyMetadataHash[] = "METADATA_HASH"; +const char kPayloadPropertyAuthorization[] = "AUTHORIZATION"; +const char kPayloadPropertyUserAgent[] = "USER_AGENT"; } // namespace chromeos_update_engine diff --git a/common/constants.h b/common/constants.h index 62f61ce7..d0013297 100644 --- a/common/constants.h +++ b/common/constants.h @@ -93,6 +93,8 @@ extern const char kPayloadPropertyFileSize[]; extern const char kPayloadPropertyFileHash[]; extern const char kPayloadPropertyMetadataSize[]; extern const char kPayloadPropertyMetadataHash[]; +extern const char kPayloadPropertyAuthorization[]; +extern const char kPayloadPropertyUserAgent[]; // A download source is any combination of protocol and server (that's of // interest to us when looking at UMA metrics) using which we may download diff --git a/common/http_fetcher.h b/common/http_fetcher.h index 11e8e9f7..d2499eb1 100644 --- a/common/http_fetcher.h +++ b/common/http_fetcher.h @@ -44,7 +44,7 @@ class HttpFetcher { // |proxy_resolver| is the resolver that will be consulted for proxy // settings. It may be null, in which case direct connections will // be used. Does not take ownership of the resolver. - HttpFetcher(ProxyResolver* proxy_resolver) + explicit HttpFetcher(ProxyResolver* proxy_resolver) : post_data_set_(false), http_response_code_(0), delegate_(nullptr), @@ -95,6 +95,12 @@ class HttpFetcher { // TransferTerminated() will be called when the transfer is actually done. virtual void TerminateTransfer() = 0; + // Add or update a custom header to be sent with every request. If the same + // |header_name| is passed twice, the second |header_value| would override the + // previous value. + virtual void SetHeader(const std::string& header_name, + const std::string& header_value) = 0; + // If data is coming in too quickly, you can call Pause() to pause the // transfer. The delegate will not have ReceivedBytes() called while // an HttpFetcher is paused. diff --git a/common/http_fetcher_unittest.cc b/common/http_fetcher_unittest.cc index aaa538c0..bd723d77 100644 --- a/common/http_fetcher_unittest.cc +++ b/common/http_fetcher_unittest.cc @@ -377,12 +377,12 @@ TYPED_TEST_CASE(HttpFetcherTest, HttpFetcherTestTypes); namespace { class HttpFetcherTestDelegate : public HttpFetcherDelegate { public: - HttpFetcherTestDelegate() : - is_expect_error_(false), times_transfer_complete_called_(0), - times_transfer_terminated_called_(0), times_received_bytes_called_(0) {} + HttpFetcherTestDelegate() = default; void ReceivedBytes(HttpFetcher* /* fetcher */, - const void* /* bytes */, size_t /* length */) override { + const void* bytes, + size_t length) override { + data.append(reinterpret_cast<const char*>(bytes), length); // Update counters times_received_bytes_called_++; } @@ -404,12 +404,15 @@ class HttpFetcherTestDelegate : public HttpFetcherDelegate { } // Are we expecting an error response? (default: no) - bool is_expect_error_; + bool is_expect_error_{false}; // Counters for callback invocations. - int times_transfer_complete_called_; - int times_transfer_terminated_called_; - int times_received_bytes_called_; + int times_transfer_complete_called_{0}; + int times_transfer_terminated_called_{0}; + int times_received_bytes_called_{0}; + + // The received data bytes. + string data; }; @@ -480,6 +483,41 @@ TYPED_TEST(HttpFetcherTest, ErrorTest) { CHECK_EQ(delegate.times_transfer_terminated_called_, 0); } +TYPED_TEST(HttpFetcherTest, ExtraHeadersInRequestTest) { + if (this->test_.IsMock()) + return; + + HttpFetcherTestDelegate delegate; + unique_ptr<HttpFetcher> fetcher(this->test_.NewSmallFetcher()); + fetcher->set_delegate(&delegate); + fetcher->SetHeader("User-Agent", "MyTest"); + fetcher->SetHeader("user-agent", "Override that header"); + fetcher->SetHeader("Authorization", "Basic user:passwd"); + + // Invalid headers. + fetcher->SetHeader("X-Foo", "Invalid\nHeader\nIgnored"); + fetcher->SetHeader("X-Bar: ", "I do not know how to parse"); + + // Hide Accept header normally added by default. + fetcher->SetHeader("Accept", ""); + + PythonHttpServer server; + int port = server.GetPort(); + ASSERT_TRUE(server.started_); + + StartTransfer(fetcher.get(), LocalServerUrlForPath(port, "/echo-headers")); + this->loop_.Run(); + + EXPECT_NE(string::npos, + delegate.data.find("user-agent: Override that header\r\n")); + EXPECT_NE(string::npos, + delegate.data.find("Authorization: Basic user:passwd\r\n")); + + EXPECT_EQ(string::npos, delegate.data.find("\nAccept:")); + EXPECT_EQ(string::npos, delegate.data.find("X-Foo: Invalid")); + EXPECT_EQ(string::npos, delegate.data.find("X-Bar: I do not")); +} + namespace { class PausingHttpFetcherTestDelegate : public HttpFetcherDelegate { public: diff --git a/common/libcurl_http_fetcher.cc b/common/libcurl_http_fetcher.cc index 761b74e1..725bdd46 100644 --- a/common/libcurl_http_fetcher.cc +++ b/common/libcurl_http_fetcher.cc @@ -128,24 +128,32 @@ void LibcurlHttpFetcher::ResumeTransfer(const string& url) { CHECK_EQ(curl_easy_setopt(curl_handle_, CURLOPT_POSTFIELDSIZE, post_data_.size()), CURLE_OK); + } + // Setup extra HTTP headers. + if (curl_http_headers_) { + curl_slist_free_all(curl_http_headers_); + curl_http_headers_ = nullptr; + } + for (const auto& header : extra_headers_) { + // curl_slist_append() copies the string. + curl_http_headers_ = + curl_slist_append(curl_http_headers_, header.second.c_str()); + } + if (post_data_set_) { // Set the Content-Type HTTP header, if one was specifically set. - CHECK(!curl_http_headers_); if (post_content_type_ != kHttpContentTypeUnspecified) { - const string content_type_attr = - base::StringPrintf("Content-Type: %s", - GetHttpContentTypeString(post_content_type_)); - curl_http_headers_ = curl_slist_append(nullptr, - content_type_attr.c_str()); - CHECK(curl_http_headers_); - CHECK_EQ( - curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, - curl_http_headers_), - CURLE_OK); + const string content_type_attr = base::StringPrintf( + "Content-Type: %s", GetHttpContentTypeString(post_content_type_)); + curl_http_headers_ = + curl_slist_append(curl_http_headers_, content_type_attr.c_str()); } else { LOG(WARNING) << "no content type set, using libcurl default"; } } + CHECK_EQ( + curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, curl_http_headers_), + CURLE_OK); if (bytes_downloaded_ > 0 || download_length_) { // Resume from where we left off. @@ -311,6 +319,17 @@ void LibcurlHttpFetcher::TerminateTransfer() { } } +void LibcurlHttpFetcher::SetHeader(const string& header_name, + const string& header_value) { + string header_line = header_name + ": " + header_value; + // Avoid the space if no data on the right side of the semicolon. + if (header_value.empty()) + header_line = header_name + ":"; + TEST_AND_RETURN(header_line.find('\n') == string::npos); + TEST_AND_RETURN(header_name.find(':') == string::npos); + extra_headers_[base::ToLowerASCII(header_name)] = header_line; +} + void LibcurlHttpFetcher::CurlPerformOnce() { CHECK(transfer_in_progress_); int running_handles = 0; diff --git a/common/libcurl_http_fetcher.h b/common/libcurl_http_fetcher.h index 66dbb18c..5a642369 100644 --- a/common/libcurl_http_fetcher.h +++ b/common/libcurl_http_fetcher.h @@ -57,6 +57,10 @@ class LibcurlHttpFetcher : public HttpFetcher { // cannot be resumed. void TerminateTransfer() override; + // Pass the headers to libcurl. + void SetHeader(const std::string& header_name, + const std::string& header_value) override; + // Suspend the transfer by calling curl_easy_pause(CURLPAUSE_ALL). void Pause() override; @@ -181,6 +185,9 @@ class LibcurlHttpFetcher : public HttpFetcher { CURL* curl_handle_{nullptr}; struct curl_slist* curl_http_headers_{nullptr}; + // The extra headers that will be sent on each request. + std::map<std::string, std::string> extra_headers_; + // Lists of all read(0)/write(1) file descriptors that we're waiting on from // the message loop. libcurl may open/close descriptors and switch their // directions so maintain two separate lists so that watch conditions can be diff --git a/common/mock_http_fetcher.cc b/common/mock_http_fetcher.cc index f3fa70db..d0348f13 100644 --- a/common/mock_http_fetcher.cc +++ b/common/mock_http_fetcher.cc @@ -20,6 +20,7 @@ #include <base/bind.h> #include <base/logging.h> +#include <base/strings/string_util.h> #include <base/time/time.h> #include <gtest/gtest.h> @@ -117,6 +118,11 @@ void MockHttpFetcher::TerminateTransfer() { delegate_->TransferTerminated(this); } +void MockHttpFetcher::SetHeader(const std::string& header_name, + const std::string& header_value) { + extra_headers_[base::ToLowerASCII(header_name)] = header_value; +} + void MockHttpFetcher::Pause() { CHECK(!paused_); paused_ = true; diff --git a/common/mock_http_fetcher.h b/common/mock_http_fetcher.h index 90d34ddd..e56318e3 100644 --- a/common/mock_http_fetcher.h +++ b/common/mock_http_fetcher.h @@ -17,6 +17,7 @@ #ifndef UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_ #define UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_ +#include <map> #include <string> #include <vector> @@ -87,6 +88,9 @@ class MockHttpFetcher : public HttpFetcher { // The transfer cannot be resumed. void TerminateTransfer() override; + void SetHeader(const std::string& header_name, + const std::string& header_value) override; + // Suspend the mock transfer. void Pause() override; @@ -125,6 +129,9 @@ class MockHttpFetcher : public HttpFetcher { // The number of bytes we've sent so far size_t sent_size_; + // The extra headers set. + std::map<std::string, std::string> extra_headers_; + // The TaskId of the timeout callback. After each chunk of data sent, we // time out for 0s just to make sure that run loop services other clients. brillo::MessageLoop::TaskId timeout_id_; diff --git a/common/multi_range_http_fetcher.h b/common/multi_range_http_fetcher.h index 8158a229..8a91eadf 100644 --- a/common/multi_range_http_fetcher.h +++ b/common/multi_range_http_fetcher.h @@ -80,6 +80,11 @@ class MultiRangeHttpFetcher : public HttpFetcher, public HttpFetcherDelegate { // State change: Downloading -> Pending transfer ended void TerminateTransfer() override; + void SetHeader(const std::string& header_name, + const std::string& header_value) override { + base_fetcher_->SetHeader(header_name, header_value); + } + void Pause() override { base_fetcher_->Pause(); } void Unpause() override { base_fetcher_->Unpause(); } diff --git a/test_http_server.cc b/test_http_server.cc index 98e7a6da..2955e79f 100644 --- a/test_http_server.cc +++ b/test_http_server.cc @@ -72,13 +72,12 @@ enum { }; struct HttpRequest { - HttpRequest() - : start_offset(0), end_offset(0), return_code(kHttpResponseOk) {} + string raw_headers; string host; string url; - off_t start_offset; - off_t end_offset; // non-inclusive, zero indicates unspecified. - HttpResponseCode return_code; + off_t start_offset{0}; + off_t end_offset{0}; // non-inclusive, zero indicates unspecified. + HttpResponseCode return_code{kHttpResponseOk}; }; bool ParseRequest(int fd, HttpRequest* request) { @@ -96,6 +95,7 @@ bool ParseRequest(int fd, HttpRequest* request) { LOG(INFO) << "got headers:\n--8<------8<------8<------8<----\n" << headers << "\n--8<------8<------8<------8<----"; + request->raw_headers = headers; // Break header into lines. vector<string> lines; @@ -452,6 +452,13 @@ ssize_t HandleErrorIfOffset(int fd, const HttpRequest& request, } } +// Returns a valid response echoing in the body of the response all the headers +// sent by the client. +void HandleEchoHeaders(int fd, const HttpRequest& request) { + WriteHeaders(fd, 0, request.raw_headers.size(), kHttpResponseOk); + WriteString(fd, request.raw_headers); +} + void HandleHang(int fd) { LOG(INFO) << "Hanging until the other side of the connection is closed."; char c; @@ -512,8 +519,8 @@ void HandleConnection(int fd) { LOG(INFO) << "pid(" << getpid() << "): handling url " << url; if (url == "/quitquitquit") { HandleQuit(fd); - } else if (base::StartsWith(url, "/download/", - base::CompareCase::SENSITIVE)) { + } else if (base::StartsWith( + url, "/download/", base::CompareCase::SENSITIVE)) { const UrlTerms terms(url, 2); HandleGet(fd, request, terms.GetSizeT(1)); } else if (base::StartsWith(url, "/flaky/", base::CompareCase::SENSITIVE)) { @@ -528,6 +535,8 @@ void HandleConnection(int fd) { base::CompareCase::SENSITIVE)) { const UrlTerms terms(url, 3); HandleErrorIfOffset(fd, request, terms.GetSizeT(1), terms.GetInt(2)); + } else if (url == "/echo-headers") { + HandleEchoHeaders(fd, request); } else if (url == "/hang") { HandleHang(fd); } else { diff --git a/update_attempter_android.cc b/update_attempter_android.cc index cc478a41..8403dec1 100644 --- a/update_attempter_android.cc +++ b/update_attempter_android.cc @@ -171,6 +171,13 @@ bool UpdateAttempterAndroid::ApplyPayload( BuildUpdateActions(); SetupDownload(); + // Setup extra headers. + HttpFetcher* fetcher = download_action_->http_fetcher(); + if (!headers[kPayloadPropertyAuthorization].empty()) + fetcher->SetHeader("Authorization", headers[kPayloadPropertyAuthorization]); + if (!headers[kPayloadPropertyUserAgent].empty()) + fetcher->SetHeader("User-Agent", headers[kPayloadPropertyUserAgent]); + cpu_limiter_.StartLimiter(); SetStatusAndNotify(UpdateStatus::UPDATE_AVAILABLE); |