Diffstat (limited to 'neuralnetworks/utils/common/src/ResilientPreparedModel.cpp')
-rw-r--r--  neuralnetworks/utils/common/src/ResilientPreparedModel.cpp | 85
1 file changed, 80 insertions(+), 5 deletions(-)
diff --git a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
index b8acee16c9..5dd5f99f5f 100644
--- a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
+++ b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
@@ -16,19 +16,52 @@
#include "ResilientPreparedModel.h"
+#include "InvalidBurst.h"
+#include "ResilientBurst.h"
+
#include <android-base/logging.h>
#include <android-base/thread_annotations.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <functional>
#include <memory>
#include <mutex>
+#include <sstream>
#include <utility>
#include <vector>
namespace android::hardware::neuralnetworks::utils {
+namespace {
+
+template <typename FnType>
+auto protect(const ResilientPreparedModel& resilientPreparedModel, const FnType& fn)
+ -> decltype(fn(*resilientPreparedModel.getPreparedModel())) {
+ auto preparedModel = resilientPreparedModel.getPreparedModel();
+ auto result = fn(*preparedModel);
+
+ // Immediately return if prepared model is not dead.
+ if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+ return result;
+ }
+
+ // Attempt recovery and return if it fails.
+ auto maybePreparedModel = resilientPreparedModel.recover(preparedModel.get());
+ if (!maybePreparedModel.has_value()) {
+ const auto& [message, code] = maybePreparedModel.error();
+ std::ostringstream oss;
+ oss << ", and failed to recover dead prepared model with error " << code << ": " << message;
+ result.error().message += oss.str();
+ return result;
+ }
+ preparedModel = std::move(maybePreparedModel).value();
+
+ return fn(*preparedModel);
+}
+
+} // namespace
nn::GeneralResult<std::shared_ptr<const ResilientPreparedModel>> ResilientPreparedModel::create(
Factory makePreparedModel) {
@@ -55,9 +88,16 @@ nn::SharedPreparedModel ResilientPreparedModel::getPreparedModel() const {
return mPreparedModel;
}
-nn::SharedPreparedModel ResilientPreparedModel::recover(
- const nn::IPreparedModel* /*failingPreparedModel*/, bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
+ const nn::IPreparedModel* failingPreparedModel) const {
std::lock_guard guard(mMutex);
+
+ // Another caller updated the failing prepared model.
+ if (mPreparedModel.get() != failingPreparedModel) {
+ return mPreparedModel;
+ }
+
+ mPreparedModel = NN_TRY(kMakePreparedModel());
return mPreparedModel;
}
@@ -65,7 +105,11 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
- return getPreparedModel()->execute(request, measure, deadline, loopTimeoutDuration);
+ const auto fn = [&request, measure, &deadline,
+ &loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
+ return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
+ };
+ return protect(*this, fn);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,12 +119,43 @@ ResilientPreparedModel::executeFenced(const nn::Request& request,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
- return getPreparedModel()->executeFenced(request, waitFor, measure, deadline,
- loopTimeoutDuration, timeoutDurationAfterFence);
+ const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
+ &timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
+ return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
+ timeoutDurationAfterFence);
+ };
+ return protect(*this, fn);
+}
+
+nn::GeneralResult<nn::SharedBurst> ResilientPreparedModel::configureExecutionBurst() const {
+#if 0
+ auto self = shared_from_this();
+ ResilientBurst::Factory makeBurst =
+ [preparedModel = std::move(self)]() -> nn::GeneralResult<nn::SharedBurst> {
+ return preparedModel->configureExecutionBurst();
+ };
+ return ResilientBurst::create(std::move(makeBurst));
+#else
+ return configureExecutionBurstInternal();
+#endif
}
std::any ResilientPreparedModel::getUnderlyingResource() const {
return getPreparedModel()->getUnderlyingResource();
}
+bool ResilientPreparedModel::isValidInternal() const {
+ return true;
+}
+
+nn::GeneralResult<nn::SharedBurst> ResilientPreparedModel::configureExecutionBurstInternal() const {
+ if (!isValidInternal()) {
+ return std::make_shared<const InvalidBurst>();
+ }
+ const auto fn = [](const nn::IPreparedModel& preparedModel) {
+ return preparedModel.configureExecutionBurst();
+ };
+ return protect(*this, fn);
+}
+
} // namespace android::hardware::neuralnetworks::utils
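Note (not part of the change above): every call now goes through protect(), which retries exactly once after re-creating the prepared model through the stored factory when the first attempt fails with nn::ErrorStatus::DEAD_OBJECT. The sketch below shows roughly how a caller would construct the wrapper; getDriverPreparedModel() is a hypothetical helper, and the Factory alias is assumed to be a std::function returning nn::GeneralResult<nn::SharedPreparedModel>, as suggested by kMakePreparedModel() above.

    // Sketch: wrap a driver prepared model so that a dead binder object is
    // transparently re-prepared and the failing call retried once.
    nn::GeneralResult<std::shared_ptr<const ResilientPreparedModel>> makeResilient() {
        // The factory runs at creation time and again inside recover() whenever a
        // call fails with nn::ErrorStatus::DEAD_OBJECT.
        return ResilientPreparedModel::create(
                []() -> nn::GeneralResult<nn::SharedPreparedModel> {
                    return getDriverPreparedModel();  // hypothetical re-preparation step
                });
    }

The returned object is used exactly like any other nn::IPreparedModel; execute(), executeFenced(), and configureExecutionBurst() all route through protect() internally.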