diff --git a/CMakeLists.txt b/CMakeLists.txt index dca2576ed..3a7e50526 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,8 @@ endif() set(GLOBAL_INCLUDE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/thirdparty + # for proto headers + ${PROJECT_BINARY_DIR}/src ) include_directories( diff --git a/maint/script/apply_code_style.sh b/maint/script/apply_code_style.sh index 9b1717788..ab8e04516 100755 --- a/maint/script/apply_code_style.sh +++ b/maint/script/apply_code_style.sh @@ -14,4 +14,4 @@ pushd "${PWD}/../../" > /dev/null | xargs "${CLANG_FORMAT}" -i -style=file 2>&1 \ | grep -v "Is a directory" echo "Done." -popd > /dev/null \ No newline at end of file +popd > /dev/null diff --git a/src/nnfusion/common/CMakeLists.txt b/src/nnfusion/common/CMakeLists.txt index 07cd4c754..dca1d2115 100644 --- a/src/nnfusion/common/CMakeLists.txt +++ b/src/nnfusion/common/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +add_subdirectory(serialize) set(SRC languageunit.cpp type_info.cpp diff --git a/src/nnfusion/common/serialize/CMakeLists.txt b/src/nnfusion/common/serialize/CMakeLists.txt new file mode 100644 index 000000000..7c18a9a37 --- /dev/null +++ b/src/nnfusion/common/serialize/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +FILE(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR} PROTOSRC_PATH) +FOREACH(item pbtypes attr_value tensor_shape node_def graph_def) + EXECUTE_PROCESS(COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --proto_path=${PROTOSRC_PATH} --cpp_out=${CMAKE_CURRENT_BINARY_DIR} ${item}.proto) + FILE(TO_NATIVE_PATH ${item}.pb.h proto_header) + FILE(TO_NATIVE_PATH ${item}.pb.cc proto_source) + list(APPEND SRC ${proto_header} ${proto_source}) +ENDFOREACH(item) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +add_library(nnfusion_serialize STATIC ${SRC}) +target_include_directories(nnfusion_serialize SYSTEM PUBLIC + ${GLOBAL_INCLUDE_PATH} +) + +target_compile_options(nnfusion_serialize PRIVATE "-fPIC") +target_link_libraries(nnfusion_serialize ${Protobuf_LIBRARIES}) diff --git a/src/nnfusion/common/serialize/attr_value.proto b/src/nnfusion/common/serialize/attr_value.proto new file mode 100644 index 000000000..8d7bc0bd9 --- /dev/null +++ b/src/nnfusion/common/serialize/attr_value.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package nnfusion.serialize; + +// import "tensor.proto"; +import "tensor_shape.proto"; +import "pbtypes.proto"; + +option cc_enable_arenas = true; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated PBType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + // repeated TensorProto tensor = 8; // "list(tensor)" + // repeated NameAttrList func = 9; // "list(attr)" + } + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + PBType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + // TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. 
func.name is a function's name or
+    // a primitive op's name. func.attr.first is the name of an attr
+    // defined for that function. func.attr.second is the value for
+    // that attr in the instantiation.
+    // NameAttrList func = 10;
+
+    // This is a placeholder only used in nodes defined inside a
+    // function. It indicates the attr value will be supplied when
+    // the function is instantiated. For example, let us suppose a
+    // node "N" in function "FN". "N" has an attr "A" with value
+    // placeholder = "foo". When FN is instantiated with attr "foo"
+    // set to "bar", the instantiated node N's attr A will have been
+    // given the value "bar".
+    // string placeholder = 9;
+  }
+}
+
+// A list of attr names and their values. The whole list is attached
+// with a string name. E.g., MatMul[T=float].
+// message NameAttrList {
+//   string name = 1;
+//   map<string, AttrValue> attr = 2;
+// }
diff --git a/src/nnfusion/common/serialize/graph_def.proto b/src/nnfusion/common/serialize/graph_def.proto
new file mode 100644
index 000000000..03b611acb
--- /dev/null
+++ b/src/nnfusion/common/serialize/graph_def.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package nnfusion.serialize;
+
+import "node_def.proto";
+
+option cc_enable_arenas = true;
+
+// Represents the graph of operations
+message GraphDef {
+  repeated NodeDef node = 1;
+  int32 version = 2;
+}
\ No newline at end of file
diff --git a/src/nnfusion/common/serialize/node_def.proto b/src/nnfusion/common/serialize/node_def.proto
new file mode 100644
index 000000000..640d8a3d2
--- /dev/null
+++ b/src/nnfusion/common/serialize/node_def.proto
@@ -0,0 +1,62 @@
+syntax = "proto3";
+
+package nnfusion.serialize;
+
+import "attr_value.proto";
+
+option cc_enable_arenas = true;
+
+message NodeDef {
+  // The name given to this operator. Used for naming inputs,
+  // logging, visualization, etc. Unique within a single GraphDef.
+  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*".
+  string name = 1;
+
+  // The operation name. There may be custom parameters in attrs.
+  // Op names starting with an underscore are reserved for internal use.
+  string op = 2;
+
+  // Each input is "node:src_output" with "node" being a string name and
+  // "src_output" indicating which output tensor to use from "node". If
+  // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
+  // may optionally be followed by control inputs that have the format
+  // "^node".
+  repeated string input = 3;
+
+  // A (possibly partial) specification for the device on which this
+  // node should be placed.
+  // The expected syntax for this string is as follows:
+  //
+  // DEVICE_SPEC ::= PARTIAL_SPEC
+  //
+  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
+  // CONSTRAINT ::= ("job:" JOB_NAME)
+  //              | ("replica:" [1-9][0-9]*)
+  //              | ("task:" [1-9][0-9]*)
+  //              | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
+  //
+  // Valid values for this string include:
+  // * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
+  // * "/job:worker/device:GPU:3"                   (partial specification)
+  // * ""                                           (no specification)
+  //
+  // If the constraints do not resolve to a single device (or if this
+  // field is empty or not present), the runtime will attempt to
+  // choose a device automatically.
+  // string device = 4;
+
+  // Operation-specific graph-construction-time configuration.
+  // Note that this should include all attrs defined in the
+  // corresponding OpDef, including those with a value matching
+  // the default -- this allows the default to change and makes
+  // NodeDefs easier to interpret on their own. However, if
+  // an attr with a default is not specified in this list, the
+  // default will be used.
+  // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
+  // one of the names from the corresponding OpDef's attr field).
+  // The values must have a type matching the corresponding OpDef
+  // attr's type field.
+  // TODO(josh11b): Add some examples here showing best practices.
+  map<string, AttrValue> attr = 5;
+
+}
\ No newline at end of file
diff --git a/src/nnfusion/common/serialize/pbtypes.proto b/src/nnfusion/common/serialize/pbtypes.proto
new file mode 100644
index 000000000..9b07fe269
--- /dev/null
+++ b/src/nnfusion/common/serialize/pbtypes.proto
@@ -0,0 +1,40 @@
+syntax = "proto3";
+
+//package nnfusion;
+package nnfusion.serialize;
+
+option cc_enable_arenas = true;
+
+// (== suppress_warning documentation-presence ==)
+// LINT.IfChange
+enum PBType {
+  // Not a legal value for Type. Used to indicate a Type field
+  // has not been set.
+  DT_INVALID = 0;
+
+  // Data types that all computation devices are expected to be
+  // capable to support.
+  DT_BOOL = 1;
+  DT_CHAR = 2;
+  DT_FLOAT = 3;
+  DT_DOUBLE = 4;
+  DT_INT8 = 5;
+  DT_INT16 = 6;
+  DT_INT32 = 7;
+  DT_INT64 = 8;
+  DT_UINT8 = 9;
+  DT_UINT16 = 10;
+  DT_UINT32 = 11;
+  DT_UINT64 = 12;
+}
+
+// For identifying the underlying type of a variant. For variants, the types
+// listed here are a subset of the types in the variant type registry,
+// corresponding to commonly used variants which must occasionally be
+// special-cased.
+// enum SpecializedType {
+//   // Invalid/unknown specialized type.
+//   ST_INVALID = 0;
+//   // "tensorflow::TensorList" in the variant type registry.
+//   ST_TENSOR_LIST = 1;
+// }
\ No newline at end of file
diff --git a/src/nnfusion/common/serialize/tensor_shape.proto b/src/nnfusion/common/serialize/tensor_shape.proto
new file mode 100644
index 000000000..943f7f708
--- /dev/null
+++ b/src/nnfusion/common/serialize/tensor_shape.proto
@@ -0,0 +1,42 @@
+// Protocol buffer representing the shape of tensors.
+
+syntax = "proto3";
+option cc_enable_arenas = true;
+
+package nnfusion.serialize;
+
+// Dimensions of a tensor.
+message TensorShapeProto {
+  // One dimension of the tensor.
+  message Dim {
+    // Size of the tensor in that dimension.
+    // This value must be >= -1, but values of -1 are reserved for "unknown"
+    // shapes (values of -1 mean "unknown" dimension). Certain wrappers
+    // that work with TensorShapeProto may fail at runtime when deserializing
+    // a TensorShapeProto containing a dim value of -1.
+    int64 size = 1;
+
+    // Optional name of the tensor dimension.
+    string name = 2;
+  };
+
+  // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
+  // for a 30 x 40 2D tensor. If an entry has size -1, this
+  // corresponds to a dimension of unknown size. The names are
+  // optional.
+  //
+  // The order of entries in "dim" matters: It indicates the layout of the
+  // values in the tensor in-memory representation.
+  //
+  // The first entry in "dim" is the outermost dimension used to layout the
+  // values, the last entry is the innermost dimension. This matches the
+  // in-memory layout of RowMajor Eigen tensors.
+  //
+  // If "dim.size()" > 0, "unknown_rank" must be false.
+  repeated Dim dim = 2;
+
+  // If true, the number of dimensions in the shape is unknown.
+  //
+  // If true, "dim.size()" must be 0.
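To make the schema above concrete, here is a minimal C++ sketch (not part of this change; the file name and node contents are illustrative) that writes and re-reads a one-node `GraphDef` with the protoc-generated API, mirroring what `Graph::serialize_to_file` below emits:

```cpp
#include <fstream>
#include "nnfusion/common/serialize/graph_def.pb.h"

int main()
{
    nnfusion::serialize::GraphDef graphdef;
    nnfusion::serialize::NodeDef* node = graphdef.add_node();
    node->set_name("dense0_dot");           // unique within the GraphDef
    node->set_op("Dot");
    node->add_input("x:0");                 // "node:src_output" convention
    node->add_input("^init");               // "^node" marks a control input

    nnfusion::serialize::AttrValue dtype;   // "T" carries the element type
    dtype.set_type(nnfusion::serialize::PBType::DT_FLOAT);
    (*node->mutable_attr())["T"] = dtype;
    graphdef.set_version(1);

    // Round-trip through the same binary on-disk format the pass uses.
    {
        std::fstream out("graph.pb", std::ios::out | std::ios::trunc | std::ios::binary);
        graphdef.SerializeToOstream(&out);
    }
    nnfusion::serialize::GraphDef parsed;
    std::fstream in("graph.pb", std::ios::in | std::ios::binary);
    return parsed.ParseFromIstream(&in) ? 0 : 1;
}
```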
+ bool unknown_rank = 3; +}; diff --git a/src/nnfusion/common/type/element_type.cpp b/src/nnfusion/common/type/element_type.cpp index a2fc8aaa9..7ffac3933 100644 --- a/src/nnfusion/common/type/element_type.cpp +++ b/src/nnfusion/common/type/element_type.cpp @@ -90,6 +90,39 @@ bool element::Type::nnfusion_element_type_to_dtype_string(const element::Type& n return true; } +bool element::Type::nnfusion_element_type_to_pbtype(const element::Type& ng_et, + nnfusion::serialize::PBType& dtype) +{ + if (ng_et == element::boolean) + dtype = nnfusion::serialize::PBType::DT_BOOL; + else if (ng_et == element::character) + dtype = nnfusion::serialize::PBType::DT_CHAR; + else if (ng_et == element::f32) + dtype = nnfusion::serialize::PBType::DT_FLOAT; + else if (ng_et == element::f64) + dtype = nnfusion::serialize::PBType::DT_DOUBLE; + else if (ng_et == element::i8) + dtype = nnfusion::serialize::PBType::DT_INT8; + else if (ng_et == element::i16) + dtype = nnfusion::serialize::PBType::DT_INT16; + else if (ng_et == element::i32) + dtype = nnfusion::serialize::PBType::DT_INT32; + else if (ng_et == element::i64) + dtype = nnfusion::serialize::PBType::DT_INT64; + else if (ng_et == element::u8) + dtype = nnfusion::serialize::PBType::DT_UINT8; + else if (ng_et == element::u16) + dtype = nnfusion::serialize::PBType::DT_UINT16; + else if (ng_et == element::u32) + dtype = nnfusion::serialize::PBType::DT_UINT32; + else if (ng_et == element::u64) + dtype = nnfusion::serialize::PBType::DT_UINT64; + else + return false; + + return true; +} + element::Type::Type( size_t bitwidth, bool is_real, bool is_signed, bool is_quantized, const std::string& cname) : m_bitwidth{bitwidth} diff --git a/src/nnfusion/common/type/element_type.hpp b/src/nnfusion/common/type/element_type.hpp index 984b0d827..19f1ecd99 100644 --- a/src/nnfusion/common/type/element_type.hpp +++ b/src/nnfusion/common/type/element_type.hpp @@ -26,6 +26,7 @@ #include #include "half/include/half.hpp" +#include "nnfusion/common/serialize/pbtypes.pb.h" #include "nnfusion/common/type/bfloat16.hpp" #include "nnfusion/util/errors.hpp" @@ -80,6 +81,8 @@ namespace nnfusion bool operator<(const Type& other) const; friend std::ostream& operator<<(std::ostream&, const Type&); static std::vector get_known_types(); + static bool nnfusion_element_type_to_pbtype(const Type& ng_et, + nnfusion::serialize::PBType& dtype); static bool nnfusion_element_type_to_dtype_string(const Type& ng_et, std::string& dtype); diff --git a/src/nnfusion/core/graph/CMakeLists.txt b/src/nnfusion/core/graph/CMakeLists.txt index 49e99cc44..41fea91b3 100644 --- a/src/nnfusion/core/graph/CMakeLists.txt +++ b/src/nnfusion/core/graph/CMakeLists.txt @@ -15,4 +15,5 @@ add_library(nnfusion_graph STATIC ${SRC}) target_include_directories(nnfusion_graph SYSTEM PUBLIC ${GLOBAL_INCLUDE_PATH} ) +target_link_libraries(nnfusion_graph nnfusion_serialize) target_compile_options(nnfusion_graph PRIVATE "-fPIC") \ No newline at end of file diff --git a/src/nnfusion/core/graph/graph.cpp b/src/nnfusion/core/graph/graph.cpp index 774b4d780..da319a0a8 100644 --- a/src/nnfusion/core/graph/graph.cpp +++ b/src/nnfusion/core/graph/graph.cpp @@ -5,6 +5,10 @@ #include "graph.hpp" #include "graph_util.hpp" +#include "nnfusion/common/serialize/attr_value.pb.h" +#include "nnfusion/common/serialize/graph_def.pb.h" +#include "nnfusion/common/serialize/pbtypes.pb.h" +#include "nnfusion/common/serialize/tensor_shape.pb.h" #include "nnfusion/util/util.hpp" using namespace nnfusion::graph; @@ -409,6 +413,93 @@ void 
Graph::set_temporary_pool_size(size_t size) m_temporary_pool_size = size; } +bool Graph::serialize_to_file(const std::string& file_path) +{ + nnfusion::serialize::GraphDef graphdef; + auto nnfusion_nodes = get_ordered_ops(true); + for (auto& nnfusion_node : nnfusion_nodes) + { + NNFUSION_CHECK( + !nnfusion_node->hasAttributes() || + (nnfusion_node->attributeNames().size() == 1 && nnfusion_node->hasAttribute("Alias"))) + << nnfusion_node->get_name() << " has " << nnfusion_node->attributeNames().size() + << " tags including \"Alias\" which cannot be serialized now."; + nnfusion::serialize::NodeDef* node = graphdef.add_node(); + // name + node->set_name(nnfusion_node->get_name()); + // op + node->set_op(nnfusion_node->get_op_type()); + // input + for (auto nnfusion_edge : nnfusion_node->get_in_edges()) + { + if (nnfusion_edge->get_src_output() == kControlSlot) + { + node->add_input("^" + nnfusion_edge->get_src()->get_name()); + } + else + { + node->add_input(nnfusion_edge->get_src()->get_name() + ":" + + std::to_string(nnfusion_edge->get_src_output())); + } + } + // TODO(gbxu): support all nnfusion ops + if (nnfusion_node->get_op_type() == "AllReduce") + { + // tensor_name + nnfusion::serialize::AttrValue tensor_name; + tensor_name.set_s(nnfusion_node->get_name()); + (*node->mutable_attr())["tensor_name"] = tensor_name; + } + // data type + if (nnfusion_node->get_output_size() == 1) + { + nnfusion::serialize::AttrValue data_type; + nnfusion::serialize::PBType dt; + nnfusion::element::Type::nnfusion_element_type_to_pbtype( + nnfusion_node->get_element_type(), dt); + data_type.set_type(dt); + (*node->mutable_attr())["T"] = data_type; + } +#if 0 + // Plan_gen can't parse this now. So just skip it for now. + else + { + nnfusion::serialize::AttrValue_ListValue* _data_types_list = + new nnfusion::serialize::AttrValue_ListValue(); + for (auto nnfusion_output : nnfusion_node->get_outputs()) + { + nnfusion::serialize::PBType dt; + nnfusion::element::Type::nnfusion_element_type_to_pbtype( + nnfusion_output->get_element_type(), dt); + _data_types_list->add_type(dt); + } + nnfusion::serialize::AttrValue _data_types; + _data_types.set_allocated_list(_data_types_list); + (*node->mutable_attr())["T"] = _data_types; + } +#endif + // _output_shapes + nnfusion::serialize::AttrValue_ListValue* _output_shapes_list = + new nnfusion::serialize::AttrValue_ListValue(); + for (auto nnfusion_output : nnfusion_node->get_outputs()) + { + auto shape = _output_shapes_list->add_shape(); + for (auto nnfusion_dim : nnfusion_output->get_shape()) + { + auto dim = shape->add_dim(); + dim->set_size(nnfusion_dim); + } + } + nnfusion::serialize::AttrValue _output_shapes; + _output_shapes.set_allocated_list(_output_shapes_list); + (*node->mutable_attr())["_output_shapes"] = _output_shapes; + } + graphdef.set_version(1); + std::fstream fs(file_path, std::ios::out | std::ios::trunc | std::ios::binary); + graphdef.SerializeToOstream(&fs); + return true; +} + size_t Graph::get_memory_io() { size_t total_io = 0; diff --git a/src/nnfusion/core/graph/graph.hpp b/src/nnfusion/core/graph/graph.hpp index 777373b2e..0873a36c7 100644 --- a/src/nnfusion/core/graph/graph.hpp +++ b/src/nnfusion/core/graph/graph.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include "gedge.hpp" @@ -125,6 +126,7 @@ namespace nnfusion void set_temporary_pool_size(size_t); size_t get_memory_io(); + bool serialize_to_file(const std::string& file_path); private: // Map from node ids to allocated nodes. 
nodes_[id] may be nullptr if
diff --git a/src/nnfusion/core/kernels/cpu/cpu_langunit.cpp b/src/nnfusion/core/kernels/cpu/cpu_langunit.cpp
index c118a755b..7d1455047 100644
--- a/src/nnfusion/core/kernels/cpu/cpu_langunit.cpp
+++ b/src/nnfusion/core/kernels/cpu/cpu_langunit.cpp
@@ -27,3 +27,5 @@ LU_DEFINE(declaration::worker_thread_pool,
           "concurrency::NumaAwareThreadPool *worker_thread_pool;\n")
 LU_DEFINE(declaration::schedule_thread_pool,
           "concurrency::NumaAwareThreadPool *schedule_thread_pool;\n")
+LU_DEFINE(declaration::superscaler_schedule_thread,
+          "concurrency::NumaAwareThreadPool *superscaler_schedule_thread;\n")
\ No newline at end of file
diff --git a/src/nnfusion/core/kernels/cpu/cpu_langunit.hpp b/src/nnfusion/core/kernels/cpu/cpu_langunit.hpp
index f84872b07..4b5c9e670 100644
--- a/src/nnfusion/core/kernels/cpu/cpu_langunit.hpp
+++ b/src/nnfusion/core/kernels/cpu/cpu_langunit.hpp
@@ -31,6 +31,7 @@ namespace nnfusion
             LU_DECLARE(eigen_global_thread_pool_device);
             LU_DECLARE(worker_thread_pool);
             LU_DECLARE(schedule_thread_pool);
+            LU_DECLARE(superscaler_schedule_thread);
         }
     } // namespace kernels
 } // namespace nnfusion
diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp
index 0867393f8..9e6c9f696 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp
@@ -208,6 +208,8 @@ namespace nnfusion
                 GENERIC_OP_LOGGING();
                 if (!FLAGS_fantares_codegen_server.empty())
                 {
+                    NNFUSION_LOG(INFO) << "Translate for " << ctx->gnode->get_op_type();
+
                     auto ir = nnfusion::op::get_translation(ctx->gnode);
 #if 0
                     std::unordered_set<std::string> wl = {
diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
index ea429d047..0a0d4c14a 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
@@ -11,7 +11,7 @@ using namespace nnfusion::kernels;
 LU_DEFINE(header::cuda, "#include <cuda.h>\n#include <cuda_runtime.h>\n");
 LU_DEFINE(header::cublas, "#include <cublas_v2.h>\n");
 LU_DEFINE(header::cudnn, "#include <cudnn.h>\n");
-LU_DEFINE(header::super_scaler, "#include \"super_scaler.h\"\n");
+LU_DEFINE(header::superscaler, "#include \"superscaler.h\"\n");
 LU_DEFINE(header::cupti, "#include <cupti.h>\n");
 LU_DEFINE(header::cuda_prof_api, "#include <cuda_profiler_api.h>\n");
 LU_DEFINE(header::cuda_fp16, "#include <cuda_fp16.h>\n");
@@ -947,4 +947,4 @@ void HostApplyLayerNorm(
         float(epsilon),
         gamma,
         beta);
 }
-)");
\ No newline at end of file
+)");
diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.hpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.hpp
index ae68e3e72..47c173c63 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.hpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.hpp
@@ -13,7 +13,7 @@ namespace nnfusion
             LU_DECLARE(cuda);
             LU_DECLARE(cublas);
             LU_DECLARE(cudnn);
-            LU_DECLARE(super_scaler);
+            LU_DECLARE(superscaler);
             LU_DECLARE(cupti);
             LU_DECLARE(cuda_prof_api);
             LU_DECLARE(cuda_fp16);
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/allreduce.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/allreduce.cpp
index 383c9b600..f97840481 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/allreduce.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/allreduce.cpp
@@ -18,18 +18,64 @@ namespace nnfusion
             class SuperScalerAllReduce : public KernelEmitter
             {
             public:
+                string tensor_name;
                 SuperScalerAllReduce(shared_ptr<KernelContext> ctx)
                     : KernelEmitter(ctx, "SuperScaler")
                 {
+                    tensor_name = ctx->gnode->get_name();
                 }

                 LanguageUnit_p emit_function_body() override
                 {
                     LanguageUnit_p _lu(new LanguageUnit(get_function_name()));
                     auto& lu = *_lu;
-                    auto data_size = m_context->inputs.front()->size(false);
+                    auto input0_size = m_context->inputs.front()->size(false);
+                    auto input0_allocated_bytes = m_context->inputs.front()->size(true);
+                    auto code = nnfusion::op::create_code_from_template(
+                        R"(sc_allreduce("@tensorname@", input0, @input0_size@, stream);
+if(input0==output0) return;
+CUDA_SAFE_CALL(cudaMemcpyAsync(output0, input0, @input0_allocated_bytes@, cudaMemcpyDefault, stream));
+)",
+                        {{"input0_size", input0_size},
+                         {"input0_allocated_bytes", input0_allocated_bytes},
+                         {"tensorname", tensor_name}});
                     // allreduce and applygradient use the same stream.
-                    lu << "super_scaler_all_reduce(input0, output0, " << data_size << ", &stream);";
+                    lu << code;
+                    return _lu;
+                }
+
+                LanguageUnit_p emit_function_signature()
+                {
+                    LanguageUnit_p _lu(new LanguageUnit(this->m_kernel_name + "_sig"));
+                    auto& lu = *_lu;
+
+                    vector<string> params;
+                    for (size_t i = 0; i < m_context->inputs.size(); i++)
+                    {
+                        stringstream ss;
+                        ss << m_context->inputs[i]->get_element_type().c_type_string() << "* ";
+                        ss << "input" << i;
+                        params.push_back(ss.str());
+                    }
+
+                    for (size_t i = 0; i < m_context->outputs.size(); i++)
+                    {
+                        stringstream ss;
+                        ss << m_context->outputs[i]->get_element_type().c_type_string() << "* ";
+                        ss << "output" << i;
+                        params.push_back(ss.str());
+                    }
+
+                    for (size_t i = 0; i < m_context->tensors.size(); i++)
+                    {
+                        stringstream ss;
+                        ss << m_context->tensors[i]->get_element_type().c_type_string() << "* ";
+                        ss << m_context->tensors[i]->get_name();
+                        params.push_back(ss.str());
+                    }
+
+                    lu << "void "
+                       << "(cudaStream_t stream, " << join(params, ", ") << ")";
                     return _lu;
                 }
 
@@ -37,15 +83,13 @@
                 {
                     LanguageUnit_p _lu(new LanguageUnit(get_function_name() + "_dep"));
                     _lu->require(header::cuda);
-                    _lu->require(header::super_scaler); // This require nccl, mpi
-                    // _lu->require(declaration::allreduce_stream);
-                    // _lu->require(declaration::applygradient_stream);
+                    _lu->require(header::superscaler);
                     return _lu;
                 }
             };
-        }
-    }
-}
+        } // namespace cuda
+    } // namespace kernels
+} // namespace nnfusion
 
 using namespace nnfusion;
 using namespace nnfusion::kernels;
diff --git a/src/nnfusion/core/operators/generic_op/generic_op.hpp b/src/nnfusion/core/operators/generic_op/generic_op.hpp
index c93bdf0b5..278dd51e3 100644
--- a/src/nnfusion/core/operators/generic_op/generic_op.hpp
+++ b/src/nnfusion/core/operators/generic_op/generic_op.hpp
@@ -30,6 +30,7 @@ namespace nnfusion
         using infersharedmemory_func_t = void (*)(std::shared_ptr<graph::GNode> gnode);
         using translate_func_t = std::string (*)(std::shared_ptr<graph::GNode> gnode);
         using translate_func_t_v2 = std::string (*)(std::shared_ptr<graph::GNode> gnode);
+        std::string get_annotation(std::string translation);
 
         // OpConfig(): f_infershape(infershape::copy_shape_from_inputs) { }
diff --git a/src/nnfusion/core/operators/generic_op/op_registration.cpp b/src/nnfusion/core/operators/generic_op/op_registration.cpp
index 97d1b74db..c98b0f4de 100644
--- a/src/nnfusion/core/operators/generic_op/op_registration.cpp
+++ b/src/nnfusion/core/operators/generic_op/op_registration.cpp
@@ -101,5 +101,27 @@ namespace nnfusion
 
             return options;
         }
+
+        std::string get_annotation(std::string translation)
+        {
+            std::string options;
+            const char annotation[] = "## @annotation: ";
+            int pos = translation.find(annotation);
+            if (pos >= 0)
+            {
+                pos += sizeof(annotation) - 1;
+                options = translation.substr(pos);
+            }
+
+            if (options.size() > 0)
+            {
+                if (options[0] != '|')
+                    options = "|" + options;
+                if (options.back() != '|')
+                    options += "|";
+            }
+
+            return options;
+        }
     } // namespace op
 } // namespace nnfusion
diff --git a/src/nnfusion/engine/device/cuda.cpp b/src/nnfusion/engine/device/cuda.cpp
index 2a9aac64c..6cfd2c0dd 100644
--- a/src/nnfusion/engine/device/cuda.cpp
+++ b/src/nnfusion/engine/device/cuda.cpp
@@ -14,6 +14,7 @@
 #include "nnfusion/engine/pass/graph/gemm_fusion_pass.hpp"
 #include "nnfusion/engine/pass/graph/gnode_device_dispatcher.hpp"
 #include "nnfusion/engine/pass/graph/gradient_weight_mapping_pass.hpp"
+#include "nnfusion/engine/pass/graph/graph_serialization_pass.hpp"
 #include "nnfusion/engine/pass/graph/kernel_fusion_pass.hpp"
 #include "nnfusion/engine/pass/graph/kernel_profiling_pass.hpp"
 #include "nnfusion/engine/pass/graph/kernel_selection.hpp"
@@ -22,6 +23,7 @@
 #include "nnfusion/engine/pass/graph/op_inplace_pass.hpp"
 #include "nnfusion/engine/pass/graph/pattern_substitution.hpp"
 #include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp"
+#include "nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp"
 #include "nnfusion/engine/pass/graph/vector_dot_transpose_pass.hpp"
 #include "nnfusion/engine/pass/extract_graph_signature.hpp"
 
@@ -49,6 +51,10 @@ CudaEngine::CudaEngine()
     g_passes->push_back(make_shared());
     g_passes->push_back(make_shared());
     g_passes->push_back(make_shared());
+    //superscaler pass
+    g_passes->push_back(make_shared());
+    g_passes->push_back(make_shared());
+    g_passes->push_back(make_shared());
 
     g_passes->push_back(make_shared());
diff --git a/src/nnfusion/engine/pass/codegen/codegen_langunit.cpp b/src/nnfusion/engine/pass/codegen/codegen_langunit.cpp
index 2490097d8..e3e1f3584 100644
--- a/src/nnfusion/engine/pass/codegen/codegen_langunit.cpp
+++ b/src/nnfusion/engine/pass/codegen/codegen_langunit.cpp
@@ -41,25 +41,22 @@ find_package(Threads REQUIRED)
 target_link_libraries(${TARGET_NAME} Threads::Threads)
 )");
 
-LU_DEFINE(nnfusion::codegen::cmake::super_scaler,
+LU_DEFINE(nnfusion::codegen::cmake::superscaler_cuda,
           R"(
-find_package(MPI)
-include_directories(${MPI_INCLUDE_PATH})
-find_library(SUPER_SCALER_LIBRARIES libsuper_scaler.so ${CMAKE_CURRENT_SOURCE_DIR})
-target_link_libraries(${TARGET_NAME}
-    ${MPI_LIBRARIES}
-    ${SUPER_SCALER_LIBRARIES}
-    nccl)
+if (NOT TARGET superscaler)
+set(TARGET_GPU_PLATFORM "CUDA" CACHE STRING "Choose your GPU platform: CUDA or ROCm")
+include(superscaler/superscaler.cmake)
+endif()
+target_link_libraries(${TARGET_NAME} superscaler)
 )");
 
-LU_DEFINE(nnfusion::codegen::cmake::rocm_super_scaler,
+LU_DEFINE(nnfusion::codegen::cmake::superscaler_rocm,
           R"(
-find_package(MPI)
-include_directories(${MPI_INCLUDE_PATH})
-find_library(ssrocm libsuper_scaler_rocm.so ${CMAKE_CURRENT_SOURCE_DIR})
-target_link_libraries(${TARGET_NAME}
-    ${MPI_LIBRARIES}
-    ${ssrocm}
+if (NOT TARGET superscaler)
+set(TARGET_GPU_PLATFORM "ROCm" CACHE STRING "Choose your GPU platform: CUDA or ROCm")
+include(superscaler/superscaler.cmake)
+endif()
+target_link_libraries(${TARGET_NAME} superscaler)
 )");
 
 LU_DEFINE(nnfusion::codegen::cmake::cuda_lib,
diff --git a/src/nnfusion/engine/pass/codegen/codegen_langunit.hpp b/src/nnfusion/engine/pass/codegen/codegen_langunit.hpp
index 8d50e1535..90f0f679b 100644
--- a/src/nnfusion/engine/pass/codegen/codegen_langunit.hpp
+++ b/src/nnfusion/engine/pass/codegen/codegen_langunit.hpp
@@ -15,8 +15,8 @@ namespace nnfusion
             LU_DECLARE(mlas);
             LU_DECLARE(threadpool);
             LU_DECLARE(threads);
-            LU_DECLARE(super_scaler);
-            LU_DECLARE(rocm_super_scaler);
+            LU_DECLARE(superscaler_cuda);
+            LU_DECLARE(superscaler_rocm);
             LU_DECLARE(cuda_lib);
             LU_DECLARE(rocm_lib);
         } // namespace cmake
diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
index bdfa00c28..42c759663 100644
--- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
@@ -13,6 +13,8 @@
 #include "nnfusion/core/kernels/kernel_emitter.hpp"
 #include "nnfusion/core/kernels/kernel_registration.hpp"
 
+#include <regex>
+
 using namespace nnfusion;
 using namespace nnfusion::graph;
 using namespace nnfusion::kernels;
@@ -52,7 +54,7 @@ void CudaCodegenPass::set_global_member(std::shared_ptr<InterpreterContext> ctx,
         }
     }
 
-    superscaler_enable = global_required.count("header::super_scaler") > 0;
+    superscaler_enable = global_required.count("header::superscaler") > 0;
     return;
 }
 
@@ -65,13 +67,6 @@ void CudaCodegenPass::initialize(std::shared_ptr<InterpreterContext> ctx,
     projgen->lup_codegen->pwd = m_codegen_folder;
     projgen->lup_codegen->write_to = "nnfusion_rt.cu";
     auto& copy_templates = projgen->lup_codegen->copy_templates;
-    if (superscaler_enable)
-    {
-        copy_templates.emplace_back("super_scaler/super_scaler.h", "./super_scaler.h");
-        NNFUSION_LOG(NNFUSION_WARNING) << "libsuper_scaler.so should be copied from "
-                                          "(build)/src/tools/nnfusion/templates/super_scaler/";
-        copy_templates.emplace_back("super_scaler/libsuper_scaler.so", "./libsuper_scaler.so");
-    }
 
     copy_templates.emplace_back("image_tests/image_test.cpp", "./image_tests/image_test.cpp");
     copy_templates.emplace_back("image_tests/CMakeLists_cuda.txt", "./image_tests/CMakeLists.txt");
@@ -97,11 +92,34 @@ void CudaCodegenPass::initialize(std::shared_ptr<InterpreterContext> ctx,
         copy_folder.push_back(threadpool_path);
     }
 
+    if (superscaler_enable)
+    {
+        std::string superscaler_path = std::string(path) + std::string("/superscaler");
+        copy_folder.push_back(superscaler_path);
+    }
+
     // setup main_block
     auto& lu_init_begin = *(projgen->lup_init->begin);
     {
-        lu_init_begin << "\nextern \"C\" void cuda_init()\n{\n";
-        lu_init_begin << "CUDA_SAFE_CALL(cudaDeviceReset());\n";
+        if (superscaler_enable)
+        {
+            lu_init_begin << "\nextern \"C\" void cuda_init(const char* resource_dir)\n{\n";
+            lu_init_begin << "CUDA_SAFE_CALL(cudaDeviceReset());\n";
+            lu_init_begin <<
+                R"(int device_id;
+int host_id;
+sc_init(resource_dir);
+sc_get_host_id(&host_id);
+sc_get_device_id(&device_id);
+printf("[host_id: %d device_id: %d] is running\n", host_id, device_id);
+CUDA_SAFE_CALL(cudaSetDevice(device_id));
+)";
+        }
+        else
+        {
+            lu_init_begin << "\nextern \"C\" void cuda_init()\n{\n";
+            lu_init_begin << "CUDA_SAFE_CALL(cudaDeviceReset());\n";
+        }
     }
 
     auto& lu_init_end = *(projgen->lup_init->end);
@@ -138,10 +156,20 @@ void CudaCodegenPass::initialize(std::shared_ptr<InterpreterContext> ctx,
     auto& lu_exit_begin = *(projgen->lup_exit->begin);
     {
         lu_exit_begin << "\nextern \"C\" void cuda_free()\n{\n";
+        if (superscaler_enable)
+        {
+            lu_exit_begin <<
+                R"(int device_id;
+sc_get_device_id(&device_id);
+CUDA_SAFE_CALL(cudaSetDevice(device_id));
+)";
+        }
     }
 
     auto& lu_exit_end = *(projgen->lup_exit->end);
     {
+        if (superscaler_enable)
+            lu_exit_end << "sc_finalize();\n";
         lu_exit_end << "}\n\n";
     }
 
@@ -276,7 +304,10 @@ bool CudaCodegenPass::collect_funcs(std::shared_ptr<InterpreterContext> ctx,
                     std::string std_thread_func_call = std::string("auto ") + std_thread_func_name +
                                                        std::string(" = std::bind") + thread_call_str;
                     lu_new_caller << std_thread_func_call;
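The lines that follow route each generated thread function either to the general schedule pool or, when SuperScaler is enabled, to a dedicated single-worker pool so collective calls stay ordered. Conceptually the emitted runtime behaves like this hedged sketch (the pool type below is a stand-in that runs tasks inline, not the real `concurrency::NumaAwareThreadPool`; thread names are illustrative):

```cpp
#include <functional>
#include <iostream>

// Stand-in with the same Schedule() shape the generated code relies on.
struct ToyPool
{
    explicit ToyPool(const char* tag) : tag(tag) {}
    void Schedule(const std::function<void()>& task)
    {
        std::cout << "[" << tag << "] running task\n";
        task(); // a real pool would enqueue; we run inline for illustration
    }
    const char* tag;
};

void dev0_thread() { /* default-stream kernel launches */ }
void dev1_thread() { /* contains sc_allreduce calls that must stay ordered */ }

int main()
{
    ToyPool schedule_thread_pool("schedule_thread_pool");        // many workers in reality
    ToyPool superscaler_schedule_thread("superscaler_schedule"); // created as (1,1): one worker

    auto dev0_func = std::bind(dev0_thread);
    schedule_thread_pool.Schedule(dev0_func);        // dev0_thread keeps the default pool

    auto dev1_func = std::bind(dev1_thread);
    superscaler_schedule_thread.Schedule(dev1_func); // every other thread serializes here
    return 0;
}
```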
- std::string t_threadpool_call = std::string("schedule_thread_pool->Schedule("); + std::string t_threadpool_call = + (superscaler_enable && thread_name != "dev0_thread") + ? std::string("superscaler_schedule_thread->Schedule(") + : std::string("schedule_thread_pool->Schedule("); t_threadpool_call += (std_thread_func_name + std::string(");\n")); lu_new_caller << t_threadpool_call; } @@ -287,6 +318,13 @@ bool CudaCodegenPass::collect_funcs(std::shared_ptr ctx, { lu_begin << "extern \"C\" void " << thread_name << "("; lu_begin << thread_call_paras << ")\n{\n"; + if (superscaler_enable) + { + lu_begin << R"(int device_id; +sc_get_device_id(&device_id); +CUDA_SAFE_CALL(cudaSetDevice(device_id)); +)"; + } } LanguageUnit_p end = @@ -383,12 +421,6 @@ std::vector>> pairs.end(), [](std::pair>& a, std::pair>& b) { - if (a.first.find("default_") != string::npos) - return false; - - if (b.first.find("default") != string::npos) - return false; - int pos_a = a.first.find("async_"); int pos_b = b.first.find("async_"); if (pos_a >= 0 && pos_b >= 0) @@ -400,8 +432,10 @@ std::vector>> d2 = d2.substr(0, d2.find(delimiter)); return std::stoi(d1) < std::stoi(d2); } - - return a.first > b.first; + else + { + return a.first > b.first; + } }); } @@ -618,6 +652,8 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru bool CudaCodegenPass::collect_stream(std::shared_ptr ctx, std::shared_ptr tu) { + std::regex r(R"(CUDA_SAFE_CALL\(cudaSetDevice\(\d)"); + //stream NNFUSION_CHECK_NOT_NULLPTR(device_async_manager); if (device_async_manager && device_async_manager->num_stream() > 0) @@ -626,9 +662,23 @@ bool CudaCodegenPass::collect_stream(std::shared_ptr ctx, auto stream_init = device_async_manager->emit_stream_init(); auto stream_destroy = device_async_manager->emit_stream_destroy(); - stream_init->require(stream_decl); - add_init_and_exit_pair(stream_init, stream_destroy); + string stream_init_code_old = stream_init->get_code(); + string stream_destroy_code_old = stream_destroy->get_code(); + string stream_init_code = + (superscaler_enable ? std::regex_replace(stream_init_code_old, r, "// $0") + : stream_init_code_old); + string stream_destroy_code = + (superscaler_enable ? std::regex_replace(stream_destroy_code_old, r, "// $0") + : stream_destroy_code_old); + LanguageUnit_p stream_init_lu( + new LanguageUnit(stream_init->get_symbol(), stream_init_code)); + LanguageUnit_p stream_destroy_lu( + new LanguageUnit(stream_destroy->get_symbol(), stream_destroy_code)); + + stream_init_lu->require(stream_decl); + add_init_and_exit_pair(stream_init_lu, stream_destroy_lu); } + //event if (device_async_manager && device_async_manager->num_event() > 0) { @@ -636,8 +686,69 @@ bool CudaCodegenPass::collect_stream(std::shared_ptr ctx, auto event_init = device_async_manager->emit_event_init(); auto event_destroy = device_async_manager->emit_event_destroy(); - event_init->require(event_decl); - add_init_and_exit_pair(event_init, event_destroy); + string event_init_code_old = event_init->get_code(); + string event_destroy_code_old = event_destroy->get_code(); + string event_init_code = + (superscaler_enable ? std::regex_replace(event_init_code_old, r, "// $0") + : event_init_code_old); + string event_destroy_code = + (superscaler_enable ? 
std::regex_replace(event_destroy_code_old, r, "// $0") + : event_destroy_code_old); + + LanguageUnit_p event_init_lu(new LanguageUnit(event_init->get_symbol(), event_init_code)); + LanguageUnit_p event_destroy_lu( + new LanguageUnit(event_destroy->get_symbol(), event_destroy_code)); + + event_init_lu->require(event_decl); + add_init_and_exit_pair(event_init_lu, event_destroy_lu); + } + + return true; +} + +bool CudaCodegenPass::collect_mem(std::shared_ptr ctx, + std::shared_ptr tu) +{ + if (!tu) + return false; + auto mem_pair = create_init_and_exit_pair("MEM_ALLOC", + "MEM_FREE"); + auto lup_mem_alloc = mem_pair.first; + auto lup_mem_free = mem_pair.second; + auto& allocator_list = tu->memory_allocator_factory->get_allocator_list(); + + size_t total_alloc = 0; + for (const auto& allocator : allocator_list) + { + total_alloc += allocator.second->max_allocated(); + } + LanguageUnit_p total = std::make_shared( + "total_memory", "// total memory:" + to_string(total_alloc) + "\n"); + lup_mem_alloc->unit_vec.push_back(total); + + std::regex r(R"(CUDA_SAFE_CALL\(cudaSetDevice\(\d)"); + + for (const auto& allocator : allocator_list) + { + auto init = allocator.second->emit_memory_init(); + auto alloc = allocator.second->emit_memory_alloc(); + auto free = allocator.second->emit_memory_free(); + + string alloc_code_old = alloc->get_code(); + string free_code_old = free->get_code(); + + string alloc_code = + (superscaler_enable ? std::regex_replace(alloc_code_old, r, "// $0") : alloc_code_old); + string free_code = + (superscaler_enable ? std::regex_replace(free_code_old, r, "// $0") : free_code_old); + + LanguageUnit_p alloc_lu(new LanguageUnit(alloc->get_symbol(), alloc_code)); + LanguageUnit_p free_lu(new LanguageUnit(free->get_symbol(), free_code)); + + lup_mem_alloc->unit_vec.push_back(alloc_lu); + lup_mem_alloc->require(init); + lup_mem_free->unit_vec.push_back(free_lu); + lup_mem_free->require(init); } return true; @@ -687,6 +798,8 @@ bool CudaCodegenPass::modify_codegen() { projgen->lup_codegen->require(header::threadpool); projgen->lup_codegen->require(declaration::schedule_thread_pool); + if (superscaler_enable) + projgen->lup_codegen->require(declaration::superscaler_schedule_thread); projgen->lup_codegen->require(header::barrier); // auto thread_decl = host_async_manager->emit_stream_decl(); // projgen->lup_codegen->require(thread_decl); @@ -723,6 +836,25 @@ bool CudaCodegenPass::modify_codegen() { lu_schedule_thread_pool_del << "delete schedule_thread_pool;\n"; } + + if (superscaler_enable) + { + auto superscaler_schedule_thread_pair = + create_init_and_exit_pair( + "init_superscaler_schedule_thread", "del_superscaler_schedule_thread"); + auto lup_superscaler_schedule_thread_pair_init = superscaler_schedule_thread_pair.first; + auto lup_superscaler_schedule_thread_pair_del = superscaler_schedule_thread_pair.second; + + auto& lu_superscaler_schedule_thread_init = *lup_superscaler_schedule_thread_pair_init; + { + lu_superscaler_schedule_thread_init + << "superscaler_schedule_thread = new concurrency::NumaAwareThreadPool(1,1);\n"; + } + auto& lu_superscaler_schedule_thread_del = *lup_superscaler_schedule_thread_pair_del; + { + lu_superscaler_schedule_thread_del << "delete superscaler_schedule_thread;\n"; + } + } } if (host_async_manager && host_async_manager->num_event() > 0) @@ -813,7 +945,10 @@ void CudaCodegenPass::create_header_file(std::shared_ptr ctx lu_header << params; lu_header << ");\n"; - lu_header << "extern \"C\" void cuda_init();\n"; + if (superscaler_enable) + lu_header << 
"extern \"C\" void cuda_init(const char*);\n"; + else + lu_header << "extern \"C\" void cuda_init();\n"; lu_header << "extern \"C\" void cuda_free();\n"; @@ -859,10 +994,17 @@ void CudaCodegenPass::create_main_file(std::shared_ptr ctx, LanguageUnit d2hcopy("d2hcopy"); LanguageUnit fillval("fillval"); - lu_main << "int main(void)"; + lu_main << "int main(int argc, char *argv[])"; lu_main.block_begin(); { - lu_main << "\ncuda_init();\n\n"; + if (superscaler_enable) + { + lu_main << "\nif(!argv[1]) {throw std::runtime_error(\"superscaler resource dir is not " + "given!\"); }\n\n"; + lu_main << "\ncuda_init(argv[1]);\n\n"; + } + else + lu_main << "\ncuda_init();\n\n"; for (size_t i = 0; i < tu->arg.size(); i++) { @@ -1066,6 +1208,7 @@ set(CMAKE_CXX_FLAGS_RELEASE "-O2") find_package(CUDA) set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}") +# set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -ftemplate-depth=4096 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O2") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -cudart shared") )"; @@ -1098,16 +1241,24 @@ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -cudart shared") lu << nnfusion::codegen::cmake::threads->get_code(); } - if (global_required.count("header::super_scaler") > 0) + if (superscaler_enable) { - // add super_scaler - lu << nnfusion::codegen::cmake::super_scaler->get_code(); + // add superscaler + lu << nnfusion::codegen::cmake::superscaler_cuda->get_code(); } } lu << R"( cuda_add_executable(main_test main_test.cpp) target_link_libraries(main_test ${TARGET_NAME}) + +if(EXISTS "${CMAKE_BINARY_DIR}/Constant") +else() +add_custom_command( + TARGET ${TARGET_NAME} + POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/Constant ${CMAKE_BINARY_DIR}/Constant +) +endif() )"; return; } diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.hpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.hpp index 2a3399e8b..47496cd26 100644 --- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.hpp +++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.hpp @@ -39,6 +39,8 @@ namespace nnfusion std::shared_ptr tu); virtual bool collect_stream(std::shared_ptr ctx, std::shared_ptr tu) override; + virtual bool collect_mem(std::shared_ptr ctx, + std::shared_ptr tu) override; virtual bool collect_funcs(std::shared_ptr ctx, std::shared_ptr tu) override; virtual std::vector>> @@ -71,4 +73,4 @@ namespace nnfusion bool superscaler_enable = false; }; } -} \ No newline at end of file +} diff --git a/src/nnfusion/engine/pass/codegen/rocm_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/rocm_codegen_pass.cpp index 48becc5cc..dbcb9e8b1 100644 --- a/src/nnfusion/engine/pass/codegen/rocm_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/rocm_codegen_pass.cpp @@ -30,6 +30,8 @@ void RocmCodegenPass::initialize(std::shared_ptr ctx, projgen->lup_codegen->write_to = "nnfusion_rt.cu"; auto& copy_templates = projgen->lup_codegen->copy_templates; copy_templates.emplace_back("rocm_adapter/rocm_adapter.h", "./rocm_adapter.h"); + // copy_templates.emplace_back("rocm_adapter/fastgen_for_sliced_kernels.sh", + // "./fastgen_for_sliced_kernels.sh"); // NNFUSION_CHECK(0 == system("chmod a+x fastgen_for_sliced_kernels.sh")); copy_templates.emplace_back("image_tests/image_test.cpp", "./image_tests/image_test.cpp"); copy_templates.emplace_back("image_tests/CMakeLists_rocm.txt", "./image_tests/CMakeLists.txt"); @@ -50,11 +52,8 @@ 
void RocmCodegenPass::initialize(std::shared_ptr ctx, if (superscaler_enable) { - copy_templates.emplace_back("super_scaler/super_scaler.h", "./super_scaler.h"); - NNFUSION_LOG(NNFUSION_WARNING) << "libsuper_scaler_rocm.so should be copied from " - "(build)/src/tools/nnfusion/templates/super_scaler/"; - copy_templates.emplace_back("super_scaler/libsuper_scaler_rocm.so", - "./libsuper_scaler_rocm.so"); + std::string superscaler_path = std::string(path) + std::string("/superscaler"); + copy_folder.push_back(superscaler_path); } copy_templates.emplace_back("image_tests/image_test.cpp", "./image_tests/image_test.cpp"); copy_templates.emplace_back("image_tests/CMakeLists_rocm.txt", "./image_tests/CMakeLists.txt"); @@ -163,10 +162,10 @@ set(CMAKE_CXX_FLAGS "-O2 -Wno-ignored-attributes -Wno-duplicate-decl-specifier") // add rocm_lib lu << nnfusion::codegen::cmake::rocm_lib->get_code(); - if (global_required.count("header::super_scaler") > 0) + if (superscaler_enable) { - // add super_scaler - lu << nnfusion::codegen::cmake::rocm_super_scaler->get_code(); + // add superscaler + lu << nnfusion::codegen::cmake::superscaler_rocm->get_code(); } } @@ -174,6 +173,13 @@ set(CMAKE_CXX_FLAGS "-O2 -Wno-ignored-attributes -Wno-duplicate-decl-specifier") add_executable(main_test main_test.cpp) target_link_libraries(main_test ${TARGET_NAME}) +# if(EXISTS "${CMAKE_BINARY_DIR}/Constant") +# else() +# add_custom_command( +# TARGET ${TARGET_NAME} +# POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/Constant ${CMAKE_BINARY_DIR}/Constant +# ) +# endif() )"; return; } @@ -284,7 +290,8 @@ bool RocmCodegenPass::after_projgen() // // fast compile script for dynamic shared lib // nnfusion::codegen::copy_file_from_templates("rocm_adapter/fastgen_for_sliced_kernels.sh", // "./fastgen_for_sliced_kernels.sh"); + // NNFUSION_CHECK(0 == system("chmod a+x fastgen_for_sliced_kernels.sh")); } NNFUSION_CHECK(chdir(cd) == 0); return true; -} \ No newline at end of file +} diff --git a/src/nnfusion/engine/pass/graph/CMakeLists.txt b/src/nnfusion/engine/pass/graph/CMakeLists.txt index 73874901a..5f9cf5d23 100644 --- a/src/nnfusion/engine/pass/graph/CMakeLists.txt +++ b/src/nnfusion/engine/pass/graph/CMakeLists.txt @@ -19,9 +19,12 @@ set(SRC assign_async_info_pass.cpp kernel_profiling_pass.cpp runtime_const_folding_pass.cpp + common_subexpression_elimination_pass.cpp pattern_substitution.cpp batchnorm_inference_folding_pass.cpp + autodiff_pass.cpp dot_transpose_pass.cpp + superscaler_dataparallelism_pass.cpp common_subexpression_elimination_pass.cpp autodiff_pass.cpp ) @@ -33,5 +36,5 @@ target_include_directories(nnfusion_engine_pass_graph SYSTEM PUBLIC target_compile_options(nnfusion_engine_pass_graph PRIVATE "-fPIC") target_link_libraries(nnfusion_engine_pass_graph PRIVATE nnfusion_common nnfusion_engine_pass_graph_blockfusion -gflags ${CURL_LIBRARIES} nnfusion_engine_pass_graph_autodiff +gflags ${CURL_LIBRARIES} ${WholeArchiveFlag} nnfusion_engine_pass_graph_autodiff ${NoWholeArchiveFlag} ) diff --git a/src/nnfusion/engine/pass/graph/graph_pass.cpp b/src/nnfusion/engine/pass/graph/graph_pass.cpp index bb363650d..875ada67c 100644 --- a/src/nnfusion/engine/pass/graph/graph_pass.cpp +++ b/src/nnfusion/engine/pass/graph/graph_pass.cpp @@ -29,4 +29,4 @@ DEFINE_string(fantares_codegen_server, "", "Antares codegen server address and port, format: :"); -DECLARE_string(fdefault_device); \ No newline at end of file +DECLARE_string(fdefault_device); diff --git 
a/src/nnfusion/engine/pass/graph/graph_serialization_pass.hpp b/src/nnfusion/engine/pass/graph/graph_serialization_pass.hpp new file mode 100644 index 000000000..e4ed78b6e --- /dev/null +++ b/src/nnfusion/engine/pass/graph/graph_serialization_pass.hpp @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include +#include "graph_pass_base.hpp" +using namespace nnfusion::graph; +DEFINE_string(fnnfusion_graph_path, "./nnfusion_graph.pb", "path to save nnfusion graph file."); +DEFINE_bool(fenable_export_graph, false, "enable exporting graph."); +namespace nnfusion +{ + namespace pass + { + namespace graph + { + class GraphSerializationPass : public GraphPassBase + { + public: + bool run_on_graph(std::shared_ptr& graph) + { + if (FLAGS_fenable_export_graph) + graph->serialize_to_file(FLAGS_fnnfusion_graph_path); + return true; + } + }; + } + } +} diff --git a/src/nnfusion/engine/pass/graph/op_inplace_pass.cpp b/src/nnfusion/engine/pass/graph/op_inplace_pass.cpp index 10e0a17de..e74b74323 100644 --- a/src/nnfusion/engine/pass/graph/op_inplace_pass.cpp +++ b/src/nnfusion/engine/pass/graph/op_inplace_pass.cpp @@ -71,6 +71,12 @@ bool OpInplacePass::run_on_graph(std::shared_ptr& graph) } } + else if (node->get_op_type() == "AllReduce") + { + auto op = std::dynamic_pointer_cast(node->get_op_ptr()); + AddInplace(op, 0, 0, false); + } + else if (nnfusion::op::get_annotation(nnfusion::op::get_translation(node)) .find("|memcpy|") != string::npos) { diff --git a/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.cpp b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.cpp new file mode 100644 index 000000000..8602ae88f --- /dev/null +++ b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.cpp @@ -0,0 +1,301 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "superscaler_dataparallelism_pass.hpp" +#include +#include "nnfusion/core/graph/gnode.hpp" +#include "nnfusion/core/graph/graph.hpp" +#include "nnfusion/core/operators/generic_op/generic_op.hpp" +#include "nnfusion/core/operators/op_define/allreduce.hpp" +#include "nnfusion/core/operators/op_define/constant.hpp" +#include "nnfusion/core/operators/op_define/result.hpp" +using namespace nnfusion::graph; +using namespace nnfusion::op; +using namespace nnfusion::pass::graph; +using namespace std; +DEFINE_bool(fadd_sc_allreduce, false, "Add Allreduce operater after ApplyGradient operator."); +DEFINE_bool(fadd_sc_allreduce_fusion, + false, + "Add fused sc Allreduce operater after ApplyGradient operator."); +DEFINE_int32(sc_allreduce_fusion_num, -1, "set the number of adjacent allreduce_op to fuse."); +DEFINE_int32(sc_allreduce_fusion_size, + -1, + "set the floats of data to fuse: 67108864 is recommended."); +DEFINE_int32(sc_allreduce_fusion_time, -1, "set the timeout to fuse: 1000 millisecond by default."); + +#define SC_ALLREDUCE_DEBUG +int SuperScalerDataParallelismPass::get_gradient_from_apply(std::shared_ptr apply_node) +{ + // TODO(gbxu): adapt for more apply op. A general way: provide API of quering "grad" from op def. + if (apply_node->get_op_type() == "ApplyGradientDescent" || + apply_node->get_op_type() == "ApplyGradient") + { + int weight_index = (apply_node->get_in_edge(0)->get_src()->is_variable() || + (apply_node->get_in_edge(0)->get_src()->is_parameter() && + std::dynamic_pointer_cast( + apply_node->get_in_edge(0)->get_src()->get_op_ptr()) + ->require_grad())) + ? 
0 + : 1; + return (weight_index + 1) % 2; + } + else + { + return -1; + } +} + +std::vector> SuperScalerDataParallelismPass::group_gradient_apply( + std::map, std::shared_ptr>> hash_to_gradient_apply) +{ + std::vector> gradient_key_subgroups; + int sc_allreduce_fusion_num = FLAGS_sc_allreduce_fusion_num; + int sc_allreduce_fusion_size = FLAGS_sc_allreduce_fusion_size; + int sc_allreduce_fusion_time = FLAGS_sc_allreduce_fusion_time; + NNFUSION_LOG(INFO) << "[sc dp pass] sc_allreduce_fusion_num:" << sc_allreduce_fusion_num; + NNFUSION_LOG(INFO) << "[sc dp pass] sc_allreduce_fusion_size:" << sc_allreduce_fusion_size; + NNFUSION_LOG(INFO) << "[sc dp pass] sc_allreduce_fusion_time:" << sc_allreduce_fusion_time; + if (sc_allreduce_fusion_num <= 0 && sc_allreduce_fusion_size <= 0 && + sc_allreduce_fusion_time <= 0) + { + sc_allreduce_fusion_num = hash_to_gradient_apply.size(); // concat all allreduce into one + NNFUSION_LOG(INFO) << "[sc dp pass] reset sc_allreduce_fusion_num:" + << sc_allreduce_fusion_num; + } + std::vector subgroup; + std::vector fused_sizes; + if (sc_allreduce_fusion_num > 0) + { + int curr_fuse_size = 0; + for (int i = 0; i < hash_to_gradient_apply.size(); i++) + { + int curr = shape_size(hash_to_gradient_apply[i].first->get_shape()); + curr_fuse_size += curr; + // allreduce nodes are adjacent and sorted from back to front when backward by default + subgroup.push_back(i); + if (subgroup.size() >= sc_allreduce_fusion_num) + { + gradient_key_subgroups.push_back(subgroup); + fused_sizes.push_back(curr_fuse_size); + subgroup.clear(); + curr_fuse_size = 0; + } + } + if (subgroup.size() != 0) // fuse the remaining allreduce nodes + { + gradient_key_subgroups.push_back(subgroup); + fused_sizes.push_back(curr_fuse_size); + subgroup.clear(); + curr_fuse_size = 0; + } + } + else + { + // timeout and buffer_size + NNFUSION_CHECK(sc_allreduce_fusion_time == -1) + << "now sc_allreduce_fusion_time is not supported."; + int curr_fuse_size = 0; + for (int i = 0; i < hash_to_gradient_apply.size(); i++) + { + // TODO: timeout mechanism + int curr = shape_size(hash_to_gradient_apply[i].first->get_shape()); + if (curr_fuse_size + curr > sc_allreduce_fusion_size) + { + gradient_key_subgroups.push_back(subgroup); + fused_sizes.push_back(curr_fuse_size); + subgroup.clear(); + curr_fuse_size = 0; + } + subgroup.push_back(i); + curr_fuse_size += curr; + } + if (subgroup.size() != 0) // fuse the remaining allreduce nodes + { + gradient_key_subgroups.push_back(subgroup); + fused_sizes.push_back(curr_fuse_size); + subgroup.clear(); + curr_fuse_size = 0; + } + } + return gradient_key_subgroups; +} + +std::shared_ptr SuperScalerDataParallelismPass::concat_into_one( + std::shared_ptr& graph, + std::vector subgroup, + std::map, std::shared_ptr>> hash_to_gradient_apply) +{ + // gradient->reshape->concat + GNodeIndexVector concat_inputs; + for (int i : subgroup) + { + auto gradient_apply = hash_to_gradient_apply[i]; + auto gradient_node = gradient_apply.first; + int n = 0; + nnfusion::AxisVector order = nnfusion::AxisVector(gradient_node->get_shape().size()); + std::generate(order.begin(), order.end(), [&n]() { return n++; }); + auto apply_node = gradient_apply.second; + auto reshape_op = std::make_shared( + order, + nnfusion::Shape(1, shape_size(gradient_node->get_shape()))); // AxisVector={0, 1..} + add_inplace(reshape_op, 0, 0, false); + auto reshape_node = graph->add_node_and_edge(reshape_op, {gradient_node}); // output_index=0 + concat_inputs.push_back(GNodeIndex(reshape_node, 0)); + } + auto concat_op = 
std::make_shared(0); + auto first_gradient_node = hash_to_gradient_apply[subgroup[0]].first; + concat_op->set_name(first_gradient_node->get_name() + "_fusion_concat_node"); + std::shared_ptr concat_node = graph->add_node_and_edge(concat_op, {concat_inputs}); + return concat_node; +} + +std::vector, int>> SuperScalerDataParallelismPass::split_from_one( + std::shared_ptr& graph, + std::map, std::shared_ptr>> hash_to_gradient_apply, + std::shared_ptr allreduce_node, + std::vector subgroup) +{ + std::vector, int>> allreduced_gradients_index; + size_t cursor = 0; + std::vector lower{0}; + std::vector upper{0}; + size_t allreduced_tensor_size = shape_size(allreduce_node->get_shape()); + for (int i : subgroup) + { + auto gradient_apply = hash_to_gradient_apply[i]; + auto gradient_node = gradient_apply.first; + // allreduce->slice + nnfusion::Shape gradient_shape = + gradient_node->get_shape(); // default get_output_shape(output_index=0) + cursor += shape_size(gradient_shape); + upper[0] = cursor; + NNFUSION_CHECK(cursor <= allreduced_tensor_size) << "slice range is out of buffer"; + auto slice_op = std::make_shared(lower, upper); + lower[0] = cursor; + slice_op->set_name(gradient_node->get_name() + "_fusion_slice_node"); + auto slice_node = graph->add_node_and_edge(slice_op, {allreduce_node}); + // allreduce->slice->reshape + auto reshape_op = std::make_shared(nnfusion::AxisVector{0}, + gradient_shape); // AxisVector={0, 1..} + add_inplace(reshape_op, 0, 0, false); + auto reshape_node = graph->add_node_and_edge(reshape_op, {slice_node}); // output_index=0 + allreduced_gradients_index.push_back( + std::pair, int>(reshape_node, i)); + } + return allreduced_gradients_index; +} + +bool SuperScalerDataParallelismPass::add_allreduce( + std::shared_ptr& graph, + std::map, std::shared_ptr>> hash_to_gradient_apply) +{ + for (int i = 0; i < hash_to_gradient_apply.size(); i++) + { + auto gradient_node = hash_to_gradient_apply[i].first; + auto apply_node = hash_to_gradient_apply[i].second; + int gradient_index = get_gradient_from_apply(apply_node); + graph->remove_edge(apply_node->get_in_edge(gradient_index)); + // Weight(weight_node) ----| + // | + // V + // (gradient) ApplyGradient-> Result + auto allreduce_op = std::make_shared(); + std::shared_ptr allreduce_node = + graph->add_node_and_edge(allreduce_op, {gradient_node}); + NNFUSION_LOG(INFO) << "[sc dp pass] allreduce name:" << allreduce_node->get_name(); + // Weight(weight_node) ------------| + // | + // V + // (gradient) --> allreduce ApplyGradient-> Result + graph->add_edge(allreduce_node, 0, apply_node, gradient_index); + // Weight(weight_node) ------------| + // | + // V + // (gradient) --> allreduce --> ApplyGradient-> Result + } +} + +bool SuperScalerDataParallelismPass::add_fused_allreduce( + std::shared_ptr& graph, + std::map, std::shared_ptr>> hash_to_gradient_apply) +{ + std::vector> gradient_key_subgroups = + group_gradient_apply(hash_to_gradient_apply); + for (std::vector subgroup : gradient_key_subgroups) + { + auto concat_node = concat_into_one(graph, subgroup, hash_to_gradient_apply); + // Weight(weight_node) ----------------| + // | + // | + // (gradient)->reshape---| | + // V V + //(gradient)->reshape-> concat ->ApplyGradient-> Result + std::shared_ptr allreduce_node; + auto allreduce_op = std::make_shared(); + allreduce_node = graph->add_node_and_edge(allreduce_op, {concat_node}); + NNFUSION_LOG(INFO) << "[sc dp pass] allreduce name:" << allreduce_node->get_name(); + // Weight(weight_node) ----------------------------| + // | + // | + 
// (gradient)->reshape---| | + // V V + //(gradient)->reshape-> concat --> allreduce ->ApplyGradient-> Result + std::vector, int>> allreduced_gradients_key = + split_from_one(graph, hash_to_gradient_apply, allreduce_node, subgroup); + // Weight(weight_node) --------------------------------------------------------------------| + // | + // | + // (gradient)->reshape---| V + // | |->reshape->ApplyGradient-> Result + // V | + //(gradient)->reshape-> concat(concated_gradient) --> allreduce --> + //slice->reshape->ApplyGradient-> Result + + for (std::pair, int> reshape_key : allreduced_gradients_key) + { + std::shared_ptr apply_node = hash_to_gradient_apply[reshape_key.second].second; + int gradient_index = get_gradient_from_apply(apply_node); + graph->remove_edge(apply_node->get_in_edge(gradient_index)); + graph->add_edge(reshape_key.first, 0, apply_node, gradient_index); + } + } + return true; +} + +bool SuperScalerDataParallelismPass::run_on_graph(std::shared_ptr& graph) +{ + sc_allreduce_enable = FLAGS_fadd_sc_allreduce; + sc_allreduce_fusion_enable = FLAGS_fadd_sc_allreduce_fusion; + if (!sc_allreduce_enable) + return true; + std::map, std::shared_ptr>> hash_to_gradient_apply; + // group gradient and apply* op from n-th layer to 1st layer + for (int i = graph->get_outputs().size() - 1; i >= 0; i--) + { + auto result_node = graph->get_outputs()[i]; + // the apply node followed by result node. so check result_node's input node + int gradient_index = get_gradient_from_apply((result_node->get_in_edge(0)->get_src())); + if (gradient_index == -1) + continue; // skip nodes whose type are not Apply*. + NNFUSION_CHECK(result_node->get_in_edges().size() == 1) + << "result node has other input except apply op:"; + // Weight(weight_node) ----| + // | + // V + // (gradient) --------> Apply-> Result + auto apply_node = result_node->get_in_edge(0)->get_src(); + std::shared_ptr gradient_node = apply_node->get_in_edge(gradient_index)->get_src(); + NNFUSION_LOG(INFO) << "[sc dp pass] find gradient: " << gradient_node->get_name() + << "; id: " << hash_to_gradient_apply.size(); + hash_to_gradient_apply[hash_to_gradient_apply.size()] = + std::pair, std::shared_ptr>(gradient_node, apply_node); + } + if (sc_allreduce_fusion_enable) + { + return add_fused_allreduce(graph, hash_to_gradient_apply); + } + else + { + return add_allreduce(graph, hash_to_gradient_apply); + } +} diff --git a/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp new file mode 100644 index 000000000..9456284b2 --- /dev/null +++ b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
diff --git a/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp
new file mode 100644
index 000000000..9456284b2
--- /dev/null
+++ b/src/nnfusion/engine/pass/graph/superscaler_dataparallelism_pass.hpp
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <vector>
+#include "graph_pass_base.hpp"
+using namespace nnfusion::graph;
+
+namespace nnfusion
+{
+    namespace pass
+    {
+        namespace graph
+        {
+            class SuperScalerDataParallelismPass : public GraphPassBase
+            {
+            public:
+                bool run_on_graph(std::shared_ptr<Graph>& graph) override;
+
+            private:
+                std::shared_ptr<GNode> concat_into_one(
+                    std::shared_ptr<Graph>& graph,
+                    std::vector<int> subgroup,
+                    std::map<int, std::pair<std::shared_ptr<GNode>, std::shared_ptr<GNode>>>
+                        hash_to_gradient_apply);
+                std::vector<std::pair<std::shared_ptr<GNode>, int>> split_from_one(
+                    std::shared_ptr<Graph>& graph,
+                    std::map<int, std::pair<std::shared_ptr<GNode>, std::shared_ptr<GNode>>>
+                        hash_to_gradient_apply,
+                    std::shared_ptr<GNode> allreduce_node,
+                    std::vector<int> subgroup);
+                bool add_allreduce(
+                    std::shared_ptr<Graph>& graph,
+                    std::map<int, std::pair<std::shared_ptr<GNode>, std::shared_ptr<GNode>>>
+                        hash_to_gradient_apply);
+                bool add_fused_allreduce(
+                    std::shared_ptr<Graph>& graph,
+                    std::map<int, std::pair<std::shared_ptr<GNode>, std::shared_ptr<GNode>>>
+                        hash_to_gradient_apply);
+                std::vector<std::vector<int>> group_gradient_apply(
+                    std::map<int, std::pair<std::shared_ptr<GNode>, std::shared_ptr<GNode>>>
+                        hash_to_gradient_apply);
+                int get_gradient_from_apply(std::shared_ptr<GNode> apply_node);
+                bool sc_allreduce_enable;
+                bool sc_allreduce_fusion_enable;
+                template <typename T>
+                void add_inplace(T op, size_t output, size_t input, bool destructive)
+                {
+                    auto op_annotations = op->get_op_annotations();
+                    if (op_annotations)
+                    {
+                        // pass-through
+                        op_annotations->add_in_place_oi_pair({output, input, destructive});
+                    }
+                    else
+                    {
+                        op_annotations = std::make_shared<Annotations>();
+                        // pass-through
+                        op_annotations->add_in_place_oi_pair({output, input, destructive});
+                        op->set_op_annotations(op_annotations);
+                    }
+                }
+            };
+        }
+    }
+}
diff --git a/src/tools/CMakeLists.txt b/src/tools/CMakeLists.txt
index a05af420e..e941f682b 100644
--- a/src/tools/CMakeLists.txt
+++ b/src/tools/CMakeLists.txt
@@ -2,3 +2,4 @@
 # Licensed under the MIT License.
 
 add_subdirectory(nnfusion)
+add_subdirectory(serialize)
diff --git a/src/tools/nnfusion/CMakeLists.txt b/src/tools/nnfusion/CMakeLists.txt
index d4d6bcd63..4be4d441d 100644
--- a/src/tools/nnfusion/CMakeLists.txt
+++ b/src/tools/nnfusion/CMakeLists.txt
@@ -18,7 +18,7 @@ if (ONNX_FRONTEND)
 endif()
 
 if (TORCHSCRIPT_FRONTEND)
-    target_link_libraries(nnfusion torchscript_import_interface frontend_util)
+    target_link_libraries(nnfusion torchscript_import_interface torchscript_import frontend_util)
 endif()
 
 target_link_libraries(nnfusion nnfusion_backend nnfusion_engine_pass_graph nnfusion_operators gflags)
@@ -30,6 +30,7 @@ add_custom_command(
     POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/thirdparty/threadpool ${CMAKE_BINARY_DIR}/src/tools/nnfusion/threadpool
     POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/thirdparty/mlas ${CMAKE_BINARY_DIR}/src/tools/nnfusion/mlas
     POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/thirdparty/mkl ${CMAKE_BINARY_DIR}/src/tools/nnfusion/mkl
+    POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/thirdparty/superscaler ${CMAKE_BINARY_DIR}/src/tools/nnfusion/superscaler
     POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/src/tools/nnfusion/templates ${CMAKE_BINARY_DIR}/test/templates
 )
diff --git a/src/tools/nnfusion/distributed_training/mnist/README.md b/src/tools/nnfusion/distributed_training/mnist/README.md
new file mode 100644
index 000000000..7f0ece670
--- /dev/null
+++ b/src/tools/nnfusion/distributed_training/mnist/README.md
@@ -0,0 +1,37 @@
+## Distributed MNIST Training Example
+### Install Dependencies
+```sh
+# install superscaler
+$ python3 -m pip install tensorflow==1.15
+$ git clone https://github.com/microsoft/SuperScaler.git && cd SuperScaler && python3 -m pip install . && cd - && rm -fr SuperScaler
+# install other dependencies
+$ python3 -m pip install torch torchvision mpi4py
+```
+
+### Prepare data
+Prepare your own trainable frozen model, or download [mnist_mlp.onnx](https://nnfusion.blob.core.windows.net/models/onnx/mnist_mlp.onnx). You also need a GPU cluster specification file describing the topology of your training environment. This specification is consumed by [SuperScaler](https://github.com/microsoft/SuperScaler.git), so see its documentation for details; an example [resource_pool.yaml](https://github.com/microsoft/SuperScaler#appendix-a-sample-resource_poolyaml) is provided there.
+```sh
+$ cd ./src/tools/nnfusion/distributed_training/mnist
+$ wget 
+$ wget 
+```
+
+### Compile
+```sh
+# this compiles the frozen model for 2 GPU workers doing data-parallel training on the same host
+$ bash ../../../superscaler/nnfusion_dp_single_host.sh mnist_mlp.onnx "-f onnx -p \"batch:3\" -fautodiff -ftraining_mode -fextern_result_memory=True" localhost:2 resource_pool.yaml
+```
+
+### Build
+```sh
+$ cd build && cmake . && make -j
+```
+
+### Train
+```sh
+$ bash ./train.sh
+# in case you are using an older version of MPI:
+$ mpirun -np 2 -x PATH -x LD_LIBRARY_PATH bash -c 'python3 nnf_py/train.py $OMPI_COMM_WORLD_LOCAL_RANK/plan.json'
+```
diff --git a/src/tools/nnfusion/distributed_training/mnist/nnf_py/__init__.py b/src/tools/nnfusion/distributed_training/mnist/nnf_py/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/tools/nnfusion/distributed_training/mnist/nnf_py/dataprc.py b/src/tools/nnfusion/distributed_training/mnist/nnf_py/dataprc.py
new file mode 100644
index 000000000..fc9961ee1
--- /dev/null
+++ b/src/tools/nnfusion/distributed_training/mnist/nnf_py/dataprc.py
@@ -0,0 +1,55 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+from torchvision import datasets, transforms
+from mpi4py import MPI
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+
+
+def get_dataloader(device_id, world_size):
+    batch_size = 3
+    kwargs = {'batch_size': batch_size,
+              'num_workers': 1,
+              'pin_memory': True}
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    # rank 0 downloads the dataset once; the other ranks wait at the barrier
+    if rank == 0:
+        datasets.MNIST('/tmp', train=True, download=True)
+        datasets.MNIST('/tmp', train=False, download=True)
+    comm.barrier()
+    dataset1 = datasets.MNIST('/tmp', train=True, download=False,
+                              transform=transform)
+    dataset2 = datasets.MNIST('/tmp', train=False, download=False,
+                              transform=transform)
+
+    # each worker samples a disjoint shard of the training set
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset1,
+        num_replicas=world_size,
+        rank=device_id,
+        shuffle=True
+    )
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset1, sampler=train_sampler, **kwargs)
+    test_dataloader = torch.utils.data.DataLoader(dataset2, **kwargs)
+
+    return train_dataloader, test_dataloader
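A quick sanity check of the loader above, assuming two workers on one host (the batch size of 3 is hard-coded in dataprc.py):

```python
import dataprc

# rank and world size would normally come from the nnfusion runtime
train_loader, test_loader = dataprc.get_dataloader(device_id=0, world_size=2)
for data, target in train_loader:
    # MNIST images arrive as [batch, 1, 28, 28] with batch_size = 3
    print(data.shape, target.shape)
    break
```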
diff --git a/src/tools/nnfusion/distributed_training/mnist/nnf_py/dtypes.py b/src/tools/nnfusion/distributed_training/mnist/nnf_py/dtypes.py
new file mode 100644
index 000000000..013858336
--- /dev/null
+++ b/src/tools/nnfusion/distributed_training/mnist/nnf_py/dtypes.py
@@ -0,0 +1,52 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import ctypes
+import torch
+
+c_float = ctypes.c_float
+c_float_p = ctypes.POINTER(ctypes.c_float)
+c_float_p_p = ctypes.POINTER(ctypes.POINTER(ctypes.c_float))
+c_int = ctypes.c_int
+c_int_p = ctypes.POINTER(ctypes.c_int)
+c_int_p_p = ctypes.POINTER(ctypes.POINTER(ctypes.c_int))
+c_int64 = ctypes.c_int64
+c_int64_p = ctypes.POINTER(ctypes.c_int64)
+
+
+def tensor_ptr(tensor):
+    """Cast a torch tensor's storage address to a typed ctypes pointer."""
+    tensor_addr = tensor.storage().data_ptr()
+    if tensor.dtype is torch.float32:
+        return ctypes.cast(tensor_addr, c_float_p)
+    elif tensor.dtype is torch.int32:
+        return ctypes.cast(tensor_addr, c_int_p)
+    elif tensor.dtype is torch.int64:
+        return ctypes.cast(tensor_addr, c_int64_p)
+    else:
+        raise Exception("Dtype is not supported: %s" % (tensor.dtype))
+
+
+def deduce_signature(tensors):
+    """Map each tensor's dtype to the matching ctypes pointer type."""
+    sig = []
+    for p in tensors:
+        if p.dtype is torch.float32:
+            sig.append(c_float_p)
+        elif p.dtype is torch.int32:
+            sig.append(c_int_p)
+        elif p.dtype is torch.int64:
+            sig.append(c_int64_p)
+        else:
+            raise Exception("Dtype is not supported: %s" % (p.dtype))
+    return tuple(sig)
+
+
+def get_data_addr(tensors):
+    """Collect a ctypes pointer for every tensor in the list."""
+    return tuple(tensor_ptr(p) for p in tensors)
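A minimal sketch of how these helpers fit together, assuming torch tensors destined for a C kernel entry point:

```python
import torch
import dtypes

tensors = [torch.ones(4, dtype=torch.float32),
           torch.zeros(4, dtype=torch.int64)]
signature = dtypes.deduce_signature(tensors)  # (c_float_p, c_int64_p)
pointers = dtypes.get_data_addr(tensors)      # ctypes pointers into the storages
```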
diff --git a/src/tools/nnfusion/distributed_training/mnist/nnf_py/nnf.py b/src/tools/nnfusion/distributed_training/mnist/nnf_py/nnf.py
new file mode 100644
index 000000000..131662225
--- /dev/null
+++ b/src/tools/nnfusion/distributed_training/mnist/nnf_py/nnf.py
@@ -0,0 +1,72 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import logging
+from ctypes import cdll, c_char_p, c_int, byref
+import dtypes
+
+
+class Runtime:
+    def __init__(self):
+        # detect an existing nnfusion runtime library
+        libnnf_rt = "none"
+        if "LIB_NNF_RT" not in os.environ.keys():
+            logging.info(
+                "libnnfusion_rt is not specified by the "
+                "environment variable LIB_NNF_RT")
+            default_path = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), '..')
+            for file in os.listdir(default_path):
+                if file.startswith("libnnf") and file.endswith("rt.so"):
+                    libnnf_rt = os.path.join(default_path, file)
+                    logging.info("libnnfusion_rt library detected")
+            self.default_path = default_path
+        else:
+            libnnf_rt = os.environ["LIB_NNF_RT"]
+
+        if not os.path.exists(libnnf_rt):
+            raise Exception("libnnfusion_rt: %s does not exist!" % (libnnf_rt))
+
+        try:
+            libnnf = cdll.LoadLibrary(libnnf_rt)
+        except Exception:
+            raise Exception("libnnfusion_rt: %s could not be loaded!"
+                            % (libnnf_rt))
+
+        # members of the session
+        self.libnnf_path = libnnf_rt
+        self.libnnf = libnnf
+
+    # call to init the session
+    def init(self, plan_file_path):
+        if "cpu" in self.libnnf_path:
+            self.libnnf.cpu_init()
+        else:
+            self.libnnf.cuda_init(c_char_p(plan_file_path.encode('utf-8')))
+
+    def device_id(self):
+        device_id = c_int()
+        self.libnnf.sc_get_device_id(byref(device_id))
+        return device_id.value
+
+    def world_size(self):
+        world_size = c_int()
+        self.libnnf.sc_get_world_size(byref(world_size))
+        return world_size.value
+
+    def feed(self, tensors=[], signature=(), params=()):
+        if tensors:
+            self.libnnf.kernel_entry.argtypes = dtypes.deduce_signature(tensors)
+            self.libnnf.kernel_entry(*(dtypes.get_data_addr(tensors)))
+        else:
+            self.libnnf.kernel_entry.argtypes = signature
+            self.libnnf.kernel_entry(*params)
+
+    def free(self):
+        if "cpu" in self.libnnf_path:
+            self.libnnf.cpu_free()
+        else:
+            self.libnnf.cuda_free()
+
+        del self.libnnf
+        del self.libnnf_path
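A minimal usage sketch of Runtime; the library is discovered as above (or via LIB_NNF_RT), and the "0/plan.json" plan path matches the per-worker layout produced by the compile script later in this patch. The input tensor here is hypothetical; real code passes the full parameter list:

```python
import torch
import nnf

rt = nnf.Runtime()
rt.init("0/plan.json")                  # load this worker's SuperScaler plan
print(rt.device_id(), rt.world_size())

inputs = [torch.ones(3, 1, 28, 28, device="cuda")]  # hypothetical tensor list
rt.feed(tensors=inputs)                 # runs kernel_entry over the tensors
rt.free()
```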
diff --git a/src/tools/nnfusion/distributed_training/mnist/nnf_py/train.py b/src/tools/nnfusion/distributed_training/mnist/nnf_py/train.py
new file mode 100644
index 000000000..fc30b4874
--- /dev/null
+++ b/src/tools/nnfusion/distributed_training/mnist/nnf_py/train.py
@@ -0,0 +1,167 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import json
+import nnf
+import dataprc
+import math
+import sys
+
+
+json_path = './para_info.json'
+loss_name = "loss"
+
+
+class NNFTrainer:
+    def __init__(self, plan_file_path):
+        self.rt = nnf.Runtime()
+        self.rt.init(plan_file_path)
+        self.device_id = self.rt.device_id()
+        self.world_size = self.rt.world_size()
+        self.cuda_device_id = 'cuda:' + str(self.rt.device_id())
+        print("running on local device:", self.cuda_device_id)
+        self.cuda_device = torch.device(self.cuda_device_id)
+        torch.cuda.set_device(self.cuda_device)
+
+        self.parambyid = dict()
+        paramlist = list()
+        self.loss_id = 0
+        self.inputs_name_id = dict()
+        self.weight_name_id = dict()  # {name: [w_id, b_id]}
+
+        # load tensor info from para_info.json
+        with open(json_path) as json_file:
+            varinfo = json.load(json_file)
+            for key in varinfo["input"]:
+                id = int(varinfo["input"][key]["id"].split(
+                    "inputs[")[1].split("]")[0])
+                shape = varinfo["input"][key]["shape"]
+                dtype = varinfo["input"][key]["id"][2:].split("*")[0]
+                self.parambyid[id] = (shape, dtype)
+                self.inputs_name_id[key] = id
+
+            for key in varinfo["weight"]:
+                id = int(varinfo["weight"][key]["id"].split(
+                    "inputs[")[1].split("]")[0])
+                shape = varinfo["weight"][key]["shape"]
+                dtype = varinfo["weight"][key]["id"][2:].split("*")[0]
+                self.parambyid[id] = (shape, dtype)
+                if key.endswith(".weight"):
+                    name = key[: key.rindex(".weight")]
+                    if name not in self.weight_name_id:
+                        self.weight_name_id[name] = [id, None]
+                    else:
+                        self.weight_name_id[name][0] = id
+                elif key.endswith(".bias"):
+                    name = key[: key.rindex(".bias")]
+                    if name not in self.weight_name_id:
+                        self.weight_name_id[name] = [None, id]
+                    else:
+                        self.weight_name_id[name][1] = id
+
+            start = len(varinfo["input"]) + len(varinfo["weight"])
+            for key in varinfo["output"]:
+                id = int(varinfo["output"][key]["id"].split(
+                    "outputs[")[1].split("]")[0]) + start
+                shape = varinfo["output"][key]["shape"]
+                dtype = varinfo["output"][key]["id"][2:].split("*")[0]
+                self.parambyid[id] = (shape, dtype)
+                if key == loss_name:
+                    self.loss_id = id
+
+        for key in sorted(self.parambyid):
+            # initializer
+            (shape, dtype) = self.parambyid[key]
+            if dtype == "float":
+                dtype = torch.float
+            elif dtype == "int64_t":
+                dtype = torch.int64
+            else:
+                raise Exception("Dtype is not supported: %s" % (dtype))
+            param = torch.ones(shape, dtype=dtype, device=self.cuda_device)
+            paramlist.append(param)
+
+        self.paramlist = paramlist
+
+        # Kaiming-initialize weights; biases get the matching uniform bound
+        for v in self.weight_name_id.values():
+            w_id, b_id = v
+            torch.nn.init.kaiming_uniform_(
+                self.paramlist[w_id], a=math.sqrt(5))
+            if b_id is not None:
+                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(
+                    self.paramlist[w_id])
+                bound = 1 / math.sqrt(fan_in)
+                torch.nn.init.uniform_(self.paramlist[b_id], -bound, bound)
+
+    def iteration(self, inputs=dict()):
+        # replace params with the data tensors, then run one step
+        for i in inputs:
+            self.paramlist[i] = inputs[i]
+        self.rt.feed(tensors=self.paramlist)
+
+    def finish(self):
+        self.rt.free()
+
+    def save(self):
+        torch.save({"general_io_state": self.paramlist}, "torch_checkpoint")
+
+
+class NaiveBertDataLoader:
+    def __init__(self, device_id, world_size):
+        self.batch_size = 3
+        self.data, _ = dataprc.get_dataloader(device_id, world_size)
+
+    def iteration(self, batch, trainer):
+        # data, target
+        batch = {"data": batch[0], "target": batch[1]}
+        for k in batch:
+            batch[k] = batch[k].to(trainer.cuda_device)
+
+        name_map = {"data": "data", "target": "target"}
+
+        inputs = dict()
+        for key, val in trainer.inputs_name_id.items():
+            if key in name_map and name_map[key] in batch:
+                inputs[val] = batch[name_map[key]]
+            else:
+                (shape, dtype) = trainer.parambyid[val]
+                if dtype == "float":
+                    dtype = torch.float
+                elif dtype == "int64_t":
+                    dtype = torch.int64
+                else:
+                    raise Exception(
+                        "Dtype is not supported: %s" % (dtype))
+                inputs[val] = torch.ones(
+                    shape, dtype=dtype, device=trainer.cuda_device)
+
+        trainer.iteration(inputs=inputs)
+        loss = trainer.paramlist[trainer.loss_id]
+        print("total loss:", loss)
+
+        return loss
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        raise Exception('no plan is given')
+    trainer = NNFTrainer(sys.argv[1])
+    data = NaiveBertDataLoader(trainer.device_id, trainer.world_size)
+    i = 1
+    for batch in data.data:
+        print(i)
+        i += 1
+        data.iteration(batch, trainer)
+    trainer.save()
+    trainer.finish()
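For orientation, this is the shape of para_info.json that the parsing code above expects. The entries below are made up, but the id strings follow the "((dtype*)inputs[i])" pattern the string splitting relies on:

```python
varinfo = {
    "input": {
        "data":   {"id": "((float*)inputs[0])",   "shape": [3, 1, 28, 28]},
        "target": {"id": "((int64_t*)inputs[1])", "shape": [3]},
    },
    "weight": {
        "fc1.weight": {"id": "((float*)inputs[2])", "shape": [128, 784]},
        "fc1.bias":   {"id": "((float*)inputs[3])", "shape": [128]},
    },
    "output": {
        "loss": {"id": "((float*)outputs[0])", "shape": [1]},
    },
}
```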
+""" +import graph_def_pb2 +import sys + +def import_nnfusion(): + """Read the existing nnfusion graph. + """ + graph_def = graph_def_pb2.GraphDef() + try: + with open(sys.argv[1], "rb") as f: + graph_def.ParseFromString(f.read()) + except IOError: + print(sys.argv[1] + ": File not found. Creating a new file.") + with open(sys.argv[2], "w+") as f: + f.write(str(graph_def)) + return graph_def + +def export_nnfusion(graph_def): + """Write the new nnfusion graph back to disk. + """ + with open(sys.argv[1], "wb") as f: + f.write(graph_def.SerializeToString()) + +if __name__ == "__main__": + import_nnfusion() diff --git a/src/tools/superscaler/nnfusion_dp_single_host.sh b/src/tools/superscaler/nnfusion_dp_single_host.sh new file mode 100644 index 000000000..4aed838e7 --- /dev/null +++ b/src/tools/superscaler/nnfusion_dp_single_host.sh @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation - All rights reserved +# Licensed under the MIT License + +#!/bin/bash +# saner programming env: these switches turn some bugs into errors +set -o errexit -o pipefail -o noclobber -o nounset + +#TODO: test whether superscaler is installed +#TODO: add cmdline arg parser +#example +#./nnfusion_dp_single_host_train.sh data/mnist_mlp.onnx "-f onnx -p \"batch:3\" -fautodiff -ftraining_mode -fextern_result_memory=True" 10.0.0.25:2 data/resource_pool.yaml + +#user specified inputs +MODEL_FILE_PATH=$(realpath $1) +MODEL_COMPILING_OPTIONS="$2" +#here, we only support one host mutilple worker +DEPLOYMENT_CONFIG=$3 +CLUSTER_SPEC_FILE=$(realpath $4) + +OLD_PWD=$(pwd) +echo $PWD +NUM_PROCS_SAME_HOST=${DEPLOYMENT_CONFIG##*:} +NNFUSION_EXE='./src/tools/nnfusion/nnfusion' +DUMPED_GRAPH_FILENAME='graph' +SUPERSCALER_COMPILING_OPTIONS="-fadd_sc_allreduce=true -fenable_export_graph=true -fnnfusion_graph_path=$DUMPED_GRAPH_FILENAME.pb" +COMPILE_CMD="$NNFUSION_EXE $MODEL_FILE_PATH $MODEL_COMPILING_OPTIONS $SUPERSCALER_COMPILING_OPTIONS" + +echo "-- Creating build dir for compiling NNFusion model" +rm -fr build && mkdir build && cd build && cp -r ../nnf_py . +pushd ../../../../../.. > /dev/null +echo "-- Building NNFusion" +rm -fr build && mkdir build && cd build && cmake .. 
diff --git a/src/tools/superscaler/nnfusion_dp_single_host.sh b/src/tools/superscaler/nnfusion_dp_single_host.sh
new file mode 100644
index 000000000..4aed838e7
--- /dev/null
+++ b/src/tools/superscaler/nnfusion_dp_single_host.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+# Copyright (c) Microsoft Corporation - All rights reserved
+# Licensed under the MIT License
+
+# saner programming env: these switches turn some bugs into errors
+set -o errexit -o pipefail -o noclobber -o nounset
+
+#TODO: test whether superscaler is installed
+#TODO: add a cmdline arg parser
+#example:
+#./nnfusion_dp_single_host_train.sh data/mnist_mlp.onnx "-f onnx -p \"batch:3\" -fautodiff -ftraining_mode -fextern_result_memory=True" 10.0.0.25:2 data/resource_pool.yaml
+
+#user-specified inputs
+MODEL_FILE_PATH=$(realpath $1)
+MODEL_COMPILING_OPTIONS="$2"
+#here, we only support multiple workers on one host
+DEPLOYMENT_CONFIG=$3
+CLUSTER_SPEC_FILE=$(realpath $4)
+
+OLD_PWD=$(pwd)
+echo $PWD
+NUM_PROCS_SAME_HOST=${DEPLOYMENT_CONFIG##*:}
+NNFUSION_EXE='./src/tools/nnfusion/nnfusion'
+DUMPED_GRAPH_FILENAME='graph'
+SUPERSCALER_COMPILING_OPTIONS="-fadd_sc_allreduce=true -fenable_export_graph=true -fnnfusion_graph_path=$DUMPED_GRAPH_FILENAME.pb"
+COMPILE_CMD="$NNFUSION_EXE $MODEL_FILE_PATH $MODEL_COMPILING_OPTIONS $SUPERSCALER_COMPILING_OPTIONS"
+
+echo "-- Creating build dir for compiling NNFusion model"
+rm -fr build && mkdir build && cd build && cp -r ../nnf_py .
+pushd ../../../../../.. > /dev/null
+echo "-- Building NNFusion"
+rm -fr build && mkdir build && cd build && cmake .. > /dev/null
+make -j > /dev/null || make
+echo "-- Compiling model $MODEL_FILE_PATH"
+eval $COMPILE_CMD > /dev/null
+echo "-- Dumping NNFusion graph"
+python3 src/tools/serialize/nnfusion_serialize_tool.py $DUMPED_GRAPH_FILENAME.pb $DUMPED_GRAPH_FILENAME.pbtxt > /dev/null 2>&1
+#copy generated resources
+cp -r nnfusion_rt/cuda_codegen/* $OLD_PWD/build
+cp -r $DUMPED_GRAPH_FILENAME.pbtxt $OLD_PWD/build
+popd > /dev/null
+
+#modify the generated CMakeLists so the runtime builds as a shared library
+sed -i 's/.*cuda_add_library.*/set(CMAKE_POSITION_INDEPENDENT_CODE ON)\ncuda_add_library(${TARGET_NAME} SHARED ${SRC})/' CMakeLists.txt
+
+echo "-- Generating superscaler running plan"
+python3 -c "from superscaler.nnfusion import generate_data_parallelism_plan; \
+    generate_data_parallelism_plan(\"$DUMPED_GRAPH_FILENAME.pbtxt\", \
+    ${NUM_PROCS_SAME_HOST}, \
+    \"${CLUSTER_SPEC_FILE}\", \
+    \"${PWD}\", \
+    communication_DSL='ring')" > /dev/null
+
+MPIRUN=$(which mpirun)
+LAUNCH_CMD="$MPIRUN \
+--tag-output \
+--output-filename .result\
+"
+for i in $(seq 0 $(($NUM_PROCS_SAME_HOST-1)))
+do
+    if [ "$i" == "0" ];then
+        LAUNCH_CMD="$LAUNCH_CMD -np 1 -host $DEPLOYMENT_CONFIG python nnf_py/train.py "$i/plan.json" "
+    else
+        LAUNCH_CMD="$LAUNCH_CMD : -np 1 -host $DEPLOYMENT_CONFIG python nnf_py/train.py "$i/plan.json" "
+    fi
+done
+
+echo "-- Writing launch command into train.sh"
+echo $LAUNCH_CMD > train.sh && chmod +x train.sh
+echo "-- Compiled successfully"
diff --git a/thirdparty/superscaler/superscaler.cmake b/thirdparty/superscaler/superscaler.cmake
new file mode 100644
index 000000000..c7b4dafd7
--- /dev/null
+++ b/thirdparty/superscaler/superscaler.cmake
@@ -0,0 +1,33 @@
+execute_process(
+    COMMAND
+        python3 -c
+        "import superscaler as _; print(_.__path__[0])"
+    OUTPUT_VARIABLE SUPERSCALER_INSTALLATION_PATH
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+message(STATUS ${SUPERSCALER_INSTALLATION_PATH})
+
+find_path(
+    SUPERSCALER_INCLUDE_DIR
+    NAMES superscaler.h
+    HINTS ${CMAKE_CURRENT_SOURCE_DIR}/superscaler)
+
+find_library(
+    SUPERSCALER_LIBRARY
+    NAMES superscaler_pywrap
+    HINTS ${SUPERSCALER_INSTALLATION_PATH}
+    PATH_SUFFIXES lib)
+
+include(FindPackageHandleStandardArgs)
+# handle the QUIETLY and REQUIRED arguments and set SUPERSCALER_FOUND to TRUE
+# if all listed variables are TRUE
+find_package_handle_standard_args(superscaler DEFAULT_MSG SUPERSCALER_LIBRARY)
+
+mark_as_advanced(SUPERSCALER_INCLUDE_DIR SUPERSCALER_LIBRARY)
+
+set(SUPERSCALER_LIBRARIES ${SUPERSCALER_LIBRARY})
+set(SUPERSCALER_INCLUDE_DIRS ${SUPERSCALER_INCLUDE_DIR})
+
+add_library(superscaler INTERFACE)
+target_include_directories(superscaler SYSTEM
+    INTERFACE ${SUPERSCALER_INCLUDE_DIRS})
+target_link_libraries(superscaler INTERFACE ${SUPERSCALER_LIBRARIES})
diff --git a/thirdparty/superscaler/superscaler.h b/thirdparty/superscaler/superscaler.h
new file mode 100644
index 000000000..475d618cf
--- /dev/null
+++ b/thirdparty/superscaler/superscaler.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+// C interface to init a superscaler session from a plan file
+void sc_init(const char* plan_path);
+
+// C interface to destroy a superscaler session
+void sc_finalize();
+
+// C interface to get the number of participants in this session
+void sc_get_world_size(int*);
+
+// C interface to get the current process's unique host id among all participants in this session
+void sc_get_host_id(int*);
+
+// C interface to get the current process's unique device id among all participants in this session
+void sc_get_device_id(int*);
+
+// in-place allreduce: data's contents will have changed after the call returns
+// tensor_name is used to index the plans for this tensor
+// if a stream is provided, superscaler will use it to run the allreduce task
+void sc_allreduce(const char* tensor_name, float* data, size_t size, void* stream);
+void sc_send(const char* tensor_name, unsigned char* input, size_t size, void* stream);
+void sc_recv(const char* tensor_name, unsigned char** output, size_t* size, void* stream);
+
+#ifdef __cplusplus
+}
+#endif
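Since the runtime links this interface into the generated libnnf*rt.so, it can also be driven directly over ctypes, much like nnf_py/nnf.py does. A minimal sketch; the library name and plan path are examples, not fixed names:

```python
import ctypes

lib = ctypes.cdll.LoadLibrary("./libnnfusion_cuda_rt.so")
lib.sc_init(ctypes.c_char_p(b"0/plan.json"))

device_id, world_size = ctypes.c_int(), ctypes.c_int()
lib.sc_get_device_id(ctypes.byref(device_id))
lib.sc_get_world_size(ctypes.byref(world_size))
print(device_id.value, world_size.value)

lib.sc_finalize()
```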