diff --git a/bindings/core/src/core.cpp b/bindings/core/src/core.cpp index 8cb3c6b78..1396a790c 100644 --- a/bindings/core/src/core.cpp +++ b/bindings/core/src/core.cpp @@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *); namespace kllvm::bindings { +std::string return_sort_for_label(std::string const &label) { + auto tag = getTagForSymbolName(label.c_str()); + return getReturnSortForTag(tag); +} + std::shared_ptr make_injection( std::shared_ptr term, std::shared_ptr from, std::shared_ptr to) { diff --git a/include/kllvm/ast/AST.h b/include/kllvm/ast/AST.h index a63538c7a..4f3afc20f 100644 --- a/include/kllvm/ast/AST.h +++ b/include/kllvm/ast/AST.h @@ -36,6 +36,18 @@ using sptr = std::shared_ptr; std::string decodeKore(std::string); +/* + * Helper function to avoid repeated call-site uses of ostringstream when we + * just want the string representation of a node, rather than to print it to a + * stream. + */ +template +std::string ast_to_string(T &&node) { + auto os = std::ostringstream{}; + std::forward(node).print(os); + return os.str(); +} + // KORESort class KORESort : public std::enable_shared_from_this { public: diff --git a/include/kllvm/bindings/core/core.h b/include/kllvm/bindings/core/core.h index 17299bf85..4a6513e8f 100644 --- a/include/kllvm/bindings/core/core.h +++ b/include/kllvm/bindings/core/core.h @@ -12,6 +12,8 @@ namespace kllvm::bindings { +std::string return_sort_for_label(std::string const &label); + std::shared_ptr make_injection( std::shared_ptr term, std::shared_ptr from, std::shared_ptr to); diff --git a/include/runtime/header.h b/include/runtime/header.h index da875e6fd..a189051a5 100644 --- a/include/runtime/header.h +++ b/include/runtime/header.h @@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag); bool hook_STRING_eq(SortString, SortString); const char *getSymbolNameForTag(uint32_t tag); +const char *getReturnSortForTag(uint32_t tag); const char *topSort(void); typedef struct { diff --git a/lib/codegen/EmitConfigParser.cpp b/lib/codegen/EmitConfigParser.cpp index f90754ee7..73f3a96e3 100644 --- a/lib/codegen/EmitConfigParser.cpp +++ b/lib/codegen/EmitConfigParser.cpp @@ -1308,6 +1308,43 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) { } } +static void +emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) { + auto &ctx = module->getContext(); + + auto const &syms = definition->getSymbols(); + + auto element_type = llvm::Type::getInt8PtrTy(ctx); + auto table_type = llvm::ArrayType::get(element_type, syms.size()); + + auto table = module->getOrInsertGlobal("return_sort_table", table_type); + auto values = std::vector{}; + + for (auto [tag, symbol] : syms) { + auto sort = symbol->getSort(); + auto sort_str = ast_to_string(*sort); + + auto char_type = llvm::Type::getInt8Ty(ctx); + auto str_type = llvm::ArrayType::get(char_type, sort_str.size() + 1); + + auto sort_name + = module->getOrInsertGlobal("sort_name_" + sort_str, str_type); + + auto i64_type = llvm::Type::getInt64Ty(ctx); + auto zero = llvm::ConstantInt::get(i64_type, 0); + + auto pointer = llvm::ConstantExpr::getInBoundsGetElementPtr( + str_type, sort_name, std::vector{zero}); + + values.push_back(pointer); + } + + auto global = llvm::dyn_cast(table); + if (!global->hasInitializer()) { + global->setInitializer(llvm::ConstantArray::get(table_type, values)); + } +} + void emitConfigParserFunctions( KOREDefinition *definition, llvm::Module *module) { emitGetTagForSymbolName(definition, module); @@ -1329,6 +1366,7 @@ void emitConfigParserFunctions( emitInjTags(definition, module); emitSortTable(definition, module); + emitReturnSortTable(definition, module); } } // namespace kllvm diff --git a/runtime/util/util.cpp b/runtime/util/util.cpp index 41535ccae..8b81e3124 100644 --- a/runtime/util/util.cpp +++ b/runtime/util/util.cpp @@ -2,6 +2,12 @@ extern "C" { +extern char *return_sort_table; + +const char *getReturnSortForTag(uint32_t tag) { + return (&return_sort_table)[tag]; +} + block *dot_k() { return leaf_block(getTagForSymbolName("dotk{}")); }