diff options
Diffstat (limited to 'rs/java')
-rw-r--r-- | rs/java/android/renderscript/Allocation.java | 498 | ||||
-rw-r--r-- | rs/java/android/renderscript/AllocationAdapter.java | 235 | ||||
-rw-r--r-- | rs/java/android/renderscript/Element.java | 48 | ||||
-rw-r--r-- | rs/java/android/renderscript/FieldPacker.java | 191 | ||||
-rw-r--r-- | rs/java/android/renderscript/FileA3D.java | 3 | ||||
-rw-r--r-- | rs/java/android/renderscript/Mesh.java | 10 | ||||
-rw-r--r-- | rs/java/android/renderscript/Path.java | 87 | ||||
-rw-r--r-- | rs/java/android/renderscript/RenderScript.java | 487 | ||||
-rw-r--r-- | rs/java/android/renderscript/RenderScriptCacheDir.java | 40 | ||||
-rw-r--r-- | rs/java/android/renderscript/Script.java | 171 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptC.java | 4 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptGroup2.java | 449 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicBLAS.java | 1510 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicBlur.java | 2 | ||||
-rw-r--r-- | rs/java/android/renderscript/ScriptIntrinsicResize.java | 8 | ||||
-rw-r--r-- | rs/java/android/renderscript/Type.java | 62 |
16 files changed, 3434 insertions, 371 deletions
diff --git a/rs/java/android/renderscript/Allocation.java b/rs/java/android/renderscript/Allocation.java index 4e895669653e..4fa2c81fec60 100644 --- a/rs/java/android/renderscript/Allocation.java +++ b/rs/java/android/renderscript/Allocation.java @@ -58,15 +58,14 @@ public class Allocation extends BaseObj { Allocation mAdaptedAllocation; int mSize; - boolean mConstrainedLOD; - boolean mConstrainedFace; - boolean mConstrainedY; - boolean mConstrainedZ; boolean mReadAllowed = true; boolean mWriteAllowed = true; + boolean mAutoPadding = false; + int mSelectedX; int mSelectedY; int mSelectedZ; int mSelectedLOD; + int mSelectedArray[]; Type.CubemapFace mSelectedFace = Type.CubemapFace.POSITIVE_X; int mCurrentDimX; @@ -77,6 +76,8 @@ public class Allocation extends BaseObj { new HashMap<Long, Allocation>(); OnBufferAvailableListener mBufferNotifier; + private Surface mGetSurfaceSurface = null; + private Element.DataType validateObjectIsPrimitiveArray(Object d, boolean checkType) { final Class c = d.getClass(); if (!c.isArray()) { @@ -272,6 +273,17 @@ public class Allocation extends BaseObj { } /** + * @hide + * Enable/Disable AutoPadding for Vec3 elements. + * + * @param useAutoPadding True: enable AutoPadding; False: disable AutoPadding + * + */ + public void setAutoPadding(boolean useAutoPadding) { + mAutoPadding = useAutoPadding; + } + + /** * Get the size of the Allocation in bytes. * * @return size of the Allocation in bytes. @@ -787,6 +799,7 @@ public class Allocation extends BaseObj { copy1DRangeFromUnchecked(xoff, count, data); } + /** * This is only intended to be used by auto-generated code reflected from * the RenderScript script files. @@ -796,12 +809,33 @@ public class Allocation extends BaseObj { * @param fp */ public void setFromFieldPacker(int xoff, int component_number, FieldPacker fp) { + setFromFieldPacker(xoff, 0, 0, component_number, fp); + } + + /** + * @hide + * This is only intended to be used by auto-generated code reflected from + * the RenderScript script files. + * + * @param xoff + * @param yoff + * @param zoff + * @param component_number + * @param fp + */ + public void setFromFieldPacker(int xoff, int yoff, int zoff, int component_number, FieldPacker fp) { mRS.validate(); if (component_number >= mType.mElement.mElements.length) { throw new RSIllegalArgumentException("Component_number " + component_number + " out of range."); } if(xoff < 0) { - throw new RSIllegalArgumentException("Offset must be >= 0."); + throw new RSIllegalArgumentException("Offset x must be >= 0."); + } + if(yoff < 0) { + throw new RSIllegalArgumentException("Offset y must be >= 0."); + } + if(zoff < 0) { + throw new RSIllegalArgumentException("Offset z must be >= 0."); } final byte[] data = fp.getData(); @@ -814,11 +848,11 @@ public class Allocation extends BaseObj { " does not match component size " + eSize + "."); } - mRS.nAllocationElementData1D(getIDSafe(), xoff, mSelectedLOD, - component_number, data, data_length); + mRS.nAllocationElementData(getIDSafe(), xoff, yoff, zoff, mSelectedLOD, + component_number, data, data_length); } - private void data1DChecks(int off, int count, int len, int dataSize) { + private void data1DChecks(int off, int count, int len, int dataSize, boolean usePadding) { mRS.validate(); if(off < 0) { throw new RSIllegalArgumentException("Offset must be >= 0."); @@ -830,8 +864,14 @@ public class Allocation extends BaseObj { throw new RSIllegalArgumentException("Overflow, Available count " + mCurrentCount + ", got " + count + " at offset " + off + "."); } - if(len < dataSize) { - throw new RSIllegalArgumentException("Array too small for allocation type."); + if(usePadding) { + if(len < dataSize / 4 * 3) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + } else { + if(len < dataSize) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } } } @@ -853,8 +893,14 @@ public class Allocation extends BaseObj { Element.DataType dt, int arrayLen) { Trace.traceBegin(RenderScript.TRACE_TAG, "copy1DRangeFromUnchecked"); final int dataSize = mType.mElement.getBytesSize() * count; - data1DChecks(off, count, arrayLen * dt.mSize, dataSize); - mRS.nAllocationData1D(getIDSafe(), off, mSelectedLOD, count, array, dataSize, dt); + // AutoPadding for Vec3 Element + boolean usePadding = false; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + usePadding = true; + } + data1DChecks(off, count, arrayLen * dt.mSize, dataSize, usePadding); + mRS.nAllocationData1D(getIDSafe(), off, mSelectedLOD, count, array, dataSize, dt, + mType.mElement.mType.mSize, usePadding); Trace.traceEnd(RenderScript.TRACE_TAG); } @@ -1031,8 +1077,24 @@ public class Allocation extends BaseObj { Trace.traceBegin(RenderScript.TRACE_TAG, "copy2DRangeFromUnchecked"); mRS.validate(); validate2DRange(xoff, yoff, w, h); + final int dataSize = mType.mElement.getBytesSize() * w * h; + // AutoPadding for Vec3 Element + boolean usePadding = false; + int sizeBytes = arrayLen * dt.mSize; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + if (dataSize / 4 * 3 > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + usePadding = true; + sizeBytes = dataSize; + } else { + if (dataSize > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + } mRS.nAllocationData2D(getIDSafe(), xoff, yoff, mSelectedLOD, mSelectedFace.mID, w, h, - array, arrayLen * dt.mSize, dt); + array, sizeBytes, dt, + mType.mElement.mType.mSize, usePadding); Trace.traceEnd(RenderScript.TRACE_TAG); } @@ -1193,8 +1255,24 @@ public class Allocation extends BaseObj { Trace.traceBegin(RenderScript.TRACE_TAG, "copy3DRangeFromUnchecked"); mRS.validate(); validate3DRange(xoff, yoff, zoff, w, h, d); + final int dataSize = mType.mElement.getBytesSize() * w * h * d; + // AutoPadding for Vec3 Element + boolean usePadding = false; + int sizeBytes = arrayLen * dt.mSize; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + if (dataSize / 4 * 3 > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + usePadding = true; + sizeBytes = dataSize; + } else { + if (dataSize > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + } mRS.nAllocationData3D(getIDSafe(), xoff, yoff, zoff, mSelectedLOD, w, h, d, - array, arrayLen * dt.mSize, dt); + array, sizeBytes, dt, + mType.mElement.mType.mSize, usePadding); Trace.traceEnd(RenderScript.TRACE_TAG); } @@ -1209,7 +1287,7 @@ public class Allocation extends BaseObj { * @param w Width of the region to update * @param h Height of the region to update * @param d Depth of the region to update - * @param data to be placed into the allocation + * @param array to be placed into the allocation */ public void copy3DRangeFrom(int xoff, int yoff, int zoff, int w, int h, int d, Object array) { Trace.traceBegin(RenderScript.TRACE_TAG, "copy3DRangeFrom"); @@ -1262,12 +1340,23 @@ public class Allocation extends BaseObj { private void copyTo(Object array, Element.DataType dt, int arrayLen) { Trace.traceBegin(RenderScript.TRACE_TAG, "copyTo"); - if (dt.mSize * arrayLen < mSize) { - throw new RSIllegalArgumentException( - "Size of output array cannot be smaller than size of allocation."); - } mRS.validate(); - mRS.nAllocationRead(getID(mRS), array, dt); + boolean usePadding = false; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + usePadding = true; + } + if (usePadding) { + if (dt.mSize * arrayLen < mSize / 4 * 3) { + throw new RSIllegalArgumentException( + "Size of output array cannot be smaller than size of allocation."); + } + } else { + if (dt.mSize * arrayLen < mSize) { + throw new RSIllegalArgumentException( + "Size of output array cannot be smaller than size of allocation."); + } + } + mRS.nAllocationRead(getID(mRS), array, dt, mType.mElement.mType.mSize, usePadding); Trace.traceEnd(RenderScript.TRACE_TAG); } @@ -1333,6 +1422,45 @@ public class Allocation extends BaseObj { } /** + * @hide + * This is only intended to be used by auto-generated code reflected from + * the RenderScript script files and should not be used by developers. + * + * @param xoff + * @param yoff + * @param zoff + * @param component_number + * @param array + */ + public void copyToFieldPacker(int xoff, int yoff, int zoff, int component_number, FieldPacker fp) { + mRS.validate(); + if (component_number >= mType.mElement.mElements.length) { + throw new RSIllegalArgumentException("Component_number " + component_number + " out of range."); + } + if(xoff < 0) { + throw new RSIllegalArgumentException("Offset x must be >= 0."); + } + if(yoff < 0) { + throw new RSIllegalArgumentException("Offset y must be >= 0."); + } + if(zoff < 0) { + throw new RSIllegalArgumentException("Offset z must be >= 0."); + } + + final byte[] data = fp.getData(); + int data_length = fp.getPos(); + int eSize = mType.mElement.mElements[component_number].getBytesSize(); + eSize *= mType.mElement.mArraySizes[component_number]; + + if (data_length != eSize) { + throw new RSIllegalArgumentException("Field packer sizelength " + data_length + + " does not match component size " + eSize + "."); + } + + mRS.nAllocationElementRead(getIDSafe(), xoff, yoff, zoff, mSelectedLOD, + component_number, data, data_length); + } + /** * Resize a 1D allocation. The contents of the allocation are preserved. * If new elements are allocated objects are created with null contents and * the new region is otherwise undefined. @@ -1364,6 +1492,318 @@ public class Allocation extends BaseObj { updateCacheInfo(mType); } + private void copy1DRangeToUnchecked(int off, int count, Object array, + Element.DataType dt, int arrayLen) { + Trace.traceBegin(RenderScript.TRACE_TAG, "copy1DRangeToUnchecked"); + final int dataSize = mType.mElement.getBytesSize() * count; + // AutoPadding for Vec3 Element + boolean usePadding = false; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + usePadding = true; + } + data1DChecks(off, count, arrayLen * dt.mSize, dataSize, usePadding); + mRS.nAllocationRead1D(getIDSafe(), off, mSelectedLOD, count, array, dataSize, dt, + mType.mElement.mType.mSize, usePadding); + Trace.traceEnd(RenderScript.TRACE_TAG); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * guarantee that the Allocation is compatible with the input buffer. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param array The dest data array + */ + public void copy1DRangeToUnchecked(int off, int count, Object array) { + copy1DRangeToUnchecked(off, count, array, + validateObjectIsPrimitiveArray(array, false), + java.lang.reflect.Array.getLength(array)); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * guarantee that the Allocation is compatible with the input buffer. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeToUnchecked(int off, int count, int[] d) { + copy1DRangeToUnchecked(off, count, (Object)d, Element.DataType.SIGNED_32, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * guarantee that the Allocation is compatible with the input buffer. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeToUnchecked(int off, int count, short[] d) { + copy1DRangeToUnchecked(off, count, (Object)d, Element.DataType.SIGNED_16, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * guarantee that the Allocation is compatible with the input buffer. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeToUnchecked(int off, int count, byte[] d) { + copy1DRangeToUnchecked(off, count, (Object)d, Element.DataType.SIGNED_8, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * guarantee that the Allocation is compatible with the input buffer. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeToUnchecked(int off, int count, float[] d) { + copy1DRangeToUnchecked(off, count, (Object)d, Element.DataType.FLOAT_32, d.length); + } + + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * and will generate exceptions if the Allocation type does not + * match the component type of the array passed in. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param array The source data array. + */ + public void copy1DRangeTo(int off, int count, Object array) { + copy1DRangeToUnchecked(off, count, array, + validateObjectIsPrimitiveArray(array, true), + java.lang.reflect.Array.getLength(array)); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * and will generate exceptions if the Allocation type is not a 32 bit + * integer type. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeTo(int off, int count, int[] d) { + validateIsInt32(); + copy1DRangeToUnchecked(off, count, d, Element.DataType.SIGNED_32, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * and will generate exceptions if the Allocation type is not a 16 bit + * integer type. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeTo(int off, int count, short[] d) { + validateIsInt16(); + copy1DRangeToUnchecked(off, count, d, Element.DataType.SIGNED_16, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * and will generate exceptions if the Allocation type is not an 8 bit + * integer type. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array + */ + public void copy1DRangeTo(int off, int count, byte[] d) { + validateIsInt8(); + copy1DRangeToUnchecked(off, count, d, Element.DataType.SIGNED_8, d.length); + } + + /** + * @hide + * Copy part of this Allocation into an array. This method does not + * and will generate exceptions if the Allocation type is not a 32 bit float + * type. + * + * @param off The offset of the first element to be copied. + * @param count The number of elements to be copied. + * @param d the source data array. + */ + public void copy1DRangeTo(int off, int count, float[] d) { + validateIsFloat32(); + copy1DRangeToUnchecked(off, count, d, Element.DataType.FLOAT_32, d.length); + } + + + void copy2DRangeToUnchecked(int xoff, int yoff, int w, int h, Object array, + Element.DataType dt, int arrayLen) { + Trace.traceBegin(RenderScript.TRACE_TAG, "copy2DRangeToUnchecked"); + mRS.validate(); + validate2DRange(xoff, yoff, w, h); + final int dataSize = mType.mElement.getBytesSize() * w * h; + // AutoPadding for Vec3 Element + boolean usePadding = false; + int sizeBytes = arrayLen * dt.mSize; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + if (dataSize / 4 * 3 > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + usePadding = true; + sizeBytes = dataSize; + } else { + if (dataSize > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + } + mRS.nAllocationRead2D(getIDSafe(), xoff, yoff, mSelectedLOD, mSelectedFace.mID, w, h, + array, sizeBytes, dt, mType.mElement.mType.mSize, usePadding); + Trace.traceEnd(RenderScript.TRACE_TAG); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param array Dest Array to be copied into + */ + public void copy2DRangeTo(int xoff, int yoff, int w, int h, Object array) { + copy2DRangeToUnchecked(xoff, yoff, w, h, array, + validateObjectIsPrimitiveArray(array, true), + java.lang.reflect.Array.getLength(array)); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param data Dest Array to be copied into + */ + public void copy2DRangeTo(int xoff, int yoff, int w, int h, byte[] data) { + validateIsInt8(); + copy2DRangeToUnchecked(xoff, yoff, w, h, data, + Element.DataType.SIGNED_8, data.length); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param data Dest Array to be copied into + */ + public void copy2DRangeTo(int xoff, int yoff, int w, int h, short[] data) { + validateIsInt16(); + copy2DRangeToUnchecked(xoff, yoff, w, h, data, + Element.DataType.SIGNED_16, data.length); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param data Dest Array to be copied into + */ + public void copy2DRangeTo(int xoff, int yoff, int w, int h, int[] data) { + validateIsInt32(); + copy2DRangeToUnchecked(xoff, yoff, w, h, data, + Element.DataType.SIGNED_32, data.length); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param data Dest Array to be copied into + */ + public void copy2DRangeTo(int xoff, int yoff, int w, int h, float[] data) { + validateIsFloat32(); + copy2DRangeToUnchecked(xoff, yoff, w, h, data, + Element.DataType.FLOAT_32, data.length); + } + + + /** + * @hide + * + */ + private void copy3DRangeToUnchecked(int xoff, int yoff, int zoff, int w, int h, int d, + Object array, Element.DataType dt, int arrayLen) { + Trace.traceBegin(RenderScript.TRACE_TAG, "copy3DRangeToUnchecked"); + mRS.validate(); + validate3DRange(xoff, yoff, zoff, w, h, d); + final int dataSize = mType.mElement.getBytesSize() * w * h * d; + // AutoPadding for Vec3 Element + boolean usePadding = false; + int sizeBytes = arrayLen * dt.mSize; + if (mAutoPadding && (mType.getElement().getVectorSize() == 3)) { + if (dataSize / 4 * 3 > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + usePadding = true; + sizeBytes = dataSize; + } else { + if (dataSize > sizeBytes) { + throw new RSIllegalArgumentException("Array too small for allocation type."); + } + } + mRS.nAllocationRead3D(getIDSafe(), xoff, yoff, zoff, mSelectedLOD, w, h, d, + array, sizeBytes, dt, mType.mElement.mType.mSize, usePadding); + Trace.traceEnd(RenderScript.TRACE_TAG); + } + + /** + * @hide + * Copy from a rectangular region in this Allocation into an array. + * + * @param xoff X offset of the region to copy in this Allocation + * @param yoff Y offset of the region to copy in this Allocation + * @param zoff Z offset of the region to copy in this Allocation + * @param w Width of the region to copy + * @param h Height of the region to copy + * @param d Depth of the region to copy + * @param array Dest Array to be copied into + */ + public void copy3DRangeTo(int xoff, int yoff, int zoff, int w, int h, int d, Object array) { + copy3DRangeToUnchecked(xoff, yoff, zoff, w, h, d, array, + validateObjectIsPrimitiveArray(array, true), + java.lang.reflect.Array.getLength(array)); + } // creation @@ -1559,7 +1999,12 @@ public class Allocation extends BaseObj { if ((mUsage & USAGE_IO_INPUT) == 0) { throw new RSInvalidStateException("Allocation is not a surface texture."); } - return mRS.nAllocationGetSurface(getID(mRS)); + + if (mGetSurfaceSurface == null) { + mGetSurfaceSurface = mRS.nAllocationGetSurface(getID(mRS)); + } + + return mGetSurfaceSurface; } /** @@ -1882,4 +2327,15 @@ public class Allocation extends BaseObj { } } + /** + * For USAGE_IO_OUTPUT, destroy() implies setSurface(null). + * + */ + @Override + public void destroy() { + if((mUsage & USAGE_IO_OUTPUT) != 0) { + setSurface(null); + } + super.destroy(); + } } diff --git a/rs/java/android/renderscript/AllocationAdapter.java b/rs/java/android/renderscript/AllocationAdapter.java index 3522a52fba8a..183726fb2fc6 100644 --- a/rs/java/android/renderscript/AllocationAdapter.java +++ b/rs/java/android/renderscript/AllocationAdapter.java @@ -21,76 +21,20 @@ package android.renderscript; * **/ public class AllocationAdapter extends Allocation { - AllocationAdapter(long id, RenderScript rs, Allocation alloc) { + Type mWindow; + + AllocationAdapter(long id, RenderScript rs, Allocation alloc, Type t) { super(id, rs, alloc.mType, alloc.mUsage); mAdaptedAllocation = alloc; + mWindow = t; } + /* long getID(RenderScript rs) { throw new RSInvalidStateException( "This operation is not supported with adapters at this time."); } - - /** - * @hide - */ - public void subData(int xoff, FieldPacker fp) { - super.setFromFieldPacker(xoff, fp); - } - /** - * @hide - */ - public void subElementData(int xoff, int component_number, FieldPacker fp) { - super.setFromFieldPacker(xoff, component_number, fp); - } - /** - * @hide - */ - public void subData1D(int off, int count, int[] d) { - super.copy1DRangeFrom(off, count, d); - } - /** - * @hide - */ - public void subData1D(int off, int count, short[] d) { - super.copy1DRangeFrom(off, count, d); - } - /** - * @hide - */ - public void subData1D(int off, int count, byte[] d) { - super.copy1DRangeFrom(off, count, d); - } - /** - * @hide - */ - public void subData1D(int off, int count, float[] d) { - super.copy1DRangeFrom(off, count, d); - } - /** - * @hide - */ - public void subData2D(int xoff, int yoff, int w, int h, int[] d) { - super.copy2DRangeFrom(xoff, yoff, w, h, d); - } - /** - * @hide - */ - public void subData2D(int xoff, int yoff, int w, int h, float[] d) { - super.copy2DRangeFrom(xoff, yoff, w, h, d); - } - /** - * @hide - */ - public void readData(int[] d) { - super.copyTo(d); - } - /** - * @hide - */ - public void readData(float[] d) { - super.copyTo(d); - } + */ void initLOD(int lod) { if (lod < 0) { @@ -125,6 +69,28 @@ public class AllocationAdapter extends Allocation { mSelectedZ = 0; } + private void updateOffsets() { + int a1 = 0, a2 = 0, a3 = 0, a4 = 0; + + if (mSelectedArray != null) { + if (mSelectedArray.length > 0) { + a1 = mSelectedArray[0]; + } + if (mSelectedArray.length > 1) { + a2 = mSelectedArray[2]; + } + if (mSelectedArray.length > 2) { + a3 = mSelectedArray[2]; + } + if (mSelectedArray.length > 3) { + a4 = mSelectedArray[3]; + } + } + mRS.nAllocationAdapterOffset(getID(mRS), mSelectedX, mSelectedY, mSelectedZ, + mSelectedLOD, mSelectedFace.mID, a1, a2, a3, a4); + + } + /** * Set the active LOD. The LOD must be within the range for the * type being adapted. The base allocation must have mipmaps. @@ -138,11 +104,13 @@ public class AllocationAdapter extends Allocation { if (!mAdaptedAllocation.getType().hasMipmaps()) { throw new RSInvalidStateException("Cannot set LOD when the allocation type does not include mipmaps."); } - if (!mConstrainedLOD) { + if (mWindow.hasMipmaps()) { throw new RSInvalidStateException("Cannot set LOD when the adapter includes mipmaps."); } initLOD(lod); + mSelectedLOD = lod; + updateOffsets(); } /** @@ -155,14 +123,38 @@ public class AllocationAdapter extends Allocation { if (!mAdaptedAllocation.getType().hasFaces()) { throw new RSInvalidStateException("Cannot set Face when the allocation type does not include faces."); } - if (!mConstrainedFace) { - throw new RSInvalidStateException("Cannot set LOD when the adapter includes mipmaps."); + if (mWindow.hasFaces()) { + throw new RSInvalidStateException("Cannot set face when the adapter includes faces."); } if (cf == null) { throw new RSIllegalArgumentException("Cannot set null face."); } mSelectedFace = cf; + updateOffsets(); + } + + + /** + * @hide + * Set the active X. The x value must be within the range for + * the allocation being adapted. + * + * @param x The x to make active. + */ + public void setX(int x) { + if (mAdaptedAllocation.getType().getX() <= x) { + throw new RSInvalidStateException("Cannot set X greater than dimension of allocation."); + } + if (mWindow.getX() == mAdaptedAllocation.getType().getX()) { + throw new RSInvalidStateException("Cannot set X when the adapter includes X."); + } + if ((mWindow.getX() + x) >= mAdaptedAllocation.getType().getX()) { + throw new RSInvalidStateException("Cannot set (X + window) which would be larger than dimension of allocation."); + } + + mSelectedX = x; + updateOffsets(); } /** @@ -179,11 +171,15 @@ public class AllocationAdapter extends Allocation { if (mAdaptedAllocation.getType().getY() <= y) { throw new RSInvalidStateException("Cannot set Y greater than dimension of allocation."); } - if (!mConstrainedY) { + if (mWindow.getY() == mAdaptedAllocation.getType().getY()) { throw new RSInvalidStateException("Cannot set Y when the adapter includes Y."); } + if ((mWindow.getY() + y) >= mAdaptedAllocation.getType().getY()) { + throw new RSInvalidStateException("Cannot set (Y + window) which would be larger than dimension of allocation."); + } mSelectedY = y; + updateOffsets(); } /** @@ -200,35 +196,112 @@ public class AllocationAdapter extends Allocation { if (mAdaptedAllocation.getType().getZ() <= z) { throw new RSInvalidStateException("Cannot set Z greater than dimension of allocation."); } - if (!mConstrainedZ) { + if (mWindow.getZ() == mAdaptedAllocation.getType().getZ()) { throw new RSInvalidStateException("Cannot set Z when the adapter includes Z."); } + if ((mWindow.getZ() + z) >= mAdaptedAllocation.getType().getZ()) { + throw new RSInvalidStateException("Cannot set (Z + window) which would be larger than dimension of allocation."); + } mSelectedZ = z; + updateOffsets(); + } + + /** + * @hide + */ + public void setArray(int arrayNum, int arrayVal) { + if (mAdaptedAllocation.getType().getArray(arrayNum) == 0) { + throw new RSInvalidStateException("Cannot set arrayNum when the allocation type does not include arrayNum dim."); + } + if (mAdaptedAllocation.getType().getArray(arrayNum) <= arrayVal) { + throw new RSInvalidStateException("Cannot set arrayNum greater than dimension of allocation."); + } + if (mWindow.getArray(arrayNum) == mAdaptedAllocation.getType().getArray(arrayNum)) { + throw new RSInvalidStateException("Cannot set arrayNum when the adapter includes arrayNum."); + } + if ((mWindow.getArray(arrayNum) + arrayVal) >= mAdaptedAllocation.getType().getArray(arrayNum)) { + throw new RSInvalidStateException("Cannot set (arrayNum + window) which would be larger than dimension of allocation."); + } + + mSelectedArray[arrayNum] = arrayVal; + updateOffsets(); } static public AllocationAdapter create1D(RenderScript rs, Allocation a) { rs.validate(); - AllocationAdapter aa = new AllocationAdapter(0, rs, a); - aa.mConstrainedLOD = true; - aa.mConstrainedFace = true; - aa.mConstrainedY = true; - aa.mConstrainedZ = true; - aa.initLOD(0); - return aa; + Type t = Type.createX(rs, a.getElement(), a.getType().getX()); + return createTyped(rs, a, t); } + static public AllocationAdapter create2D(RenderScript rs, Allocation a) { rs.validate(); - AllocationAdapter aa = new AllocationAdapter(0, rs, a); - aa.mConstrainedLOD = true; - aa.mConstrainedFace = true; - aa.mConstrainedY = false; - aa.mConstrainedZ = true; - aa.initLOD(0); - return aa; + Type t = Type.createXY(rs, a.getElement(), a.getType().getX(), a.getType().getY()); + return createTyped(rs, a, t); } + /** + * @hide + * + * Create an arbitrary window into the base allocation + * The type describes the shape of the window. + * + * Any dimensions present in the type must be equal or smaller + * to the dimensions in the source allocation. A dimension + * present in the allocation that is not present in the type + * will be constrained away with the selectors + * + * If a dimension is present in the type and allcation one of + * two things will happen + * + * If the type is smaller than the allocation a window will be + * created, the selected value in the adapter for that dimension + * will act as the base address and the type will describe the + * size of the view starting at that point. + * + * If the type and allocation dimension are of the same size + * then setting the selector for the dimension will be an error. + */ + static public AllocationAdapter createTyped(RenderScript rs, Allocation a, Type t) { + rs.validate(); + + if (a.mAdaptedAllocation != null) { + throw new RSInvalidStateException("Adapters cannot be nested."); + } + + if (!a.getType().getElement().equals(t.getElement())) { + throw new RSInvalidStateException("Element must match Allocation type."); + } + + if (t.hasFaces() || t.hasMipmaps()) { + throw new RSInvalidStateException("Adapters do not support window types with Mipmaps or Faces."); + } + + Type at = a.getType(); + if ((t.getX() > at.getX()) || + (t.getY() > at.getY()) || + (t.getZ() > at.getZ()) || + (t.getArrayCount() > at.getArrayCount())) { + + throw new RSInvalidStateException("Type cannot have dimension larger than the source allocation."); + } + + if (t.getArrayCount() > 0) { + for (int i = 0; i < t.getArray(i); i++) { + if (t.getArray(i) > at.getArray(i)) { + throw new RSInvalidStateException("Type cannot have dimension larger than the source allocation."); + } + } + } + + // Create the object + long id = rs.nAllocationAdapterCreate(a.getID(rs), t.getID(rs)); + if (id == 0) { + throw new RSRuntimeException("AllocationAdapter creation failed."); + } + return new AllocationAdapter(id, rs, a, t); + } /** * Override the Allocation resize. Resizing adapters is not diff --git a/rs/java/android/renderscript/Element.java b/rs/java/android/renderscript/Element.java index c6b5b0d99ea2..60ff996d0505 100644 --- a/rs/java/android/renderscript/Element.java +++ b/rs/java/android/renderscript/Element.java @@ -114,11 +114,15 @@ public class Element extends BaseObj { * MATRIX the three matrix types contain FLOAT_32 elements and are treated * as 32 bits for alignment purposes. * - * RS_* objects. 32 bit opaque handles. + * RS_* objects: opaque handles with implementation dependent + * sizes. */ public enum DataType { NONE (0, 0), - //FLOAT_16 (1, 2), + /** + * @hide + */ + FLOAT_16 (1, 2), FLOAT_32 (2, 4), FLOAT_64 (3, 8), SIGNED_8 (4, 1), @@ -386,6 +390,16 @@ public class Element extends BaseObj { return rs.mElement_I64; } + /** + * @hide + */ + public static Element F16(RenderScript rs) { + if(rs.mElement_F16 == null) { + rs.mElement_F16 = createUser(rs, DataType.FLOAT_16); + } + return rs.mElement_F16; + } + public static Element F32(RenderScript rs) { if(rs.mElement_F32 == null) { rs.mElement_F32 = createUser(rs, DataType.FLOAT_32); @@ -520,6 +534,36 @@ public class Element extends BaseObj { return rs.mElement_RGBA_8888; } + /** + * @hide + */ + public static Element F16_2(RenderScript rs) { + if(rs.mElement_HALF_2 == null) { + rs.mElement_HALF_2 = createVector(rs, DataType.FLOAT_16, 2); + } + return rs.mElement_HALF_2; + } + + /** + * @hide + */ + public static Element F16_3(RenderScript rs) { + if(rs.mElement_FLOAT_3 == null) { + rs.mElement_FLOAT_3 = createVector(rs, DataType.FLOAT_16, 3); + } + return rs.mElement_HALF_3; + } + + /** + * @hide + */ + public static Element F16_4(RenderScript rs) { + if(rs.mElement_HALF_4 == null) { + rs.mElement_HALF_4 = createVector(rs, DataType.FLOAT_16, 4); + } + return rs.mElement_HALF_4; + } + public static Element F32_2(RenderScript rs) { if(rs.mElement_FLOAT_2 == null) { rs.mElement_FLOAT_2 = createVector(rs, DataType.FLOAT_32, 2); diff --git a/rs/java/android/renderscript/FieldPacker.java b/rs/java/android/renderscript/FieldPacker.java index 20b07e7592ac..de1c49730aaa 100644 --- a/rs/java/android/renderscript/FieldPacker.java +++ b/rs/java/android/renderscript/FieldPacker.java @@ -47,6 +47,15 @@ public class FieldPacker { // subAlign() can never work correctly for copied FieldPacker objects. } + static FieldPacker createFromArray(Object[] args) { + FieldPacker fp = new FieldPacker(RenderScript.sPointerSize * 8); + for (Object arg : args) { + fp.addSafely(arg); + } + fp.resize(fp.mPos); + return fp; + } + public void align(int v) { if ((v <= 0) || ((v & (v - 1)) != 0)) { throw new RSIllegalArgumentException("argument must be a non-negative non-zero power of 2: " + v); @@ -241,8 +250,7 @@ public class FieldPacker { addI64(0); addI64(0); addI64(0); - } - else { + } else { addI32((int)obj.getID(null)); } } else { @@ -619,11 +627,182 @@ public class FieldPacker { return mPos; } - private final byte mData[]; + private void add(Object obj) { + if (obj instanceof Boolean) { + addBoolean((Boolean)obj); + return; + } + + if (obj instanceof Byte) { + addI8((Byte)obj); + return; + } + + if (obj instanceof Short) { + addI16((Short)obj); + return; + } + + if (obj instanceof Integer) { + addI32((Integer)obj); + return; + } + + if (obj instanceof Long) { + addI64((Long)obj); + return; + } + + if (obj instanceof Float) { + addF32((Float)obj); + return; + } + + if (obj instanceof Double) { + addF64((Double)obj); + return; + } + + if (obj instanceof Byte2) { + addI8((Byte2)obj); + return; + } + + if (obj instanceof Byte3) { + addI8((Byte3)obj); + return; + } + + if (obj instanceof Byte4) { + addI8((Byte4)obj); + return; + } + + if (obj instanceof Short2) { + addI16((Short2)obj); + return; + } + + if (obj instanceof Short3) { + addI16((Short3)obj); + return; + } + + if (obj instanceof Short4) { + addI16((Short4)obj); + return; + } + + if (obj instanceof Int2) { + addI32((Int2)obj); + return; + } + + if (obj instanceof Int3) { + addI32((Int3)obj); + return; + } + + if (obj instanceof Int4) { + addI32((Int4)obj); + return; + } + + if (obj instanceof Long2) { + addI64((Long2)obj); + return; + } + + if (obj instanceof Long3) { + addI64((Long3)obj); + return; + } + + if (obj instanceof Long4) { + addI64((Long4)obj); + return; + } + + if (obj instanceof Float2) { + addF32((Float2)obj); + return; + } + + if (obj instanceof Float3) { + addF32((Float3)obj); + return; + } + + if (obj instanceof Float4) { + addF32((Float4)obj); + return; + } + + if (obj instanceof Double2) { + addF64((Double2)obj); + return; + } + + if (obj instanceof Double3) { + addF64((Double3)obj); + return; + } + + if (obj instanceof Double4) { + addF64((Double4)obj); + return; + } + + if (obj instanceof Matrix2f) { + addMatrix((Matrix2f)obj); + return; + } + + if (obj instanceof Matrix3f) { + addMatrix((Matrix3f)obj); + return; + } + + if (obj instanceof Matrix4f) { + addMatrix((Matrix4f)obj); + return; + } + + if (obj instanceof BaseObj) { + addObj((BaseObj)obj); + return; + } + } + + private boolean resize(int newSize) { + if (newSize == mLen) { + return false; + } + + byte[] newData = new byte[newSize]; + System.arraycopy(mData, 0, newData, 0, mPos); + mData = newData; + mLen = newSize; + return true; + } + + private void addSafely(Object obj) { + boolean retry; + final int oldPos = mPos; + do { + retry = false; + try { + add(obj); + } catch (ArrayIndexOutOfBoundsException e) { + mPos = oldPos; + resize(mLen * 2); + retry = true; + } + } while (retry); + } + + private byte mData[]; private int mPos; private int mLen; private BitSet mAlignment; - } - - diff --git a/rs/java/android/renderscript/FileA3D.java b/rs/java/android/renderscript/FileA3D.java index 41648101cc88..9d8f1624a051 100644 --- a/rs/java/android/renderscript/FileA3D.java +++ b/rs/java/android/renderscript/FileA3D.java @@ -145,6 +145,9 @@ public class FileA3D extends BaseObj { case MESH: entry.mLoadedObj = new Mesh(objectID, rs); break; + + default: + throw new RSRuntimeException("Unrecognized object type in file."); } entry.mLoadedObj.updateFromNative(); diff --git a/rs/java/android/renderscript/Mesh.java b/rs/java/android/renderscript/Mesh.java index 1a5dc9e7a6c7..13c8e1c91052 100644 --- a/rs/java/android/renderscript/Mesh.java +++ b/rs/java/android/renderscript/Mesh.java @@ -363,6 +363,9 @@ public class Mesh extends BaseObj { alloc = Allocation.createTyped(mRS, entry.t, mUsage); } else if(entry.e != null) { alloc = Allocation.createSized(mRS, entry.e, entry.size, mUsage); + } else { + // Should never happen because the builder will always set one + throw new IllegalStateException("Builder corrupt, no valid element in entry."); } vertexBuffers[ct] = alloc; vtx[ct] = alloc.getID(mRS); @@ -375,6 +378,9 @@ public class Mesh extends BaseObj { alloc = Allocation.createTyped(mRS, entry.t, mUsage); } else if(entry.e != null) { alloc = Allocation.createSized(mRS, entry.e, entry.size, mUsage); + } else { + // Should never happen because the builder will always set one + throw new IllegalStateException("Builder corrupt, no valid element in entry."); } long allocID = (alloc == null) ? 0 : alloc.getID(mRS); indexBuffers[ct] = alloc; @@ -811,9 +817,7 @@ public class Mesh extends BaseObj { sm.getVertexAllocation(0).copy1DRangeFromUnchecked(0, mMaxIndex, mVtxData); if(uploadToBufferObject) { - if (uploadToBufferObject) { - sm.getVertexAllocation(0).syncAll(Allocation.USAGE_SCRIPT); - } + sm.getVertexAllocation(0).syncAll(Allocation.USAGE_SCRIPT); } sm.getIndexSetAllocation(0).copy1DRangeFromUnchecked(0, mIndexCount, mIndexData); diff --git a/rs/java/android/renderscript/Path.java b/rs/java/android/renderscript/Path.java deleted file mode 100644 index f3502aacbd64..000000000000 --- a/rs/java/android/renderscript/Path.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (C) 2008 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package android.renderscript; - -/** - * @hide - * - */ -public class Path extends BaseObj { - - public enum Primitive { - QUADRATIC_BEZIER(0), - CUBIC_BEZIER(1); - - int mID; - Primitive(int id) { - mID = id; - } - } - - Allocation mVertexBuffer; - Allocation mLoopBuffer; - Primitive mPrimitive; - float mQuality; - boolean mCoverageToAlpha; - - Path(long id, RenderScript rs, Primitive p, Allocation vtx, Allocation loop, float q) { - super(id, rs); - mVertexBuffer = vtx; - mLoopBuffer = loop; - mPrimitive = p; - mQuality = q; - } - - public Allocation getVertexAllocation() { - return mVertexBuffer; - } - - public Allocation getLoopAllocation() { - return mLoopBuffer; - } - - public Primitive getPrimitive() { - return mPrimitive; - } - - @Override - void updateFromNative() { - } - - - public static Path createStaticPath(RenderScript rs, Primitive p, float quality, Allocation vtx) { - long id = rs.nPathCreate(p.mID, false, vtx.getID(rs), 0, quality); - Path newPath = new Path(id, rs, p, null, null, quality); - return newPath; - } - - public static Path createStaticPath(RenderScript rs, Primitive p, float quality, Allocation vtx, Allocation loops) { - return null; - } - - public static Path createDynamicPath(RenderScript rs, Primitive p, float quality, Allocation vtx) { - return null; - } - - public static Path createDynamicPath(RenderScript rs, Primitive p, float quality, Allocation vtx, Allocation loops) { - return null; - } - - -} - - diff --git a/rs/java/android/renderscript/RenderScript.java b/rs/java/android/renderscript/RenderScript.java index e86c46153b08..6b1939c679b0 100644 --- a/rs/java/android/renderscript/RenderScript.java +++ b/rs/java/android/renderscript/RenderScript.java @@ -29,6 +29,7 @@ import android.util.Log; import android.view.Surface; import android.os.SystemProperties; import android.os.Trace; +import java.util.ArrayList; /** * This class provides access to a RenderScript context, which controls RenderScript @@ -49,6 +50,12 @@ public class RenderScript { @SuppressWarnings({"UnusedDeclaration", "deprecation"}) static final boolean LOG_ENABLED = false; + static private ArrayList<RenderScript> mProcessContextList = new ArrayList<RenderScript>(); + private boolean mIsProcessContext = false; + private int mContextFlags = 0; + private int mContextSdkVersion = 0; + + private Context mApplicationContext; /* @@ -123,8 +130,6 @@ public class RenderScript { native void nContextInitToClient(long con); native void nContextDeinitToClient(long con); - static File mCacheDir; - // this should be a monotonically increasing ID // used in conjunction with the API version of a device static final long sMinorID = 1; @@ -139,23 +144,6 @@ public class RenderScript { return sMinorID; } - /** - * Sets the directory to use as a persistent storage for the - * renderscript object file cache. - * - * @hide - * @param cacheDir A directory the current process can write to - */ - public static void setupDiskCache(File cacheDir) { - if (!sInitialized) { - Log.e(LOG_TAG, "RenderScript.setupDiskCache() called when disabled"); - return; - } - - // Defer creation of cache path to nScriptCCreate(). - mCacheDir = cacheDir; - } - /** * ContextType specifies the specific type of context to be created. * @@ -244,6 +232,11 @@ public class RenderScript { validate(); rsnContextSetPriority(mContext, p); } + native void rsnContextSetCacheDir(long con, String cacheDir); + synchronized void nContextSetCacheDir(String cacheDir) { + validate(); + rsnContextSetCacheDir(mContext, cacheDir); + } native void rsnContextDump(long con, int bits); synchronized void nContextDump(int bits) { validate(); @@ -302,6 +295,57 @@ public class RenderScript { rsnContextResume(mContext); } + native long rsnClosureCreate(long con, long kernelID, long returnValue, + long[] fieldIDs, long[] values, int[] sizes, long[] depClosures, + long[] depFieldIDs); + synchronized long nClosureCreate(long kernelID, long returnValue, + long[] fieldIDs, long[] values, int[] sizes, long[] depClosures, + long[] depFieldIDs) { + validate(); + return rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values, + sizes, depClosures, depFieldIDs); + } + + native long rsnInvokeClosureCreate(long con, long invokeID, byte[] params, + long[] fieldIDs, long[] values, int[] sizes); + synchronized long nInvokeClosureCreate(long invokeID, byte[] params, + long[] fieldIDs, long[] values, int[] sizes) { + validate(); + return rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs, + values, sizes); + } + + native void rsnClosureSetArg(long con, long closureID, int index, + long value, int size); + synchronized void nClosureSetArg(long closureID, int index, long value, + int size) { + validate(); + rsnClosureSetArg(mContext, closureID, index, value, size); + } + + native void rsnClosureSetGlobal(long con, long closureID, long fieldID, + long value, int size); + // Does this have to be synchronized? + synchronized void nClosureSetGlobal(long closureID, long fieldID, + long value, int size) { + validate(); // TODO: is this necessary? + rsnClosureSetGlobal(mContext, closureID, fieldID, value, size); + } + + native long rsnScriptGroup2Create(long con, String name, String cachePath, + long[] closures); + synchronized long nScriptGroup2Create(String name, String cachePath, + long[] closures) { + validate(); + return rsnScriptGroup2Create(mContext, name, cachePath, closures); + } + + native void rsnScriptGroup2Execute(long con, long groupID); + synchronized void nScriptGroup2Execute(long groupID) { + validate(); + rsnScriptGroup2Execute(mContext, groupID); + } + native void rsnAssignName(long con, long obj, byte[] name); synchronized void nAssignName(long obj, byte[] name) { validate(); @@ -436,16 +480,18 @@ public class RenderScript { } - native void rsnAllocationData1D(long con, long id, int off, int mip, int count, Object d, int sizeBytes, int dt); - synchronized void nAllocationData1D(long id, int off, int mip, int count, Object d, int sizeBytes, Element.DataType dt) { + native void rsnAllocationData1D(long con, long id, int off, int mip, int count, Object d, int sizeBytes, int dt, + int mSize, boolean usePadding); + synchronized void nAllocationData1D(long id, int off, int mip, int count, Object d, int sizeBytes, Element.DataType dt, + int mSize, boolean usePadding) { validate(); - rsnAllocationData1D(mContext, id, off, mip, count, d, sizeBytes, dt.mID); + rsnAllocationData1D(mContext, id, off, mip, count, d, sizeBytes, dt.mID, mSize, usePadding); } - native void rsnAllocationElementData1D(long con,long id, int xoff, int mip, int compIdx, byte[] d, int sizeBytes); - synchronized void nAllocationElementData1D(long id, int xoff, int mip, int compIdx, byte[] d, int sizeBytes) { + native void rsnAllocationElementData(long con,long id, int xoff, int yoff, int zoff, int mip, int compIdx, byte[] d, int sizeBytes); + synchronized void nAllocationElementData(long id, int xoff, int yoff, int zoff, int mip, int compIdx, byte[] d, int sizeBytes) { validate(); - rsnAllocationElementData1D(mContext, id, xoff, mip, compIdx, d, sizeBytes); + rsnAllocationElementData(mContext, id, xoff, yoff, zoff, mip, compIdx, d, sizeBytes); } native void rsnAllocationData2D(long con, @@ -469,11 +515,13 @@ public class RenderScript { } native void rsnAllocationData2D(long con, long id, int xoff, int yoff, int mip, int face, - int w, int h, Object d, int sizeBytes, int dt); + int w, int h, Object d, int sizeBytes, int dt, + int mSize, boolean usePadding); synchronized void nAllocationData2D(long id, int xoff, int yoff, int mip, int face, - int w, int h, Object d, int sizeBytes, Element.DataType dt) { + int w, int h, Object d, int sizeBytes, Element.DataType dt, + int mSize, boolean usePadding) { validate(); - rsnAllocationData2D(mContext, id, xoff, yoff, mip, face, w, h, d, sizeBytes, dt.mID); + rsnAllocationData2D(mContext, id, xoff, yoff, mip, face, w, h, d, sizeBytes, dt.mID, mSize, usePadding); } native void rsnAllocationData2D(long con, long id, int xoff, int yoff, int mip, int face, Bitmap b); @@ -501,33 +549,56 @@ public class RenderScript { } native void rsnAllocationData3D(long con, long id, int xoff, int yoff, int zoff, int mip, - int w, int h, int depth, Object d, int sizeBytes, int dt); + int w, int h, int depth, Object d, int sizeBytes, int dt, + int mSize, boolean usePadding); synchronized void nAllocationData3D(long id, int xoff, int yoff, int zoff, int mip, - int w, int h, int depth, Object d, int sizeBytes, Element.DataType dt) { + int w, int h, int depth, Object d, int sizeBytes, Element.DataType dt, + int mSize, boolean usePadding) { validate(); - rsnAllocationData3D(mContext, id, xoff, yoff, zoff, mip, w, h, depth, d, sizeBytes, dt.mID); + rsnAllocationData3D(mContext, id, xoff, yoff, zoff, mip, w, h, depth, d, sizeBytes, + dt.mID, mSize, usePadding); } - native void rsnAllocationRead(long con, long id, Object d, int dt); - synchronized void nAllocationRead(long id, Object d, Element.DataType dt) { + native void rsnAllocationRead(long con, long id, Object d, int dt, int mSize, boolean usePadding); + synchronized void nAllocationRead(long id, Object d, Element.DataType dt, int mSize, boolean usePadding) { validate(); - rsnAllocationRead(mContext, id, d, dt.mID); + rsnAllocationRead(mContext, id, d, dt.mID, mSize, usePadding); } native void rsnAllocationRead1D(long con, long id, int off, int mip, int count, Object d, - int sizeBytes, int dt); + int sizeBytes, int dt, int mSize, boolean usePadding); synchronized void nAllocationRead1D(long id, int off, int mip, int count, Object d, - int sizeBytes, Element.DataType dt) { + int sizeBytes, Element.DataType dt, int mSize, boolean usePadding) { validate(); - rsnAllocationRead1D(mContext, id, off, mip, count, d, sizeBytes, dt.mID); + rsnAllocationRead1D(mContext, id, off, mip, count, d, sizeBytes, dt.mID, mSize, usePadding); + } + + native void rsnAllocationElementRead(long con,long id, int xoff, int yoff, int zoff, + int mip, int compIdx, byte[] d, int sizeBytes); + synchronized void nAllocationElementRead(long id, int xoff, int yoff, int zoff, + int mip, int compIdx, byte[] d, int sizeBytes) { + validate(); + rsnAllocationElementRead(mContext, id, xoff, yoff, zoff, mip, compIdx, d, sizeBytes); } native void rsnAllocationRead2D(long con, long id, int xoff, int yoff, int mip, int face, - int w, int h, Object d, int sizeBytes, int dt); + int w, int h, Object d, int sizeBytes, int dt, + int mSize, boolean usePadding); synchronized void nAllocationRead2D(long id, int xoff, int yoff, int mip, int face, - int w, int h, Object d, int sizeBytes, Element.DataType dt) { + int w, int h, Object d, int sizeBytes, Element.DataType dt, + int mSize, boolean usePadding) { validate(); - rsnAllocationRead2D(mContext, id, xoff, yoff, mip, face, w, h, d, sizeBytes, dt.mID); + rsnAllocationRead2D(mContext, id, xoff, yoff, mip, face, w, h, d, sizeBytes, dt.mID, mSize, usePadding); + } + + native void rsnAllocationRead3D(long con, long id, int xoff, int yoff, int zoff, int mip, + int w, int h, int depth, Object d, int sizeBytes, int dt, + int mSize, boolean usePadding); + synchronized void nAllocationRead3D(long id, int xoff, int yoff, int zoff, int mip, + int w, int h, int depth, Object d, int sizeBytes, Element.DataType dt, + int mSize, boolean usePadding) { + validate(); + rsnAllocationRead3D(mContext, id, xoff, yoff, zoff, mip, w, h, depth, d, sizeBytes, dt.mID, mSize, usePadding); } native long rsnAllocationGetType(long con, long id); @@ -542,6 +613,20 @@ public class RenderScript { rsnAllocationResize1D(mContext, id, dimX); } + native long rsnAllocationAdapterCreate(long con, long allocId, long typeId); + synchronized long nAllocationAdapterCreate(long allocId, long typeId) { + validate(); + return rsnAllocationAdapterCreate(mContext, allocId, typeId); + } + + native void rsnAllocationAdapterOffset(long con, long id, int x, int y, int z, + int mip, int face, int a1, int a2, int a3, int a4); + synchronized void nAllocationAdapterOffset(long id, int x, int y, int z, + int mip, int face, int a1, int a2, int a3, int a4) { + validate(); + rsnAllocationAdapterOffset(mContext, id, x, y, z, mip, face, a1, a2, a3, a4); + } + native long rsnFileA3DCreateFromAssetStream(long con, long assetStream); synchronized long nFileA3DCreateFromAssetStream(long assetStream) { validate(); @@ -605,52 +690,14 @@ public class RenderScript { validate(); rsnScriptInvoke(mContext, id, slot); } - native void rsnScriptForEach(long con, long id, int slot, long ain, long aout, byte[] params); - native void rsnScriptForEach(long con, long id, int slot, long ain, long aout); - native void rsnScriptForEachClipped(long con, long id, int slot, long ain, long aout, byte[] params, - int xstart, int xend, int ystart, int yend, int zstart, int zend); - native void rsnScriptForEachClipped(long con, long id, int slot, long ain, long aout, - int xstart, int xend, int ystart, int yend, int zstart, int zend); - synchronized void nScriptForEach(long id, int slot, long ain, long aout, byte[] params) { - validate(); - if (params == null) { - rsnScriptForEach(mContext, id, slot, ain, aout); - } else { - rsnScriptForEach(mContext, id, slot, ain, aout, params); - } - } - synchronized void nScriptForEachClipped(long id, int slot, long ain, long aout, byte[] params, - int xstart, int xend, int ystart, int yend, int zstart, int zend) { - validate(); - if (params == null) { - rsnScriptForEachClipped(mContext, id, slot, ain, aout, xstart, xend, ystart, yend, zstart, zend); - } else { - rsnScriptForEachClipped(mContext, id, slot, ain, aout, params, xstart, xend, ystart, yend, zstart, zend); - } - } - - /** - * Multi-input code. - * - */ + native void rsnScriptForEach(long con, long id, int slot, long[] ains, + long aout, byte[] params, int[] limits); - // @hide - native void rsnScriptForEachMultiClipped(long con, long id, int slot, long[] ains, long aout, byte[] params, - int xstart, int xend, int ystart, int yend, int zstart, int zend); - // @hide - native void rsnScriptForEachMultiClipped(long con, long id, int slot, long[] ains, long aout, - int xstart, int xend, int ystart, int yend, int zstart, int zend); - - // @hide - synchronized void nScriptForEachMultiClipped(long id, int slot, long[] ains, long aout, byte[] params, - int xstart, int xend, int ystart, int yend, int zstart, int zend) { - validate(); - if (params == null) { - rsnScriptForEachMultiClipped(mContext, id, slot, ains, aout, xstart, xend, ystart, yend, zstart, zend); - } else { - rsnScriptForEachMultiClipped(mContext, id, slot, ains, aout, params, xstart, xend, ystart, yend, zstart, zend); - } + synchronized void nScriptForEach(long id, int slot, long[] ains, long aout, + byte[] params, int[] limits) { + validate(); + rsnScriptForEach(mContext, id, slot, ains, aout, params, limits); } native void rsnScriptInvokeV(long con, long id, int slot, byte[] params); @@ -743,6 +790,12 @@ public class RenderScript { return rsnScriptKernelIDCreate(mContext, sid, slot, sig); } + native long rsnScriptInvokeIDCreate(long con, long sid, int slot); + synchronized long nScriptInvokeIDCreate(long sid, int slot) { + validate(); + return rsnScriptInvokeIDCreate(mContext, sid, slot); + } + native long rsnScriptFieldIDCreate(long con, long sid, int slot); synchronized long nScriptFieldIDCreate(long sid, int slot) { validate(); @@ -850,14 +903,70 @@ public class RenderScript { rsnMeshGetIndices(mContext, id, idxIds, primitives, vtxIdCount); } - native long rsnPathCreate(long con, int prim, boolean isStatic, long vtx, long loop, float q); - synchronized long nPathCreate(int prim, boolean isStatic, long vtx, long loop, float q) { + native void rsnScriptIntrinsicBLAS_Single(long con, long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + float alpha, long A, long B, float beta, long C, int incX, int incY, + int KL, int KU); + synchronized void nScriptIntrinsicBLAS_Single(long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + float alpha, long A, long B, float beta, long C, int incX, int incY, + int KL, int KU) { validate(); - return rsnPathCreate(mContext, prim, isStatic, vtx, loop, q); + rsnScriptIntrinsicBLAS_Single(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alpha, A, B, beta, C, incX, incY, KL, KU); } + native void rsnScriptIntrinsicBLAS_Double(long con, long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + double alpha, long A, long B, double beta, long C, int incX, int incY, + int KL, int KU); + synchronized void nScriptIntrinsicBLAS_Double(long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + double alpha, long A, long B, double beta, long C, int incX, int incY, + int KL, int KU) { + validate(); + rsnScriptIntrinsicBLAS_Double(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alpha, A, B, beta, C, incX, incY, KL, KU); + } + + native void rsnScriptIntrinsicBLAS_Complex(long con, long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + float alphaX, float alphaY, long A, long B, float betaX, float betaY, long C, int incX, int incY, + int KL, int KU); + synchronized void nScriptIntrinsicBLAS_Complex(long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + float alphaX, float alphaY, long A, long B, float betaX, float betaY, long C, int incX, int incY, + int KL, int KU) { + validate(); + rsnScriptIntrinsicBLAS_Complex(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU); + } + + native void rsnScriptIntrinsicBLAS_Z(long con, long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + double alphaX, double alphaY, long A, long B, double betaX, double betaY, long C, int incX, int incY, + int KL, int KU); + synchronized void nScriptIntrinsicBLAS_Z(long id, int func, int TransA, + int TransB, int Side, int Uplo, int Diag, int M, int N, int K, + double alphaX, double alphaY, long A, long B, double betaX, double betaY, long C, int incX, int incY, + int KL, int KU) { + validate(); + rsnScriptIntrinsicBLAS_Z(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU); + } + + native void rsnScriptIntrinsicBLAS_BNNM(long con, long id, int M, int N, int K, + long A, int a_offset, long B, int b_offset, long C, int c_offset, + int c_mult_int); + synchronized void nScriptIntrinsicBLAS_BNNM(long id, int M, int N, int K, + long A, int a_offset, long B, int b_offset, long C, int c_offset, + int c_mult_int) { + validate(); + rsnScriptIntrinsicBLAS_BNNM(mContext, id, M, N, K, A, a_offset, B, b_offset, C, c_offset, c_mult_int); + } + + + long mDev; long mContext; + private boolean mDestroyed = false; + @SuppressWarnings({"FieldCanBeLocal"}) MessageThread mMessageThread; @@ -869,6 +978,7 @@ public class RenderScript { Element mElement_I32; Element mElement_U64; Element mElement_I64; + Element mElement_F16; Element mElement_F32; Element mElement_F64; Element mElement_BOOLEAN; @@ -892,6 +1002,10 @@ public class RenderScript { Element mElement_RGBA_4444; Element mElement_RGBA_8888; + Element mElement_HALF_2; + Element mElement_HALF_3; + Element mElement_HALF_4; + Element mElement_FLOAT_2; Element mElement_FLOAT_3; Element mElement_FLOAT_4; @@ -1040,8 +1154,10 @@ public class RenderScript { * their priority to LOW to avoid starving forground processes. */ public enum Priority { - LOW (Process.THREAD_PRIORITY_BACKGROUND + (5 * Process.THREAD_PRIORITY_LESS_FAVORABLE)), - NORMAL (Process.THREAD_PRIORITY_DISPLAY); + // These values used to represent official thread priority values + // now they are simply enums to be used by the runtime side + LOW (15), + NORMAL (-8); int mID; Priority(int id) { @@ -1203,20 +1319,13 @@ public class RenderScript { } /** - * @hide - */ - public static RenderScript create(Context ctx, int sdkVersion) { - return create(ctx, sdkVersion, ContextType.NORMAL, CREATE_FLAG_NONE); - } - - /** * Create a RenderScript context. * * @hide * @param ctx The context. * @return RenderScript */ - public static RenderScript create(Context ctx, int sdkVersion, ContextType ct, int flags) { + private static RenderScript internalCreate(Context ctx, int sdkVersion, ContextType ct, int flags) { if (!sInitialized) { Log.e(LOG_TAG, "RenderScript.create() called when disabled; someone is likely to crash"); return null; @@ -1231,16 +1340,28 @@ public class RenderScript { rs.mDev = rs.nDeviceCreate(); rs.mContext = rs.nContextCreate(rs.mDev, flags, sdkVersion, ct.mID); rs.mContextType = ct; + rs.mContextFlags = flags; + rs.mContextSdkVersion = sdkVersion; if (rs.mContext == 0) { throw new RSDriverException("Failed to create RS context."); } + + // set up cache directory for entire context + final String CACHE_PATH = "com.android.renderscript.cache"; + File f = new File(RenderScriptCacheDir.mCacheDir, CACHE_PATH); + String mCachePath = f.getAbsolutePath(); + f.mkdirs(); + rs.nContextSetCacheDir(mCachePath); + rs.mMessageThread = new MessageThread(rs); rs.mMessageThread.start(); return rs; } /** - * Create a RenderScript context. + * calls create(ctx, ContextType.NORMAL, CREATE_FLAG_NONE) + * + * See documentation for @create for details * * @param ctx The context. * @return RenderScript @@ -1250,21 +1371,33 @@ public class RenderScript { } /** - * Create a RenderScript context. + * calls create(ctx, ct, CREATE_FLAG_NONE) * + * See documentation for @create for details * * @param ctx The context. * @param ct The type of context to be created. * @return RenderScript */ public static RenderScript create(Context ctx, ContextType ct) { - int v = ctx.getApplicationInfo().targetSdkVersion; - return create(ctx, v, ct, CREATE_FLAG_NONE); + return create(ctx, ct, CREATE_FLAG_NONE); } - /** - * Create a RenderScript context. + + /** + * Gets or creates a RenderScript context of the specified type. + * + * The returned context will be cached for future reuse within + * the process. When an application is finished using + * RenderScript it should call releaseAllContexts() + * + * A process context is a context designed for easy creation and + * lifecycle management. Multiple calls to this function will + * return the same object provided they are called with the same + * options. This allows it to be used any time a RenderScript + * context is needed. * + * Prior to API 23 this always created a new context. * * @param ctx The context. * @param ct The type of context to be created. @@ -1277,6 +1410,100 @@ public class RenderScript { } /** + * calls create(ctx, sdkVersion, ContextType.NORMAL, CREATE_FLAG_NONE) + * + * Used by the RenderScriptThunker to maintain backward compatibility. + * + * @hide + * @param ctx The context. + * @param sdkVersion The target SDK Version. + * @return RenderScript + */ + public static RenderScript create(Context ctx, int sdkVersion) { + return create(ctx, sdkVersion, ContextType.NORMAL, CREATE_FLAG_NONE); + } + + /** + * Gets or creates a RenderScript context of the specified type. + * + * @hide + * @param ctx The context. + * @param ct The type of context to be created. + * @param sdkVersion The target SDK Version. + * @param flags The OR of the CREATE_FLAG_* options desired + * @return RenderScript + */ + public static RenderScript create(Context ctx, int sdkVersion, ContextType ct, int flags) { + if (sdkVersion < 23) { + return internalCreate(ctx, sdkVersion, ct, flags); + } + + synchronized (mProcessContextList) { + for (RenderScript prs : mProcessContextList) { + if ((prs.mContextType == ct) && + (prs.mContextFlags == flags) && + (prs.mContextSdkVersion == sdkVersion)) { + + return prs; + } + } + + RenderScript prs = internalCreate(ctx, sdkVersion, ct, flags); + prs.mIsProcessContext = true; + mProcessContextList.add(prs); + return prs; + } + } + + /** + * @hide + * + * Releases all the process contexts. This is the same as + * calling .destroy() on each unique context retreived with + * create(...). If no contexts have been created this + * function does nothing. + * + * Typically you call this when your application is losing focus + * and will not be using a context for some time. + * + * This has no effect on a context created with + * createMultiContext() + */ + public static void releaseAllContexts() { + ArrayList<RenderScript> oldList; + synchronized (mProcessContextList) { + oldList = mProcessContextList; + mProcessContextList = new ArrayList<RenderScript>(); + } + + for (RenderScript prs : oldList) { + prs.mIsProcessContext = false; + prs.destroy(); + } + oldList.clear(); + } + + + + /** + * Create a RenderScript context. + * + * This is an advanced function intended for applications which + * need to create more than one RenderScript context to be used + * at the same time. + * + * If you need a single context please use create() + * + * @hide + * @param ctx The context. + * @return RenderScript + */ + public static RenderScript createMultiContext(Context ctx, ContextType ct, int flags, int API_number) { + return internalCreate(ctx, API_number, ct, flags); + } + + + /** * Print the currently available debugging information about the state of * the RS context to the log. * @@ -1295,27 +1522,55 @@ public class RenderScript { nContextFinish(); } + private void helpDestroy() { + boolean shouldDestroy = false; + synchronized(this) { + if (!mDestroyed) { + shouldDestroy = true; + mDestroyed = true; + } + } + + if (shouldDestroy) { + nContextFinish(); + + nContextDeinitToClient(mContext); + mMessageThread.mRun = false; + try { + mMessageThread.join(); + } catch(InterruptedException e) { + } + + nContextDestroy(); + + nDeviceDestroy(mDev); + mDev = 0; + } + } + + protected void finalize() throws Throwable { + helpDestroy(); + super.finalize(); + } + + /** * Destroys this RenderScript context. Once this function is called, * using this context or any objects belonging to this context is * illegal. * + * API 23+, this function is a NOP if the context was created + * with create(). Please use releaseAllContexts() to clean up + * contexts created with the create function. + * */ public void destroy() { - validate(); - nContextFinish(); - - nContextDeinitToClient(mContext); - mMessageThread.mRun = false; - try { - mMessageThread.join(); - } catch(InterruptedException e) { + if (mIsProcessContext) { + // users cannot destroy a process context + return; } - - nContextDestroy(); - - nDeviceDestroy(mDev); - mDev = 0; + validate(); + helpDestroy(); } boolean isAlive() { diff --git a/rs/java/android/renderscript/RenderScriptCacheDir.java b/rs/java/android/renderscript/RenderScriptCacheDir.java new file mode 100644 index 000000000000..95a9d7575945 --- /dev/null +++ b/rs/java/android/renderscript/RenderScriptCacheDir.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2008-2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.renderscript; + +import java.io.File; + +/** + * Used only for tracking the RenderScript cache directory. + * @hide + */ +public class RenderScriptCacheDir { + /** + * Sets the directory to use as a persistent storage for the + * renderscript object file cache. + * + * @hide + * @param cacheDir A directory the current process can write to + */ + public static void setupDiskCache(File cacheDir) { + // Defer creation of cache path to nScriptCCreate(). + mCacheDir = cacheDir; + } + + static File mCacheDir; + +} diff --git a/rs/java/android/renderscript/Script.java b/rs/java/android/renderscript/Script.java index c49ef948db97..65056ac5c22d 100644 --- a/rs/java/android/renderscript/Script.java +++ b/rs/java/android/renderscript/Script.java @@ -48,7 +48,8 @@ public class Script extends BaseObj { /** * Only to be used by generated reflected classes. */ - protected KernelID createKernelID(int slot, int sig, Element ein, Element eout) { + protected KernelID createKernelID(int slot, int sig, Element ein, + Element eout) { KernelID k = mKIDs.get(slot); if (k != null) { return k; @@ -65,6 +66,46 @@ public class Script extends BaseObj { } /** + * @hide Pending API review + * InvokeID is an identifier for an invoke function. It is used + * as an identifier for ScriptGroup creation. + * + * This class should not be directly created. Instead use the method in the + * reflected or intrinsic code "getInvokeID_funcname()". + * + */ + public static final class InvokeID extends BaseObj { + Script mScript; + int mSlot; + InvokeID(long id, RenderScript rs, Script s, int slot) { + super(id, rs); + mScript = s; + mSlot = slot; + } + } + + private final SparseArray<InvokeID> mIIDs = new SparseArray<InvokeID>(); + /** + * @hide Pending API review + * Only to be used by generated reflected classes. + */ + protected InvokeID createInvokeID(int slot) { + InvokeID i = mIIDs.get(slot); + if (i != null) { + return i; + } + + long id = mRS.nScriptInvokeIDCreate(getID(mRS), slot); + if (id == 0) { + throw new RSDriverException("Failed to create KernelID"); + } + + i = new InvokeID(id, mRS, this, slot); + mIIDs.put(slot, i); + return i; + } + + /** * FieldID is an identifier for a Script + exported field pair. It is used * as an identifier for ScriptGroup creation. * @@ -127,59 +168,56 @@ public class Script extends BaseObj { * Only intended for use by generated reflected code. * */ - protected void forEach(int slot, Allocation ain, Allocation aout, FieldPacker v) { - mRS.validate(); - mRS.validateObject(ain); - mRS.validateObject(aout); - if (ain == null && aout == null) { - throw new RSIllegalArgumentException( - "At least one of ain or aout is required to be non-null."); - } - long in_id = 0; - if (ain != null) { - in_id = ain.getID(mRS); - } - long out_id = 0; - if (aout != null) { - out_id = aout.getID(mRS); - } - byte[] params = null; - if (v != null) { - params = v.getData(); - } - mRS.nScriptForEach(getID(mRS), slot, in_id, out_id, params); + protected void forEach(int slot, Allocation ain, Allocation aout, + FieldPacker v) { + forEach(slot, ain, aout, v, null); } /** * Only intended for use by generated reflected code. * */ - protected void forEach(int slot, Allocation ain, Allocation aout, FieldPacker v, LaunchOptions sc) { + protected void forEach(int slot, Allocation ain, Allocation aout, + FieldPacker v, LaunchOptions sc) { + // TODO: Is this necessary if nScriptForEach calls validate as well? mRS.validate(); mRS.validateObject(ain); mRS.validateObject(aout); + if (ain == null && aout == null) { throw new RSIllegalArgumentException( "At least one of ain or aout is required to be non-null."); } - if (sc == null) { - forEach(slot, ain, aout, v); - return; - } - long in_id = 0; + long[] in_ids = null; if (ain != null) { - in_id = ain.getID(mRS); + in_ids = mInIdsBuffer; + in_ids[0] = ain.getID(mRS); } + long out_id = 0; if (aout != null) { out_id = aout.getID(mRS); } + byte[] params = null; if (v != null) { params = v.getData(); } - mRS.nScriptForEachClipped(getID(mRS), slot, in_id, out_id, params, sc.xstart, sc.xend, sc.ystart, sc.yend, sc.zstart, sc.zend); + + int[] limits = null; + if (sc != null) { + limits = new int[6]; + + limits[0] = sc.xstart; + limits[1] = sc.xend; + limits[2] = sc.ystart; + limits[3] = sc.yend; + limits[4] = sc.zstart; + limits[5] = sc.zend; + } + + mRS.nScriptForEach(getID(mRS), slot, in_ids, out_id, params, limits); } /** @@ -187,8 +225,9 @@ public class Script extends BaseObj { * * @hide */ - protected void forEach(int slot, Allocation[] ains, Allocation aout, FieldPacker v) { - forEach(slot, ains, aout, v, new LaunchOptions()); + protected void forEach(int slot, Allocation[] ains, Allocation aout, + FieldPacker v) { + forEach(slot, ains, aout, v, null); } /** @@ -196,42 +235,63 @@ public class Script extends BaseObj { * * @hide */ - protected void forEach(int slot, Allocation[] ains, Allocation aout, FieldPacker v, LaunchOptions sc) { + protected void forEach(int slot, Allocation[] ains, Allocation aout, + FieldPacker v, LaunchOptions sc) { + // TODO: Is this necessary if nScriptForEach calls validate as well? mRS.validate(); - - for (Allocation ain : ains) { - mRS.validateObject(ain); + if (ains != null) { + for (Allocation ain : ains) { + mRS.validateObject(ain); + } } - mRS.validateObject(aout); + if (ains == null && aout == null) { throw new RSIllegalArgumentException( "At least one of ain or aout is required to be non-null."); } - if (sc == null) { - forEach(slot, ains, aout, v); - return; - } - - long[] in_ids = new long[ains.length]; - for (int index = 0; index < ains.length; ++index) { - in_ids[index] = ains[index].getID(mRS); + long[] in_ids; + if (ains != null) { + in_ids = new long[ains.length]; + for (int index = 0; index < ains.length; ++index) { + in_ids[index] = ains[index].getID(mRS); + } + } else { + in_ids = null; } long out_id = 0; if (aout != null) { out_id = aout.getID(mRS); } + byte[] params = null; if (v != null) { params = v.getData(); } - mRS.nScriptForEachMultiClipped(getID(mRS), slot, in_ids, out_id, params, sc.xstart, sc.xend, sc.ystart, sc.yend, sc.zstart, sc.zend); + + int[] limits = null; + if (sc != null) { + limits = new int[6]; + + limits[0] = sc.xstart; + limits[1] = sc.xend; + limits[2] = sc.ystart; + limits[3] = sc.yend; + limits[4] = sc.zstart; + limits[5] = sc.zend; + } + + mRS.nScriptForEach(getID(mRS), slot, in_ids, out_id, params, limits); } + long[] mInIdsBuffer; + Script(long id, RenderScript rs) { super(id, rs); + + mInIdsBuffer = new long[1]; } @@ -243,11 +303,17 @@ public class Script extends BaseObj { mRS.validate(); mRS.validateObject(va); if (va != null) { - if (mRS.getApplicationContext().getApplicationInfo().targetSdkVersion >= 20) { + + android.content.Context context = mRS.getApplicationContext(); + + if (context.getApplicationInfo().targetSdkVersion >= 20) { final Type t = va.mType; - if (t.hasMipmaps() || t.hasFaces() || (t.getY() != 0) || (t.getZ() != 0)) { + if (t.hasMipmaps() || t.hasFaces() || (t.getY() != 0) || + (t.getZ() != 0)) { + throw new RSIllegalArgumentException( - "API 20+ only allows simple 1D allocations to be used with bind."); + "API 20+ only allows simple 1D allocations to be " + + "used with bind."); } } mRS.nScriptBindAllocation(getID(mRS), va.getID(mRS), slot); @@ -378,11 +444,14 @@ public class Script extends BaseObj { protected Allocation mAllocation; protected void init(RenderScript rs, int dimx) { - mAllocation = Allocation.createSized(rs, mElement, dimx, Allocation.USAGE_SCRIPT); + mAllocation = Allocation.createSized(rs, mElement, dimx, + Allocation.USAGE_SCRIPT); } protected void init(RenderScript rs, int dimx, int usages) { - mAllocation = Allocation.createSized(rs, mElement, dimx, Allocation.USAGE_SCRIPT | usages); + mAllocation = + Allocation.createSized(rs, mElement, dimx, + Allocation.USAGE_SCRIPT | usages); } protected FieldBase() { diff --git a/rs/java/android/renderscript/ScriptC.java b/rs/java/android/renderscript/ScriptC.java index 64d21e49dee5..bf706c131e85 100644 --- a/rs/java/android/renderscript/ScriptC.java +++ b/rs/java/android/renderscript/ScriptC.java @@ -124,7 +124,7 @@ public class ScriptC extends Script { // Create the RS cache path if we haven't done so already. if (mCachePath == null) { - File f = new File(rs.mCacheDir, CACHE_PATH); + File f = new File(RenderScriptCacheDir.mCacheDir, CACHE_PATH); mCachePath = f.getAbsolutePath(); f.mkdirs(); } @@ -135,7 +135,7 @@ public class ScriptC extends Script { private static synchronized long internalStringCreate(RenderScript rs, String resName, byte[] bitcode) { // Create the RS cache path if we haven't done so already. if (mCachePath == null) { - File f = new File(rs.mCacheDir, CACHE_PATH); + File f = new File(RenderScriptCacheDir.mCacheDir, CACHE_PATH); mCachePath = f.getAbsolutePath(); f.mkdirs(); } diff --git a/rs/java/android/renderscript/ScriptGroup2.java b/rs/java/android/renderscript/ScriptGroup2.java new file mode 100644 index 000000000000..417bbee73019 --- /dev/null +++ b/rs/java/android/renderscript/ScriptGroup2.java @@ -0,0 +1,449 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.renderscript; + +import android.util.Log; +import android.util.Pair; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + +****************************** +You have tried to change the API from what has been previously approved. + +To make these errors go away, you have two choices: +1) You can add "@hide" javadoc comments to the methods, etc. listed in the +errors above. + +2) You can update current.txt by executing the following command: +make update-api + +To submit the revised current.txt to the main Android repository, +you will need approval. +****************************** + +@hide Pending Android public API approval. +*/ +public class ScriptGroup2 extends BaseObj { + + public static class Closure extends BaseObj { + private Object[] mArgs; + private Allocation mReturnValue; + private Map<Script.FieldID, Object> mBindings; + + private Future mReturnFuture; + private Map<Script.FieldID, Future> mGlobalFuture; + + private FieldPacker mFP; + + private static final String TAG = "Closure"; + + public Closure(long id, RenderScript rs) { + super(id, rs); + } + + public Closure(RenderScript rs, Script.KernelID kernelID, Type returnType, + Object[] args, Map<Script.FieldID, Object> globals) { + super(0, rs); + + mArgs = args; + mReturnValue = Allocation.createTyped(rs, returnType); + mBindings = globals; + mGlobalFuture = new HashMap<Script.FieldID, Future>(); + + int numValues = args.length + globals.size(); + + long[] fieldIDs = new long[numValues]; + long[] values = new long[numValues]; + int[] sizes = new int[numValues]; + long[] depClosures = new long[numValues]; + long[] depFieldIDs = new long[numValues]; + + int i; + for (i = 0; i < args.length; i++) { + Object obj = args[i]; + fieldIDs[i] = 0; + if (obj instanceof UnboundValue) { + UnboundValue unbound = (UnboundValue)obj; + unbound.addReference(this, i); + } else { + retrieveValueAndDependenceInfo(rs, i, args[i], values, sizes, + depClosures, depFieldIDs); + } + } + + for (Map.Entry<Script.FieldID, Object> entry : globals.entrySet()) { + Object obj = entry.getValue(); + Script.FieldID fieldID = entry.getKey(); + fieldIDs[i] = fieldID.getID(rs); + if (obj instanceof UnboundValue) { + UnboundValue unbound = (UnboundValue)obj; + unbound.addReference(this, fieldID); + } else { + retrieveValueAndDependenceInfo(rs, i, obj, values, + sizes, depClosures, depFieldIDs); + } + i++; + } + + long id = rs.nClosureCreate(kernelID.getID(rs), mReturnValue.getID(rs), + fieldIDs, values, sizes, depClosures, depFieldIDs); + + setID(id); + } + + public Closure(RenderScript rs, Script.InvokeID invokeID, + Object[] args, Map<Script.FieldID, Object> globals) { + super(0, rs); + mFP = FieldPacker.createFromArray(args); + + mArgs = args; + mBindings = globals; + mGlobalFuture = new HashMap<Script.FieldID, Future>(); + + int numValues = globals.size(); + + long[] fieldIDs = new long[numValues]; + long[] values = new long[numValues]; + int[] sizes = new int[numValues]; + long[] depClosures = new long[numValues]; + long[] depFieldIDs = new long[numValues]; + + int i = 0; + for (Map.Entry<Script.FieldID, Object> entry : globals.entrySet()) { + Object obj = entry.getValue(); + Script.FieldID fieldID = entry.getKey(); + fieldIDs[i] = fieldID.getID(rs); + if (obj instanceof UnboundValue) { + UnboundValue unbound = (UnboundValue)obj; + unbound.addReference(this, fieldID); + } else { + // TODO(yangni): Verify obj not a future. + retrieveValueAndDependenceInfo(rs, i, obj, values, + sizes, depClosures, depFieldIDs); + } + i++; + } + + long id = rs.nInvokeClosureCreate(invokeID.getID(rs), mFP.getData(), fieldIDs, + values, sizes); + + setID(id); + } + + private static + void retrieveValueAndDependenceInfo(RenderScript rs, + int index, Object obj, + long[] values, int[] sizes, + long[] depClosures, + long[] depFieldIDs) { + + if (obj instanceof Future) { + Future f = (Future)obj; + obj = f.getValue(); + depClosures[index] = f.getClosure().getID(rs); + Script.FieldID fieldID = f.getFieldID(); + depFieldIDs[index] = fieldID != null ? fieldID.getID(rs) : 0; + if (obj == null) { + // Value is originally created by the owner closure + values[index] = 0; + sizes[index] = 0; + return; + } + } else { + depClosures[index] = 0; + depFieldIDs[index] = 0; + } + + ValueAndSize vs = new ValueAndSize(rs, obj); + values[index] = vs.value; + sizes[index] = vs.size; + } + + public Future getReturn() { + if (mReturnFuture == null) { + mReturnFuture = new Future(this, null, mReturnValue); + } + + return mReturnFuture; + } + + public Future getGlobal(Script.FieldID field) { + Future f = mGlobalFuture.get(field); + + if (f == null) { + // If the field is not bound to this closure, this will return a future + // without an associated value (reference). So this is not working for + // cross-module (cross-script) linking in this case where a field not + // explicitly bound. + f = new Future(this, field, mBindings.get(field)); + mGlobalFuture.put(field, f); + } + + return f; + } + + void setArg(int index, Object obj) { + mArgs[index] = obj; + ValueAndSize vs = new ValueAndSize(mRS, obj); + mRS.nClosureSetArg(getID(mRS), index, vs.value, vs.size); + } + + void setGlobal(Script.FieldID fieldID, Object obj) { + mBindings.put(fieldID, obj); + ValueAndSize vs = new ValueAndSize(mRS, obj); + mRS.nClosureSetGlobal(getID(mRS), fieldID.getID(mRS), vs.value, vs.size); + } + + private static final class ValueAndSize { + public ValueAndSize(RenderScript rs, Object obj) { + if (obj instanceof Allocation) { + value = ((Allocation)obj).getID(rs); + size = -1; + } else if (obj instanceof Boolean) { + value = ((Boolean)obj).booleanValue() ? 1 : 0; + size = 4; + } else if (obj instanceof Integer) { + value = ((Integer)obj).longValue(); + size = 4; + } else if (obj instanceof Long) { + value = ((Long)obj).longValue(); + size = 8; + } else if (obj instanceof Float) { + value = ((Float)obj).longValue(); + size = 4; + } else if (obj instanceof Double) { + value = ((Double)obj).longValue(); + size = 8; + } + } + public long value; + public int size; + } + } + + public static class Future { + Closure mClosure; + Script.FieldID mFieldID; + Object mValue; + + Future(Closure closure, Script.FieldID fieldID, Object value) { + mClosure = closure; + mFieldID = fieldID; + mValue = value; + } + + Closure getClosure() { return mClosure; } + Script.FieldID getFieldID() { return mFieldID; } + Object getValue() { return mValue; } + } + + public static class UnboundValue { + // Either mFieldID or mArgIndex should be set but not both. + List<Pair<Closure, Script.FieldID>> mFieldID; + // -1 means unset. Legal values are 0 .. n-1, where n is the number of + // arguments for the referencing closure. + List<Pair<Closure, Integer>> mArgIndex; + + UnboundValue() { + mFieldID = new ArrayList<Pair<Closure, Script.FieldID>>(); + mArgIndex = new ArrayList<Pair<Closure, Integer>>(); + } + + void addReference(Closure closure, int index) { + mArgIndex.add(Pair.create(closure, Integer.valueOf(index))); + } + + void addReference(Closure closure, Script.FieldID fieldID) { + mFieldID.add(Pair.create(closure, fieldID)); + } + + void set(Object value) { + for (Pair<Closure, Integer> p : mArgIndex) { + Closure closure = p.first; + int index = p.second.intValue(); + closure.setArg(index, value); + } + for (Pair<Closure, Script.FieldID> p : mFieldID) { + Closure closure = p.first; + Script.FieldID fieldID = p.second; + closure.setGlobal(fieldID, value); + } + } + } + + String mName; + List<Closure> mClosures; + List<UnboundValue> mInputs; + Future[] mOutputs; + + private static final String TAG = "ScriptGroup2"; + + public ScriptGroup2(long id, RenderScript rs) { + super(id, rs); + } + + ScriptGroup2(RenderScript rs, String name, List<Closure> closures, + List<UnboundValue> inputs, Future[] outputs) { + super(0, rs); + mName = name; + mClosures = closures; + mInputs = inputs; + mOutputs = outputs; + + long[] closureIDs = new long[closures.size()]; + for (int i = 0; i < closureIDs.length; i++) { + closureIDs[i] = closures.get(i).getID(rs); + } + long id = rs.nScriptGroup2Create(name, ScriptC.mCachePath, closureIDs); + setID(id); + } + + public Object[] execute(Object... inputs) { + if (inputs.length < mInputs.size()) { + Log.e(TAG, this.toString() + " receives " + inputs.length + " inputs, " + + "less than expected " + mInputs.size()); + return null; + } + + if (inputs.length > mInputs.size()) { + Log.i(TAG, this.toString() + " receives " + inputs.length + " inputs, " + + "more than expected " + mInputs.size()); + } + + for (int i = 0; i < mInputs.size(); i++) { + Object obj = inputs[i]; + if (obj instanceof Future || obj instanceof UnboundValue) { + Log.e(TAG, this.toString() + ": input " + i + + " is a future or unbound value"); + return null; + } + UnboundValue unbound = mInputs.get(i); + unbound.set(obj); + } + + mRS.nScriptGroup2Execute(getID(mRS)); + + Object[] outputObjs = new Object[mOutputs.length]; + int i = 0; + for (Future f : mOutputs) { + outputObjs[i++] = f.getValue(); + } + return outputObjs; + } + + /** + @hide Pending Android public API approval. + */ + public static final class Binding { + public Script.FieldID mField; + public Object mValue; + public Binding(Script.FieldID field, Object value) { + mField = field; + mValue = value; + } + } + + /** + @hide Pending Android public API approval. + */ + public static final class Builder { + RenderScript mRS; + List<Closure> mClosures; + List<UnboundValue> mInputs; + private static final String TAG = "ScriptGroup2.Builder"; + + public Builder(RenderScript rs) { + mRS = rs; + mClosures = new ArrayList<Closure>(); + mInputs = new ArrayList<UnboundValue>(); + } + + public Closure addKernel(Script.KernelID k, Type returnType, Object[] args, + Map<Script.FieldID, Object> globalBindings) { + Closure c = new Closure(mRS, k, returnType, args, globalBindings); + mClosures.add(c); + return c; + } + + public Closure addInvoke(Script.InvokeID invoke, Object[] args, + Map<Script.FieldID, Object> globalBindings) { + Closure c = new Closure(mRS, invoke, args, globalBindings); + mClosures.add(c); + return c; + } + + public UnboundValue addInput() { + UnboundValue unbound = new UnboundValue(); + mInputs.add(unbound); + return unbound; + } + + public Closure addKernel(Script.KernelID k, Type returnType, Object... argsAndBindings) { + ArrayList<Object> args = new ArrayList<Object>(); + Map<Script.FieldID, Object> bindingMap = new HashMap<Script.FieldID, Object>(); + if (!seperateArgsAndBindings(argsAndBindings, args, bindingMap)) { + return null; + } + return addKernel(k, returnType, args.toArray(), bindingMap); + } + + public Closure addInvoke(Script.InvokeID invoke, Object... argsAndBindings) { + ArrayList<Object> args = new ArrayList<Object>(); + Map<Script.FieldID, Object> bindingMap = new HashMap<Script.FieldID, Object>(); + if (!seperateArgsAndBindings(argsAndBindings, args, bindingMap)) { + return null; + } + return addInvoke(invoke, args.toArray(), bindingMap); + } + + public ScriptGroup2 create(String name, Future... outputs) { + if (name == null || name.isEmpty() || name.length() > 100 || + !name.equals(name.replaceAll("[^a-zA-Z0-9-]", "_"))) { + throw new RSIllegalArgumentException("invalid script group name"); + } + ScriptGroup2 ret = new ScriptGroup2(mRS, name, mClosures, mInputs, outputs); + return ret; + } + + private boolean seperateArgsAndBindings(Object[] argsAndBindings, + ArrayList<Object> args, + Map<Script.FieldID, Object> bindingMap) { + int i; + for (i = 0; i < argsAndBindings.length; i++) { + if (argsAndBindings[i] instanceof Binding) { + break; + } + args.add(argsAndBindings[i]); + } + + for (; i < argsAndBindings.length; i++) { + if (!(argsAndBindings[i] instanceof Binding)) { + return false; + } + Binding b = (Binding)argsAndBindings[i]; + bindingMap.put(b.mField, b.mValue); + } + + return true; + } + + } +} diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java new file mode 100644 index 000000000000..16b703356b47 --- /dev/null +++ b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java @@ -0,0 +1,1510 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.renderscript; + +import android.annotation.IntDef; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * + * BLAS + * + * @hide + **/ +public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { + private Allocation mLUT; + + private ScriptIntrinsicBLAS(long id, RenderScript rs) { + super(id, rs); + } + + private static final int RsBlas_sdsdot = 1; + private static final int RsBlas_dsdot = 2; + private static final int RsBlas_sdot = 3; + private static final int RsBlas_ddot = 4; + private static final int RsBlas_cdotu_sub = 5; + private static final int RsBlas_cdotc_sub = 6; + private static final int RsBlas_zdotu_sub = 7; + private static final int RsBlas_zdotc_sub = 8; + private static final int RsBlas_snrm2 = 9; + private static final int RsBlas_sasum = 10; + private static final int RsBlas_dnrm2 = 11; + private static final int RsBlas_dasum = 12; + private static final int RsBlas_scnrm2 = 13; + private static final int RsBlas_scasum = 14; + private static final int RsBlas_dznrm2 = 15; + private static final int RsBlas_dzasum = 16; + private static final int RsBlas_isamax = 17; + private static final int RsBlas_idamax = 18; + private static final int RsBlas_icamax = 19; + private static final int RsBlas_izamax = 20; + private static final int RsBlas_sswap = 21; + private static final int RsBlas_scopy = 22; + private static final int RsBlas_saxpy = 23; + private static final int RsBlas_dswap = 24; + private static final int RsBlas_dcopy = 25; + private static final int RsBlas_daxpy = 26; + private static final int RsBlas_cswap = 27; + private static final int RsBlas_ccopy = 28; + private static final int RsBlas_caxpy = 29; + private static final int RsBlas_zswap = 30; + private static final int RsBlas_zcopy = 31; + private static final int RsBlas_zaxpy = 32; + private static final int RsBlas_srotg = 33; + private static final int RsBlas_srotmg = 34; + private static final int RsBlas_srot = 35; + private static final int RsBlas_srotm = 36; + private static final int RsBlas_drotg = 37; + private static final int RsBlas_drotmg = 38; + private static final int RsBlas_drot = 39; + private static final int RsBlas_drotm = 40; + private static final int RsBlas_sscal = 41; + private static final int RsBlas_dscal = 42; + private static final int RsBlas_cscal = 43; + private static final int RsBlas_zscal = 44; + private static final int RsBlas_csscal = 45; + private static final int RsBlas_zdscal = 46; + private static final int RsBlas_sgemv = 47; + private static final int RsBlas_sgbmv = 48; + private static final int RsBlas_strmv = 49; + private static final int RsBlas_stbmv = 50; + private static final int RsBlas_stpmv = 51; + private static final int RsBlas_strsv = 52; + private static final int RsBlas_stbsv = 53; + private static final int RsBlas_stpsv = 54; + private static final int RsBlas_dgemv = 55; + private static final int RsBlas_dgbmv = 56; + private static final int RsBlas_dtrmv = 57; + private static final int RsBlas_dtbmv = 58; + private static final int RsBlas_dtpmv = 59; + private static final int RsBlas_dtrsv = 60; + private static final int RsBlas_dtbsv = 61; + private static final int RsBlas_dtpsv = 62; + private static final int RsBlas_cgemv = 63; + private static final int RsBlas_cgbmv = 64; + private static final int RsBlas_ctrmv = 65; + private static final int RsBlas_ctbmv = 66; + private static final int RsBlas_ctpmv = 67; + private static final int RsBlas_ctrsv = 68; + private static final int RsBlas_ctbsv = 69; + private static final int RsBlas_ctpsv = 70; + private static final int RsBlas_zgemv = 71; + private static final int RsBlas_zgbmv = 72; + private static final int RsBlas_ztrmv = 73; + private static final int RsBlas_ztbmv = 74; + private static final int RsBlas_ztpmv = 75; + private static final int RsBlas_ztrsv = 76; + private static final int RsBlas_ztbsv = 77; + private static final int RsBlas_ztpsv = 78; + private static final int RsBlas_ssymv = 79; + private static final int RsBlas_ssbmv = 80; + private static final int RsBlas_sspmv = 81; + private static final int RsBlas_sger = 82; + private static final int RsBlas_ssyr = 83; + private static final int RsBlas_sspr = 84; + private static final int RsBlas_ssyr2 = 85; + private static final int RsBlas_sspr2 = 86; + private static final int RsBlas_dsymv = 87; + private static final int RsBlas_dsbmv = 88; + private static final int RsBlas_dspmv = 89; + private static final int RsBlas_dger = 90; + private static final int RsBlas_dsyr = 91; + private static final int RsBlas_dspr = 92; + private static final int RsBlas_dsyr2 = 93; + private static final int RsBlas_dspr2 = 94; + private static final int RsBlas_chemv = 95; + private static final int RsBlas_chbmv = 96; + private static final int RsBlas_chpmv = 97; + private static final int RsBlas_cgeru = 98; + private static final int RsBlas_cgerc = 99; + private static final int RsBlas_cher = 100; + private static final int RsBlas_chpr = 101; + private static final int RsBlas_cher2 = 102; + private static final int RsBlas_chpr2 = 103; + private static final int RsBlas_zhemv = 104; + private static final int RsBlas_zhbmv = 105; + private static final int RsBlas_zhpmv = 106; + private static final int RsBlas_zgeru = 107; + private static final int RsBlas_zgerc = 108; + private static final int RsBlas_zher = 109; + private static final int RsBlas_zhpr = 110; + private static final int RsBlas_zher2 = 111; + private static final int RsBlas_zhpr2 = 112; + private static final int RsBlas_sgemm = 113; + private static final int RsBlas_ssymm = 114; + private static final int RsBlas_ssyrk = 115; + private static final int RsBlas_ssyr2k = 116; + private static final int RsBlas_strmm = 117; + private static final int RsBlas_strsm = 118; + private static final int RsBlas_dgemm = 119; + private static final int RsBlas_dsymm = 120; + private static final int RsBlas_dsyrk = 121; + private static final int RsBlas_dsyr2k = 122; + private static final int RsBlas_dtrmm = 123; + private static final int RsBlas_dtrsm = 124; + private static final int RsBlas_cgemm = 125; + private static final int RsBlas_csymm = 126; + private static final int RsBlas_csyrk = 127; + private static final int RsBlas_csyr2k = 128; + private static final int RsBlas_ctrmm = 129; + private static final int RsBlas_ctrsm = 130; + private static final int RsBlas_zgemm = 131; + private static final int RsBlas_zsymm = 132; + private static final int RsBlas_zsyrk = 133; + private static final int RsBlas_zsyr2k = 134; + private static final int RsBlas_ztrmm = 135; + private static final int RsBlas_ztrsm = 136; + private static final int RsBlas_chemm = 137; + private static final int RsBlas_cherk = 138; + private static final int RsBlas_cher2k = 139; + private static final int RsBlas_zhemm = 140; + private static final int RsBlas_zherk = 141; + private static final int RsBlas_zher2k = 142; + + // BLAS extensions start here + private static final int RsBlas_bnnm = 1000; + + /** + */ + public static ScriptIntrinsicBLAS create(RenderScript rs) { + long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs)); + return new ScriptIntrinsicBLAS(id, rs); + } + + @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE}) + @Retention(RetentionPolicy.SOURCE) + public @interface Transpose {} + + @IntDef({UPPER, LOWER}) + @Retention(RetentionPolicy.SOURCE) + public @interface Uplo {} + + @IntDef({NON_UNIT, UNIT}) + @Retention(RetentionPolicy.SOURCE) + public @interface Diag {} + + @IntDef({LEFT, RIGHT}) + @Retention(RetentionPolicy.SOURCE) + public @interface Side {} + + public static final int NO_TRANSPOSE = 111; + public static final int TRANSPOSE = 112; + public static final int CONJ_TRANSPOSE = 113; + + public static final int UPPER = 121; + public static final int LOWER = 122; + + public static final int NON_UNIT = 131; + public static final int UNIT = 132; + + public static final int LEFT = 141; + public static final int RIGHT = 142; + + static void validateSide(@Side int Side) { + if (Side != LEFT && Side != RIGHT) { + throw new RSRuntimeException("Invalid side passed to BLAS"); + } + } + + static void validateTranspose(@Transpose int Trans) { + if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE && + Trans != CONJ_TRANSPOSE) { + throw new RSRuntimeException("Invalid transpose passed to BLAS"); + } + } + + static void validateConjTranspose(@Transpose int Trans) { + if (Trans != NO_TRANSPOSE && + Trans != CONJ_TRANSPOSE) { + throw new RSRuntimeException("Invalid transpose passed to BLAS"); + } + } + + static void validateDiag(@Diag int Diag) { + if (Diag != NON_UNIT && Diag != UNIT) { + throw new RSRuntimeException("Invalid diag passed to BLAS"); + } + } + + static void validateUplo(@Uplo int Uplo) { + if (Uplo != LEFT && Uplo != RIGHT) { + throw new RSRuntimeException("Invalid uplo passed to BLAS"); + } + } + + + /** + * Level 2 BLAS + */ + + static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) { + validateTranspose(TransA); + int M = A.getType().getY(); + int N = A.getType().getX(); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = -1, expectedYDim = -1; + if (TransA == NO_TRANSPOSE) { + expectedXDim = 1 + (N - 1) * incX; + expectedYDim = 1 + (M - 1) * incY; + } else { + expectedXDim = 1 + (M - 1) * incX; + expectedYDim = 1 + (N - 1) * incY; + } + if (X.getType().getX() != expectedXDim || + Y.getType().getY() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); + } + } + void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + + void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); + } + void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); + } + void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); + } + void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // GBMV has the same validation requirements as GEMV + KL and KU >= 0 + validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); + if (KL < 0 || KU < 0) { + throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); + } + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); + } + + static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { + validateTranspose(TransA); + int N = A.getType().getY(); + if (A.getType().getX() != N) { + throw new RSRuntimeException("A must be a square matrix for TRMV"); + } + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for TRMV"); + } + } + + static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + validateTranspose(TransA); + validateUplo(Uplo); + validateDiag(Diag); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + + return N; + } + + void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBMV has the same requirements as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + + } + void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + + } + void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + + } + void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { + // TRSV is the same as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + + } + void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F32(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F64(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { + // TBSV is the same as TRMV + validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); + int N = A.getType().getY(); + if (K < 0) { + throw new RSRuntimeException("Number of diagonals must be positive"); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); + } + void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { + // TPSV is same as TPMV + int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); + } + + /** + * Level 2, S and D only + */ + static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) { + validateUplo(Uplo); + int N = A.getType().getY(); + if (A.getType().getX() != N) { + throw new RSRuntimeException("A must be a square matrix for SYMV"); + } + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e) ) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (incX <= 0 || incY <= 0) { + throw new RSRuntimeException("Vector increments must be greater than 0"); + } + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); + } + return N; + } + static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e) ) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int M = A.getType().getY(); + int N = A.getType().getX(); + + if (N < 1 || M < 1) { + throw new RSRuntimeException("M and N must be 1 or greater for GER"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GER"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GER"); + } + + + } + static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) { + validateUplo(Uplo); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + int N = A.getType().getX(); + + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + if (N != A.getType().getY()) { + throw new RSRuntimeException("A must be a symmetric matrix"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYR"); + } + return N; + } + static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + + static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateUplo(Uplo); + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int N = A.getType().getX(); + + if (N != A.getType().getY()) { + throw new RSRuntimeException("A must be a symmetric matrix"); + } + + int expectedXDim = 1 + (N - 1) * incX; + int expectedYDim = 1 + (N - 1) * incY; + if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SYR"); + } + return N; + + } + static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + validateUplo(Uplo); + if (!Ap.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + if (Ap.getType().getY() > 1) { + throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); + } + + int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); + if (Ap.getType().getX() != ((N * (N+1)) / 2)) { + throw new RSRuntimeException("Invalid dimension for Ap"); + } + + int expectedXDim = 1 + (N - 1) * incX; + int expectedYDim = 1 + (N - 1) * incY; + if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); + } + + return N; + } + + void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { + int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); + } + void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); + } + void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); + } + void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { + // SBMV is the same as SYMV + int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { + int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); + } + void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); + } + void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); + } + void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); + } + void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); + } + + + /** + * Level 2, C and Z only + */ + + static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + if (!A.getType().getElement().isCompatible(e) || + !X.getType().getElement().isCompatible(e) || + !Y.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (X.getType().getY() > 1 || Y.getType().getY() > 1) { + throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); + } + + int M = A.getType().getY(); + int N = A.getType().getX(); + + int expectedXDim = 1 + (N - 1) * incX; + if (X.getType().getX() != expectedXDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GERU"); + } + int expectedYDim = 1 + (N - 1) * incY; + if (Y.getType().getX() != expectedYDim) { + throw new RSRuntimeException("Incorrect vector dimensions for GERU"); + } + + } + + void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HEMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HBMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + if (K < 0) { + throw new RSRuntimeException("K must be 0 or greater for HBMV"); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { + // HPMV is the same as SPR2 + int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as GERU + validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { + // same as SYR + int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); + } + void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { + // equivalent to SPR for validation + int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); + } + void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as SYR2 + int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + // same as SPR2 + int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); + } + void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HEMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HBMV is the same as SYR2 validation-wise + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + if (K < 0) { + throw new RSRuntimeException("K must be 0 or greater for HBMV"); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { + // HPMV is the same as SPR2 + int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); + } + void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as GERU + validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); + int M = A.getType().getY(); + int N = A.getType().getX(); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { + // same as SYR + int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); + } + void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { + // equivalent to SPR for validation + int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); + } + void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { + // same as SYR2 + int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); + } + void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { + // same as SPR2 + int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); + } + + + /** + * Level 3 BLAS + */ + + static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { + int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; + if ((A != null && !A.getType().getElement().isCompatible(e)) || + (B != null && !B.getType().getElement().isCompatible(e)) || + (C != null && !C.getType().getElement().isCompatible(e))) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (C != null) { + cX = C.getType().getY(); + cY = C.getType().getX(); + } + if (Side == RIGHT) { + if (B != null) { + bX = A.getType().getY(); + bY = A.getType().getX(); + } + if (A != null) { + aX = B.getType().getY(); + aY = B.getType().getX(); + } + } else { + if (A != null) { + if (TransA == TRANSPOSE) { + aY = A.getType().getY(); + aX = A.getType().getX(); + } else { + aX = A.getType().getY(); + aY = A.getType().getX(); + } + } + if (B != null) { + if (TransB == TRANSPOSE) { + bY = B.getType().getY(); + bX = B.getType().getX(); + } else { + bX = B.getType().getY(); + bY = B.getType().getX(); + } + } + } + if (A != null && B != null && C != null) { + if (aY != bX || aX != cX || bY != cY) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } + } else if (A != null && C != null) { + // A and C only + if (aX != cY || aY != cX) { + throw new RSRuntimeException("Called BLAS with invalid dimensions"); + } + } else if (A != null && B != null) { + // A and B only + } + + } + + public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A, + Allocation B, float beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); + + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A, + Allocation B, double beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A, + Allocation B, Float2 beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A, + Allocation B, Double2 beta, Allocation C) { + validateTranspose(TransA); + validateTranspose(TransB); + validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); + int M = -1, N = -1, K = -1; + if (TransA == TRANSPOSE) { + M = A.getType().getX(); + K = A.getType().getY(); + } else { + M = A.getType().getY(); + K = A.getType().getX(); + } + if (TransB == TRANSPOSE) { + N = B.getType().getY(); + } else { + N = B.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, + Allocation B, float beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, + Allocation B, double beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), + beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, + Allocation B, Float2 beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, + Allocation B, Double2 beta, Allocation C) { + validateSide(Side); + validateUplo(Uplo); + validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), + beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); + } + + public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) { + validateTranspose(Trans); + validateUplo(Uplo); + validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, + C.getID(mRS), 0, 0, 0, 0); + } + + static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { + validateTranspose(Trans); + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + int Cdim = -1; + // A is n x k if no transpose, k x n if transpose + // C is n x n + if (Trans == TRANSPOSE) { + // check columns versus C + Cdim = A.getType().getX(); + } else { + // check rows versus C + Cdim = A.getType().getY(); + } + if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { + throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); + } + // A dims == B dims + if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { + throw new RSRuntimeException("Invalid A and B in SYR2K"); + } + } + public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F32(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); + } + public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F64(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); + } + public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { + validateUplo(Uplo); + validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); + int K = -1; + if (Trans == TRANSPOSE) { + K = A.getType().getY(); + } else { + K = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { + validateSide(Side); + validateTranspose(TransA); + int aX = -1, aY = -1, bX = -1, bY = -1; + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + if (TransA == TRANSPOSE) { + aY = A.getType().getY(); + aX = A.getType().getX(); + } else { + aY = A.getType().getX(); + aX = A.getType().getY(); + } + bX = B.getType().getY(); + bY = B.getType().getX(); + if (Side == LEFT) { + if (aX == 0 || aY != bX) { + throw new RSRuntimeException("Called TRMM with invalid matrices"); + } + } else { + if (bY != aX || aY == 0) { + throw new RSRuntimeException("Called TRMM with invalid matrices"); + } + } + } + public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F32(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); + } + public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F64(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); + } + public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + + static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { + int adim = -1, bX = -1, bY = -1; + validateSide(Side); + validateTranspose(TransA); + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + adim = A.getType().getX(); + if (adim != A.getType().getY()) { + // this may be unnecessary, the restriction could potentially be relaxed + // A needs to contain at least that symmetric matrix but could theoretically be larger + // for now we assume adapters are sufficient, will reevaluate in the future + throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); + } + bX = B.getType().getY(); + bY = B.getType().getX(); + if (Side == LEFT) { + // A is M*M + if (adim != bY) { + throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); + } + } else { + // A is N*N + if (adim != bX) { + throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); + } + } + } + public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F32(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); + } + public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F64(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); + } + public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { + validateUplo(Uplo); + validateDiag(Diag); + validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, + alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); + } + + static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) { + validateSide(Side); + + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + + // A must be square; can potentially be relaxed similar to TRSM + int adim = A.getType().getX(); + if (adim != A.getType().getY()) { + throw new RSRuntimeException("Called HEMM with non-square A"); + } + if ((Side == LEFT && adim != B.getType().getY()) || + (Side == RIGHT && adim != B.getType().getX())) { + throw new RSRuntimeException("Called HEMM with invalid B"); + } + if (B.getType().getX() != C.getType().getX() || + B.getType().getY() != C.getType().getY()) { + throw new RSRuntimeException("Called HEMM with mismatched B and C"); + } + } + public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateHEMM(Element.F32_2(mRS), Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, + alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateHEMM(Element.F32_2(mRS), Side, A, B, C); + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, + alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { + if (!A.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + validateConjTranspose(Trans); + int cdim = C.getType().getX(); + if (cdim != C.getType().getY()) { + throw new RSRuntimeException("Called HERK with non-square C"); + } + if (Trans == NO_TRANSPOSE) { + if (cdim != A.getType().getX()) { + throw new RSRuntimeException("Called HERK with invalid A"); + } + } else { + if (cdim != A.getType().getY()) { + throw new RSRuntimeException("Called HERK with invalid A"); + } + } + } + public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { + validateUplo(Uplo); + validateHERK(Element.F32_2(mRS), Trans, A, C); + int k = 0; + if (Trans == TRANSPOSE) { + k = A.getType().getY(); + } else { + k = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, + alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { + validateUplo(Uplo); + validateHERK(Element.F64_2(mRS), Trans, A, C); + int k = 0; + if (Trans == TRANSPOSE) { + k = A.getType().getY(); + } else { + k = A.getType().getX(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, + alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { + if (!A.getType().getElement().isCompatible(e) || + !B.getType().getElement().isCompatible(e) || + !C.getType().getElement().isCompatible(e)) { + throw new RSRuntimeException("Called BLAS with wrong Element type"); + } + validateConjTranspose(Trans); + int cdim = C.getType().getX(); + if (cdim != C.getType().getY()) { + throw new RSRuntimeException("Called HER2K with non-square C"); + } + if (Trans == NO_TRANSPOSE) { + if (A.getType().getY() != cdim) { + throw new RSRuntimeException("Called HER2K with invalid matrices"); + } + } else { + if (A.getType().getX() != cdim) { + throw new RSRuntimeException("Called HER2K with invalid matrices"); + } + } + if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { + throw new RSRuntimeException("Called HER2K with invalid A and B matrices"); + } + } + public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) { + validateUplo(Uplo); + validateHER2K(Element.F32_2(mRS), Trans, A, B, C); + int k = 0; + if (Trans == NO_TRANSPOSE) { + k = A.getType().getX(); + } else { + k = A.getType().getY(); + } + mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, + A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) { + validateUplo(Uplo); + validateHER2K(Element.F64_2(mRS), Trans, A, B, C); + int k = 0; + if (Trans == NO_TRANSPOSE) { + k = A.getType().getX(); + } else { + k = A.getType().getY(); + } + mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, + A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); + } + + + /** + * + * 8-bit GEMM-like operation for neural networks + * + * @hide + **/ + public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) { + validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C); + + int M = -1, N = -1, K = -1; + M = A.getType().getY(); + N = B.getType().getY(); + K = A.getType().getX(); + + + mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult); + + } + +} diff --git a/rs/java/android/renderscript/ScriptIntrinsicBlur.java b/rs/java/android/renderscript/ScriptIntrinsicBlur.java index 5c4edd3ae9d8..60e2b6d99b1e 100644 --- a/rs/java/android/renderscript/ScriptIntrinsicBlur.java +++ b/rs/java/android/renderscript/ScriptIntrinsicBlur.java @@ -34,7 +34,7 @@ public final class ScriptIntrinsicBlur extends ScriptIntrinsic { * Create an intrinsic for applying a blur to an allocation. The * default radius is 5.0. * - * Supported elements types are {@link Element#U8_4} + * Supported elements types are {@link Element#U8_4 Element#U8} * * @param rs The RenderScript context * @param e Element type for inputs and outputs diff --git a/rs/java/android/renderscript/ScriptIntrinsicResize.java b/rs/java/android/renderscript/ScriptIntrinsicResize.java index d6764ccab75e..cee4c33ee87d 100644 --- a/rs/java/android/renderscript/ScriptIntrinsicResize.java +++ b/rs/java/android/renderscript/ScriptIntrinsicResize.java @@ -29,6 +29,8 @@ public final class ScriptIntrinsicResize extends ScriptIntrinsic { /** * Supported elements types are {@link Element#U8}, {@link * Element#U8_2}, {@link Element#U8_3}, {@link Element#U8_4} + * {@link Element#F32}, {@link Element#F32_2}, {@link + * Element#F32_3}, {@link Element#F32_4} * * @param rs The RenderScript context * @@ -52,7 +54,11 @@ public final class ScriptIntrinsicResize extends ScriptIntrinsic { if (!e.isCompatible(Element.U8(mRS)) && !e.isCompatible(Element.U8_2(mRS)) && !e.isCompatible(Element.U8_3(mRS)) && - !e.isCompatible(Element.U8_4(mRS))) { + !e.isCompatible(Element.U8_4(mRS)) && + !e.isCompatible(Element.F32(mRS)) && + !e.isCompatible(Element.F32_2(mRS)) && + !e.isCompatible(Element.F32_3(mRS)) && + !e.isCompatible(Element.F32_4(mRS))) { throw new RSIllegalArgumentException("Unsuported element type."); } diff --git a/rs/java/android/renderscript/Type.java b/rs/java/android/renderscript/Type.java index 98aeaa95d3bc..a58e42cd9712 100644 --- a/rs/java/android/renderscript/Type.java +++ b/rs/java/android/renderscript/Type.java @@ -52,6 +52,9 @@ public class Type extends BaseObj { int mDimYuv; int mElementCount; Element mElement; + int mArrays[]; + + static final int mMaxArrays = 4; public enum CubemapFace { POSITIVE_X (0), @@ -146,6 +149,30 @@ public class Type extends BaseObj { return mElementCount; } + /** + * @hide + */ + public int getArray(int dim) { + if ((dim < 0) || (dim >= mMaxArrays)) { + throw new RSIllegalArgumentException("Array dimension out of range."); + } + + if (mArrays == null || dim >= mArrays.length) { + // Dimension in range but no array for that dimension allocated + return 0; + } + + return mArrays[dim]; + } + + /** + * @hide + */ + public int getArrayCount() { + if (mArrays != null) return mArrays.length; + return 0; + } + void calcElementCount() { boolean hasLod = hasMipmaps(); int x = getX(); @@ -180,6 +207,13 @@ public class Type extends BaseObj { count += x * y * z * faces; } + + if (mArrays != null) { + for (int ct = 0; ct < mArrays.length; ct++) { + count *= mArrays[ct]; + } + } + mElementCount = count; } @@ -296,6 +330,7 @@ public class Type extends BaseObj { boolean mDimMipmaps; boolean mDimFaces; int mYuv; + int[] mArray = new int[mMaxArrays]; Element mElement; @@ -341,6 +376,22 @@ public class Type extends BaseObj { return this; } + /** + * @hide + * + * @param dim + * @param value + * + * @return Builder + */ + public Builder setArray(int dim, int value) { + if(dim < 0 || dim >= mMaxArrays) { + throw new RSIllegalArgumentException("Array dimension out of range."); + } + mArray[dim] = value; + return this; + } + public Builder setMipmaps(boolean value) { mDimMipmaps = value; return this; @@ -405,6 +456,16 @@ public class Type extends BaseObj { } } + int[] arrays = null; + for (int ct = mMaxArrays - 1; ct >= 0; ct--) { + if (mArray[ct] != 0 && arrays == null) { + arrays = new int[ct]; + } + if ((mArray[ct] == 0) && (arrays != null)) { + throw new RSInvalidStateException("Array dimensions must be contigous from 0."); + } + } + long id = mRS.nTypeCreate(mElement.getID(mRS), mDimX, mDimY, mDimZ, mDimMipmaps, mDimFaces, mYuv); Type t = new Type(id, mRS); @@ -415,6 +476,7 @@ public class Type extends BaseObj { t.mDimMipmaps = mDimMipmaps; t.mDimFaces = mDimFaces; t.mDimYuv = mYuv; + t.mArrays = arrays; t.calcElementCount(); return t; |