diff --git a/cpp/.gitignore b/cpp/.gitignore new file mode 100644 index 000000000..7b99f6aaa --- /dev/null +++ b/cpp/.gitignore @@ -0,0 +1,10 @@ +assets/* +!assets/.gitkeep + +include/baselines3_models/* +!include/baselines3_models/predictor.h +!include/baselines3_models/preprocessing.h + +src/baselines3_models/* +!src/baselines3_models/predictor.cpp +!src/baselines3_models/preprocessing.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt new file mode 100644 index 000000000..f10361c27 --- /dev/null +++ b/cpp/CMakeLists.txt @@ -0,0 +1,61 @@ +cmake_minimum_required(VERSION 3.16.3) +project(baselines3_models) +include(cmake/CMakeRC.cmake) + +cmrc_add_resource_library(baselines3_model_resources +ALIAS baselines3_model::rc +NAMESPACE baselines3_model + +#static +#!static +) + +# Install PyTorch C++ first, see: https://pytorch.org/cppdocs/installing.html +# Don't forget to add it to your CMAKE_PREFIX_PATH +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +#Enable C++17 +set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -fPIC") + +option(BASELINES3_BIN "Building bin" OFF) +option(BASELINES3_PYBIND "Build python bindings (requires pybind11)" OFF) + +set(ALL_SOURCES +src/baselines3_models/predictor.cpp +src/baselines3_models/preprocessing.cpp + +#sources +#!sources +) + +set(LIBRARIES +"${TORCH_LIBRARIES}" baselines3_model::rc +) + +add_library(baselines3_models SHARED ${ALL_SOURCES}) +target_link_libraries(baselines3_models ${LIBRARIES}) +target_include_directories(baselines3_models PUBLIC + $ +) + +if (BASELINES3_BIN) + add_executable(predict ${CMAKE_CURRENT_SOURCE_DIR}/src/predict.cpp) + target_link_libraries(predict baselines3_models) +endif() + +if (BASELINES3_PYBIND) + set (Python_EXECUTABLE "/usr/bin/python3.8") + # apt-get install python3-dev + find_package(Python COMPONENTS Interpreter Development) + # apt-get install python3-pybind11 + find_package(pybind11 REQUIRED) + + 
pybind11_add_module(baselines3_py ${ALL_SOURCES}) + target_link_libraries(baselines3_py PRIVATE ${LIBRARIES}) + target_compile_definitions(baselines3_py PUBLIC -DEXPORT_PYBIND) + + target_include_directories(baselines3_py PUBLIC + $ + ) +endif() \ No newline at end of file diff --git a/cpp/assets/.gitkeep b/cpp/assets/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/cpp/cmake/CMakeRC.cmake b/cpp/cmake/CMakeRC.cmake new file mode 100644 index 000000000..1a034b5f1 --- /dev/null +++ b/cpp/cmake/CMakeRC.cmake @@ -0,0 +1,644 @@ +# This block is executed when generating an intermediate resource file, not when +# running in CMake configure mode +if(_CMRC_GENERATE_MODE) + # Read in the digits + file(READ "${INPUT_FILE}" bytes HEX) + # Format each pair into a character literal. Heuristics seem to favor doing + # the conversion in groups of five for fastest conversion + string(REGEX REPLACE "(..)(..)(..)(..)(..)" "'\\\\x\\1','\\\\x\\2','\\\\x\\3','\\\\x\\4','\\\\x\\5'," chars "${bytes}") + # Since we did this in groups, we have some leftovers to clean up + string(LENGTH "${bytes}" n_bytes2) + math(EXPR n_bytes "${n_bytes2} / 2") + math(EXPR remainder "${n_bytes} % 5") # <-- '5' is the grouping count from above + set(cleanup_re "$") + set(cleanup_sub ) + while(remainder) + set(cleanup_re "(..)${cleanup_re}") + set(cleanup_sub "'\\\\x\\${remainder}',${cleanup_sub}") + math(EXPR remainder "${remainder} - 1") + endwhile() + if(NOT cleanup_re STREQUAL "$") + string(REGEX REPLACE "${cleanup_re}" "${cleanup_sub}" chars "${chars}") + endif() + string(CONFIGURE [[ + namespace { const char file_array[] = { @chars@ 0 }; } + namespace cmrc { namespace @NAMESPACE@ { namespace res_chars { + extern const char* const @SYMBOL@_begin = file_array; + extern const char* const @SYMBOL@_end = file_array + @n_bytes@; + }}} + ]] code) + file(WRITE "${OUTPUT_FILE}" "${code}") + # Exit from the script. 
Nothing else needs to be processed + return() +endif() + +set(_version 2.0.0) + +cmake_minimum_required(VERSION 3.3) +include(CMakeParseArguments) + +if(COMMAND cmrc_add_resource_library) + if(NOT DEFINED _CMRC_VERSION OR NOT (_version STREQUAL _CMRC_VERSION)) + message(WARNING "More than one CMakeRC version has been included in this project.") + endif() + # CMakeRC has already been included! Don't do anything + return() +endif() + +set(_CMRC_VERSION "${_version}" CACHE INTERNAL "CMakeRC version. Used for checking for conflicts") + +set(_CMRC_SCRIPT "${CMAKE_CURRENT_LIST_FILE}" CACHE INTERNAL "Path to CMakeRC script") + +function(_cmrc_normalize_path var) + set(path "${${var}}") + file(TO_CMAKE_PATH "${path}" path) + while(path MATCHES "//") + string(REPLACE "//" "/" path "${path}") + endwhile() + string(REGEX REPLACE "/+$" "" path "${path}") + set("${var}" "${path}" PARENT_SCOPE) +endfunction() + +get_filename_component(_inc_dir "${CMAKE_BINARY_DIR}/_cmrc/include" ABSOLUTE) +set(CMRC_INCLUDE_DIR "${_inc_dir}" CACHE INTERNAL "Directory for CMakeRC include files") +# Let's generate the primary include file +file(MAKE_DIRECTORY "${CMRC_INCLUDE_DIR}/cmrc") +set(hpp_content [==[ +#ifndef CMRC_CMRC_HPP_INCLUDED +#define CMRC_CMRC_HPP_INCLUDED + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if !(defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND) || defined(CMRC_NO_EXCEPTIONS)) +#define CMRC_NO_EXCEPTIONS 1 +#endif + +namespace cmrc { namespace detail { struct dummy; } } + +#define CMRC_DECLARE(libid) \ + namespace cmrc { namespace detail { \ + struct dummy; \ + static_assert(std::is_same::value, "CMRC_DECLARE() must only appear at the global namespace"); \ + } } \ + namespace cmrc { namespace libid { \ + cmrc::embedded_filesystem get_filesystem(); \ + } } static_assert(true, "") + +namespace cmrc { + +class file { + const char* _begin = nullptr; + const char* _end = nullptr; + +public: + using 
iterator = const char*; + using const_iterator = iterator; + iterator begin() const noexcept { return _begin; } + iterator cbegin() const noexcept { return _begin; } + iterator end() const noexcept { return _end; } + iterator cend() const noexcept { return _end; } + std::size_t size() const { return static_cast(std::distance(begin(), end())); } + + file() = default; + file(iterator beg, iterator end) noexcept : _begin(beg), _end(end) {} +}; + +class directory_entry; + +namespace detail { + +class directory; +class file_data; + +class file_or_directory { + union _data_t { + class file_data* file_data; + class directory* directory; + } _data; + bool _is_file = true; + +public: + explicit file_or_directory(file_data& f) { + _data.file_data = &f; + } + explicit file_or_directory(directory& d) { + _data.directory = &d; + _is_file = false; + } + bool is_file() const noexcept { + return _is_file; + } + bool is_directory() const noexcept { + return !is_file(); + } + const directory& as_directory() const noexcept { + assert(!is_file()); + return *_data.directory; + } + const file_data& as_file() const noexcept { + assert(is_file()); + return *_data.file_data; + } +}; + +class file_data { +public: + const char* begin_ptr; + const char* end_ptr; + file_data(const file_data&) = delete; + file_data(const char* b, const char* e) : begin_ptr(b), end_ptr(e) {} +}; + +inline std::pair split_path(const std::string& path) { + auto first_sep = path.find("/"); + if (first_sep == path.npos) { + return std::make_pair(path, ""); + } else { + return std::make_pair(path.substr(0, first_sep), path.substr(first_sep + 1)); + } +} + +struct created_subdirectory { + class directory& directory; + class file_or_directory& index_entry; +}; + +class directory { + std::list _files; + std::list _dirs; + std::map _index; + + using base_iterator = std::map::const_iterator; + +public: + + directory() = default; + directory(const directory&) = delete; + + created_subdirectory add_subdir(std::string name) 
& { + _dirs.emplace_back(); + auto& back = _dirs.back(); + auto& fod = _index.emplace(name, file_or_directory{back}).first->second; + return created_subdirectory{back, fod}; + } + + file_or_directory* add_file(std::string name, const char* begin, const char* end) & { + assert(_index.find(name) == _index.end()); + _files.emplace_back(begin, end); + return &_index.emplace(name, file_or_directory{_files.back()}).first->second; + } + + const file_or_directory* get(const std::string& path) const { + auto pair = split_path(path); + auto child = _index.find(pair.first); + if (child == _index.end()) { + return nullptr; + } + auto& entry = child->second; + if (pair.second.empty()) { + // We're at the end of the path + return &entry; + } + + if (entry.is_file()) { + // We can't traverse into a file. Stop. + return nullptr; + } + // Keep going down + return entry.as_directory().get(pair.second); + } + + class iterator { + base_iterator _base_iter; + base_iterator _end_iter; + public: + using value_type = directory_entry; + using difference_type = std::ptrdiff_t; + using pointer = const value_type*; + using reference = const value_type&; + using iterator_category = std::input_iterator_tag; + + iterator() = default; + explicit iterator(base_iterator iter, base_iterator end) : _base_iter(iter), _end_iter(end) {} + + iterator begin() const noexcept { + return *this; + } + + iterator end() const noexcept { + return iterator(_end_iter, _end_iter); + } + + inline value_type operator*() const noexcept; + + bool operator==(const iterator& rhs) const noexcept { + return _base_iter == rhs._base_iter; + } + + bool operator!=(const iterator& rhs) const noexcept { + return !(*this == rhs); + } + + iterator operator++() noexcept { + auto cp = *this; + ++_base_iter; + return cp; + } + + iterator& operator++(int) noexcept { + ++_base_iter; + return *this; + } + }; + + using const_iterator = iterator; + + iterator begin() const noexcept { + return iterator(_index.begin(), _index.end()); + } + 
+ iterator end() const noexcept { + return iterator(); + } +}; + +inline std::string normalize_path(std::string path) { + while (path.find("/") == 0) { + path.erase(path.begin()); + } + while (!path.empty() && (path.rfind("/") == path.size() - 1)) { + path.pop_back(); + } + auto off = path.npos; + while ((off = path.find("//")) != path.npos) { + path.erase(path.begin() + static_cast(off)); + } + return path; +} + +using index_type = std::map; + +} // detail + +class directory_entry { + std::string _fname; + const detail::file_or_directory* _item; + +public: + directory_entry() = delete; + explicit directory_entry(std::string filename, const detail::file_or_directory& item) + : _fname(filename) + , _item(&item) + {} + + const std::string& filename() const & { + return _fname; + } + std::string filename() const && { + return std::move(_fname); + } + + bool is_file() const { + return _item->is_file(); + } + + bool is_directory() const { + return _item->is_directory(); + } +}; + +directory_entry detail::directory::iterator::operator*() const noexcept { + assert(begin() != end()); + return directory_entry(_base_iter->first, _base_iter->second); +} + +using directory_iterator = detail::directory::iterator; + +class embedded_filesystem { + // Never-null: + const cmrc::detail::index_type* _index; + const detail::file_or_directory* _get(std::string path) const { + path = detail::normalize_path(path); + auto found = _index->find(path); + if (found == _index->end()) { + return nullptr; + } else { + return found->second; + } + } + +public: + explicit embedded_filesystem(const detail::index_type& index) + : _index(&index) + {} + + file open(const std::string& path) const { + auto entry_ptr = _get(path); + if (!entry_ptr || !entry_ptr->is_file()) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); +#endif + } + auto& dat = 
entry_ptr->as_file(); + return file{dat.begin_ptr, dat.end_ptr}; + } + + bool is_file(const std::string& path) const noexcept { + auto entry_ptr = _get(path); + return entry_ptr && entry_ptr->is_file(); + } + + bool is_directory(const std::string& path) const noexcept { + auto entry_ptr = _get(path); + return entry_ptr && entry_ptr->is_directory(); + } + + bool exists(const std::string& path) const noexcept { + return !!_get(path); + } + + directory_iterator iterate_directory(const std::string& path) const { + auto entry_ptr = _get(path); + if (!entry_ptr) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); +#endif + } + if (!entry_ptr->is_directory()) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error not a directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::not_a_directory), path); +#endif + } + return entry_ptr->as_directory().begin(); + } +}; + +} + +#endif // CMRC_CMRC_HPP_INCLUDED +]==]) + +set(cmrc_hpp "${CMRC_INCLUDE_DIR}/cmrc/cmrc.hpp" CACHE INTERNAL "") +set(_generate 1) +if(EXISTS "${cmrc_hpp}") + file(READ "${cmrc_hpp}" _current) + if(_current STREQUAL hpp_content) + set(_generate 0) + endif() +endif() +file(GENERATE OUTPUT "${cmrc_hpp}" CONTENT "${hpp_content}" CONDITION ${_generate}) + +add_library(cmrc-base INTERFACE) +target_include_directories(cmrc-base INTERFACE $) +# Signal a basic C++11 feature to require C++11. 
+target_compile_features(cmrc-base INTERFACE cxx_nullptr) +set_property(TARGET cmrc-base PROPERTY INTERFACE_CXX_EXTENSIONS OFF) +add_library(cmrc::base ALIAS cmrc-base) + +function(cmrc_add_resource_library name) + set(args ALIAS NAMESPACE TYPE) + cmake_parse_arguments(ARG "" "${args}" "" "${ARGN}") + # Generate the identifier for the resource library's namespace + set(ns_re "[a-zA-Z_][a-zA-Z0-9_]*") + if(NOT DEFINED ARG_NAMESPACE) + # Check that the library name is also a valid namespace + if(NOT name MATCHES "${ns_re}") + message(SEND_ERROR "Library name is not a valid namespace. Specify the NAMESPACE argument") + endif() + set(ARG_NAMESPACE "${name}") + else() + if(NOT ARG_NAMESPACE MATCHES "${ns_re}") + message(SEND_ERROR "NAMESPACE for ${name} is not a valid C++ namespace identifier (${ARG_NAMESPACE})") + endif() + endif() + set(libname "${name}") + # Check that type is either "STATIC" or "OBJECT", or default to "STATIC" if + # not set + if(NOT DEFINED ARG_TYPE) + set(ARG_TYPE STATIC) + elseif(NOT "${ARG_TYPE}" MATCHES "^(STATIC|OBJECT)$") + message(SEND_ERROR "${ARG_TYPE} is not a valid TYPE (STATIC and OBJECT are acceptable)") + set(ARG_TYPE STATIC) + endif() + # Generate a library with the compiled in character arrays. 
+ string(CONFIGURE [=[ + #include + #include + #include + + namespace cmrc { + namespace @ARG_NAMESPACE@ { + + namespace res_chars { + // These are the files which are available in this resource library + $, + > + } + + namespace { + + const cmrc::detail::index_type& + get_root_index() { + static cmrc::detail::directory root_directory_; + static cmrc::detail::file_or_directory root_directory_fod{root_directory_}; + static cmrc::detail::index_type root_index; + root_index.emplace("", &root_directory_fod); + struct dir_inl { + class cmrc::detail::directory& directory; + }; + dir_inl root_directory_dir{root_directory_}; + (void)root_directory_dir; + $, + > + $, + > + return root_index; + } + + } + + cmrc::embedded_filesystem get_filesystem() { + static auto& index = get_root_index(); + return cmrc::embedded_filesystem{index}; + } + + } // @ARG_NAMESPACE@ + } // cmrc + ]=] cpp_content @ONLY) + get_filename_component(libdir "${CMAKE_CURRENT_BINARY_DIR}/__cmrc_${name}" ABSOLUTE) + get_filename_component(lib_tmp_cpp "${libdir}/lib_.cpp" ABSOLUTE) + string(REPLACE "\n " "\n" cpp_content "${cpp_content}") + file(GENERATE OUTPUT "${lib_tmp_cpp}" CONTENT "${cpp_content}") + get_filename_component(libcpp "${libdir}/lib.cpp" ABSOLUTE) + add_custom_command(OUTPUT "${libcpp}" + DEPENDS "${lib_tmp_cpp}" "${cmrc_hpp}" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${lib_tmp_cpp}" "${libcpp}" + COMMENT "Generating ${name} resource loader" + ) + # Generate the actual static library. Each source file is just a single file + # with a character array compiled in containing the contents of the + # corresponding resource file. 
+ add_library(${name} ${ARG_TYPE} ${libcpp}) + set_property(TARGET ${name} PROPERTY CMRC_LIBDIR "${libdir}") + set_property(TARGET ${name} PROPERTY CMRC_NAMESPACE "${ARG_NAMESPACE}") + target_link_libraries(${name} PUBLIC cmrc::base) + set_property(TARGET ${name} PROPERTY CMRC_IS_RESOURCE_LIBRARY TRUE) + if(ARG_ALIAS) + add_library("${ARG_ALIAS}" ALIAS ${name}) + endif() + cmrc_add_resources(${name} ${ARG_UNPARSED_ARGUMENTS}) +endfunction() + +function(_cmrc_register_dirs name dirpath) + if(dirpath STREQUAL "") + return() + endif() + # Skip this dir if we have already registered it + get_target_property(registered "${name}" _CMRC_REGISTERED_DIRS) + if(dirpath IN_LIST registered) + return() + endif() + # Register the parent directory first + get_filename_component(parent "${dirpath}" DIRECTORY) + if(NOT parent STREQUAL "") + _cmrc_register_dirs("${name}" "${parent}") + endif() + # Now generate the registration + set_property(TARGET "${name}" APPEND PROPERTY _CMRC_REGISTERED_DIRS "${dirpath}") + _cm_encode_fpath(sym "${dirpath}") + if(parent STREQUAL "") + set(parent_sym root_directory) + else() + _cm_encode_fpath(parent_sym "${parent}") + endif() + get_filename_component(leaf "${dirpath}" NAME) + set_property( + TARGET "${name}" + APPEND PROPERTY CMRC_MAKE_DIRS + "static auto ${sym}_dir = ${parent_sym}_dir.directory.add_subdir(\"${leaf}\")\;" + "root_index.emplace(\"${dirpath}\", &${sym}_dir.index_entry)\;" + ) +endfunction() + +function(cmrc_add_resources name) + get_target_property(is_reslib ${name} CMRC_IS_RESOURCE_LIBRARY) + if(NOT TARGET ${name} OR NOT is_reslib) + message(SEND_ERROR "cmrc_add_resources called on target '${name}' which is not an existing resource library") + return() + endif() + + set(options) + set(args WHENCE PREFIX) + set(list_args) + cmake_parse_arguments(ARG "${options}" "${args}" "${list_args}" "${ARGN}") + + if(NOT ARG_WHENCE) + set(ARG_WHENCE ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + _cmrc_normalize_path(ARG_WHENCE) + 
get_filename_component(ARG_WHENCE "${ARG_WHENCE}" ABSOLUTE) + + # Generate the identifier for the resource library's namespace + get_target_property(lib_ns "${name}" CMRC_NAMESPACE) + + get_target_property(libdir ${name} CMRC_LIBDIR) + get_target_property(target_dir ${name} SOURCE_DIR) + file(RELATIVE_PATH reldir "${target_dir}" "${CMAKE_CURRENT_SOURCE_DIR}") + if(reldir MATCHES "^\\.\\.") + message(SEND_ERROR "Cannot call cmrc_add_resources in a parent directory from the resource library target") + return() + endif() + + foreach(input IN LISTS ARG_UNPARSED_ARGUMENTS) + _cmrc_normalize_path(input) + get_filename_component(abs_in "${input}" ABSOLUTE) + # Generate a filename based on the input filename that we can put in + # the intermediate directory. + file(RELATIVE_PATH relpath "${ARG_WHENCE}" "${abs_in}") + if(relpath MATCHES "^\\.\\.") + # For now we just error on files that exist outside of the soure dir. + message(SEND_ERROR "Cannot add file '${input}': File must be in a subdirectory of ${ARG_WHENCE}") + continue() + endif() + if(DEFINED ARG_PREFIX) + _cmrc_normalize_path(ARG_PREFIX) + endif() + if(ARG_PREFIX AND NOT ARG_PREFIX MATCHES "/$") + set(ARG_PREFIX "${ARG_PREFIX}/") + endif() + get_filename_component(dirpath "${ARG_PREFIX}${relpath}" DIRECTORY) + _cmrc_register_dirs("${name}" "${dirpath}") + get_filename_component(abs_out "${libdir}/intermediate/${relpath}.cpp" ABSOLUTE) + # Generate a symbol name relpath the file's character array + _cm_encode_fpath(sym "${relpath}") + # Get the symbol name for the parent directory + if(dirpath STREQUAL "") + set(parent_sym root_directory) + else() + _cm_encode_fpath(parent_sym "${dirpath}") + endif() + # Generate the rule for the intermediate source file + _cmrc_generate_intermediate_cpp(${lib_ns} ${sym} "${abs_out}" "${abs_in}") + target_sources(${name} PRIVATE "${abs_out}") + set_property(TARGET ${name} APPEND PROPERTY CMRC_EXTERN_DECLS + "// Pointers to ${input}" + "extern const char* const ${sym}_begin\;" + 
"extern const char* const ${sym}_end\;" + ) + get_filename_component(leaf "${relpath}" NAME) + set_property( + TARGET ${name} + APPEND PROPERTY CMRC_MAKE_FILES + "root_index.emplace(" + " \"${ARG_PREFIX}${relpath}\"," + " ${parent_sym}_dir.directory.add_file(" + " \"${leaf}\"," + " res_chars::${sym}_begin," + " res_chars::${sym}_end" + " )" + ")\;" + ) + endforeach() +endfunction() + +function(_cmrc_generate_intermediate_cpp lib_ns symbol outfile infile) + add_custom_command( + # This is the file we will generate + OUTPUT "${outfile}" + # These are the primary files that affect the output + DEPENDS "${infile}" "${_CMRC_SCRIPT}" + COMMAND + "${CMAKE_COMMAND}" + -D_CMRC_GENERATE_MODE=TRUE + -DNAMESPACE=${lib_ns} + -DSYMBOL=${symbol} + "-DINPUT_FILE=${infile}" + "-DOUTPUT_FILE=${outfile}" + -P "${_CMRC_SCRIPT}" + COMMENT "Generating intermediate file for ${infile}" + ) +endfunction() + +function(_cm_encode_fpath var fpath) + string(MAKE_C_IDENTIFIER "${fpath}" ident) + string(MD5 hash "${fpath}") + string(SUBSTRING "${hash}" 0 4 hash) + set(${var} f_${hash}_${ident} PARENT_SCOPE) +endfunction() diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h new file mode 100644 index 000000000..77d023492 --- /dev/null +++ b/cpp/include/baselines3_models/predictor.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace baselines3_models { +class Predictor { +public: + enum PolicyType { + // If we have an actor network and a value network + ACTOR_VALUE, + ACTOR_VALUE_DISCRETE, + // The network is a Q-Network: outputs Q(s,a) for all a for a given s + QNET_ALL + }; + + Predictor(std::string actor_filename, std::string q_filename, std::string v_filename); + + torch::Tensor predict(torch::Tensor &observation, bool unscale_action = true); + double value(torch::Tensor &observation); + + std::vector predict_vector(std::vector obs); + + virtual torch::Tensor preprocess_observation(torch::Tensor &observation); + virtual 
torch::Tensor process_action(torch::Tensor &action); + virtual std::vector enumerate_actions(); + +protected: + torch::jit::script::Module model_actor; + torch::jit::script::Module model_q; + torch::jit::script::Module model_v; + PolicyType policy_type; +}; +} // namespace baselines3_models \ No newline at end of file diff --git a/cpp/include/baselines3_models/preprocessing.h b/cpp/include/baselines3_models/preprocessing.h new file mode 100644 index 000000000..4dfcc098a --- /dev/null +++ b/cpp/include/baselines3_models/preprocessing.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +namespace baselines3_models { + +torch::Tensor multi_one_hot(torch::Tensor &input, torch::Tensor &classes); + +} \ No newline at end of file diff --git a/cpp/model_template.cpp b/cpp/model_template.cpp new file mode 100644 index 000000000..276429ae4 --- /dev/null +++ b/cpp/model_template.cpp @@ -0,0 +1,40 @@ +/*** + * This file was AUTOGENERATED by Stable Baselines3 Zoo + * https://github.com/DLR-RM/rl-baselines3-zoo + */ +#include "baselines3_models/FILE_NAME.h" +#include "baselines3_models/preprocessing.h" +#ifdef EXPORT_PYBIND +#include +#include +#endif + +namespace baselines3_models { + +CLASS_NAME::CLASS_NAME() : Predictor("MODEL_ACTOR", "MODEL_Q", "MODEL_V") { + policy_type = POLICY_TYPE; +} + +torch::Tensor CLASS_NAME::preprocess_observation(torch::Tensor &observation) { + torch::Tensor result; + PREPROCESS_OBSERVATION + return result; +} + +torch::Tensor CLASS_NAME::process_action(torch::Tensor &action) { + torch::Tensor result; + PROCESS_ACTION + return result; +} + +#ifdef EXPORT_PYBIND +namespace py = pybind11; + +PYBIND11_MODULE(baselines3_py, m) { + py::class_(m, "CLASS_NAME") + .def(py::init()) + .def("predict", &CLASS_NAME::predict_vector); +} +#endif + +} // namespace baselines3_models \ No newline at end of file diff --git a/cpp/model_template.h b/cpp/model_template.h new file mode 100644 index 000000000..d4bec8d82 --- /dev/null +++ b/cpp/model_template.h @@ 
-0,0 +1,27 @@ +/*** + * This file was AUTOGENERATED by Stable Baselines3 Zoo + * https://github.com/DLR-RM/rl-baselines3-zoo + */ +#pragma once + +#include "baselines3_models/predictor.h" +#include "torch/script.h" + +namespace baselines3_models { +class CLASS_NAME : public Predictor { +public: + CLASS_NAME(); + + /** + * Observation space is: + * OBSERVATION_SPACE + */ + torch::Tensor preprocess_observation(torch::Tensor &observation) override; + + /** + * Action space is: + * ACTION_SPACE + */ + torch::Tensor process_action(torch::Tensor &action) override; +}; +} // namespace baselines3_models \ No newline at end of file diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp new file mode 100644 index 000000000..429ccedba --- /dev/null +++ b/cpp/src/baselines3_models/predictor.cpp @@ -0,0 +1,97 @@ +#include "baselines3_models/predictor.h" +#include "cmrc/cmrc.hpp" + +CMRC_DECLARE(baselines3_model); + +namespace baselines3_models { + +static void _load_model(std::string filename, + torch::jit::script::Module &model) { + if (filename != "") { + auto fs = cmrc::baselines3_model::get_filesystem(); + auto f = fs.open(filename); + std::string data(f.begin(), f.end()); + std::istringstream stream(data); + model = torch::jit::load(stream); + } +} + +Predictor::Predictor(std::string actor_filename, std::string q_filename, + std::string v_filename) { + + _load_model(actor_filename, model_actor); + _load_model(q_filename, model_q); + _load_model(v_filename, model_v); +} + +torch::Tensor Predictor::predict(torch::Tensor &observation, + bool unscale_action) { + c10::InferenceMode guard; + torch::Tensor processed_observation = preprocess_observation(observation); + at::Tensor action; + std::vector inputs; + inputs.push_back(processed_observation.unsqueeze(0)); + + if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + action = model_actor.forward(inputs).toTensor(); + if (unscale_action) { + action = 
process_action(action); + } + if (policy_type == ACTOR_VALUE_DISCRETE) { + action = torch::argmax(action); + } + } else if (policy_type == QNET_ALL) { + auto q_values = model_q.forward(inputs).toTensor(); + action = torch::argmax(q_values); + } else { + throw std::runtime_error("Unknown policy type"); + } + + return action; +} + +double Predictor::value(torch::Tensor &observation) { + double value = 0.0; + + torch::Tensor processed_observation = preprocess_observation(observation); + at::Tensor action; + std::vector inputs; + + if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + inputs.push_back(processed_observation.unsqueeze(0)); + auto v = model_v.forward(inputs).toTensor(); + value = v.data_ptr()[0]; + } else if (policy_type == QNET_ALL) { + inputs.push_back(processed_observation.unsqueeze(0)); + auto q = model_q.forward(inputs).toTensor(); + value = torch::max(q).data_ptr()[0]; + } else { + throw std::runtime_error("Unknown policy type"); + } + + return value; +} + +std::vector Predictor::predict_vector(std::vector obs) { + torch::Tensor observation = torch::from_blob(obs.data(), obs.size()); + torch::Tensor action = predict(observation); + action = action.contiguous().to(torch::kFloat32); + std::vector result(action.data_ptr(), + action.data_ptr() + action.numel()); + return result; +} + +torch::Tensor Predictor::preprocess_observation(torch::Tensor &observation) { + return observation; +} + +torch::Tensor Predictor::process_action(torch::Tensor &action) { + return action; +} + +std::vector Predictor::enumerate_actions() { + std::vector result; + return result; +} + +} // namespace baselines3_models \ No newline at end of file diff --git a/cpp/src/baselines3_models/preprocessing.cpp b/cpp/src/baselines3_models/preprocessing.cpp new file mode 100644 index 000000000..e0707e450 --- /dev/null +++ b/cpp/src/baselines3_models/preprocessing.cpp @@ -0,0 +1,24 @@ +#include "baselines3_models/preprocessing.h" + +using namespace torch::indexing; + 
+namespace baselines3_models { + +torch::Tensor multi_one_hot(torch::Tensor &input, torch::Tensor &classes) { + int entries = torch::sum(classes).item(); + + torch::Tensor result = + torch::zeros({1, entries}, torch::TensorOptions().dtype(torch::kLong)); + + int offset = 0; + for (int k = 0; k < classes.sizes()[0]; k++) { + int n = classes[k].item(); + + result.index({0, Slice(offset, offset + n)}) = torch::one_hot(input[k], n); + offset += n; + } + + return result; +} + +} // namespace baselines3_models \ No newline at end of file diff --git a/cpp/src/predict.cpp b/cpp/src/predict.cpp new file mode 100644 index 000000000..f854b79e9 --- /dev/null +++ b/cpp/src/predict.cpp @@ -0,0 +1,16 @@ +// This file is just a demonstration, you can adapt to test your model +// First, include your model: +#include "baselines3_models/cartpole_v1.h" + +using namespace baselines3_models; + +int main(int argc, const char *argv[]) { + // Create an instance of it: + CartPole_v1 cartpole; + + // Build an observation: + torch::Tensor observation = torch::tensor({0., 0., 0., 0.}); + + // You can now check the prediction: + std::cout << cartpole.predict(observation) << std::endl; +} \ No newline at end of file diff --git a/enjoy.py b/enjoy.py index d94fa03fa..1dfe171e5 100644 --- a/enjoy.py +++ b/enjoy.py @@ -11,7 +11,9 @@ import utils.import_envs # noqa: F401 pylint: disable=unused-import from utils import ALGOS, create_test_env, get_saved_hyperparams +from utils.cpp_exporter import CppExporter from utils.callbacks import tqdm + from utils.exp_manager import ExperimentManager from utils.load_from_hub import download_from_hub from utils.utils import StoreDict, get_model_path @@ -32,6 +34,7 @@ def main(): # noqa: C901 ) parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions") parser.add_argument("--device", help="PyTorch device to be use (ex: cpu, cuda...)", default="auto", type=str) + parser.add_argument("--export-cpp", help="Export to 
C++ code", default="", type=str) parser.add_argument( "--load-best", action="store_true", default=False, help="Load best model instead of last model if available" ) @@ -82,6 +85,10 @@ def main(): # noqa: C901 env_name: EnvironmentName = args.env algo = args.algo folder = args.folder + device = args.device + + if args.export_cpp: + device = "cpu" try: _, model_path, log_path = get_model_path( @@ -188,10 +195,16 @@ def main(): # noqa: C901 "clip_range": lambda _: 0.0, } - model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=args.device, **kwargs) + model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=device, **kwargs) obs = env.reset() + if args.export_cpp: + print("Exporting to C++...") + exporter = CppExporter(model, args.export_cpp, args.env) + exporter.export() + exit() + # Deterministic by default except for atari games stochastic = args.stochastic or is_atari and not args.deterministic deterministic = not stochastic diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py new file mode 100644 index 000000000..3a1cfa093 --- /dev/null +++ b/utils/cpp_exporter.py @@ -0,0 +1,296 @@ +import os +import re +import shutil +from pathlib import Path +from typing import List + +import torch as th +from gym import spaces +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.preprocessing import is_image_space, preprocess_obs +from stable_baselines3.dqn.policies import DQNPolicy +from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.td3.policies import TD3Policy + + +class CppExporter(object): + def __init__(self, model: BaseAlgorithm, directory: str, name: str): + """ + C++ module exporter + + :param model: The algorithm that should be exported + :param directory: Output directory + :param name: The module name + """ + self.model = model + self.directory = directory + self.name = 
name.replace("-", "_") + + # Templates directory is found relatively to this script (../cpp) + self.template_directory = str(Path(__file__).parent.parent / Path("cpp")) + self.vars = {} + + # Name of the asset (.pt file) + self.asset_fnames = [] + self.cpp_fname = None + + def generate_directory(self) -> None: + """ + Generates the target directory if it doesn't exists + """ + + def ignore(directory: str, files: List[str]) -> List[str]: + if directory == self.template_directory: + return [".gitignore", "model_template.h", "model_template.cpp"] + + return [] + + if not os.path.isdir(self.directory): + shutil.copytree(self.template_directory, self.directory, ignore=ignore) + + def generate_observation_preprocessing(self): + """ + Generates observation preprocessing code for this model + + :raises NotImplementedError: If the observation space is not supported + """ + observation_space = self.model.env.observation_space + policy = self.model.policy + preprocess_observation = "" + self.vars["OBSERVATION_SPACE"] = repr(observation_space) + + if isinstance(observation_space, spaces.Box): + if is_image_space(observation_space) and policy.normalize_images: + # Normalizing image pixels + preprocess_observation += "result = observation / 255.;\n" + else: + # Keeping observation as it is + preprocess_observation += "result = observation;\n" + elif isinstance(observation_space, spaces.Discrete): + # Applying one hot representation + preprocess_observation += f"result = torch::one_hot(observation, {observation_space.n});\n" + elif isinstance(observation_space, spaces.MultiDiscrete): + # Applying multiple one hot representation (using C++ function) + classes = ",".join(map(str, observation_space.nvec)) + preprocess_observation += f"torch::Tensor classes = torch::tensor({classes});\n" + preprocess_observation += "result = multi_one_hot(observation, classes);\n" + else: + raise NotImplementedError(f"C++ exporting does not support observation {observation_space}") + + 
self.vars["PREPROCESS_OBSERVATION"] = preprocess_observation + + def generate_action_processing(self): + """ + Generates the action post-processing + + :raises NotImplementedError: If the action space is not supported + """ + action_space = self.model.env.action_space + process_action = "" + self.vars["ACTION_SPACE"] = repr(action_space) + + if isinstance(action_space, spaces.Box): + # Handling action clipping + low_values = ",".join([f"(float){x}" for x in action_space.low]) + process_action += "torch::Tensor action_low = torch::tensor({%s});\n" % low_values + high_values = ",".join([f"(float){x}" for x in action_space.high]) + process_action += "torch::Tensor action_high = torch::tensor({%s});\n" % high_values + + if self.model.policy.squash_output: + # Unscaling the action assuming it lies in [-1, 1], since squash networks use Tanh as + # final activation functions + process_action += "result = action_low + (0.5 * (action + 1.0) * (action_high - action_low));\n" + else: + # Clipping not squashed action + process_action += "result = torch::clip(action, action_low, action_high);\n" + elif isinstance(action_space, spaces.Discrete) or isinstance(action_space, spaces.MultiDiscrete): + # Keeping input as it is + process_action += "result = action;\n" + else: + raise NotImplementedError(f"C++ exporting does not support processing action {action_space}") + + self.vars["PROCESS_ACTION"] = process_action + + def export_code(self): + """ + Export the C++ code + """ + self.vars["CLASS_NAME"] = self.name + fname = self.name.lower() + + self.vars["FILE_NAME"] = fname + self.cpp_fname = os.path.join("src", "baselines3_models", f"{fname}.cpp") + include_fname = os.path.join("include", "baselines3_models", f"{fname}.h") + target_header = os.path.join(self.directory, include_fname) + target_cpp = os.path.join(self.directory, self.cpp_fname) + + self.generate_observation_preprocessing() + self.generate_action_processing() + + self.render("model_template.h", target_header) + 
self.render("model_template.cpp", target_cpp) + + def render(self, template: str, target: str): + """ + Renders some template, replacing self.vars variables by their values + + :param str template: The template name + :param str target: The target file + """ + with open(self.template_directory + "/" + template, "r") as template_f: + with open(target, "w") as target_f: + data = template_f.read() + for var in self.vars: + data = data.replace(var, self.vars[var]) + target_f.write(data) + + print("Generated " + target) + + def update_cmake(self): + """ + Updates the target's CMakeLists.txt, adding files in static and sources section + + :raises ValueError: If a section can't be found in the CMakeLists + """ + cmake_contents = open(self.directory + "/CMakeLists.txt", "r").read() + + def add_to_section(section_name: str, fname: str, contents: str): + pattern = f"#{section_name}(.+)#!{section_name}" + flags = re.MULTILINE + re.DOTALL + + match = re.search(pattern, cmake_contents, flags=flags) + + if match is None: + raise ValueError(f"Couldn't find {section_name} section in CMakeLists.txt") + + files = match[1].strip() + if files: + files = list(map(str.strip, files.split("\n"))) + else: + files = [] + + if fname not in files: + print(f"Adding {fname} to CMake {section_name}") + files.append(fname) + + new_section = f"#{section_name}\n" + ("\n".join(files)) + "\n" + f"#!{section_name}" + + return re.sub(pattern, new_section, contents, flags=flags) + + for asset in self.asset_fnames: + cmake_contents = add_to_section("static", asset, cmake_contents) + cmake_contents = add_to_section("sources", self.cpp_fname, cmake_contents) + + with open(self.directory + "/CMakeLists.txt", "w") as f: + f.write(cmake_contents) + + def export_model(self): + """ + Export the Algorithm's model using Pytorch's JIT script tracer + + :raises NotImplementedError: If the policy is not supported + """ + policy = self.model.policy + obs = th.Tensor(self.model.env.reset()) + + def get_fname(suffix: 
str): + asset_fname = os.path.join("assets", f"{self.name.lower()}_{suffix}.pt") + fname = os.path.join(self.directory, asset_fname) + return asset_fname, fname + + traced = { + "actor": None, + "q": None, + "v": None, + } + + obs = preprocess_obs(obs, self.model.env.observation_space) + + if isinstance(policy, TD3Policy): + # Actor extract features and apply mu + actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu) + traced["actor"] = th.jit.trace(actor_model, obs) + + # Value function is a combination of actor and Q + class TD3PolicyValue(th.nn.Module): + def __init__(self, policy: TD3Policy, actor_model: th.nn.Module): + super(TD3PolicyValue, self).__init__() + + self.actor = actor_model + self.critic = policy.critic + + def forward(self, obs): + action = self.actor_model(obs) + critic_features = self.critic.features_extractor(obs) + return self.critic.q_networks[0](th.cat([critic_features, action], dim=1)) + + # Note(antonin): unused variable action + action = policy.actor.mu(policy.actor.extract_features(obs)) + v_model = TD3PolicyValue(policy, actor_model) + traced["v"] = th.jit.trace(v_model, obs) + self.vars["POLICY_TYPE"] = "ACTOR_VALUE" + elif isinstance(policy, SACPolicy): + # Feature extractor, latent pi and mu + if self.model.use_sde: + # XXX: Check for bijector ? 
+                actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu)
+            else:
+                # Without gSDE the squashing Tanh is part of the exported actor
+                actor_model = th.nn.Sequential(
+                    policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu, th.nn.Tanh()
+                )
+            traced["actor"] = th.jit.trace(actor_model, obs)
+
+            # Value function is a combination of actor and Q (same shape as TD3)
+            class SACPolicyValue(th.nn.Module):
+                def __init__(self, policy: SACPolicy, actor_model: th.nn.Module):
+                    super(SACPolicyValue, self).__init__()
+
+                    self.actor_model = actor_model
+                    self.critic = policy.critic
+
+                def forward(self, obs):
+                    action = self.actor_model(obs)
+                    critic_features = self.critic.features_extractor(obs)
+                    return self.critic.q_networks[0](th.cat([critic_features, action], dim=1))
+
+            v_model = SACPolicyValue(policy, actor_model)
+            traced["v"] = th.jit.trace(v_model, obs)
+            self.vars["POLICY_TYPE"] = "ACTOR_VALUE"
+        elif isinstance(policy, ActorCriticPolicy):
+            # Actor is feature extractor, mlp and action net
+            actor_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.policy_net, policy.action_net)
+            traced["actor"] = th.jit.trace(actor_model, obs)
+
+            # The value network is computed directly in ActorCriticPolicy (and not the Q network)
+            value_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.value_net, policy.value_net)
+            traced["v"] = th.jit.trace(value_model, obs)
+
+            if isinstance(self.model.env.action_space, spaces.Discrete):
+                self.vars["POLICY_TYPE"] = "ACTOR_VALUE_DISCRETE"
+            else:
+                self.vars["POLICY_TYPE"] = "ACTOR_VALUE"
+        elif isinstance(policy, DQNPolicy):
+            # For DQN, we only use one Q network that outputs Q(s,a) for all possible actions, it is then
+            # both used for action prediction using argmax and for value prediction
+            q_model = th.nn.Sequential(policy.q_net.features_extractor, policy.q_net.q_net)
+            traced["q"] = th.jit.trace(q_model, obs)
+            self.vars["POLICY_TYPE"] = "QNET_ALL"
+        else:
+            raise NotImplementedError(f"C++ exporting does not support policy {policy}")
+
+        # Save every traced module as a .pt asset and record its path in the
+        # template vars; missing roles get an empty substitution.
+        for entry in traced.keys():
+            var = f"MODEL_{entry.upper()}"
+            if traced[entry] is None:
+                self.vars[var] = ""
+            else:
+                asset_fname, fname = get_fname(entry)
+                traced[entry].save(fname)
+                print(f"Generated {fname}")
+                self.asset_fnames.append(asset_fname)
+                self.vars[var] = asset_fname
+
+    def export(self) -> None:
+        """
+        Run the full export pipeline: copy the template directory, trace and
+        save the model assets, render the C++ sources, then register the new
+        files in the target CMakeLists.txt.
+        """
+        self.generate_directory()
+        self.export_model()
+        self.export_code()
+        self.update_cmake()