Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into snnn/cleanup1
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Nov 14, 2024
2 parents f669f69 + 12dfe28 commit 2d0f091
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 138 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ Do not modify directly.*
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|GroupNorm||21+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Hardmax|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||11+|**T** = tensor(float), tensor(float16)|
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class TensorIdTracker {
}

// eslint-disable-next-line no-bitwise
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true);

if (copyOld && this.activeUpload) {
Expand Down Expand Up @@ -349,7 +349,7 @@ class TensorManagerImpl implements TensorManager {
public async getCachedTensor(
dataType: MLOperandDataType,
shape: readonly number[],
usage: MLTensorUsageFlags,
usage: MLTensorUsageFlags | undefined,
writable: boolean,
readable: boolean,
): Promise<TensorWrapper> {
Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/webnn/webnn.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,8 @@ declare const MLTensorUsage: {
};

interface MLTensorDescriptor extends MLOperandDescriptor {
usage: MLTensorUsageFlags;
  /** @deprecated Use readable/writable instead of usage */
usage: MLTensorUsageFlags | undefined;
importableToWebGPU?: boolean;
readable?: boolean;
writable?: boolean;
Expand Down
4 changes: 2 additions & 2 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty
shape: dims as number[],
// Assign both shape and dimensions while transitioning to new API.
dimensions: dims as number[],
usage: MLTensorUsage.READ,
usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ,
readable: true,
});

Expand All @@ -686,7 +686,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
shape: cpuTensor.dims as number[],
// Assign both shape and dimensions while transitioning to new API.
dimensions: cpuTensor.dims as number[],
usage: MLTensorUsage.WRITE,
usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.WRITE,
writable: true,
});
mlContext.writeTensor(mlTensor, cpuTensor.data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_

#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit.
std::vector<const char*> providers = {kCpuExecutionProvider, kDmlExecutionProvider};
std::vector<const char*> providers = {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::ConvSelector>(is_int8_allowed,
false,
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO( 21, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, MatMulNBits, typeNameListTwo, supportedTypeListMatMulNBits, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMatMulNBits)},

// Operators that need to alias an input with an output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,7 @@ using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Transpose = 21;
static const int sc_sinceVer_Identity = 21;
static const int sc_sinceVer_QLinearMatMul = 21;
static const int sc_sinceVer_GroupNorm = 21;
}

namespace MsftOperatorSet1
Expand Down
Loading

0 comments on commit 2d0f091

Please sign in to comment.