diff options
Diffstat (limited to 'fs_mgr/libsnapshot/cow_decompress.cpp')
-rw-r--r-- | fs_mgr/libsnapshot/cow_decompress.cpp | 264 |
1 files changed, 264 insertions, 0 deletions
diff --git a/fs_mgr/libsnapshot/cow_decompress.cpp b/fs_mgr/libsnapshot/cow_decompress.cpp new file mode 100644 index 000000000..faceafe17 --- /dev/null +++ b/fs_mgr/libsnapshot/cow_decompress.cpp @@ -0,0 +1,264 @@ +// +// Copyright (C) 2020 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "cow_decompress.h" + +#include <utility> + +#include <android-base/logging.h> +#include <brotli/decode.h> +#include <zlib.h> + +namespace android { +namespace snapshot { + +class NoDecompressor final : public IDecompressor { + public: + bool Decompress(size_t) override; +}; + +bool NoDecompressor::Decompress(size_t) { + size_t stream_remaining = stream_->Size(); + while (stream_remaining) { + size_t buffer_size = stream_remaining; + uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size)); + if (!buffer) { + LOG(ERROR) << "Could not acquire buffer from sink"; + return false; + } + + // Read until we can fill the buffer. + uint8_t* buffer_pos = buffer; + size_t bytes_to_read = std::min(buffer_size, stream_remaining); + while (bytes_to_read) { + size_t read; + if (!stream_->Read(buffer_pos, bytes_to_read, &read)) { + return false; + } + if (!read) { + LOG(ERROR) << "Stream ended prematurely"; + return false; + } + if (!sink_->ReturnData(buffer_pos, read)) { + LOG(ERROR) << "Could not return buffer to sink"; + return false; + } + buffer_pos += read; + bytes_to_read -= read; + stream_remaining -= read; + } + } + return true; +} + +std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() { + return std::unique_ptr<IDecompressor>(new NoDecompressor()); +} + +// Read chunks of the COW and incrementally stream them to the decoder. +class StreamDecompressor : public IDecompressor { + public: + bool Decompress(size_t output_bytes) override; + + virtual bool Init() = 0; + virtual bool DecompressInput(const uint8_t* data, size_t length) = 0; + virtual bool Done() = 0; + + protected: + bool GetFreshBuffer(); + + size_t output_bytes_; + size_t stream_remaining_; + uint8_t* output_buffer_ = nullptr; + size_t output_buffer_remaining_ = 0; +}; + +static constexpr size_t kChunkSize = 4096; + +bool StreamDecompressor::Decompress(size_t output_bytes) { + if (!Init()) { + return false; + } + + stream_remaining_ = stream_->Size(); + output_bytes_ = output_bytes; + + uint8_t chunk[kChunkSize]; + while (stream_remaining_) { + size_t read = std::min(stream_remaining_, sizeof(chunk)); + if (!stream_->Read(chunk, read, &read)) { + return false; + } + if (!read) { + LOG(ERROR) << "Stream ended prematurely"; + return false; + } + if (!DecompressInput(chunk, read)) { + return false; + } + + stream_remaining_ -= read; + + if (stream_remaining_ && Done()) { + LOG(ERROR) << "Decompressor terminated early"; + return false; + } + } + if (!Done()) { + LOG(ERROR) << "Decompressor expected more bytes"; + return false; + } + return true; +} + +bool StreamDecompressor::GetFreshBuffer() { + size_t request_size = std::min(output_bytes_, kChunkSize); + output_buffer_ = + reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_)); + if (!output_buffer_) { + LOG(ERROR) << "Could not acquire buffer from sink"; + return false; + } + return true; +} + +class GzDecompressor final : public StreamDecompressor { + public: + ~GzDecompressor(); + + bool Init() override; + bool DecompressInput(const uint8_t* data, size_t length) override; + bool Done() override { return ended_; } + + private: + z_stream z_ = {}; + bool ended_ = false; +}; + +bool GzDecompressor::Init() { + if (int rv = inflateInit(&z_); rv != Z_OK) { + LOG(ERROR) << "inflateInit returned error code " << rv; + return false; + } + return true; +} + +GzDecompressor::~GzDecompressor() { + inflateEnd(&z_); +} + +bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) { + z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data)); + z_.avail_in = length; + + while (z_.avail_in) { + // If no more output buffer, grab a new buffer. + if (z_.avail_out == 0) { + if (!GetFreshBuffer()) { + return false; + } + z_.next_out = reinterpret_cast<Bytef*>(output_buffer_); + z_.avail_out = output_buffer_remaining_; + } + + // Remember the position of the output buffer so we can call ReturnData. + auto avail_out = z_.avail_out; + + // Decompress. + int rv = inflate(&z_, Z_NO_FLUSH); + if (rv != Z_OK && rv != Z_STREAM_END) { + LOG(ERROR) << "inflate returned error code " << rv; + return false; + } + + size_t returned = avail_out - z_.avail_out; + if (!sink_->ReturnData(output_buffer_, returned)) { + LOG(ERROR) << "Could not return buffer to sink"; + return false; + } + output_buffer_ += returned; + output_buffer_remaining_ -= returned; + + if (rv == Z_STREAM_END) { + if (z_.avail_in) { + LOG(ERROR) << "Gz stream ended prematurely"; + return false; + } + ended_ = true; + return true; + } + } + return true; +} + +std::unique_ptr<IDecompressor> IDecompressor::Gz() { + return std::unique_ptr<IDecompressor>(new GzDecompressor()); +} + +class BrotliDecompressor final : public StreamDecompressor { + public: + ~BrotliDecompressor(); + + bool Init() override; + bool DecompressInput(const uint8_t* data, size_t length) override; + bool Done() override { return BrotliDecoderIsFinished(decoder_); } + + private: + BrotliDecoderState* decoder_ = nullptr; +}; + +bool BrotliDecompressor::Init() { + decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + return true; +} + +BrotliDecompressor::~BrotliDecompressor() { + if (decoder_) { + BrotliDecoderDestroyInstance(decoder_); + } +} + +bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) { + size_t available_in = length; + const uint8_t* next_in = data; + + bool needs_more_output = false; + while (available_in || needs_more_output) { + if (!output_buffer_remaining_ && !GetFreshBuffer()) { + return false; + } + + auto output_buffer = output_buffer_; + auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in, + &output_buffer_remaining_, &output_buffer_, nullptr); + if (r == BROTLI_DECODER_RESULT_ERROR) { + LOG(ERROR) << "brotli decode failed"; + return false; + } + if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) { + return false; + } + needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT); + } + return true; +} + +std::unique_ptr<IDecompressor> IDecompressor::Brotli() { + return std::unique_ptr<IDecompressor>(new BrotliDecompressor()); +} + +} // namespace snapshot +} // namespace android |