summaryrefslogtreecommitdiff
path: root/neuralnetworks/utils/common/src/ResilientBuffer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'neuralnetworks/utils/common/src/ResilientBuffer.cpp')
-rw-r--r--neuralnetworks/utils/common/src/ResilientBuffer.cpp48
1 files changed, 44 insertions, 4 deletions
diff --git a/neuralnetworks/utils/common/src/ResilientBuffer.cpp b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
index cf5496ac39..47abbe268f 100644
--- a/neuralnetworks/utils/common/src/ResilientBuffer.cpp
+++ b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
@@ -20,6 +20,7 @@
#include <android-base/thread_annotations.h>
#include <nnapi/IBuffer.h>
#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <functional>
@@ -29,6 +30,34 @@
#include <vector>
namespace android::hardware::neuralnetworks::utils {
+namespace {
+
+template <typename FnType>
+auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn)
+ -> decltype(fn(*resilientBuffer.getBuffer())) {
+ auto buffer = resilientBuffer.getBuffer();
+ auto result = fn(*buffer);
+
+ // Immediately return if device is not dead.
+ if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+ return result;
+ }
+
+ // Attempt recovery and return if it fails.
+ auto maybeBuffer = resilientBuffer.recover(buffer.get());
+ if (!maybeBuffer.has_value()) {
+ const auto& [resultErrorMessage, resultErrorCode] = result.error();
+ const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error();
+ return nn::error(resultErrorCode)
+ << resultErrorMessage << ", and failed to recover dead buffer with error "
+ << recoveryErrorCode << ": " << recoveryErrorMessage;
+ }
+ buffer = std::move(maybeBuffer).value();
+
+ return fn(*buffer);
+}
+
+} // namespace
nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create(
Factory makeBuffer) {
@@ -53,9 +82,16 @@ nn::SharedBuffer ResilientBuffer::getBuffer() const {
std::lock_guard guard(mMutex);
return mBuffer;
}
-nn::SharedBuffer ResilientBuffer::recover(const nn::IBuffer* /*failingBuffer*/,
- bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover(
+ const nn::IBuffer* failingBuffer) const {
std::lock_guard guard(mMutex);
+
+ // Another caller updated the failing prepared model.
+ if (mBuffer.get() != failingBuffer) {
+ return mBuffer;
+ }
+
+ mBuffer = NN_TRY(kMakeBuffer());
return mBuffer;
}
@@ -64,12 +100,16 @@ nn::Request::MemoryDomainToken ResilientBuffer::getToken() const {
}
nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const {
- return getBuffer()->copyTo(dst);
+ const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); };
+ return protect(*this, fn);
}
nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src,
const nn::Dimensions& dimensions) const {
- return getBuffer()->copyFrom(src, dimensions);
+ const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) {
+ return buffer.copyFrom(src, dimensions);
+ };
+ return protect(*this, fn);
}
} // namespace android::hardware::neuralnetworks::utils