diff options
Diffstat (limited to 'neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp')
-rw-r--r-- | neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp | 77 |
1 files changed, 62 insertions, 15 deletions
diff --git a/neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp b/neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp index 5cab62c625..790558fde4 100644 --- a/neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp +++ b/neuralnetworks/utils/adapter/aidl/src/PreparedModel.cpp @@ -118,17 +118,20 @@ nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationN return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs)); } -nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedModel& preparedModel, - const Request& request, - bool measureTiming, int64_t deadlineNs, - int64_t loopTimeoutDurationNs) { +nn::ExecutionResult<ExecutionResult> executeSynchronously( + const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming, + int64_t deadlineNs, int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints, + const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) { const auto nnRequest = NN_TRY(convertInput(request)); const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO; const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs)); const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs)); + auto nnHints = NN_TRY(convertInput(hints)); + auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix)); const auto result = - preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration); + preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, + nnHints, nnExtensionNameToPrefix); if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) { const auto& [message, code, outputShapes] = result.error(); @@ -147,16 +150,21 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedMod nn::GeneralResult<FencedExecutionResult> executeFenced( const nn::IPreparedModel& preparedModel, const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measureTiming, - int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs) { + int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs, + const std::vector<TokenValuePair>& hints, + const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) { const auto nnRequest = NN_TRY(convertInput(request)); const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor)); const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO; const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs)); const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs)); const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs)); + auto nnHints = NN_TRY(convertInput(hints)); + auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix)); auto [syncFence, executeFencedInfoCallback] = NN_TRY(preparedModel.executeFenced( - nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration)); + nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration, + nnHints, nnExtensionNameToPrefix)); ndk::ScopedFileDescriptor fileDescriptor; if (syncFence.hasFd()) { @@ -171,11 +179,16 @@ nn::GeneralResult<FencedExecutionResult> executeFenced( nn::GeneralResult<nn::SharedExecution> createReusableExecution( const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming, - int64_t loopTimeoutDurationNs) { + int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints, + const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) { const auto nnRequest = NN_TRY(convertInput(request)); const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO; const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs)); - return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration); + auto nnHints = NN_TRY(convertInput(hints)); + auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix)); + + return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration, + nnHints, nnExtensionNameToPrefix); } nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IExecution& execution, @@ -231,7 +244,7 @@ ndk::ScopedAStatus PreparedModel::executeSynchronously(const Request& request, b int64_t loopTimeoutDurationNs, ExecutionResult* executionResult) { auto result = adapter::executeSynchronously(*kPreparedModel, request, measureTiming, deadlineNs, - loopTimeoutDurationNs); + loopTimeoutDurationNs, {}, {}); if (!result.has_value()) { const auto& [message, code, _] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); @@ -247,7 +260,41 @@ ndk::ScopedAStatus PreparedModel::executeFenced( bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs, FencedExecutionResult* executionResult) { auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, measureTiming, - deadlineNs, loopTimeoutDurationNs, durationNs); + deadlineNs, loopTimeoutDurationNs, durationNs, {}, {}); + if (!result.has_value()) { + const auto& [message, code] = result.error(); + const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); + return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( + static_cast<int32_t>(aidlCode), message.c_str()); + } + *executionResult = std::move(result).value(); + return ndk::ScopedAStatus::ok(); +} + +ndk::ScopedAStatus PreparedModel::executeSynchronouslyWithConfig(const Request& request, + const ExecutionConfig& config, + int64_t deadlineNs, + ExecutionResult* executionResult) { + auto result = adapter::executeSynchronously( + *kPreparedModel, request, config.measureTiming, deadlineNs, + config.loopTimeoutDurationNs, config.executionHints, config.extensionNameToPrefix); + if (!result.has_value()) { + const auto& [message, code, _] = result.error(); + const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); + return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( + static_cast<int32_t>(aidlCode), message.c_str()); + } + *executionResult = std::move(result).value(); + return ndk::ScopedAStatus::ok(); +} + +ndk::ScopedAStatus PreparedModel::executeFencedWithConfig( + const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, + const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs, + FencedExecutionResult* executionResult) { + auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, config.measureTiming, + deadlineNs, config.loopTimeoutDurationNs, durationNs, + config.executionHints, config.extensionNameToPrefix); if (!result.has_value()) { const auto& [message, code] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); @@ -275,11 +322,11 @@ nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const { } ndk::ScopedAStatus PreparedModel::createReusableExecution(const Request& request, - bool measureTiming, - int64_t loopTimeoutDurationNs, + const ExecutionConfig& config, std::shared_ptr<IExecution>* execution) { - auto result = adapter::createReusableExecution(*kPreparedModel, request, measureTiming, - loopTimeoutDurationNs); + auto result = adapter::createReusableExecution( + *kPreparedModel, request, config.measureTiming, config.loopTimeoutDurationNs, + config.executionHints, config.extensionNameToPrefix); if (!result.has_value()) { const auto& [message, code] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); |