summaryrefslogtreecommitdiff
path: root/neuralnetworks/aidl/vts/functional/ValidateModel.cpp
diff options
context:
space:
mode:
authorLev Proleev <levp@google.com>2020-12-15 19:25:32 +0000
committerLev Proleev <levp@google.com>2021-02-09 15:28:25 +0000
commitc185e88ccf62ee47fd13d52f447b4db471701fce (patch)
treec813d49a122c21091aded691671046aaea0ad14e /neuralnetworks/aidl/vts/functional/ValidateModel.cpp
parent6b6dfcd439f431cf15c6cf82bc170b61f2871670 (diff)
Implement VTS tests for NNAPI AIDL interface
The tests are copied from HIDL 1.0-3 VTS tests and updated to use AIDL. Bug: 172922059 Test: VtsHalNeuralnetworksTargetTest Change-Id: Ife08409e9b46420685a1ccb0b3256286c973dbf5 Merged-In: Ife08409e9b46420685a1ccb0b3256286c973dbf5 (cherry picked from commit b38bb4f12a1ceb33ebd0dd798650a74a8ef9d20e)
Diffstat (limited to 'neuralnetworks/aidl/vts/functional/ValidateModel.cpp')
-rw-r--r--neuralnetworks/aidl/vts/functional/ValidateModel.cpp1338
1 files changed, 1338 insertions, 0 deletions
diff --git a/neuralnetworks/aidl/vts/functional/ValidateModel.cpp b/neuralnetworks/aidl/vts/functional/ValidateModel.cpp
new file mode 100644
index 0000000000..b84d981abd
--- /dev/null
+++ b/neuralnetworks/aidl/vts/functional/ValidateModel.cpp
@@ -0,0 +1,1338 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "neuralnetworks_aidl_hal_test"
+
+#include <aidl/android/hardware/common/NativeHandle.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_enums.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/hal/aidl/Conversions.h>
+#include <nnapi/hal/aidl/Utils.h>
+
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "Callbacks.h"
+#include "GeneratedTestHarness.h"
+#include "Utils.h"
+#include "VtsHalNeuralnetworks.h"
+
+namespace aidl::android::hardware::neuralnetworks::vts::functional {
+
+using common::NativeHandle;
+using implementation::PreparedModelCallback;
+
+using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*, Priority*)>;
+
+///////////////////////// UTILITY FUNCTIONS /////////////////////////
+
+static void validateGetSupportedOperations(const std::shared_ptr<IDevice>& device,
+ const std::string& message, const Model& model) {
+ SCOPED_TRACE(message + " [getSupportedOperations]");
+
+ std::vector<bool> supported;
+ const auto retStatus = device->getSupportedOperations(model, &supported);
+
+ ASSERT_FALSE(retStatus.isOk());
+ ASSERT_EQ(retStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
+ ASSERT_EQ(static_cast<ErrorStatus>(retStatus.getServiceSpecificError()),
+ ErrorStatus::INVALID_ARGUMENT);
+}
+
+static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const std::string& message,
+ const Model& model, ExecutionPreference preference,
+ Priority priority) {
+ SCOPED_TRACE(message + " [prepareModel]");
+
+ std::shared_ptr<PreparedModelCallback> preparedModelCallback =
+ ndk::SharedRefBase::make<PreparedModelCallback>();
+ const auto prepareLaunchStatus =
+ device->prepareModel(model, preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken,
+ preparedModelCallback);
+ ASSERT_FALSE(prepareLaunchStatus.isOk());
+ ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
+ ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
+ ErrorStatus::INVALID_ARGUMENT);
+
+ preparedModelCallback->wait();
+ ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
+ ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
+ std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
+ ASSERT_EQ(nullptr, preparedModel.get());
+}
+
+static bool validExecutionPreference(ExecutionPreference preference) {
+ return preference == ExecutionPreference::LOW_POWER ||
+ preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
+ preference == ExecutionPreference::SUSTAINED_SPEED;
+}
+
+static bool validExecutionPriority(Priority priority) {
+ return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
+}
+
+// Primary validation function. This function will take a valid model, apply a
+// mutation to invalidate the model, the execution preference, or the priority,
+// then pass these to supportedOperations and/or prepareModel if that method is
+// called with an invalid argument.
+static void validate(const std::shared_ptr<IDevice>& device, const std::string& message,
+ const Model& originalModel, const PrepareModelMutation& mutate) {
+ Model model = utils::clone(originalModel).value();
+ ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
+ Priority priority = kDefaultPriority;
+ mutate(&model, &preference, &priority);
+
+ if (validExecutionPreference(preference) && validExecutionPriority(priority)) {
+ validateGetSupportedOperations(device, message, model);
+ }
+
+ validatePrepareModel(device, message, model, preference, priority);
+}
+
+static uint32_t addOperand(Model* model) {
+ model->main.operands.push_back({
+ .type = OperandType::INT32,
+ .dimensions = {},
+ .scale = 0.0f,
+ .zeroPoint = 0,
+ .lifetime = OperandLifeTime::SUBGRAPH_INPUT,
+ .location = {.poolIndex = 0, .offset = 0, .length = 0},
+ });
+ return model->main.operands.size() - 1;
+}
+
+static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
+ uint32_t index = addOperand(model);
+ model->main.operands[index].lifetime = lifetime;
+ return index;
+}
+
+// If we introduce a CONSTANT_COPY for an operand of size operandSize,
+// how much will this increase the size of the model? This assumes
+// that we can (re)use all of model.operandValues for the operand
+// value.
+static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
+ const size_t operandValuesSize = model.operandValues.size();
+ return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
+}
+
+// Highly specialized utility routine for converting an operand to
+// CONSTANT_COPY lifetime.
+//
+// Expects that:
+// - operand has a known size
+// - operand->lifetime has already been set to CONSTANT_COPY
+// - operand->location has been zeroed out
+//
+// Does the following:
+// - initializes operand->location to point to the beginning of model->operandValues
+// - resizes model->operandValues (if necessary) to be large enough for the operand
+// value, padding it with zeroes on the end
+//
+// Potential problem:
+// By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
+// operand with unspecified (but deterministic) data. This means that the model may be invalidated
+// in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
+// graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
+// value). For now, this should be fine because it just means we're not testing what we think we're
+// testing in certain cases; but we can handwave this and assume we're probabilistically likely to
+// exercise the validation code over the span of the entire test set and operand space.
+//
+// Aborts if the specified operand type is an extension type or OEM type.
+static void becomeConstantCopy(Model* model, Operand* operand) {
+ // sizeOfData will abort if the specified type is an extension type or OEM type.
+ const size_t sizeOfOperand = sizeOfData(*operand);
+ EXPECT_NE(sizeOfOperand, size_t(0));
+ operand->location.poolIndex = 0;
+ operand->location.offset = 0;
+ operand->location.length = sizeOfOperand;
+ if (model->operandValues.size() < sizeOfOperand) {
+ model->operandValues.resize(sizeOfOperand);
+ }
+}
+
+// The sizeForBinder() functions estimate the size of the
+// representation of a value when sent to binder. It's probably a bit
+// of an under-estimate, because we don't know the size of the
+// metadata in the binder format (e.g., representation of the size of
+// a vector); but at least it adds up "big" things like vector
+// contents. However, it doesn't treat inter-field or end-of-struct
+// padding in a methodical way -- there's no attempt to be consistent
+// in whether or not padding in the native (C++) representation
+// contributes to the estimated size for the binder representation;
+// and there's no attempt to understand what padding (if any) is
+// needed in the binder representation.
+//
+// This assumes that non-metadata uses a fixed length encoding (e.g.,
+// a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
+// using an encoding whose length is related to the magnitude of the
+// encoded value).
+
+template <typename Type>
+static size_t sizeForBinder(const Type& val) {
+ static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
+ "expected a trivially copyable type");
+ return sizeof(val);
+}
+
+template <typename Type>
+static size_t sizeForBinder(const std::vector<Type>& vec) {
+ return std::accumulate(vec.begin(), vec.end(), 0,
+ [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
+}
+
+template <>
+size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
+ size_t size = 0;
+
+ size += sizeForBinder(symmPerChannelQuantParams.scales);
+ size += sizeForBinder(symmPerChannelQuantParams.channelDim);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const std::optional<OperandExtraParams>& optionalExtraParams) {
+ if (!optionalExtraParams.has_value()) {
+ return 0;
+ }
+ const auto& extraParams = optionalExtraParams.value();
+ using Tag = OperandExtraParams::Tag;
+ switch (extraParams.getTag()) {
+ case Tag::channelQuant:
+ return sizeForBinder(extraParams.get<Tag::channelQuant>());
+ case Tag::extension:
+ return sizeForBinder(extraParams.get<Tag::extension>());
+ }
+ LOG(FATAL) << "Unrecognized extraParams tag: " << static_cast<int>(extraParams.getTag());
+ return 0;
+}
+
+template <>
+size_t sizeForBinder(const Operand& operand) {
+ size_t size = 0;
+
+ size += sizeForBinder(operand.type);
+ size += sizeForBinder(operand.dimensions);
+ size += sizeForBinder(operand.scale);
+ size += sizeForBinder(operand.zeroPoint);
+ size += sizeForBinder(operand.lifetime);
+ size += sizeForBinder(operand.location);
+ size += sizeForBinder(operand.extraParams);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const Operation& operation) {
+ size_t size = 0;
+
+ size += sizeForBinder(operation.type);
+ size += sizeForBinder(operation.inputs);
+ size += sizeForBinder(operation.outputs);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const std::string& name) {
+ return name.size();
+}
+
+template <>
+size_t sizeForBinder(const Memory& memory) {
+ // This is just a guess.
+
+ size_t size = 0;
+ const NativeHandle& handle = memory.handle;
+ size += sizeof(decltype(handle.fds)::value_type) * handle.fds.size();
+ size += sizeof(decltype(handle.ints)::value_type) * handle.ints.size();
+ size += sizeForBinder(memory.name);
+ size += sizeof(memory);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const Subgraph& subgraph) {
+ size_t size = 0;
+
+ size += sizeForBinder(subgraph.operands);
+ size += sizeForBinder(subgraph.operations);
+ size += sizeForBinder(subgraph.inputIndexes);
+ size += sizeForBinder(subgraph.outputIndexes);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const ExtensionNameAndPrefix& extensionNameToPrefix) {
+ size_t size = 0;
+
+ size += sizeForBinder(extensionNameToPrefix.name);
+ size += sizeForBinder(extensionNameToPrefix.prefix);
+
+ return size;
+}
+
+template <>
+size_t sizeForBinder(const Model& model) {
+ size_t size = 0;
+
+ size += sizeForBinder(model.main);
+ size += sizeForBinder(model.referenced);
+ size += sizeForBinder(model.operandValues);
+ size += sizeForBinder(model.pools);
+ size += sizeForBinder(model.relaxComputationFloat32toFloat16);
+ size += sizeForBinder(model.extensionNameToPrefix);
+
+ return size;
+}
+
+// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
+//
+// "The Binder transaction buffer has a limited fixed size,
+// currently 1Mb, which is shared by all transactions in progress
+// for the process."
+//
+// Will our representation fit under this limit? There are two complications:
+// - Our representation size is just approximate (see sizeForBinder()).
+// - This object may not be the only occupant of the Binder transaction buffer.
+// So we'll be very conservative: We want the representation size to be no
+// larger than half the transaction buffer size.
+//
+// If our representation grows large enough that it still fits within
+// the transaction buffer but combined with other transactions may
+// exceed the buffer size, then we may see intermittent HAL transport
+// errors.
+static bool exceedsBinderSizeLimit(size_t representationSize) {
+ // Instead of using this fixed buffer size, we might instead be able to use
+ // ProcessState::self()->getMmapSize(). However, this has a potential
+ // problem: The binder/mmap size of the current process does not necessarily
+ // indicate the binder/mmap size of the service (i.e., the other process).
+ // The only way it would be a good indication is if both the current process
+ // and the service use the default size.
+ static const size_t kHalfBufferSize = 1024 * 1024 / 2;
+
+ return representationSize > kHalfBufferSize;
+}
+
+///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
+
+static void mutateExecutionOrderTest(const std::shared_ptr<IDevice>& device, const Model& model,
+ const std::vector<uint32_t>& numberOfConsumers) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ const Operation& operationObj = model.main.operations[operation];
+ for (uint32_t input : operationObj.inputs) {
+ if (model.main.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
+ model.main.operands[input].lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
+ // This operation reads an operand written by some
+ // other operation. Move this operation to the
+ // beginning of the sequence, ensuring that it reads
+ // the operand before that operand is written, thereby
+ // violating execution order rules.
+ const std::string message = "mutateExecutionOrderTest: operation " +
+ std::to_string(operation) + " is a reader";
+ validate(device, message, model,
+ [operation](Model* model, ExecutionPreference*, Priority*) {
+ auto& operations = model->main.operations;
+ std::rotate(operations.begin(), operations.begin() + operation,
+ operations.begin() + operation + 1);
+ });
+ break; // only need to do this once per operation
+ }
+ }
+ for (uint32_t output : operationObj.outputs) {
+ if (numberOfConsumers[output] > 0) {
+ // This operation writes an operand read by some other
+ // operation. Move this operation to the end of the
+ // sequence, ensuring that it writes the operand after
+ // that operand is read, thereby violating execution
+ // order rules.
+ const std::string message = "mutateExecutionOrderTest: operation " +
+ std::to_string(operation) + " is a writer";
+ validate(device, message, model,
+ [operation](Model* model, ExecutionPreference*, Priority*) {
+ auto& operations = model->main.operations;
+ std::rotate(operations.begin() + operation,
+ operations.begin() + operation + 1, operations.end());
+ });
+ break; // only need to do this once per operation
+ }
+ }
+ }
+}
+
+///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
+
+static const int32_t invalidOperandTypes[] = {
+ -1,
+ static_cast<int32_t>(*(ndk::enum_range<OperandType>().end() - 1)) + 1,
+};
+
+static void mutateOperandTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ for (int32_t invalidOperandType : invalidOperandTypes) {
+ const std::string message = "mutateOperandTypeTest: operand " +
+ std::to_string(operand) + " set to value " +
+ std::to_string(invalidOperandType);
+ validate(device, message, model,
+ [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
+ model->main.operands[operand].type =
+ static_cast<OperandType>(invalidOperandType);
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE OPERAND RANK /////////////////////////
+
+static uint32_t getInvalidRank(OperandType type) {
+ switch (type) {
+ case OperandType::FLOAT16:
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ return 1;
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ return 0;
+ default:
+ return 0;
+ }
+}
+
+static void mutateOperandRankTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ const uint32_t invalidRank = getInvalidRank(model.main.operands[operand].type);
+ if (invalidRank == 0) {
+ continue;
+ }
+ const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
+ " has rank of " + std::to_string(invalidRank);
+ validate(device, message, model,
+ [operand, invalidRank](Model* model, ExecutionPreference*, Priority*) {
+ model->main.operands[operand].dimensions =
+ std::vector<int32_t>(invalidRank, 0);
+ });
+ }
+}
+
+///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
+
+static float getInvalidScale(OperandType type) {
+ switch (type) {
+ case OperandType::FLOAT16:
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::SUBGRAPH:
+ return 1.0f;
+ case OperandType::TENSOR_INT32:
+ return -1.0f;
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ return 0.0f;
+ default:
+ return 0.0f;
+ }
+}
+
+static void mutateOperandScaleTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ const float invalidScale = getInvalidScale(model.main.operands[operand].type);
+ const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
+ " has scale of " + std::to_string(invalidScale);
+ validate(device, message, model,
+ [operand, invalidScale](Model* model, ExecutionPreference*, Priority*) {
+ model->main.operands[operand].scale = invalidScale;
+ });
+ }
+}
+
+///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
+
+static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
+ switch (type) {
+ case OperandType::FLOAT16:
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_FLOAT32:
+ case OperandType::TENSOR_INT32:
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ case OperandType::SUBGRAPH:
+ return {1};
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ return {-1, 256};
+ case OperandType::TENSOR_QUANT8_SYMM:
+ return {-129, -1, 1, 128};
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ return {-1, 65536};
+ case OperandType::TENSOR_QUANT16_SYMM:
+ return {-32769, -1, 1, 32768};
+ default:
+ return {};
+ }
+}
+
+static void mutateOperandZeroPointTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ const std::vector<int32_t> invalidZeroPoints =
+ getInvalidZeroPoints(model.main.operands[operand].type);
+ for (int32_t invalidZeroPoint : invalidZeroPoints) {
+ const std::string message = "mutateOperandZeroPointTest: operand " +
+ std::to_string(operand) + " has zero point of " +
+ std::to_string(invalidZeroPoint);
+ validate(device, message, model,
+ [operand, invalidZeroPoint](Model* model, ExecutionPreference*, Priority*) {
+ model->main.operands[operand].zeroPoint = invalidZeroPoint;
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
+
+static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
+ const Operand& operand) {
+ // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
+ // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
+
+ // Ways to get an invalid lifetime:
+ // - change whether a lifetime means an operand should have a writer
+ std::vector<OperandLifeTime> ret;
+ switch (operand.lifetime) {
+ case OperandLifeTime::SUBGRAPH_OUTPUT:
+ case OperandLifeTime::TEMPORARY_VARIABLE:
+ ret = {
+ OperandLifeTime::SUBGRAPH_INPUT,
+ OperandLifeTime::CONSTANT_COPY,
+ };
+ break;
+ case OperandLifeTime::CONSTANT_COPY:
+ case OperandLifeTime::CONSTANT_POOL:
+ case OperandLifeTime::SUBGRAPH_INPUT:
+ ret = {
+ OperandLifeTime::TEMPORARY_VARIABLE,
+ OperandLifeTime::SUBGRAPH_OUTPUT,
+ };
+ break;
+ case OperandLifeTime::NO_VALUE:
+ // Not enough information to know whether
+ // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
+ // is this operand written (then CONSTANT_COPY would be
+ // invalid) or not (then TEMPORARY_VARIABLE would be
+ // invalid)?
+ break;
+ case OperandLifeTime::SUBGRAPH:
+ break;
+ default:
+ ADD_FAILURE();
+ break;
+ }
+
+ const size_t operandSize = sizeOfData(operand); // will be zero if shape is unknown
+ if (!operandSize ||
+ exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
+ // Unknown size or too-large size
+ ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
+ }
+
+ return ret;
+}
+
+static void mutateOperandLifeTimeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ const size_t modelSize = sizeForBinder(model);
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ const std::vector<OperandLifeTime> invalidLifeTimes =
+ getInvalidLifeTimes(model, modelSize, model.main.operands[operand]);
+ for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
+ const std::string message = "mutateOperandLifetimeTest: operand " +
+ std::to_string(operand) + " has lifetime " +
+ toString(invalidLifeTime) + " instead of lifetime " +
+ toString(model.main.operands[operand].lifetime);
+ validate(device, message, model,
+ [operand, invalidLifeTime](Model* model, ExecutionPreference*, Priority*) {
+ static const DataLocation kZeroDataLocation = {};
+ Operand& operandObj = model->main.operands[operand];
+ switch (operandObj.lifetime) {
+ case OperandLifeTime::SUBGRAPH_INPUT: {
+ auto& inputs = model->main.inputIndexes;
+ inputs.erase(std::remove(inputs.begin(), inputs.end(), operand),
+ inputs.end());
+ break;
+ }
+ case OperandLifeTime::SUBGRAPH_OUTPUT: {
+ auto& outputs = model->main.outputIndexes;
+ outputs.erase(std::remove(outputs.begin(), outputs.end(), operand),
+ outputs.end());
+ break;
+ }
+ default:
+ break;
+ }
+ operandObj.lifetime = invalidLifeTime;
+ operandObj.location = kZeroDataLocation;
+ switch (invalidLifeTime) {
+ case OperandLifeTime::CONSTANT_COPY: {
+ becomeConstantCopy(model, &operandObj);
+ break;
+ }
+ case OperandLifeTime::SUBGRAPH_INPUT:
+ model->main.inputIndexes.push_back(operand);
+ break;
+ case OperandLifeTime::SUBGRAPH_OUTPUT:
+ model->main.outputIndexes.push_back(operand);
+ break;
+ default:
+ break;
+ }
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
+
+static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
+ const Operand& operand) {
+ // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
+ // - change whether a lifetime means an operand is a model input, a model output, or neither
+ // - preserve whether or not a lifetime means an operand should have a writer
+ switch (operand.lifetime) {
+ case OperandLifeTime::CONSTANT_COPY:
+ case OperandLifeTime::CONSTANT_POOL:
+ return OperandLifeTime::SUBGRAPH_INPUT;
+ case OperandLifeTime::SUBGRAPH_INPUT: {
+ const size_t operandSize = sizeOfData(operand); // will be zero if shape is unknown
+ if (!operandSize ||
+ exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
+ // Unknown size or too-large size
+ break;
+ }
+ return OperandLifeTime::CONSTANT_COPY;
+ }
+ case OperandLifeTime::SUBGRAPH_OUTPUT:
+ return OperandLifeTime::TEMPORARY_VARIABLE;
+ case OperandLifeTime::TEMPORARY_VARIABLE:
+ return OperandLifeTime::SUBGRAPH_OUTPUT;
+ case OperandLifeTime::NO_VALUE:
+ // Not enough information to know whether
+ // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
+ // appropriate choice -- is this operand written (then
+ // TEMPORARY_VARIABLE would be appropriate) or not (then
+ // CONSTANT_COPY would be appropriate)?
+ break;
+ case OperandLifeTime::SUBGRAPH:
+ break;
+ default:
+ ADD_FAILURE();
+ break;
+ }
+
+ return std::nullopt;
+}
+
+static void mutateOperandInputOutputTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ const size_t modelSize = sizeForBinder(model);
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ const std::optional<OperandLifeTime> changedLifeTime =
+ getInputOutputLifeTime(model, modelSize, model.main.operands[operand]);
+ if (changedLifeTime) {
+ const std::string message = "mutateOperandInputOutputTest: operand " +
+ std::to_string(operand) + " has lifetime " +
+ toString(*changedLifeTime) + " instead of lifetime " +
+ toString(model.main.operands[operand].lifetime);
+ validate(device, message, model,
+ [operand, changedLifeTime](Model* model, ExecutionPreference*, Priority*) {
+ static const DataLocation kZeroDataLocation = {};
+ Operand& operandObj = model->main.operands[operand];
+ operandObj.lifetime = *changedLifeTime;
+ operandObj.location = kZeroDataLocation;
+ if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
+ becomeConstantCopy(model, &operandObj);
+ }
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
+
+static void mutateOperandAddWriterTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ for (size_t badOutputNum = 0;
+ badOutputNum < model.main.operations[operation].outputs.size(); ++badOutputNum) {
+ const uint32_t outputOperandIndex =
+ model.main.operations[operation].outputs[badOutputNum];
+ const std::string message = "mutateOperandAddWriterTest: operation " +
+ std::to_string(operation) + " writes to " +
+ std::to_string(outputOperandIndex);
+ // We'll insert a copy of the operation, all of whose
+ // OTHER output operands are newly-created -- i.e.,
+ // there'll only be a duplicate write of ONE of that
+ // operation's output operands.
+ validate(device, message, model,
+ [operation, badOutputNum](Model* model, ExecutionPreference*, Priority*) {
+ Operation newOperation = model->main.operations[operation];
+ for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
+ ++outputNum) {
+ if (outputNum == badOutputNum) continue;
+
+ Operand operandValue =
+ model->main.operands[newOperation.outputs[outputNum]];
+ if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
+ operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
+ } else {
+ ASSERT_EQ(operandValue.lifetime,
+ OperandLifeTime::TEMPORARY_VARIABLE);
+ }
+ newOperation.outputs[outputNum] = model->main.operands.size();
+ model->main.operands.push_back(operandValue);
+ }
+ // Where do we insert the extra writer (a new
+ // operation)? It has to be later than all the
+ // writers of its inputs. The easiest thing to do
+ // is to insert it at the end of the operation
+ // sequence.
+ model->main.operations.push_back(newOperation);
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE EXTRA ??? /////////////////////////
+
+// TODO: Operand::location
+
+///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
+
+static void mutateOperand(Operand* operand, OperandType type) {
+ Operand newOperand = *operand;
+ newOperand.type = type;
+ switch (type) {
+ case OperandType::FLOAT16:
+ case OperandType::FLOAT32:
+ case OperandType::INT32:
+ case OperandType::UINT32:
+ case OperandType::BOOL:
+ newOperand.dimensions = {};
+ newOperand.scale = 0.0f;
+ newOperand.zeroPoint = 0;
+ break;
+ case OperandType::TENSOR_BOOL8:
+ case OperandType::TENSOR_FLOAT16:
+ case OperandType::TENSOR_FLOAT32:
+ newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
+ : std::vector<int32_t>({1});
+ newOperand.scale = 0.0f;
+ newOperand.zeroPoint = 0;
+ break;
+ case OperandType::TENSOR_INT32:
+ newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
+ : std::vector<int32_t>({1});
+ newOperand.zeroPoint = 0;
+ break;
+ case OperandType::TENSOR_QUANT8_ASYMM:
+ case OperandType::TENSOR_QUANT8_SYMM:
+ case OperandType::TENSOR_QUANT16_ASYMM:
+ case OperandType::TENSOR_QUANT16_SYMM:
+ newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
+ : std::vector<int32_t>({1});
+ newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
+ break;
+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
+ newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
+ : std::vector<int32_t>({1});
+ newOperand.scale = 0.0f;
+ newOperand.zeroPoint = 0;
+
+ SymmPerChannelQuantParams channelQuant;
+ channelQuant.channelDim = 0;
+ channelQuant.scales = std::vector<float>(
+ operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
+ : 0);
+ for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
+ channelQuant.scales[i] = 1.0f;
+ }
+ newOperand.extraParams->set<OperandExtraParams::Tag::channelQuant>(
+ std::move(channelQuant));
+ } break;
+ default:
+ break;
+ }
+ *operand = newOperand;
+}
+
+static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
+ if (type == model.main.operands[operand].type) {
+ return true;
+ }
+ for (const Operation& operation : model.main.operations) {
+ // Skip mutateOperationOperandTypeTest for the following operations.
+ // - LSH_PROJECTION's second argument is allowed to have any type.
+ // - ARGMIN and ARGMAX's first argument can be any of
+ // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
+ // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
+ // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
+ // - DEQUANTIZE input can be any of
+ // TENSOR_(QUANT8_ASYMM|QUANT8_ASYMM_SIGNED|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL),
+ // output can be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
+ // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
+ // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
+ // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
+ // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
+ // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
+ // - AXIS_ALIGNED_BBOX_TRANSFORM bounding boxes (arg 1) can be of
+ // TENSOR_QUANT8_ASYMM or TENSOR_QUANT8_ASYMM_SIGNED.
+ // - RANK's input can have any TENSOR_* type.
+ switch (operation.type) {
+ case OperationType::LSH_PROJECTION: {
+ if (operand == operation.inputs[1]) {
+ return true;
+ }
+ } break;
+ case OperationType::CAST:
+ case OperationType::ARGMAX:
+ case OperationType::ARGMIN: {
+ if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
+ type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ return true;
+ }
+ } break;
+ case OperationType::QUANTIZE: {
+ if (operand == operation.inputs[0] &&
+ (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
+ return true;
+ }
+ if (operand == operation.outputs[0] &&
+ (type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
+ return true;
+ }
+ } break;
+ case OperationType::RANDOM_MULTINOMIAL: {
+ if (operand == operation.inputs[0] &&
+ (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
+ return true;
+ }
+ } break;
+ case OperationType::DEQUANTIZE: {
+ if (operand == operation.inputs[0] &&
+ (type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
+ type == OperandType::TENSOR_QUANT8_SYMM ||
+ type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
+ return true;
+ }
+ if (operand == operation.outputs[0] &&
+ (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
+ return true;
+ }
+ } break;
+ case OperationType::TRANSPOSE_CONV_2D:
+ case OperationType::GROUPED_CONV_2D:
+ case OperationType::DEPTHWISE_CONV_2D:
+ case OperationType::CONV_2D: {
+ if (operand == operation.inputs[1] &&
+ (type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
+ return true;
+ }
+ } break;
+ case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: {
+ if (operand == operation.inputs[1] &&
+ (type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
+ return true;
+ }
+ } break;
+ case OperationType::RANK: {
+ if (operand == operation.inputs[0] &&
+ (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
+ type == OperandType::TENSOR_INT32 ||
+ type == OperandType::TENSOR_QUANT8_ASYMM ||
+ type == OperandType::TENSOR_QUANT16_SYMM ||
+ type == OperandType::TENSOR_BOOL8 ||
+ type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
+ type == OperandType::TENSOR_QUANT16_ASYMM ||
+ type == OperandType::TENSOR_QUANT8_SYMM ||
+ type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
+ return true;
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+ return false;
+}
+
+static void mutateOperationOperandTypeTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ for (OperandType invalidOperandType : ndk::enum_range<OperandType>()) {
+ if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
+ continue;
+ }
+ const std::string message = "mutateOperationOperandTypeTest: operand " +
+ std::to_string(operand) + " set to type " +
+ toString(invalidOperandType);
+ validate(device, message, model,
+ [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
+ mutateOperand(&model->main.operands[operand], invalidOperandType);
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
+
+static const int32_t invalidOperationTypes[] = {
+ -1,
+ static_cast<int32_t>(*(ndk::enum_range<OperationType>().end() - 1)) + 1,
+};
+
+static void mutateOperationTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ for (int32_t invalidOperationType : invalidOperationTypes) {
+ const std::string message = "mutateOperationTypeTest: operation " +
+ std::to_string(operation) + " set to value " +
+ std::to_string(invalidOperationType);
+ validate(device, message, model,
+ [operation, invalidOperationType](Model* model, ExecutionPreference*,
+ Priority*) {
+ model->main.operations[operation].type =
+ static_cast<OperationType>(invalidOperationType);
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
+
+static void mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ const uint32_t invalidOperand = model.main.operands.size();
+ for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
+ const std::string message = "mutateOperationInputOperandIndexTest: operation " +
+ std::to_string(operation) + " input " +
+ std::to_string(input);
+ validate(device, message, model,
+ [operation, input, invalidOperand](Model* model, ExecutionPreference*,
+ Priority*) {
+ model->main.operations[operation].inputs[input] = invalidOperand;
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
+
+static void mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ const uint32_t invalidOperand = model.main.operands.size();
+ for (size_t output = 0; output < model.main.operations[operation].outputs.size();
+ ++output) {
+ const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
+ std::to_string(operation) + " output " +
+ std::to_string(output);
+ validate(device, message, model,
+ [operation, output, invalidOperand](Model* model, ExecutionPreference*,
+ Priority*) {
+ model->main.operations[operation].outputs[output] = invalidOperand;
+ });
+ }
+ }
+}
+
+///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
+
+static void mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice>& device,
+ const Model& model,
+ const std::vector<uint32_t>& numberOfConsumers) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ for (size_t outputNum = 0; outputNum < model.main.operations[operation].outputs.size();
+ ++outputNum) {
+ const uint32_t outputOperandIndex = model.main.operations[operation].outputs[outputNum];
+ if (numberOfConsumers[outputOperandIndex] > 0) {
+ const std::string message = "mutateOperationRemoveWriteTest: operation " +
+ std::to_string(operation) + " writes to " +
+ std::to_string(outputOperandIndex);
+ validate(device, message, model,
+ [operation, outputNum](Model* model, ExecutionPreference*, Priority*) {
+ int32_t& outputOperandIndex =
+ model->main.operations[operation].outputs[outputNum];
+ Operand operandValue = model->main.operands[outputOperandIndex];
+ if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
+ operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
+ } else {
+ ASSERT_EQ(operandValue.lifetime,
+ OperandLifeTime::TEMPORARY_VARIABLE);
+ }
+ outputOperandIndex = model->main.operands.size();
+ model->main.operands.push_back(operandValue);
+ });
+ }
+ }
+ }
+}
+
+///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
+
+static void removeValueAndDecrementGreaterValues(std::vector<int32_t>* vec, uint32_t value) {
+ if (vec) {
+ // remove elements matching "value"
+ vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());
+
+ // decrement elements exceeding "value"
+ std::transform(vec->begin(), vec->end(), vec->begin(),
+ [value](uint32_t v) { return v > value ? v-- : v; });
+ }
+}
+
+static void removeOperand(Model* model, uint32_t index) {
+ model->main.operands.erase(model->main.operands.begin() + index);
+ for (Operation& operation : model->main.operations) {
+ removeValueAndDecrementGreaterValues(&operation.inputs, index);
+ removeValueAndDecrementGreaterValues(&operation.outputs, index);
+ }
+ removeValueAndDecrementGreaterValues(&model->main.inputIndexes, index);
+ removeValueAndDecrementGreaterValues(&model->main.outputIndexes, index);
+}
+
+static bool removeOperandSkip(size_t operandIndex, const Model& model,
+ const std::vector<uint32_t>& numberOfConsumers) {
+ if (numberOfConsumers[operandIndex] == 0) {
+ // Removing an unused operand has no effect.
+ return true;
+ }
+ for (const Operation& operation : model.main.operations) {
+ // Skip removeOperandTest for the following operations.
+ // - SPLIT's outputs are not checked during prepareModel.
+ if (operation.type == OperationType::SPLIT) {
+ for (const size_t index : operation.outputs) {
+ if (index == operandIndex) {
+ return true;
+ }
+ }
+ }
+ // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have
+ // either one, two, three or four outputs depending on their
+ // mergeOutputs parameter and if state outputs are provided.
+ // UNIDIRECTIONAL_SEQUENCE_LSTM and UNIDIRECTIONAL_SEQUENCE_RNN can have
+ // either one or three outputs depending on whether state outputs are
+ // provided.
+ if (operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM ||
+ operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_RNN ||
+ operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
+ operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
+ for (const size_t index : operation.outputs) {
+ if (index == operandIndex) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+static void removeOperandTest(const std::shared_ptr<IDevice>& device, const Model& model,
+ const std::vector<uint32_t>& numberOfConsumers) {
+ for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
+ if (removeOperandSkip(operand, model, numberOfConsumers)) {
+ continue;
+ }
+ const std::string message = "removeOperandTest: operand " + std::to_string(operand);
+ validate(device, message, model, [operand](Model* model, ExecutionPreference*, Priority*) {
+ removeOperand(model, operand);
+ });
+ }
+}
+
+///////////////////////// REMOVE OPERATION /////////////////////////
+
+static void removeOperation(Model* model, uint32_t index) {
+ auto& operations = model->main.operations;
+ operations.erase(operations.begin() + index);
+}
+
+static void removeOperationTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ const std::string message = "removeOperationTest: operation " + std::to_string(operation);
+ validate(device, message, model,
+ [operation](Model* model, ExecutionPreference*, Priority*) {
+ removeOperation(model, operation);
+ });
+ }
+}
+
+///////////////////////// REMOVE OPERATION INPUT /////////////////////////
+
+static bool removeOperationInputSkip(const Operation& op, size_t input) {
+ // Skip removeOperationInputTest for the following operations.
+ // - CONCATENATION has at least 2 inputs, with the last element being INT32.
+ // - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
+ // SPACE_TO_DEPTH, SPACE_TO_DEPTH, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
+ // layout parameter.
+ // RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional
+ // align_corners and half_pixel_centers parameters.
+ // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
+ // parameter.
+ switch (op.type) {
+ case OperationType::CONCATENATION: {
+ if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
+ return true;
+ }
+ } break;
+ case OperationType::DEPTHWISE_CONV_2D: {
+ if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
+ return true;
+ }
+ } break;
+ case OperationType::CONV_2D:
+ case OperationType::AVERAGE_POOL_2D:
+ case OperationType::MAX_POOL_2D:
+ case OperationType::L2_POOL_2D: {
+ if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
+ return true;
+ }
+ } break;
+ case OperationType::RESIZE_BILINEAR: {
+ if (op.inputs.size() >= 4 && input >= 3) {
+ return true;
+ }
+ } break;
+ case OperationType::RESIZE_NEAREST_NEIGHBOR: {
+ if (op.inputs.size() >= 5 && input >= 3) {
+ return true;
+ }
+ } break;
+ case OperationType::SPACE_TO_DEPTH:
+ case OperationType::DEPTH_TO_SPACE:
+ case OperationType::BATCH_TO_SPACE_ND: {
+ if (op.inputs.size() == 3 && input == 2) {
+ return true;
+ }
+ } break;
+ case OperationType::SPACE_TO_BATCH_ND: {
+ if (op.inputs.size() == 4 && input == 3) {
+ return true;
+ }
+ } break;
+ case OperationType::L2_NORMALIZATION: {
+ if (op.inputs.size() == 2 && input == 1) {
+ return true;
+ }
+ } break;
+ case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
+ if (op.inputs.size() == 6 && input == 5) {
+ return true;
+ }
+ } break;
+ case OperationType::SOFTMAX: {
+ if (op.inputs.size() == 3 && input == 2) {
+ return true;
+ }
+ } break;
+ default:
+ break;
+ }
+ return false;
+}
+
+static void removeOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
+ const Operation& op = model.main.operations[operation];
+ if (removeOperationInputSkip(op, input)) {
+ continue;
+ }
+ const std::string message = "removeOperationInputTest: operation " +
+ std::to_string(operation) + ", input " +
+ std::to_string(input);
+ validate(device, message, model,
+ [operation, input](Model* model, ExecutionPreference*, Priority*) {
+ auto& inputs = model->main.operations[operation].inputs;
+ inputs.erase(inputs.begin() + input);
+ });
+ }
+ }
+}
+
+///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
+
+static void removeOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ for (size_t output = 0; output < model.main.operations[operation].outputs.size();
+ ++output) {
+ const std::string message = "removeOperationOutputTest: operation " +
+ std::to_string(operation) + ", output " +
+ std::to_string(output);
+ validate(device, message, model,
+ [operation, output](Model* model, ExecutionPreference*, Priority*) {
+ auto& outputs = model->main.operations[operation].outputs;
+ outputs.erase(outputs.begin() + output);
+ });
+ }
+ }
+}
+
+///////////////////////// MODEL VALIDATION /////////////////////////
+
+// TODO: remove model input
+// TODO: remove model output
+// TODO: add unused operation
+
+///////////////////////// ADD OPERATION INPUT /////////////////////////
+
+static bool addOperationInputSkip(const Operation& op) {
+ // Skip addOperationInputTest for the following operations.
+ // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
+ // parameter.
+ if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
+ (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
+ (op.type == OperationType::SOFTMAX && op.inputs.size() == 2) ||
+ (op.type == OperationType::RESIZE_BILINEAR && op.inputs.size() < 6) ||
+ (op.type == OperationType::RESIZE_NEAREST_NEIGHBOR && op.inputs.size() < 6)) {
+ return true;
+ }
+ return false;
+}
+
+static void addOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ if (addOperationInputSkip(model.main.operations[operation])) {
+ continue;
+ }
+ const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
+ validate(device, message, model,
+ [operation](Model* model, ExecutionPreference*, Priority*) {
+ uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_INPUT);
+ model->main.operations[operation].inputs.push_back(index);
+ model->main.inputIndexes.push_back(index);
+ });
+ }
+}
+
+///////////////////////// ADD OPERATION OUTPUT /////////////////////////
+
+static void addOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
+ for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
+ const std::string message =
+ "addOperationOutputTest: operation " + std::to_string(operation);
+ validate(device, message, model,
+ [operation](Model* model, ExecutionPreference*, Priority*) {
+ uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_OUTPUT);
+ model->main.operations[operation].outputs.push_back(index);
+ model->main.outputIndexes.push_back(index);
+ });
+ }
+}
+
+///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
+
+static const int32_t invalidExecutionPreferences[] = {
+ static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1, // lower bound
+ static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1, // upper bound
+};
+
+static void mutateExecutionPreferenceTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ for (int32_t invalidPreference : invalidExecutionPreferences) {
+ const std::string message =
+ "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
+ validate(device, message, model,
+ [invalidPreference](Model*, ExecutionPreference* preference, Priority*) {
+ *preference = static_cast<ExecutionPreference>(invalidPreference);
+ });
+ }
+}
+
+///////////////////////// VALIDATE PRIORITY /////////////////////////
+
+static const int32_t invalidPriorities[] = {
+ static_cast<int32_t>(Priority::LOW) - 1, // lower bound
+ static_cast<int32_t>(Priority::HIGH) + 1, // upper bound
+};
+
+static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
+ const Model& model) {
+ for (int32_t invalidPriority : invalidPriorities) {
+ const std::string message =
+ "mutatePriorityTest: priority " + std::to_string(invalidPriority);
+ validate(device, message, model,
+ [invalidPriority](Model*, ExecutionPreference*, Priority* priority) {
+ *priority = static_cast<Priority>(invalidPriority);
+ });
+ }
+}
+
+////////////////////////// ENTRY POINT //////////////////////////////
+
+void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
+ const auto numberOfConsumers = nn::countNumberOfConsumers(
+ model.main.operands.size(), nn::convert(model.main.operations).value());
+ mutateExecutionOrderTest(device, model, numberOfConsumers);
+ mutateOperandTypeTest(device, model);
+ mutateOperandRankTest(device, model);
+ mutateOperandScaleTest(device, model);
+ mutateOperandZeroPointTest(device, model);
+ mutateOperandLifeTimeTest(device, model);
+ mutateOperandInputOutputTest(device, model);
+ mutateOperandAddWriterTest(device, model);
+ mutateOperationOperandTypeTest(device, model);
+ mutateOperationTypeTest(device, model);
+ mutateOperationInputOperandIndexTest(device, model);
+ mutateOperationOutputOperandIndexTest(device, model);
+ mutateOperationRemoveWriteTest(device, model, numberOfConsumers);
+ removeOperandTest(device, model, numberOfConsumers);
+ removeOperationTest(device, model);
+ removeOperationInputTest(device, model);
+ removeOperationOutputTest(device, model);
+ addOperationInputTest(device, model);
+ addOperationOutputTest(device, model);
+ mutateExecutionPreferenceTest(device, model);
+ mutateExecutionPriorityTest(device, model);
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::vts::functional