summaryrefslogtreecommitdiff
path: root/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp')
-rw-r--r--neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp89
1 files changed, 32 insertions, 57 deletions
diff --git a/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp b/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp
index c6055a6747..3570a74a6d 100644
--- a/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp
+++ b/neuralnetworks/utils/adapter/hidl/src/PreparedModel.cpp
@@ -55,15 +55,6 @@ auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>
return result;
}
-nn::GeneralResult<nn::Version> validateRequestForModel(const nn::Request& request,
- const nn::Model& model) {
- nn::GeneralResult<nn::Version> version = nn::validateRequestForModel(request, model);
- if (!version.ok()) {
- version.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
- }
- return version;
-}
-
class FencedExecutionCallback final : public V1_3::IFencedExecutionCallback {
public:
explicit FencedExecutionCallback(const nn::ExecuteFencedInfoCallback& callback)
@@ -144,58 +135,48 @@ void notify(CallbackType* callback, ExecutionResult result) {
}
nn::GeneralResult<void> execute(const nn::SharedPreparedModel& preparedModel,
- const Executor& executor, const V1_0::Request& request,
+ const V1_0::Request& request,
const sp<V1_0::IExecutionCallback>& callback) {
if (callback.get() == nullptr) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
}
- auto nnRequest = NN_TRY(convertInput(request));
+ const auto nnRequest = NN_TRY(convertInput(request));
- const std::any resource = preparedModel->getUnderlyingResource();
- if (const auto* model = std::any_cast<const nn::Model*>(&resource)) {
- CHECK(*model != nullptr);
- NN_TRY(adapter::validateRequestForModel(nnRequest, **model));
- }
+ auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {}, {}, {});
- Task task = [preparedModel, nnRequest = std::move(nnRequest), callback] {
- auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {}, {}, {});
- notify(callback.get(), std::move(result));
- };
- executor(std::move(task), {});
+ if (!result.ok() && result.error().code == nn::ErrorStatus::INVALID_ARGUMENT) {
+ const auto& [message, code, outputShapes] = result.error();
+ return nn::error(code) << message;
+ }
+ notify(callback.get(), std::move(result));
return {};
}
nn::GeneralResult<void> execute_1_2(const nn::SharedPreparedModel& preparedModel,
- const Executor& executor, const V1_0::Request& request,
- V1_2::MeasureTiming measure,
+ const V1_0::Request& request, V1_2::MeasureTiming measure,
const sp<V1_2::IExecutionCallback>& callback) {
if (callback.get() == nullptr) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
}
- auto nnRequest = NN_TRY(convertInput(request));
+ const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasure = NN_TRY(convertInput(measure));
- const std::any resource = preparedModel->getUnderlyingResource();
- if (const auto* model = std::any_cast<const nn::Model*>(&resource)) {
- CHECK(*model != nullptr);
- NN_TRY(adapter::validateRequestForModel(nnRequest, **model));
- }
+ auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {});
- Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, callback] {
- auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {});
- notify(callback.get(), std::move(result));
- };
- executor(std::move(task), {});
+ if (!result.ok() && result.error().code == nn::ErrorStatus::INVALID_ARGUMENT) {
+ const auto& [message, code, outputShapes] = result.error();
+ return nn::error(code) << message;
+ }
+ notify(callback.get(), std::move(result));
return {};
}
nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel,
- const Executor& executor, const V1_3::Request& request,
- V1_2::MeasureTiming measure,
+ const V1_3::Request& request, V1_2::MeasureTiming measure,
const V1_3::OptionalTimePoint& deadline,
const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
const sp<V1_3::IExecutionCallback>& callback) {
@@ -203,25 +184,20 @@ nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
}
- auto nnRequest = NN_TRY(convertInput(request));
+ const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasure = NN_TRY(convertInput(measure));
const auto nnDeadline = NN_TRY(convertInput(deadline));
const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));
- const std::any resource = preparedModel->getUnderlyingResource();
- if (const auto* model = std::any_cast<const nn::Model*>(&resource)) {
- CHECK(*model != nullptr);
- NN_TRY(adapter::validateRequestForModel(nnRequest, **model));
- }
+ auto result =
+ preparedModel->execute(nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration, {}, {});
- Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, nnDeadline,
- nnLoopTimeoutDuration, callback] {
- auto result = preparedModel->execute(nnRequest, nnMeasure, nnDeadline,
- nnLoopTimeoutDuration, {}, {});
- notify(callback.get(), std::move(result));
- };
- executor(std::move(task), nnDeadline);
+ if (!result.ok() && result.error().code == nn::ErrorStatus::INVALID_ARGUMENT) {
+ const auto& [message, code, outputShapes] = result.error();
+ return nn::error(code) << message;
+ }
+ notify(callback.get(), std::move(result));
return {};
}
@@ -304,10 +280,9 @@ nn::GeneralResult<std::pair<hidl_handle, sp<V1_3::IFencedExecutionCallback>>> ex
} // namespace
-PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel, Executor executor)
- : kPreparedModel(std::move(preparedModel)), kExecutor(std::move(executor)) {
+PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel)
+ : kPreparedModel(std::move(preparedModel)) {
CHECK(kPreparedModel != nullptr);
- CHECK(kExecutor != nullptr);
}
nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
@@ -316,7 +291,7 @@ nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
Return<V1_0::ErrorStatus> PreparedModel::execute(const V1_0::Request& request,
const sp<V1_0::IExecutionCallback>& callback) {
- auto result = adapter::execute(kPreparedModel, kExecutor, request, callback);
+ auto result = adapter::execute(kPreparedModel, request, callback);
if (!result.has_value()) {
auto [message, code] = std::move(result).error();
LOG(ERROR) << "adapter::PreparedModel::execute failed with " << code << ": " << message;
@@ -329,7 +304,7 @@ Return<V1_0::ErrorStatus> PreparedModel::execute(const V1_0::Request& request,
Return<V1_0::ErrorStatus> PreparedModel::execute_1_2(const V1_0::Request& request,
V1_2::MeasureTiming measure,
const sp<V1_2::IExecutionCallback>& callback) {
- auto result = adapter::execute_1_2(kPreparedModel, kExecutor, request, measure, callback);
+ auto result = adapter::execute_1_2(kPreparedModel, request, measure, callback);
if (!result.has_value()) {
auto [message, code] = std::move(result).error();
LOG(ERROR) << "adapter::PreparedModel::execute_1_2 failed with " << code << ": " << message;
@@ -344,7 +319,7 @@ Return<V1_3::ErrorStatus> PreparedModel::execute_1_3(
const V1_3::OptionalTimePoint& deadline,
const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
const sp<V1_3::IExecutionCallback>& callback) {
- auto result = adapter::execute_1_3(kPreparedModel, kExecutor, request, measure, deadline,
+ auto result = adapter::execute_1_3(kPreparedModel, request, measure, deadline,
loopTimeoutDuration, callback);
if (!result.has_value()) {
auto [message, code] = std::move(result).error();
@@ -405,8 +380,8 @@ Return<void> PreparedModel::configureExecutionBurst(
cb(V1_2::utils::convert(code).value(), nullptr);
return Void();
}
- auto burstContext = std::move(result).value();
- cb(V1_0::ErrorStatus::NONE, std::move(burstContext));
+ const auto burstContext = std::move(result).value();
+ cb(V1_0::ErrorStatus::NONE, burstContext);
return Void();
}