diff options
Diffstat (limited to 'neuralnetworks/utils/common/src/ResilientBuffer.cpp')
-rw-r--r-- | neuralnetworks/utils/common/src/ResilientBuffer.cpp | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/neuralnetworks/utils/common/src/ResilientBuffer.cpp b/neuralnetworks/utils/common/src/ResilientBuffer.cpp index cf5496ac39..47abbe268f 100644 --- a/neuralnetworks/utils/common/src/ResilientBuffer.cpp +++ b/neuralnetworks/utils/common/src/ResilientBuffer.cpp @@ -20,6 +20,7 @@ #include <android-base/thread_annotations.h> #include <nnapi/IBuffer.h> #include <nnapi/Result.h> +#include <nnapi/TypeUtils.h> #include <nnapi/Types.h> #include <functional> @@ -29,6 +30,34 @@ #include <vector> namespace android::hardware::neuralnetworks::utils { +namespace { + +template <typename FnType> +auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn) + -> decltype(fn(*resilientBuffer.getBuffer())) { + auto buffer = resilientBuffer.getBuffer(); + auto result = fn(*buffer); + + // Immediately return if device is not dead. + if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) { + return result; + } + + // Attempt recovery and return if it fails. + auto maybeBuffer = resilientBuffer.recover(buffer.get()); + if (!maybeBuffer.has_value()) { + const auto& [resultErrorMessage, resultErrorCode] = result.error(); + const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error(); + return nn::error(resultErrorCode) + << resultErrorMessage << ", and failed to recover dead buffer with error " + << recoveryErrorCode << ": " << recoveryErrorMessage; + } + buffer = std::move(maybeBuffer).value(); + + return fn(*buffer); +} + +} // namespace nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create( Factory makeBuffer) { @@ -53,9 +82,16 @@ nn::SharedBuffer ResilientBuffer::getBuffer() const { std::lock_guard guard(mMutex); return mBuffer; } -nn::SharedBuffer ResilientBuffer::recover(const nn::IBuffer* /*failingBuffer*/, - bool /*blocking*/) const { +nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover( + const nn::IBuffer* failingBuffer) const { std::lock_guard guard(mMutex); + + // Another caller updated the failing prepared model. + if (mBuffer.get() != failingBuffer) { + return mBuffer; + } + + mBuffer = NN_TRY(kMakeBuffer()); return mBuffer; } @@ -64,12 +100,16 @@ nn::Request::MemoryDomainToken ResilientBuffer::getToken() const { } nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const { - return getBuffer()->copyTo(dst); + const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); }; + return protect(*this, fn); } nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src, const nn::Dimensions& dimensions) const { - return getBuffer()->copyFrom(src, dimensions); + const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) { + return buffer.copyFrom(src, dimensions); + }; + return protect(*this, fn); } } // namespace android::hardware::neuralnetworks::utils |