From 75968b9ecae4b6335913a0f774efd8970a0f8284 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Wed, 21 Feb 2024 12:51:37 -0800 Subject: [PATCH] Cherry-pick for 1.17.1 patch release (#19477) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description As title. ### Motivation and Context --------- Co-authored-by: petermcaughan Co-authored-by: Peter McAughan Co-authored-by: Adrian Lizarraga Co-authored-by: Patrice Vignola Co-authored-by: ivberg Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Baiju Meswani Co-authored-by: Preetha Veeramalai Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Sheil Kumar Co-authored-by: Sheil Kumar Co-authored-by: Prathik Rao Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Jian Chen Co-authored-by: Xavier Dupré Co-authored-by: satyajandhyala --- cmake/adjust_global_compile_flags.cmake | 9 +- docs/ContribOperators.md | 32 +- js/web/package.json | 10 + .../cpu/transformers/beam_search_impl_gpt.h | 2 +- .../cpu/transformers/beam_search_impl_t5.h | 2 +- .../transformers/beam_search_impl_whisper.h | 6 +- .../transformers/beam_search_parameters.cc | 8 +- .../cpu/transformers/generation_shared.h | 9 +- .../cpu/transformers/logits_processor.h | 81 +++-- .../transformers/generation_device_helper.cc | 12 +- .../core/framework/execution_providers.h | 55 +++- .../core/graph/contrib_ops/contrib_defs.cc | 40 +-- onnxruntime/core/graph/graph.cc | 37 ++- onnxruntime/core/graph/graph_viewer.cc | 18 +- .../core/optimizer/noop_elimination.cc | 73 +++-- .../ort_optimizer_api_impl.cc | 2 +- .../core/platform/windows/telemetry.cc | 24 +- onnxruntime/core/platform/windows/telemetry.h | 15 +- .../src/GraphDescBuilder.cpp | 15 +- .../src/MLOperatorAuthorImpl.cpp | 12 +- .../src/MLOperatorAuthorImpl.h | 1 + .../openvino/backends/basic_backend.cc | 11 +- .../core/providers/openvino/ov_interface.cc | 4 +- .../core/providers/openvino/ov_interface.h | 2 +- .../qnn/builder/opbuilder/split_op_builder.cc | 40 ++- onnxruntime/core/session/inference_session.cc | 72 ++++- onnxruntime/core/session/inference_session.h | 14 +- .../core/session/provider_bridge_ort.cc | 18 +- .../core/session/provider_registration.cc | 4 + onnxruntime/core/util/thread_utils.cc | 10 + .../quantization/matmul_4bits_quantizer.py | 9 +- .../tools/quantization/onnx_quantizer.py | 35 ++- .../tools/quantization/qdq_quantizer.py | 2 + .../python/tools/symbolic_shape_infer.py | 28 ++ .../transformers/fusion_bart_attention.py | 239 ++++++++++++++- .../transformers/models/whisper/README.md | 46 ++- .../transformers/models/whisper/benchmark.py | 22 +- .../models/whisper/benchmark_all.py | 6 + .../models/whisper/convert_to_onnx.py | 288 ++++++++++-------- .../models/whisper/requirements-cpu.txt | 2 + .../models/whisper/requirements-cuda.txt | 4 + .../models/whisper/requirements.txt | 11 + .../models/whisper/whisper_chain.py | 281 ++++++++++------- .../models/whisper/whisper_decoder.py | 16 +- .../models/whisper/whisper_encoder.py | 17 +- .../whisper/whisper_encoder_decoder_init.py | 31 +- .../models/whisper/whisper_helper.py | 136 ++++++--- .../models/whisper/whisper_openai_helper.py | 76 +++++ .../python/tools/transformers/onnx_model.py | 48 +++ .../tools/transformers/onnx_model_bart.py | 2 +- .../tools/transformers/quantize_helper.py | 
3 +- .../transformers/torch_onnx_export_helper.py | 3 +- .../test/framework/allocation_planner_test.cc | 21 +- onnxruntime/test/framework/bfc_arena_test.cc | 2 + .../test/framework/execution_frame_test.cc | 55 +++- .../logging/HowToValidateEtwSinkOutput.md | 6 +- .../test/providers/qnn/split_op_test.cc | 41 ++- ...untime_test_python_symbolic_shape_infer.py | 202 ++++++++++++ .../test/python/quantization/test_qdq.py | 7 + .../test_quantizer_shape_inference.py | 92 ++++++ .../test/python/quantization/test_subgraph.py | 64 ++++ .../python/transformers/test_generation.py | 25 +- .../test_whisper_timestamp_processor.py | 4 +- .../ortmodule/_custom_gradient_registry.py | 7 +- .../ortmodule/_custom_op_symbolic_registry.py | 13 + .../cpu/torch_interop_utils/setup.py | 2 +- .../python/orttraining_test_ortmodule_api.py | 28 ++ .../nodejs/templates/test_linux.yml | 3 +- .../nodejs/templates/test_macos.yml | 2 +- .../nodejs/templates/test_win.yml | 2 +- ...orttraining-py-packaging-pipeline-cuda.yml | 2 +- .../azure-pipelines/py-packaging-pipeline.yml | 6 + .../azure-pipelines/templates/c-api-cpu.yml | 57 ++-- .../templates/py-packaging-stage.yml | 11 + .../templates/py-win-x64-qnn.yml | 177 +++++++++++ winml/lib/Api/HardwareCoreEnumerator.cpp | 27 +- 76 files changed, 2220 insertions(+), 579 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py create mode 100644 onnxruntime/test/python/quantization/test_quantizer_shape_inference.py create mode 100644 onnxruntime/test/python/quantization/test_subgraph.py create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 30d8cbf78fb1a..8b4be045c8674 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build -if (NOT onnxruntime_MINIMAL_BUILD) +# Enable stream for all the non-minimal build, except for DML. There's currently a bug +# in the allocation planner when reusing buffers and more than one streams are used that +# make it possible (although rarely) to reach a reference count of 0 for a buffer that is +# still being used. Since DML doesn't benefit from multiple streams, disabling it is the +# safest option for now. +# https://github.com/microsoft/onnxruntime/issues/19480 +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) add_compile_definitions(ORT_ENABLE_STREAM) endif() diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index fd26b09b09531..f91d66c22ea24 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+beginning_timestamp_token_id : int
+The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-The id of the token that indicates decoding starts.
+The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-no_speech_token : int
+no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+no_timestamps_token_id : int
+The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+start_of_lm_token_id : int
+The id of the token that indicates LM starts
+transcribe_token_id : int
+The id of the transcribe task
+translate_token_id : int
+The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
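The new WhisperBeamSearch attributes above are plain token ids from the Whisper tokenizer. As an illustrative sketch only (not part of this patch), they could be gathered roughly as below; it assumes the Hugging Face WhisperTokenizer, the special-token strings that the logits_processor.h comments further down in this patch use, and that <|0.00|> is the token right after <|notimestamps|>, matching the offset the old hard-coded logic relied on.

# Illustrative only -- not from this patch. Gathers the token ids that the new
# WhisperBeamSearch attributes expect from a Hugging Face Whisper tokenizer.
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
attrs = {
    "eos_token_id": tok.eos_token_id,                                         # <|endoftext|>
    "decoder_start_token_id": tok.convert_tokens_to_ids("<|startoftranscript|>"),
    "translate_token_id": tok.convert_tokens_to_ids("<|translate|>"),
    "transcribe_token_id": tok.convert_tokens_to_ids("<|transcribe|>"),
    "start_of_lm_token_id": tok.convert_tokens_to_ids("<|startoflm|>"),
    "no_speech_token_id": tok.convert_tokens_to_ids("<|nospeech|>"),          # older tokenizers may name this <|nocaptions|>
    "no_timestamps_token_id": tok.convert_tokens_to_ids("<|notimestamps|>"),
}
# <|0.00|> immediately follows <|notimestamps|> in the vocabulary.
attrs["beginning_timestamp_token_id"] = attrs["no_timestamps_token_id"] + 1
print(attrs)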
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
@@ -5810,11 +5820,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/js/web/package.json b/js/web/package.json index 047de382943e6..d306390fac594 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -69,11 +69,14 @@ "exports": { ".": { "node": "./dist/ort.node.min.js", + "types": "./types.d.ts", "default": { "import": "./dist/esm/ort.min.js", "require": "./dist/cjs/ort.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.js", + "types": "./types.d.ts", "default": "./dist/ort.min.js" } } @@ -81,34 +84,41 @@ "./experimental": { "import": "./dist/esm/ort.all.min.js", "require": "./dist/cjs/ort.all.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.all.js", + "types": "./types.d.ts", "default": "./dist/ort.all.min.js" } }, "./wasm": { "import": "./dist/esm/ort.wasm.min.js", "require": "./dist/cjs/ort.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { "import": "./dist/esm/ort.wasm-core.min.js", "require": "./dist/cjs/ort.wasm-core.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { "import": "./dist/esm/ort.webgl.min.js", "require": "./dist/cjs/ort.webgl.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgl.min.js" }, "./webgpu": { "import": "./dist/esm/ort.webgpu.min.js", "require": "./dist/cjs/ort.webgpu.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" }, "./training": { "import": "./dist/esm/ort.training.wasm.min.js", "require": "./dist/cjs/ort.training.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.training.wasm.min.js" } }, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 56d950ca2f41e..d65ff9c5fb4f8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 94547887d3a90..3dbdd7b0fcd70 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 91b93a125ad7a..97dc513d4b54f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& 
encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } @@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..8a466dd9d9c18 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -141,7 +141,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8f..34510902cf309 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -180,7 +180,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h 
b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4688ff272cee9..3c213ee944119 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. 
- const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = 
*std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -272,7 +287,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -334,7 +355,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23c..b8f8d7691a9b6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed 
below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,7 +13,9 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +45,49 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..0c90ad768a92f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ 
ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. 
Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1234,20 +1239,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. 
Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1321,7 +1325,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1362,7 +1366,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. 
Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f71b7ecebcf1a..eff1510ef8dc6 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2842,7 +2863,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. 
diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..d2b73e4c2c130 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 
1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. 
#pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4736..2456b396de3f6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. uint32_t c_maxConstNodeDataSize = 8; - ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; - if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { - // The tensor description's size should be no larger than the constant input unless it was rounded to + constantInput = constantCpuGraphInputGetter(arg->Name()); + } + + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. 
assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index dbd06abf82f72..d524780de71b8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter m_requiredConstantCpuInputs.begin(), m_requiredConstantCpuInputs.end(), inputIndex) != m_requiredConstantCpuInputs.end(); - + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } @@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. - if (impl->has_raw_data()) + if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor)); + m_dataPtr = reinterpret_cast(m_unpackedExternalTensor.data()); + m_tensorByteSize = m_unpackedExternalTensor.size(); + } + else if (impl->has_raw_data()) { m_dataPtr = reinterpret_cast(impl->mutable_raw_data()->data()); m_tensorByteSize = impl->raw_data().size(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6530d89d895e7..59e253e88457a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; + std::vector m_unpackedExternalTensor; std::byte* m_dataPtr = nullptr; // Lifetime is managed by the caller and guaranteed to outlive this class diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index e6c093d584031..0779940983aea 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -70,10 +70,13 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else - if (global_context_.disable_dynamic_shapes && dev_prec != "CPU_FP16") { - const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork( - model, hw_target, device_config, subgraph_context_.subgraph_name); + if (!subgraph_context_.has_dynamic_input_shape && + global_context_.onnx_model_path_name != "" && + dev_prec != "CPU_FP16") { + exe_network_ = 
global_context_.ie_core.LoadNetwork(global_context_.onnx_model_path_name, + hw_target, + device_config, + subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 931173fd7ef47..ea481791111fc 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -87,13 +87,13 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, } } -OVExeNetwork OVCore::LoadNetwork(const std::string& model, +OVExeNetwork OVCore::LoadNetwork(const std::string onnx_model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name) { ov::CompiledModel obj; try { - obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); + obj = oe.compile_model(onnx_model_path, hw_target, device_config); OVExeNetwork exe(obj); return exe; } catch (const Exception& e) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 3db19463809cf..cf4d867d4df55 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -45,7 +45,7 @@ class OVCore { std::string& hw_target, ov::AnyMap& device_config, std::string name); - OVExeNetwork LoadNetwork(const std::string& model_stream, + OVExeNetwork LoadNetwork(const std::string model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index f4b0d1ff59175..9849a05db329c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -55,6 +55,19 @@ Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// Converts an ONNX list of split lengths to a QNN list of split indices. +// Note that the first split index at 0 is implicit (QNN SDK >= 2.19 will raise a validation error if included). +static void ConvertSplitLengthsToSplitIndices(gsl::span split_lengths, + std::vector& split_indices) { + uint32_t split_it = 0; + for (size_t i = 0; i < split_lengths.size(); ++i) { + if (i > 0) { // Do not include the 0th split index. 
+ split_indices.push_back(split_it); + } + split_it += SafeInt(split_lengths[i]); + } +} + Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -79,22 +92,15 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); size_t tensor_byte_size = unpacked_tensor.size(); size_t size = tensor_byte_size / sizeof(int64_t); - split_index.push_back(0); // QNN need the start index of each range and starts from 0 - std::transform(tensor_data, tensor_data + size, std::back_inserter(split_index), - [](int64_t item) { return SafeInt(item); }); - split_index.pop_back(); + ConvertSplitLengthsToSplitIndices({tensor_data, size}, split_index); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic split"); } } else { NodeAttrHelper node_helper(node_unit); if (node_helper.HasAttr("split")) { - auto split = node_helper.Get("split", std::vector{0}); - uint32_t split_it = 0; - for (size_t i = 0; i < split.size(); ++i) { - split_index.push_back(split_it); - split_it += split[i]; - } + auto split_lengths = node_helper.Get("split", std::vector{0}); + ConvertSplitLengthsToSplitIndices(split_lengths, split_index); } } @@ -105,11 +111,19 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr "Cannot get shape"); ORT_ENFORCE(static_cast(input_shape.size()) > axis_value, "axis not valid!"); ORT_RETURN_IF_NOT(input_shape.at(axis_value) > 0, "Shape value not valid!"); - auto num_outputs = node_unit.Outputs().size(); - auto step = SafeInt(input_shape.at(axis_value) / num_outputs); + + // ONNX spec states that if not evenly divisible by `num_outputs`, the last chunk is smaller. + // Therefore, we have to use ceil() when computing shape[axis] / num_outputs. 
+ // See: core/providers/cpu/tensor/split.cc::PrepareForCompute() + const float num_outputs = static_cast(node_unit.Outputs().size()); + const float split_dim_size = static_cast(input_shape[axis_value]); + const uint32_t step = SafeInt(std::ceil(split_dim_size / num_outputs)); uint32_t split_it = 0; + for (size_t i = 0; i < num_outputs; ++i) { - split_index.push_back(split_it); + if (i > 0) { // 0th split index is implicit (QNN >= 2.19 raises validation error if included) + split_index.push_back(split_it); + } split_it += step; } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..c8fc812fe1238 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -46,10 +46,11 @@ #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/platform/Barrier.h" -#include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" #ifdef _WIN32 #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -239,6 +240,10 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; +#ifdef _WIN32 +OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +#endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -349,17 +354,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); - // a monotonically increasing session id for use in telemetry - session_id_ = global_session_id_.fetch_add(1); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // a monotonically increasing session id for use in telemetry + session_id_ = global_session_id_.fetch_add(1); + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[global_session_id_++] = this; + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + LogAllSessions(); + } + }); +#endif + SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. 
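To make the intent of the QNN Split builder changes above concrete, here is a small Python sketch (not part of the patch; helper names are illustrative) of the two conversions the C++ code performs: turning ONNX `split` lengths into QNN split indices with the 0th index left implicit, and deriving the per-output step with ceil() when `num_outputs` does not divide the axis length evenly.

```
import math

def split_lengths_to_indices(split_lengths):
    """ONNX 'split' lengths -> QNN split indices; the leading 0 index stays implicit."""
    indices, offset = [], 0
    for i, length in enumerate(split_lengths):
        if i > 0:  # do not emit the implicit 0th index (QNN SDK >= 2.19 rejects it)
            indices.append(offset)
        offset += length
    return indices

def equal_split_indices(axis_dim, num_outputs):
    """Equal chunks; the last chunk may be smaller, so the step uses ceil()."""
    step = math.ceil(axis_dim / num_outputs)
    return [i * step for i in range(1, num_outputs)]

print(split_lengths_to_indices([2, 3, 4]))  # [2, 5]
print(equal_split_indices(10, 3))           # step = 4 -> [4, 8] (chunk sizes 4, 4, 2)
```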
- TraceSessionOptions(session_options); + TraceSessionOptions(session_options, false); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -473,7 +508,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { + (void)captureState; // Otherwise Linux build error + LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 @@ -496,7 +533,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.graph_optimization_level), "graph_optimization_level"), TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), - TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute")); + TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), + TraceLoggingBoolean(captureState, "isCaptureState")); TraceLoggingWrite( telemetry_provider_handle, @@ -509,7 +547,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"), TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), - TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero")); + TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), + TraceLoggingBoolean(captureState, "isCaptureState")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -518,7 +557,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBoolean(captureState, "isCaptureState")); } #endif } @@ -614,6 +654,12 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); +#endif + active_sessions_.erase(global_session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -3043,4 +3089,14 @@ IOBinding* SessionIOBinding::Get() { return binding_.get(); } +#ifdef _WIN32 +void InferenceSession::LogAllSessions() { + std::lock_guard lock(active_sessions_mutex_); + for (const auto& session_pair : active_sessions_) { + InferenceSession* session = session_pair.second; + TraceSessionOptions(session->session_options_, true); + } +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 96db49aabdaf6..f8211bfd2dd4e 100644 --- a/onnxruntime/core/session/inference_session.h +++ 
b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -21,11 +22,12 @@ #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" +#include "core/framework/session_options.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/framework/session_options.h" +#include "core/platform/ort_mutex.h" #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -119,6 +121,10 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; +#ifdef _WIN32 + static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ +#endif public: #if !defined(ORT_MINIMAL_BUILD) @@ -642,7 +648,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options); + void TraceSessionOptions(const SessionOptions& session_options, bool captureState); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -679,6 +685,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 29c2c6b0cce16..c8579fd88ac41 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1476,7 +1476,11 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; + if (legacy_ov_options->enable_npu_fast_compile) { + ov_options_converted_map["enable_npu_fast_compile"] = "false"; + } else { + ov_options_converted_map["enable_npu_fast_compile"] = "true"; + } if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; @@ -1495,14 +1499,12 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; - if (legacy_ov_options->enable_dynamic_shapes != '\0') { - std::string enable_dynamic_shapes = reinterpret_cast(legacy_ov_options->enable_dynamic_shapes); - if (enable_dynamic_shapes == "true" || enable_dynamic_shapes == "True") { - ov_options_converted_map["disable_dynamic_shapes"] = "false"; - } else if (enable_dynamic_shapes == "false" || enable_dynamic_shapes == "False") { - ov_options_converted_map["disable_dynamic_shapes"] = "true"; - } + if (legacy_ov_options->enable_dynamic_shapes) { + ov_options_converted_map["disable_dynamic_shapes"] = "false"; + } else { + ov_options_converted_map["disable_dynamic_shapes"] = "true"; } + // Add new provider option below ov_options_converted_map["num_streams"] = "1"; return ov_options_converted_map; diff --git 
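The legacy-to-V2 OpenVINO option mapping above translates `enable_dynamic_shapes` into the V2 key `disable_dynamic_shapes` and always sets `num_streams`. A hedged sketch of passing V2-style options directly from the Python API is shown below; the accepted keys and values depend on the OpenVINO EP in your build, so treat them as illustrative and verify against the EP documentation.

```
import onnxruntime as ort

# V2-style OpenVINO provider options; key names mirror the mapping above
# ("device_type", "disable_dynamic_shapes", "num_streams"), but confirm them
# for your ONNX Runtime / OpenVINO EP version.
ov_options = {
    "device_type": "CPU",
    "disable_dynamic_shapes": "true",
    "num_streams": "1",
}

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("OpenVINOExecutionProvider", ov_options), "CPUExecutionProvider"],
)
print(sess.get_providers())
```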
a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 86b3d01c640a3..ac059bfd00668 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + for (const auto& config_pair : provider_options) { + ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str())); + } + if (strcmp(provider_name, "DML") == 0) { #if defined(USE_DML) options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index 48f58add8237b..a5a165e150cf1 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -7,6 +7,7 @@ #ifdef _WIN32 #include +#include #endif #include #include "core/session/ort_apis.h" @@ -98,7 +99,16 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { } options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { +#ifdef _WIN32 + // Only set thread affinity on Server with auto affinity. + // On client best to let OS scheduler handle. + // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage + if (IsWindowsServer()) { + to.affinities = std::move(default_affinities); + } +#else to.affinities = std::move(default_affinities); +#endif } } if (options.thread_pool_size <= 1) { diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3e9f9a6544a71..eb7bbec997d59 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -349,6 +349,10 @@ def process(self): self.int4_quant_algo() +def ort_convert_str_to_bool(value): + return value.lower() in ("true", "1") + + def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. @@ -366,7 +370,10 @@ def parse_args(): "--symmetric", required=False, default=True, - type=bool, + const=True, + nargs="?", + type=ort_convert_str_to_bool, + choices=[True, False], help="Indicate whether to quantize the model symmetrically", ) parser.add_argument( diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 898a5f70ac45e..e7e562085c670 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -389,7 +389,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." 
) @@ -446,6 +446,23 @@ def is_valid_quantize_weight(self, weight_name): return False return self.parent.is_valid_quantize_weight(weight_name) + def _get_default_tensor_type(self, tensor_name): + if "DefaultTensorType" in self.extra_options: + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) + return self.extra_options["DefaultTensorType"] + raise RuntimeError( + f"Unable to find data type for weight_name={tensor_name!r}. " + f"shape_inference failed to return a type probably this node is " + f"from a different domain or using an input produced by such an operator. " + f"This may happen if you quantize a model already quantized. " + f"You may use extra_options `DefaultTensorType` to indicate " + f"the default weight type, usually `onnx.TensorProto.FLOAT`." + ) + def get_tensor_type(self, tensor_name, mandatory=False): weight = find_by_name(tensor_name, self.model.initializer()) if weight is not None: @@ -454,11 +471,11 @@ def get_tensor_type(self, tensor_name, mandatory=False): vi = self.value_infos[tensor_name] if vi.type.HasField("tensor_type"): if mandatory and vi.type.tensor_type.elem_type == 0: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return vi.type.tensor_type.elem_type if (not self.enable_subgraph_quantization) or (self.parent is None): if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None otype = self.parent.is_valid_quantize_weight(tensor_name) if otype is not None: @@ -468,7 +485,7 @@ def get_tensor_type(self, tensor_name, mandatory=False): if res is not None: return res if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None def is_float_tensor(self, tensor_name): @@ -1336,9 +1353,15 @@ def _dequantize_value(self, value_name): if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names): quantized_value = self.quantized_value_map[value_name] # Add DequantizeLinear Node for this input + scale_init = find_by_name(quantized_value.scale_name, self.model.initializer()) - # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + + # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done. + if self.model.model.producer_name != "onnx-quantizer" or ( + self.model.model.producer_name == "onnx-quantizer" and scale_init is not None + ): + # axis is not specified so scale_init must be a scalar. 
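The new `_get_default_tensor_type` fallback above is only taken when the caller supplies a `DefaultTensorType` entry in `extra_options`. A minimal sketch of opting into it through the public quantization API follows; the paths are placeholders, and the keyword arguments should be checked against your onnxruntime version.

```
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

# When shape inference cannot recover a tensor's element type (e.g. tensors produced by
# custom-domain ops or an already-quantized graph), fall back to FLOAT instead of raising
# "Unable to find data type for weight_name=...".
quantize_dynamic(
    model_input="model.onnx",         # placeholder input path
    model_output="model.quant.onnx",  # placeholder output path
    weight_type=QuantType.QInt8,
    extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
```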
+ assert onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index b0153aed766ad..123cfe913d6e2 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -270,6 +270,8 @@ def quantize_model(self): self.model.model.producer_name = __producer__ self.model.model.producer_version = __version__ + if self.qdq_op_domain == ms_domain: + self.model.set_opset_import(ms_domain, 1) return self.model.model diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef4c4ae906243..251d41a24ccc7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -197,6 +197,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "BiasGelu": self._infer_BiasGelu, "BiasSplitGelu": self._infer_BiasSplitGelu, "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, + "DequantizeLinear": self._infer_DequantizeLinear, "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, "FastGelu": self._infer_FastGelu, "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, @@ -212,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, "PythonOp": self._infer_PythonOp, + "QuantizeLinear": self._infer_QuantizeLinear, "QuickGelu": self._infer_FastGelu, "RelativePositionBias": self._infer_RelativePositionBias, "RemovePadding": self._infer_RemovePadding, @@ -238,6 +240,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} @@ -457,6 +460,8 @@ def _onnx_infer_single_node(self, node): "GemmFastGelu", "LayerNormalization", "LongformerAttention", + "DequantizeLinear", + "QuantizeLinear", "RelativePositionBias", "RemovePadding", "RestorePadding", @@ -979,6 +984,29 @@ def _infer_NhwcConv(self, node): # noqa: N802 ) ) + def _infer_DequantizeLinear(self, node): # noqa: N802 + # Get the output data type from the scale input (index 1, required). + output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_QuantizeLinear(self, node): # noqa: N802 + # Get the output data type from the zero-point input (index 2, optional). + # Otherwise, default to uint8 + output_dtype = onnx.TensorProto.UINT8 + if len(node.input) > 2 and node.input[2]: + output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + + # Get the output shape from the first input. 
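The two symbolic-shape-inference handlers added above propagate the output dtype of `DequantizeLinear` from its scale input and of `QuantizeLinear` from its zero-point (defaulting to uint8), with shapes copied from input 0. Below is a small sketch that exercises the tool on a toy Q/DQ graph; the `onnxruntime.tools.symbolic_shape_infer` import path assumes the installed wheel, so adjust it if you run from the source tree.

```
from onnx import TensorProto, helper
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Toy graph: float x -> QuantizeLinear -> DequantizeLinear -> float y, with a symbolic batch dim.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch", 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)
scale = helper.make_tensor("scale", TensorProto.FLOAT, [], [0.1])
zp = helper.make_tensor("zp", TensorProto.UINT8, [], [128])
nodes = [
    helper.make_node("QuantizeLinear", ["x", "scale", "zp"], ["xq"]),
    helper.make_node("DequantizeLinear", ["xq", "scale", "zp"], ["y"]),
]
graph = helper.make_graph(nodes, "qdq_demo", [x], [y], initializer=[scale, zp])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
for vi in list(inferred.graph.value_info) + list(inferred.graph.output):
    # Expect xq to be UINT8 (elem_type 2, from the zero-point) and y FLOAT (elem_type 1, from the scale).
    dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
    print(vi.name, vi.type.tensor_type.elem_type, dims)
```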
+ output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + def _infer_Einsum(self, node): # noqa: N802 # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 equation = get_attribute(node, "equation") diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 71801401e9d06..ebecc1db24792 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -74,13 +74,74 @@ def check_runtime_shape_path( return True + def check_runtime_shape_path_openai( + self, + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ): + reshape_qkv_2_path = self.model.match_parent_path( + reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0] + ) + if reshape_qkv_2_path is None: + return False + else: + if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]: + return False + + matmul_qk_path_1 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] + ) + matmul_qk_path_2 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] + ) + if matmul_qk_path_1 is None or matmul_qk_path_2 is None: + return False + + mul_1 = matmul_qk_path_1[0] + mul_2 = matmul_qk_path_2[0] + if mul_1.input[1] != mul_2.input[1]: + return False + if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: + return False + + # For decoder attentions only + if add_qk is not None: + add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) + if add_qk_path is None: + return False + slice_q_path_1 = self.model.match_parent_path( + add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] + ) + slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if slice_q_path_1 is None and slice_q_path_2 is None: + return False + _, unsqueeze_1, _, _ = slice_q_path_1 + unsqueeze_2, _, _ = slice_q_path_2 + if unsqueeze_1.input[0] != unsqueeze_2.input[0]: + return False + if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: + return False + + return True + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Track if fusion is occurring for OpenAI implementation of Whisper + model_impl_openai = False + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. 
qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0, 0, 0], ) + qkv_nodes_openai = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) if qkv_nodes is not None: ( add_out, @@ -90,6 +151,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): reshape_qkv_1, matmul_qkv, ) = qkv_nodes + elif qkv_nodes_openai is not None: + qkv_nodes = qkv_nodes_openai + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + # Set model implementation to openai + model_impl_openai = True else: return @@ -137,6 +209,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None], ) + v_nodes_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) v_nodes_with_past_self_attn = self.model.match_parent_path( # Decoder attention with past value concatenated before MatMul matmul_qkv, @@ -149,12 +226,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape"], [1], ) + v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) present_v = transpose_v.output[0] + elif v_nodes_openai is not None: + v_nodes = v_nodes_openai + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + + # Find the child path to access the correct present_v values + # Openai impl provides present/past v values in 3D format + # whereas ort MultiHeadAttention expects v values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Add -> Reshape -> Transpose -> Present_V + reshape_path = self.model.match_child_path( + add_v, + ["Reshape", "Transpose"], + exclude=[reshape_v_1], + ) + # For decoder attention types + # add_v_node Reshape <- Transpose <-Past_V + # \ / + # \ / + # -> Concat <- + # | + # |--> Reshape -> Transpose -> Present_V + concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_add_v) = reshape_path + if transpose_add_v.output[0] in graph_output_names: + present_v = transpose_add_v.output[0] + if concat_path is not None: + (concat_v, _, transpose_concat_v) = concat_path + if transpose_concat_v.output[0] in graph_output_names: + present_v = transpose_concat_v.output[0] + concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_v_in = concat_nodes + past_v = transpose_concat_v_in.input[0] elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn v_nodes = v_nodes_with_past_self_attn @@ -171,6 +288,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) ) present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + 
elif ( + v_nodes_with_past_cross_attn_openai is not None + and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn_openai + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return @@ -181,12 +310,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] ) + qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + add_qk = None if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: _, _, add_qk, _, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 + elif qk_nodes_2_openai is not None: + _, add_qk, matmul_qk = qk_nodes_2_openai + qk_nodes = qk_nodes_2_openai else: return @@ -195,8 +329,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, 0, 1], ) + q_nodes_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + reshape_q_2 = None if q_nodes is not None: reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + elif q_nodes_openai is not None: + q_nodes = q_nodes_openai + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return @@ -205,6 +348,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, 1], ) + k_nodes_with_bias_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) k_nodes_no_bias = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], @@ -222,11 +370,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape"], [1, 0], ) + k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias + elif k_nodes_with_bias_openai is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai + k_nodes = k_nodes_with_bias_openai + present_k = matmul_k.output[0] + + # Find the child path to access the correct present_k values + # Openai impl provides present/past k values in 3D format + # whereas ort MultiHeadAttention expects k values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Matmul -> Reshape -> Transpose -> Present_K + reshape_path = self.model.match_child_path( + matmul_k, + ["Reshape", "Transpose"], + exclude=[reshape_k_1], + ) + # For decoder attention types + # matmul_k_node Reshape <- Transpose <- Past_K + # \ / + # \ / + # -> Concat <- + # | + # |--> 
Reshape -> Transpose -> Present_K + concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_matmul_k) = reshape_path + if transpose_matmul_k.output[0] in graph_output_names: + present_k = transpose_matmul_k.output[0] + if concat_path is not None: + (concat_k, _, transpose_concat_k) = concat_path + if transpose_concat_k.output[0] in graph_output_names: + present_k = transpose_concat_k.output[0] + concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_k_in = concat_nodes + past_k = transpose_concat_k_in.input[0] elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -249,12 +438,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) ) present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + elif ( + k_nodes_no_bias_with_past_cross_attn_openai is not None + and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): # Create empty Add node for attention graph bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] empty_bias_name = "empty_bias" @@ -270,13 +471,29 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) - if not past_k and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, + if ( + model_impl_openai + and not past_k + and not self.check_runtime_shape_path_openai( + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ) + ): + return + elif ( + not model_impl_openai + and not past_k + and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ) ): return @@ -301,8 +518,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) + decoder_attention_with_past = ( + (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v + ) decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past 
= three_root_inputs and qk_nodes == qk_nodes_1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 02100266200f8..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). @@ -10,10 +27,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,40 +56,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. 
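Once a model has been exported with one of the commands above, a quick smoke test from Python looks roughly like the sketch below. The file path and the input/output names are assumptions based on the export and benchmark scripts (`input_features`, `max_length`, and friends, plus the optional `decoder_input_ids`, `logits_processor`, and `temperature` inputs); confirm them with `sess.get_inputs()` on your exported model, and note that the mel-feature shape depends on the checkpoint (128 mel bins for large-v3, 80 for the smaller models).

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "whisperlargev3/whisper-large-v3_beamsearch.onnx",  # placeholder path to the exported model
    providers=["CPUExecutionProvider"],
)
print([i.name for i in sess.get_inputs()])  # verify the names assumed below

batch, n_mels, n_frames = 1, 128, 3000  # large-v3 uses 128 mel bins (assumption; smaller models use 80)
inputs = {
    "input_features": np.zeros((batch, n_mels, n_frames), dtype=np.float32),  # real runs: log-mel features
    "max_length": np.array([448], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}
# Models exported with --use_forced_decoder_ids / --use_logits_processor / --use_temperature
# additionally take "decoder_input_ids" (int32), "logits_processor" (int32), and "temperature" (float32).

sequences = sess.run(None, inputs)[0]
print(sequences.shape)  # expected: (batch, num_return_sequences, max_length)
```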
+To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 
759ae6d14f184..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime @@ -54,6 +60,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +189,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +201,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +509,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +596,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e15a12c07bed7..35211aab272e4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,34 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), + ) + + conversion_args.add_argument( + "--model_impl", + required=False, + default="hf", + choices=["hf", "openai"], + type=str, + help="Select implementation for export of encoder and decoder subgraphs", ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -46,7 +63,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -54,19 +71,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -76,221 +98,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. 
Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. 
pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) + optional_inputs.set_defaults(use_prefix_vocab_mask=False) - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + ) + optional_inputs.set_defaults(extra_decoding_ids=False) + + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(use_temperature=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. 
\ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) + + ################################### + # Quantization options for Whisper + ################################### - parser.add_argument( + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. 
Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -300,6 +327,7 @@ def parse_arguments(argv=None): def export_onnx_models( model_name_or_path, + model_impl, cache_dir, output_dir, use_gpu, @@ -307,7 +335,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -321,7 +349,7 @@ def export_onnx_models( device = torch.device("cuda:0" if use_gpu else "cpu") models = WhisperHelper.load_model( - model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path + model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) config = models["decoder"].config @@ -352,7 +380,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, use_int32_inputs=use_int32_inputs, ) else: @@ -396,7 +423,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -431,6 +458,7 @@ def main(argv=None): output_paths = export_onnx_models( args.model_name_or_path, + args.model_impl, cache_dir, output_dir, args.use_gpu, @@ -438,7 +466,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -451,7 +479,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... :") args.beam_model_output_dir = WhisperHelper.get_onnx_path( output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt new file mode 100644 index 0000000000000..db2cd95324328 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.17.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt new file mode 100644 index 0000000000000..9bd215de9bc09 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system. 
+# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.17.1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt new file mode 100644 index 0000000000000..c307a3665f8a0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -0,0 +1,11 @@ +torch>=1.13.0 +transformers>=4.24.0 +openai-whisper +ffmpeg-python +datasets +soundfile +librosa +optimum +onnxruntime-extensions>=0.9.0 +protobuf==3.20.2 +numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 33958e55f8c38..0b128f122e0f4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import logging import os @@ -9,7 +15,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +52,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,37 +64,27 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + 
"cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores") - if args.output_scores: - beam_outputs.append("scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) - - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None + sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" + scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores" + beam_outputs = [ + "sequences", + sequence_scores_name if args.output_sequence_scores else "", + scores_name if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] + + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -97,26 +107,70 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + + if args.output_sequence_scores: + output_sequence_scores_cast_node = helper.make_node( + "Cast", + inputs=["sequence_scores_fp16"], + outputs=["sequence_scores"], + name="CastOutputSequenceScoresToFp32", + to=TensorProto.FLOAT, + ) + graph_nodes.append(output_sequence_scores_cast_node) + + if args.output_scores: + output_scores_cast_node = helper.make_node( + "Cast", + inputs=["scores_fp16"], + outputs=["scores"], + name="CastScoresToFp32", + to=TensorProto.FLOAT, + ) + graph_nodes.append(output_scores_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", 
tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + ( + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "" + ), + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -126,73 +180,63 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - graph_inputs = [ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - 
graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) - # graph outputs + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -213,11 +257,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] - if args.precision == Precision.FLOAT16 - else [node] - ) + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -226,9 +266,16 @@ def 
chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -262,10 +309,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + ".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index eca5ce3de15d3..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -18,6 +18,7 @@ from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export from transformers import WhisperConfig, file_utils +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -67,10 +68,13 @@ def forward( class WhisperDecoder(torch.nn.Module): """A Whisper decoder with past key values""" - def __init__(self, decoder, config): + def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None): super().__init__() self.decoder = decoder self.config = config + self.model_impl = model_impl + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) def forward(self, decoder_input_ids, *past): encoder_outputs = file_utils.ModelOutput() @@ -78,6 +82,14 @@ def forward(self, decoder_input_ids, *past): encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states encoder_outputs["hidden_states"] = dummy_encoder_hidden_states encoder_outputs["attentions"] = None + + if self.model_impl == "openai": + dummy_encoder_hidden_states.unsqueeze(0) + dec_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, dummy_encoder_hidden_states, past=past + ) + return dec_out, present + if len(past) == 0: past_key_values = None else: @@ -213,7 +225,7 @@ def export_onnx( decoder.config, batch_size=2, encode_sequence_length=3000, - past_decode_sequence_length=5 if isinstance(decoder, WhisperDecoder) else 0, + past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, use_int32_inputs=use_int32_inputs, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 826d6e42c0775..93281848a5c9c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -25,12 +25,15 @@ class WhisperEncoder(torch.nn.Module): """Whisper encoder outputs only the last hidden state""" - def __init__(self, encoder, config: WhisperConfig): + def __init__(self, encoder, config: WhisperConfig, 
model_impl: str = "hf"): super().__init__() self.encoder = encoder self.config = config + self.model_impl = model_impl def forward(self, input_features): + if self.model_impl == "openai": + return self.encoder(input_features) return self.encoder.model.encoder(input_features)[0] @@ -40,7 +43,11 @@ def __init__(self, input_features): @staticmethod def create_dummy( - batch_size: int, sequence_length: int, feature_size: int, device: torch.device, use_int32_inputs: bool + batch_size: int, + sequence_length: int, + feature_size: int, + device: torch.device, + use_int32_inputs: bool = False, ): """Create dummy inputs for Whisper encoder. @@ -61,9 +68,9 @@ def create_dummy( return WhisperEncoderInputs(input_features) def to_list(self) -> List: - if self.input_features is None: + if self.input_ids is None: return [] - return [self.input_features] + return [self.input_ids] class WhisperEncoderHelper: @@ -74,6 +81,7 @@ def export_onnx( onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export encoder to ONNX @@ -90,6 +98,7 @@ def export_onnx( sequence_length=3000, feature_size=config.num_mel_bins, device=device, + use_int32_inputs=use_int32_inputs, ) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index a145178dbf37e..832f692e9980d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- +import copy import logging import os import tempfile @@ -19,6 +20,7 @@ from transformers import WhisperConfig from whisper_decoder import WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderInputs +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -34,11 +36,16 @@ def __init__( decoder: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: Optional[int] = None, + model_impl: str = "hf", + model: torch.nn.Module = None, ): super().__init__() self.config = config - self.whisper_encoder = WhisperEncoder(encoder, config) + self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl) self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id) + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) + self.model_impl = model_impl def forward( self, @@ -47,9 +54,14 @@ def forward( ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) - decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) - present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) - present = present_self + present_cross + if self.model_impl == "openai": + encoder_hidden_states.unsqueeze(0) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + return decinit_out, encoder_hidden_states, present + else: + decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) + present_self, present_cross = 
PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) + present = present_self + present_cross return decinit_out[0], encoder_hidden_states, present @@ -63,7 +75,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -72,7 +84,6 @@ def create_dummy( sequence_length=3000, feature_size=config.num_mel_bins, device=device, - use_int32_inputs=use_int32_inputs, ) decoder_input_ids = None if use_decoder_input_ids: @@ -114,13 +125,15 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) input_list = inputs.to_list() - out = model(inputs.encoder_input_ids, inputs.decoder_input_ids) + # TODO : Investigate whether copy of model if needed + cloned_model = copy.deepcopy(model).to(device) + out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) @@ -146,7 +159,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a4bef1f06b4fe..a1d0d7fb3deeb 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,24 +23,20 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff # noqa: E402 -from onnx_model import OnnxModel # noqa: E402 -from optimizer import optimize_model # noqa: E402 - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -72,9 +70,49 @@ def get_onnx_path( directory = os.path.join(output_dir, model_name) if new_folder else output_dir return os.path.join(directory, model_name + ".onnx") + @staticmethod + def load_model_openai( + model_name_or_path: str, + cache_dir: str, + device: torch.device, + ) -> torch.nn.Module: + """Load model given a pretrained name or path, then build models for ONNX conversion. 
+ + Args: + model_name_or_path (str): pretrained model name or path + cache_dir (str): cache directory + device (torch.device): device to run the model + merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True. + Returns: + Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. + """ + from whisper import _ALIGNMENT_HEADS, _MODELS, _download + from whisper.model import ModelDimensions, Whisper + + in_memory = False + + model_name = model_name_or_path.split("/")[-1][8:] + checkpoint_file, alignment_heads = None, None + if model_name in _MODELS: + checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory) + alignment_heads = _ALIGNMENT_HEADS[model_name] + + with open(checkpoint_file, "rb") as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) + @staticmethod def load_model( model_name_or_path: str, + model_impl: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True, @@ -94,18 +132,29 @@ def load_model( if version.parse(transformers_version) >= version.parse("4.36.0"): extra_kwargs["attn_implementation"] = "eager" model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) + + if model_impl == "openai": + openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device) + model_encoder, model_decoder = openai_model.encoder, openai_model.decoder + passed_model = openai_model + else: + model_encoder, model_decoder = model, model + passed_model = None + if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) - decoder = WhisperDecoder(model, model.config) + decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model) decoder.eval().to(device) if merge_encoder_and_decoder_init: encoder_decoder_init = WhisperEncoderDecoderInit( - model, - model, + model_encoder, + model_decoder, model.config, decoder_start_token_id=None, + model_impl=model_impl, + model=passed_model, ) return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} else: @@ -295,7 +344,12 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -332,43 +386,51 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 
50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + parity = ( + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options + ) + max_diff = 0 - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py new file mode 100644 index 0000000000000..941f61cf7cc29 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class WhisperDecoderInitOpenai(torch.nn.Module): + """WhisperDecoderInit for Openai.""" + + def __init__( + self, + model: torch.nn.Module, + decoder: torch.nn.Module, + ): + super().__init__() + self.whisper_model = model + self.whisper_decoder = decoder + self.kv_cache = {} + + @torch.no_grad() + def forward( + self, + tokens, + audio_features, + past=None, + ): + # Create a kv_cache for past_values + past_kv_cache = dict() + if past is not None: + # Convert past values from 4D to 3D + past = [torch.transpose(val, 1, 2) for val in past] + past = [val.reshape(val.shape[:2] + (-1,)) for val in past] + half_idx = len(past) // 2 + for idx, block in enumerate(self.whisper_decoder.blocks): + past_kv_cache[block.attn.key] = past[2 * idx] + past_kv_cache[block.attn.value] = past[2 * idx + 1] + past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx] + past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1] + + if not self.kv_cache: + self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() + + logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + + # Add concat node for past values + if past is not None: + for block in self.whisper_decoder.blocks: + self.kv_cache[block.attn.key] = torch.cat( + [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1 + ).detach() + self.kv_cache[block.attn.value] = torch.cat( + [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1 + ).detach() + + present_self, present_cross = [], [] + # Group self and cross values + for block in self.whisper_decoder.blocks: + present_self.append(self.kv_cache[block.attn.key]) + present_self.append(self.kv_cache[block.attn.value]) + if past is None: + present_cross.append(self.kv_cache[block.cross_attn.key]) + present_cross.append(self.kv_cache[block.cross_attn.value]) + + present_self = present_self + present_cross + # Add reshape and transpose ops to convert from 3D to 4D + present_self = [ + present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self + ] + return logits, present_self diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 9d1066b6e372b..7ae146ccc0b5d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -426,6 +426,54 @@ def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, r return None + def match_child_path( + self, + node, + child_op_types, + child_output_index=None, + return_indice=None, + exclude=[], # noqa: B006 + ): + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + child_op_types (str): constraint of child node op_type of each input edge. + child_output_index (list): constraint of input index of each input edge. None means no constraint. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + children: a list of matched children node. 
+ """ + if child_output_index is not None: + assert len(child_output_index) == len(child_op_types) + + current_node = node + matched_children = [] + for i, op_type in enumerate(child_op_types): + matched_child = None + node_children = self.get_children(current_node) + for child_i, child in enumerate(node_children): + if child.op_type == op_type and child not in exclude: + if child_output_index is not None and child_output_index[i] != child_i: + logger.debug( + f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", + stack_info=True, + ) + return None + matched_child = child + if matched_child is None: + logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + return None + + matched_children.append(matched_child) + current_node = matched_child + return matched_children + def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): if output_name_to_node is None: output_name_to_node = self.output_name_to_node() diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 2a48722d17a19..61a786d7af60b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -121,7 +121,7 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): class BartOnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index a449e881ad361..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data onnx_model_path, quantized_model_path, use_external_data_format=use_external_data_format, + extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT}, ) logger.info(f"quantized model saved to:{quantized_model_path}") # TODO: inlcude external data in total model size. 
diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b174ee4138be3..d7b1de5c930c5 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test { if (invoke_createPlan_explicityly) { onnxruntime::GraphViewer graph_viewer{graph_}; - status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_, - kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context, - MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/ - ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_); + status = SequentialPlanner::CreatePlan( + nullptr, + graph_viewer, + outer_scope_node_args, + execution_providers_, + kernel_create_info_map, + {}, + {}, + state_->GetOrtValueNameIdxMap(), + test_context, +#ifdef ORT_ENABLE_STREAM + MockStreamHandleRegsitry(), +#endif + /* {{kCpuExecutionProvider, 1}}, {},*/ + ORT_TSTR(""), + DefaultLoggingManager().DefaultLogger(), + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); // AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 0d3e4449da939..e9f734057da1c 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -337,6 +337,7 @@ struct StreamMock : public Stream { Status CleanUpOnRunEnd() override { return Status::OK(); } }; +#ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); CheckStats(&a, 0, 0, 0, 0); @@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) { EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; a.Free(p2); } +#endif TEST(BFCArenaTest, TestExtendStrategy) { int64_t extend_delta_bytes = 0; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index ec572ce9deed8..60752d7456d97 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); 
ASSERT_EQ(start_index, 0); @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable() : nullptr; @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector(6, 1.0f), &v3); std::vector outputs; - ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x1_idx, x2_idx, x3_idx}), + AsSpan({v1, v2, v3}), + AsSpan({t3_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3); OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4); @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { CreateMLValue(cpu_allocator, std::vector{2, 2}, std::vector(4, 1.0f), &t_value); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({x_value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor()); ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value)); diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md index 59fe946b929f2..309b474c016c9 100644 --- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md +++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md @@ -3,13 +3,13 @@ The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW. ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time. -Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7). +Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](/docs/FAQ.md?plain=1#L7). However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically. 
- Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D -- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h) -- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner +- Keyword: Logs (0x2) keyword per [logging.h](/include/onnxruntime/core/common/logging/logging.h) +- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](/onnxruntime/core/platform/windows/logging/etw_sink.cc) to [ONNX severity](/include/onnxruntime/core/common/logging/severity.h) in an intuitive manner Notes: - The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc index 57e4b211777bb..6dc721edb421e 100644 --- a/onnxruntime/test/providers/qnn/split_op_test.cc +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -302,19 +302,46 @@ TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { // Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute // and 'split' input. TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Split 6 into 3 outputs of lengths [2, 2, 2] + TestInputDef input_def({6, 2}, false, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f}); + // Use 'split' input (initializer). - RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), - {2, 2}, // split - 0, // axis - -1, // num_outputs - 18, // opset + RunQDQSplitOpTestOnHTP(input_def, + {2, 2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset ExpectedEPNodeAssignment::All); // Use 'num_outputs' attribute. - RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + RunQDQSplitOpTestOnHTP(input_def, + {}, // split (use num_outputs instead) + 0, // axis + 3, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend. Use an uneven split (last chunk should be smaller). +TEST_F(QnnHTPBackendTests, Split_NonEqual_Axis0_Opset18) { + // Split 7 into 3 outputs of lengths [3, 3, 1] + TestInputDef input_def({7, 2}, false, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f}); + + // Use a `split` input with uneven split lengths. + RunQDQSplitOpTestOnHTP(input_def, + {3, 3, 1}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use a `num_outputs` attribute that does not evenly divide into shape[axis]. 
+ RunQDQSplitOpTestOnHTP(input_def, {}, // split (use num_outputs instead) 0, // axis - 2, // num_outputs + 3, // num_outputs 18, // opset ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 67db411ddc246..eca1430448e8e 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -392,6 +392,208 @@ def test_div_precision(self): self.assertEqual(len(output_dims), 1) self.assertEqual(output_dims[0].dim_value, 512) + def test_quantize_linear(self): + """ + Test ONNX QuantizeLinear op. + Check that the output shape is propagated from the first input and that the output data + type comes from the zero-point input. + """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.INT8, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + "zero_point", + ], + outputs=["output_s8"], + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_s8", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_s8", TensorProto.INT8, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_quantize_linear_ms_domain(self): + """ + Test QuantizeLinear op ('com.microsoft' domain). + Check that the output shape is propagated from the first input and that the output data + type comes from the zero-point input. + """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.UINT16, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + "zero_point", + ], + outputs=["output_u16"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_u16", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_MSDomain_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_u16", TensorProto.UINT16, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_quantize_linear_no_zp_input(self): + """ + Test QuantizeLinear op ('com.microsoft' domain). + Check that the output shape is propagated from the first input. + The zero-point input is missing, so the output data type should default to uint8. 
+ """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + ], + outputs=["output_u8"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_u8", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_NoZP_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + # Check that the output shape is propagated from the first input and that the + # output data type comes from the zero-point input. + expected_shapes = [ + helper.make_tensor_value_info("output_u8", TensorProto.UINT8, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_dequantize_linear_ms_domain(self): + """ + Test DequantizeLinear operator ('com.microsoft' domain). + Check that the output shape is propagated from the first input and that the output data + type comes from the scale input. + """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.UINT16, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "DequantizeLinear", + inputs=[ + "input_u16", + "scale", + "zero_point", + ], + outputs=["output_f32"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_u16", TensorProto.UINT16, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "DequantizeLinear_MSDomain_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 4de797400836f..223f405e8947a 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -601,6 +601,13 @@ def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=No ) check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + # If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support), + # then ensure the model has the appropriate opset import. 
+ if extra_options and extra_options.get("UseQDQContribOps", False): + qdq_model = onnx.load_model(model_qdq_path) + ms_opset = next((opset for opset in qdq_model.opset_import if opset.domain == "com.microsoft"), None) + self.assertIsNot(ms_opset, None) + def verify_qop(self, per_channel, is_quant_type_int8): np.random.seed(1) model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..2b5d1f36070e5 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh + +from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType + + +class TestQuantizerShapeInference(unittest.TestCase): + def test_com_microsoft(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("MatMul", ["X", "W1"], ["T1"]), + oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"), + oh.make_node("MatMul", ["T2", "W3"], ["T3"]), + oh.make_node("MatMul", ["T3", "W4"], ["Y"]), + ], + "name", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])], + [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])], + [ + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"), + ], + ), + opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], + ) + model_shaped = onnx.shape_inference.infer_shapes(model) + shaped_results = set(t.name for t in model_shaped.graph.value_info) + # every result after T1 depends on T2 coming from a node com.microsoft, + # shape_inference cannot go beyond this point + self.assertEqual(shaped_results, {"T1"}) + + # first try: checks it raises an exception + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + {"MatMulConstBOnly": True}, # extra_options, + # {'DefaultTensorType': 1, } + ) + + with self.assertRaises(RuntimeError) as e: + quantizer.quantize_model() + self.assertIn("Unable to find data type for weight_name=", str(e)) + + # second try: checks it works + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + { + "MatMulConstBOnly": True, + "DefaultTensorType": 1, + }, + ) + + model 
= quantizer.quantize_model() + ops = {n.op_type for n in model.graph.node} + self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"}) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py new file mode 100644 index 0000000000000..c425bf956f976 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_subgraph.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import tempfile +import unittest +import urllib.request + +import onnx + +from onnxruntime.quantization import quantize_dynamic + + +class TestDynamicQuantizationSubgraph(unittest.TestCase): + def test_dynamic_quantization_subgraph(self): + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx") + quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx") + urllib.request.urlretrieve( + "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path + ) + + quantize_dynamic( + model_input=onnx_path, + model_output=quantized_onnx_path, + per_channel=True, + op_types_to_quantize=[ + "Conv", + "MatMul", + "Attention", + "LSTM", + "Gather", + "Transpose", + "EmbedLayerNormalization", + ], + extra_options={"EnableSubgraph": True}, + ) + model = onnx.load(quantized_onnx_path) + + # The initializer `shared.weight_merged_0` is attached to the top-level graph, and used in a Gather node in each subgraphs. + # We expect the quantized Gather (after which a DequantizeLinear is attached) initializer to also be attached to the top-level graph. + found_gather_quantized = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_quantized": + found_gather_quantized = True + break + self.assertTrue(found_gather_quantized) + + found_gather_scale = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_scale": + found_gather_scale = True + break + self.assertTrue(found_gather_scale) + + # No initializers related to the Gather node should be attached to the subgraphs. 
+ for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + for initializer in attr.g.initializer: + self.assertTrue("shared.weight" not in initializer.name) diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c9db1fbc02931..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -361,7 +361,8 @@ def run_configs(self, optional_arguments): # INT8 CPU arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments - self.run_export(arguments) + if "--model_impl" not in arguments: + self.run_export(arguments) @pytest.mark.slow def test_required_args(self): @@ -380,18 +381,24 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) + + @pytest.mark.slow + def test_openai_impl_whisper(self): + optional_args = ["--model_impl", "openai"] + self.run_configs(optional_args) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "timestamp": (0.0, 5.44), } ], diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 77317242727b4..4883075112dcb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -241,7 +241,7 @@ def native_group_norm_gradient(): # are available for all versions, though they are not that convienent to use. 
def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] - if "bilinear" in backward_fn: + if "bicubic" in backward_fn: scales = ["I(2)", *scales] return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -271,3 +271,8 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) + + +@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") +def upsample_bicubic2d_gradient(): + return _upsample_gradient("upsample_bicubic2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 99e8851b6a697..9288027f0188c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,3 +808,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + + +@register_symbolic("upsample_bicubic2d") +def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): + return g.op( + "org.pytorch.aten::ATen", + input, + output_size, + align_corners, + scale_factors, + operator_s="upsample_bicubic2d", + overload_name_s="vec", + ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index fa72f3b134917..898c242bb3c32 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -23,7 +23,7 @@ cur_file_dir, ] -extra_compile_args = {"cxx": ["-O3"]} +extra_compile_args = {"cxx": ["-O3", "-std=c++17"]} setup( name="torch_interop_utils", ext_modules=[ diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 938d33cc9a714..6a6832e06330a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1805,6 +1805,34 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +def test_aten_upsample_bicubic(): + class _NeuralNetUpsampleBicubic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic") + + device = "cuda" + pt_model = _NeuralNetUpsampleBicubic().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + 
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D): diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 864d1002a90fc..7b03c0e82f4bb 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: @@ -18,4 +18,3 @@ stages: value: '$(Build.BinariesDirectory)' steps: - template: test.yml - diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 871d7894e5315..dc52e9a22f05b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -3,7 +3,7 @@ parameters: stages: - stage: Nodejs_Test_MacOS_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml index c823ac788f925..9b3c61b2d3d85 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index f244851f8cc37..d9ab85ee80ce3 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -15,7 +15,7 @@ stages: torch_version: '2.0.0' opset_version: '15' cuda_version: '11.8' - cmake_cuda_architectures: 60;61;70;75;80;86;90 + cmake_cuda_architectures: 60;61;70;75;80;86 docker_file: Dockerfile.manylinux2_28_training_cuda11_8 agent_pool: Onnxruntime-Linux-GPU upload_wheel: 'yes' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 5349b1ca67ab1..6b0ae085fa4db 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -34,6 +34,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' 
+ type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -70,5 +75,6 @@ stages: enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 3325e261715cf..de98a64ad90ac 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -501,12 +501,13 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- stage: Nodejs_Packaging_CPU +- stage: Nodejs_Packaging dependsOn: + - Windows_CI_GPU_DML_Dev + - Windows_CI_GPU_DML_Dev_arm64 - Linux_C_API_Packaging_CPU + - Linux_C_API_Packaging_GPU_TensorRT_x64 - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: @@ -533,18 +534,6 @@ stages: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win x64)' - inputs: - artifactName: 'onnxruntime-win-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win ARM64)' - inputs: - artifactName: 'onnxruntime-win-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (OSX)' inputs: @@ -554,7 +543,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (Linux x64)' inputs: - artifactName: 'onnxruntime-linux-x64' + artifactName: 'onnxruntime-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - task: DownloadPipelineArtifact@0 @@ -566,13 +555,13 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-x64' + artifactName: 'drop-onnxruntime-nodejs-win-x64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/' - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win ARM64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-arm64' + artifactName: 'drop-onnxruntime-nodejs-win-arm64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/arm64/' - task: DownloadPipelineArtifact@0 @@ -590,7 +579,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Linux x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-linux-x64' + artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/x64/' - task: DownloadPipelineArtifact@0 @@ -631,38 +620,32 @@ stages: # Node.js binding win32/x64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64\lib' - Contents: 
'*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' - - task: CopyFiles@2 - displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\x64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' # Node.js binding win32/arm64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-arm64\lib' - Contents: '*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' - - task: CopyFiles@2 - displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\arm64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' # Node.js binding linux/x64 - task: CopyFiles@2 displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64\lib' - Contents: 'libonnxruntime.so.*' + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-tensorrt\lib' + Contents: | + libonnxruntime.so.* + libonnxruntime_providers_*.so TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 28870a9eea7e0..6bfb1862d528a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -40,6 +40,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. 
- name: cmake_build_type type: string @@ -459,3 +464,9 @@ stages: QNN_SDK: 'qnn-v2.18.0.240101_win' PYTHON_VERSION: '3.11' NUMPY_VERSION: '1.25.2' + + - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - template: py-win-x64-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: 'qnn-v2.18.0.240101_win' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml new file mode 100644 index 0000000000000..30f21e933ee36 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -0,0 +1,177 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'Onnxruntime-QNNEP-Windows-2022-CPU' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_x64_qnn_Wheels + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + strategy: + matrix: + Python38_x64: + PythonVersion: '3.8' + Python39_x64: + PythonVersion: '3.9' + Python310_x64: + PythonVersion: '3.10' + Python311_x64: + PythonVersion: '3.11' + Python312_x64: + PythonVersion: '3.12' + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import sys + np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: download-deps.yml + + - task: PythonScript@0 + displayName: 'Update deps.txt' + inputs: + scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py + arguments: --new_dir $(Build.BinariesDirectory)/deps + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Install ONNX' + inputs: + filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' + workingDirectory: '$(Build.BinariesDirectory)' + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: 
'$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'x64' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index a89ac561f8860..b6b44690f4f6c 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -14,7 +14,7 @@ struct LogicalProcessorInformation { struct CoreCounter { uint32_t PhysicalCores = 0; - uint32_t SocDieCores = 0; + uint32_t Num2CacheCores = 0; }; static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { @@ -75,7 +75,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() { read += currentProcessorInfo->Size; } - cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); return cores; } @@ -83,8 +83,27 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. 
auto cores = GetNumberOPhysicalAndEngineeringCores(); - // We want to use the number of physical cores, but exclude soc cores - return cores.PhysicalCores - cores.SocDieCores; + +#if !defined(_M_ARM64) && !defined(__aarch64__) + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + int regs_leaf0[4]; + int regs_leaf7[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf7, 0x7); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && + (kVendorID_Intel[2] == regs_leaf0[3]); + + auto isHybrid = (regs_leaf7[3] & (1 << 15)); + + if (isIntel && isHybrid) { + // We want to use the number of physical cores, but exclude soc cores + // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores + return cores.PhysicalCores - cores.Num2CacheCores; + } +#endif + + return cores.PhysicalCores; } } // namespace WINMLP
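
Editor's note on the DefaultIntraOpNumThreads() change above: on Intel hybrid parts (detected via the CPUID leaf 7 hybrid flag, EDX bit 15) it now returns PhysicalCores minus the cores that only share an L2 cache, and plain PhysicalCores everywhere else. This only adjusts the runtime's default; an application that benchmarks a different sweet spot can still pin the thread count explicitly through the public session options. Below is a minimal, hedged sketch using the standard onnxruntime Python API; it is not part of this patch, and "model.onnx" is a placeholder path.

import os

import onnxruntime as ort

# Override the runtime's default intra-op thread count with an explicit value.
so = ort.SessionOptions()
# Example policy only: use all logical cores reported by the OS; pick whatever
# profiling shows works best on the target machine.
so.intra_op_num_threads = os.cpu_count() or 1

# "model.onnx" is a placeholder used purely for illustration.
sess = ort.InferenceSession("model.onnx", sess_options=so)
print("intra_op_num_threads =", so.intra_op_num_threads)

Setting the value explicitly bypasses the core-counting heuristic entirely, which is the intended escape hatch when the hybrid-aware default is not ideal for a particular workload.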