/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Device.h"

#include "Adapter.h"
#include "Buffer.h"
#include "PreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
#include <aidl/android/hardware/neuralnetworks/BufferDesc.h>
#include <aidl/android/hardware/neuralnetworks/BufferRole.h>
#include <aidl/android/hardware/neuralnetworks/DeviceBuffer.h>
#include <aidl/android/hardware/neuralnetworks/DeviceType.h>
#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
#include <aidl/android/hardware/neuralnetworks/ExecutionPreference.h>
#include <aidl/android/hardware/neuralnetworks/Extension.h>
#include <aidl/android/hardware/neuralnetworks/IPreparedModelCallback.h>
#include <aidl/android/hardware/neuralnetworks/IPreparedModelParcel.h>
#include <aidl/android/hardware/neuralnetworks/Model.h>
#include <aidl/android/hardware/neuralnetworks/NumberOfCacheFiles.h>
#include <aidl/android/hardware/neuralnetworks/Priority.h>
#include <android-base/logging.h>
#include <android/binder_auto_utils.h>
#include <android/binder_interface_utils.h>
#include <nnapi/IBuffer.h>
#include <nnapi/IDevice.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace aidl::android::hardware::neuralnetworks::adapter {
namespace {

// Converts an AIDL HAL object to its canonical (nn::) counterpart, remapping any
// conversion failure to INVALID_ARGUMENT since a malformed argument came from the client.
template <typename Type>
auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
    auto result = nn::convert(object);
    if (!result.has_value()) {
        result.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
    }
    return result;
}

nn::Duration makeDuration(int64_t durationNs) {
    return nn::Duration(std::chrono::nanoseconds(durationNs));
}

// Interprets an AIDL deadline: -1 means "no deadline", values < -1 are invalid, and
// any non-negative value is an absolute time point in nanoseconds.
nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationNs) {
    if (durationNs < -1) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid time point " << durationNs;
    }
    return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
}

// Converts a client-supplied byte vector to a fixed-size canonical cache token,
// rejecting tokens of the wrong length.
nn::GeneralResult<nn::CacheToken> convertCacheToken(const std::vector<uint8_t>& token) {
    nn::CacheToken nnToken;
    if (token.size() != nnToken.size()) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid token";
    }
    std::copy(token.begin(), token.end(), nnToken.begin());
    return nnToken;
}

// Recovers the canonical prepared model wrapped by an adapter IPreparedModel binder.
// Fails on null or remote (cross-process) objects, which cannot be downcast safely.
nn::GeneralResult<nn::SharedPreparedModel> downcast(const IPreparedModelParcel& preparedModel) {
    if (preparedModel.preparedModel == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "preparedModel is nullptr";
    }
    if (preparedModel.preparedModel->isRemote()) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Cannot convert remote models";
    }

    // This static_cast is safe because adapter::PreparedModel is the only class that implements
    // the IPreparedModel interface in the adapter service code.
    const auto* casted = static_cast<const PreparedModel*>(preparedModel.preparedModel.get());
    return casted->getUnderlyingPreparedModel();
}

// Applies downcast() to each parcel, propagating the first failure via NN_TRY.
nn::GeneralResult<std::vector<nn::SharedPreparedModel>> downcastAll(
        const std::vector<IPreparedModelParcel>& preparedModels) {
    std::vector<nn::SharedPreparedModel> canonical;
    canonical.reserve(preparedModels.size());
    for (const auto& preparedModel : preparedModels) {
        canonical.push_back(NN_TRY(downcast(preparedModel)));
    }
    return canonical;
}

// Core of Device::allocate: converts all AIDL arguments to canonical form, allocates the
// driver buffer, and wraps it in an AIDL Buffer together with its memory-domain token.
nn::GeneralResult<DeviceBuffer> allocate(const nn::IDevice& device, const BufferDesc& desc,
                                         const std::vector<IPreparedModelParcel>& preparedModels,
                                         const std::vector<BufferRole>& inputRoles,
                                         const std::vector<BufferRole>& outputRoles) {
    auto nnDesc = NN_TRY(convertInput(desc));
    auto nnPreparedModels = NN_TRY(downcastAll(preparedModels));
    auto nnInputRoles = NN_TRY(convertInput(inputRoles));
    auto nnOutputRoles = NN_TRY(convertInput(outputRoles));

    auto buffer = NN_TRY(device.allocate(nnDesc, nnPreparedModels, nnInputRoles, nnOutputRoles));
    CHECK(buffer != nullptr);

    const nn::Request::MemoryDomainToken token = buffer->getToken();
    auto aidlBuffer = ndk::SharedRefBase::make<Buffer>(std::move(buffer));
    return DeviceBuffer{.buffer = std::move(aidlBuffer), .token = static_cast<int32_t>(token)};
}

// Core of Device::getSupportedOperations: converts the model and queries the driver.
nn::GeneralResult<std::vector<bool>> getSupportedOperations(const nn::IDevice& device,
                                                            const Model& model) {
    const auto nnModel = NN_TRY(convertInput(model));
    return device.getSupportedOperations(nnModel);
}

using PrepareModelResult = nn::GeneralResult<nn::SharedPreparedModel>;

// Wraps a canonical prepared model in the AIDL adapter binder object; null passes through.
std::shared_ptr<PreparedModel> adaptPreparedModel(nn::SharedPreparedModel preparedModel) {
    if (preparedModel == nullptr) {
        return nullptr;
    }
    return ndk::SharedRefBase::make<PreparedModel>(std::move(preparedModel));
}

// Delivers an asynchronous prepareModel result to the client callback, translating
// a canonical error into an AIDL ErrorStatus (defaulting to GENERAL_FAILURE).
void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
    if (!result.has_value()) {
        const auto& [message, status] = result.error();
        LOG(ERROR) << message;
        const auto aidlCode = utils::convert(status).value_or(ErrorStatus::GENERAL_FAILURE);
        callback->notify(aidlCode, nullptr);
    } else {
        auto preparedModel = std::move(result).value();
        auto aidlPreparedModel = adaptPreparedModel(std::move(preparedModel));
        callback->notify(ErrorStatus::NONE, std::move(aidlPreparedModel));
    }
}

// Core of Device::prepareModel: validates/converts the arguments up front (so errors are
// reported synchronously), then schedules the actual driver call on the executor. The
// result is delivered through the callback.
nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Executor& executor,
                                     const Model& model, ExecutionPreference preference,
                                     Priority priority, int64_t deadlineNs,
                                     const std::vector<ndk::ScopedFileDescriptor>& modelCache,
                                     const std::vector<ndk::ScopedFileDescriptor>& dataCache,
                                     const std::vector<uint8_t>& token,
                                     const std::shared_ptr<IPreparedModelCallback>& callback) {
    if (callback.get() == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    auto nnModel = NN_TRY(convertInput(model));
    const auto nnPreference = NN_TRY(convertInput(preference));
    const auto nnPriority = NN_TRY(convertInput(priority));
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    auto nnModelCache = NN_TRY(convertInput(modelCache));
    auto nnDataCache = NN_TRY(convertInput(dataCache));
    const auto nnToken = NN_TRY(convertCacheToken(token));

    // Move the converted arguments into the task so the deferred call owns its inputs.
    Task task = [device, nnModel = std::move(nnModel), nnPreference, nnPriority, nnDeadline,
                 nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
                 nnToken, callback] {
        auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
                                           nnModelCache, nnDataCache, nnToken);
        notify(callback.get(), std::move(result));
    };
    executor(std::move(task), nnDeadline);

    return {};
}

// Core of Device::prepareModelFromCache: same deferred-execution pattern as prepareModel,
// but the model is recovered from the driver cache rather than converted from AIDL.
nn::GeneralResult<void> prepareModelFromCache(
        const nn::SharedDevice& device, const Executor& executor, int64_t deadlineNs,
        const std::vector<ndk::ScopedFileDescriptor>& modelCache,
        const std::vector<ndk::ScopedFileDescriptor>& dataCache,
        const std::vector<uint8_t>& token,
        const std::shared_ptr<IPreparedModelCallback>& callback) {
    if (callback.get() == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
    }

    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    auto nnModelCache = NN_TRY(convertInput(modelCache));
    auto nnDataCache = NN_TRY(convertInput(dataCache));
    const auto nnToken = NN_TRY(convertCacheToken(token));

    auto task = [device, nnDeadline, nnModelCache = std::move(nnModelCache),
                 nnDataCache = std::move(nnDataCache), nnToken, callback] {
        auto result = device->prepareModelFromCache(nnDeadline, nnModelCache, nnDataCache, nnToken);
        notify(callback.get(), std::move(result));
    };
    executor(std::move(task), nnDeadline);

    return {};
}

}  // namespace

Device::Device(::android::nn::SharedDevice device, Executor executor)
    : kDevice(std::move(device)), kExecutor(std::move(executor)) {
    CHECK(kDevice != nullptr);
    CHECK(kExecutor != nullptr);
}

ndk::ScopedAStatus Device::allocate(const BufferDesc& desc,
                                    const std::vector<IPreparedModelParcel>& preparedModels,
                                    const std::vector<BufferRole>& inputRoles,
                                    const std::vector<BufferRole>& outputRoles,
                                    DeviceBuffer* buffer) {
    auto result = adapter::allocate(*kDevice, desc, preparedModels, inputRoles, outputRoles);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *buffer = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getCapabilities(Capabilities* capabilities) {
    *capabilities = utils::convert(kDevice->getCapabilities()).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getNumberOfCacheFilesNeeded(NumberOfCacheFiles* numberOfCacheFiles) {
    const auto [numModelCache, numDataCache] = kDevice->getNumberOfCacheFilesNeeded();
    *numberOfCacheFiles = NumberOfCacheFiles{.numModelCache = static_cast<int32_t>(numModelCache),
                                             .numDataCache = static_cast<int32_t>(numDataCache)};
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getSupportedExtensions(std::vector<Extension>* extensions) {
    *extensions = utils::convert(kDevice->getSupportedExtensions()).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getSupportedOperations(const Model& model,
                                                  std::vector<bool>* supported) {
    auto result = adapter::getSupportedOperations(*kDevice, model);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    *supported = std::move(result).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getType(DeviceType* deviceType) {
    *deviceType = utils::convert(kDevice->getType()).value();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::getVersionString(std::string* version) {
    *version = kDevice->getVersionString();
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference preference,
                                        Priority priority, int64_t deadlineNs,
                                        const std::vector<ndk::ScopedFileDescriptor>& modelCache,
                                        const std::vector<ndk::ScopedFileDescriptor>& dataCache,
                                        const std::vector<uint8_t>& token,
                                        const std::shared_ptr<IPreparedModelCallback>& callback) {
    const auto result = adapter::prepareModel(kDevice, kExecutor, model, preference, priority,
                                              deadlineNs, modelCache, dataCache, token, callback);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        // On synchronous failure the callback is still notified, per the HAL contract.
        callback->notify(aidlCode, nullptr);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    return ndk::ScopedAStatus::ok();
}

ndk::ScopedAStatus Device::prepareModelFromCache(
        int64_t deadlineNs, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
        const std::vector<ndk::ScopedFileDescriptor>& dataCache,
        const std::vector<uint8_t>& token,
        const std::shared_ptr<IPreparedModelCallback>& callback) {
    const auto result = adapter::prepareModelFromCache(kDevice, kExecutor, deadlineNs, modelCache,
                                                       dataCache, token, callback);
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        // On synchronous failure the callback is still notified, per the HAL contract.
        callback->notify(aidlCode, nullptr);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    return ndk::ScopedAStatus::ok();
}

}  // namespace aidl::android::hardware::neuralnetworks::adapter