Diffstat (limited to 'neuralnetworks/utils/adapter/aidl/src/Device.cpp')
-rw-r--r-- | neuralnetworks/utils/adapter/aidl/src/Device.cpp | 47
1 file changed, 35 insertions, 12 deletions
diff --git a/neuralnetworks/utils/adapter/aidl/src/Device.cpp b/neuralnetworks/utils/adapter/aidl/src/Device.cpp
index 763be7f3fa..84aaddbe9d 100644
--- a/neuralnetworks/utils/adapter/aidl/src/Device.cpp
+++ b/neuralnetworks/utils/adapter/aidl/src/Device.cpp
@@ -148,13 +148,14 @@ void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
     }
 }
 
-nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Executor& executor,
-                                     const Model& model, ExecutionPreference preference,
-                                     Priority priority, int64_t deadlineNs,
-                                     const std::vector<ndk::ScopedFileDescriptor>& modelCache,
-                                     const std::vector<ndk::ScopedFileDescriptor>& dataCache,
-                                     const std::vector<uint8_t>& token,
-                                     const std::shared_ptr<IPreparedModelCallback>& callback) {
+nn::GeneralResult<void> prepareModel(
+        const nn::SharedDevice& device, const Executor& executor, const Model& model,
+        ExecutionPreference preference, Priority priority, int64_t deadlineNs,
+        const std::vector<ndk::ScopedFileDescriptor>& modelCache,
+        const std::vector<ndk::ScopedFileDescriptor>& dataCache, const std::vector<uint8_t>& token,
+        const std::vector<TokenValuePair>& hints,
+        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
+        const std::shared_ptr<IPreparedModelCallback>& callback) {
     if (callback.get() == nullptr) {
         return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
     }
@@ -166,12 +167,16 @@ nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Execu
     auto nnModelCache = NN_TRY(convertInput(modelCache));
     auto nnDataCache = NN_TRY(convertInput(dataCache));
     const auto nnToken = NN_TRY(convertCacheToken(token));
+    auto nnHints = NN_TRY(convertInput(hints));
+    auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
 
     Task task = [device, nnModel = std::move(nnModel), nnPreference, nnPriority, nnDeadline,
                  nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
-                 nnToken, callback] {
-        auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
-                                           nnModelCache, nnDataCache, nnToken);
+                 nnToken, nnHints = std::move(nnHints),
+                 nnExtensionNameToPrefix = std::move(nnExtensionNameToPrefix), callback] {
+        auto result =
+                device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline, nnModelCache,
+                                     nnDataCache, nnToken, nnHints, nnExtensionNameToPrefix);
         notify(callback.get(), std::move(result));
     };
     executor(std::move(task), nnDeadline);
@@ -273,8 +278,9 @@ ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference
                                         const std::vector<ndk::ScopedFileDescriptor>& dataCache,
                                         const std::vector<uint8_t>& token,
                                         const std::shared_ptr<IPreparedModelCallback>& callback) {
-    const auto result = adapter::prepareModel(kDevice, kExecutor, model, preference, priority,
-                                              deadlineNs, modelCache, dataCache, token, callback);
+    const auto result =
+            adapter::prepareModel(kDevice, kExecutor, model, preference, priority, deadlineNs,
+                                  modelCache, dataCache, token, {}, {}, callback);
     if (!result.has_value()) {
         const auto& [message, code] = result.error();
         const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -301,4 +307,21 @@ ndk::ScopedAStatus Device::prepareModelFromCache(
     return ndk::ScopedAStatus::ok();
 }
 
+ndk::ScopedAStatus Device::prepareModelWithConfig(
+        const Model& model, const PrepareModelConfig& config,
+        const std::shared_ptr<IPreparedModelCallback>& callback) {
+    const auto result = adapter::prepareModel(
+            kDevice, kExecutor, model, config.preference, config.priority, config.deadlineNs,
+            config.modelCache, config.dataCache, config.cacheToken, config.compilationHints,
+            config.extensionNameToPrefix, callback);
+    if (!result.has_value()) {
+        const auto& [message, code] = result.error();
+        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
+        callback->notify(aidlCode, nullptr);
+        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
+                static_cast<int32_t>(aidlCode), message.c_str());
+    }
+    return ndk::ScopedAStatus::ok();
+}
+
 }  // namespace aidl::android::hardware::neuralnetworks::adapter
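For context, a minimal client-side sketch (not part of the patch) of how the prepareModelWithConfig entry point adapted above might be driven. It only touches the PrepareModelConfig fields this change forwards (preference, priority, deadlineNs, the cache fields, compilationHints, extensionNameToPrefix); the header path, the helper name prepareWithConfig, the chosen enum values, and the reading of -1 as "no deadline" are illustrative assumptions, and the callback is assumed to be an existing IPreparedModelCallback implementation supplied by the caller.

// Illustrative sketch only; not part of this change.
#include <aidl/android/hardware/neuralnetworks/IDevice.h>

#include <memory>

namespace nn_aidl = aidl::android::hardware::neuralnetworks;

// Prepares 'model' through the config-based entry point, leaving the
// compilation-cache fields default-initialized (no caching) and the new
// compilationHints / extensionNameToPrefix vectors empty.
ndk::ScopedAStatus prepareWithConfig(
        const std::shared_ptr<nn_aidl::IDevice>& device, const nn_aidl::Model& model,
        const std::shared_ptr<nn_aidl::IPreparedModelCallback>& callback) {
    nn_aidl::PrepareModelConfig config;
    config.preference = nn_aidl::ExecutionPreference::SUSTAINED_SPEED;
    config.priority = nn_aidl::Priority::MEDIUM;
    config.deadlineNs = -1;  // assumed to convert to "no deadline" in the adapter
    // config.compilationHints would carry vendor-specific TokenValuePair entries;
    // config.extensionNameToPrefix maps extension names to their operand/operation prefixes.
    // Both are exactly the values the adapter now threads through to nn::IDevice::prepareModel.
    return device->prepareModelWithConfig(model, config, callback);
}

On the older Device::prepareModel path the adapter passes empty hint and extension-prefix vectors ({}, {}), so existing clients keep their previous behavior.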