Skip to content

Commit

Permalink
Merge pull request #419 from neilcook/luastate
Browse files Browse the repository at this point in the history
Refactor lustate handling to be much more efficient
  • Loading branch information
neilcook authored Apr 3, 2024
2 parents acbdfd9 + b2b1146 commit 4141a4a
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 67 deletions.
85 changes: 56 additions & 29 deletions trackalert/trackalert-luastate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*/

#pragma once

#include "ext/luawrapper/include/LuaContext.hpp"
#include "misc.hh"
#include <mutex>
Expand All @@ -42,7 +43,9 @@ typedef std::function<void()> background_t;
extern background_t g_background;
typedef std::unordered_map<std::string, background_t> bg_func_map_t;

vector<std::function<void(void)>> setupLua(bool client, bool allow_report, LuaContext& c_lua, report_t& report_func, bg_func_map_t* bg_func_map, CustomFuncMap& custom_func_map, const std::string& config);
vector<std::function<void(void)>>
setupLua(bool client, bool allow_report, LuaContext& c_lua, report_t& report_func, bg_func_map_t* bg_func_map,
CustomFuncMap& custom_func_map, const std::string& config);

struct LuaThreadContext {
LuaContext lua_context;
Expand All @@ -54,40 +57,46 @@ struct LuaThreadContext {

#define NUM_LUA_STATES 6

class LuaMultiThread
{
class LuaMultiThread {
public:
LuaMultiThread() : num_states(NUM_LUA_STATES),
state_index(0)
LuaMultiThread() : num_states(NUM_LUA_STATES)
{
LuaMultiThread{num_states};
}

LuaMultiThread(unsigned int nstates) : num_states(nstates),
state_index(0)
LuaMultiThread(unsigned int nstates) : num_states(nstates)
{
for (unsigned int i=0; i<num_states; i++) {
lua_cv.push_back(std::make_shared<LuaThreadContext>());
}
for (unsigned int i = 0; i < num_states; i++) {
lua_pool.push_back(std::make_shared<LuaThreadContext>());
}
lua_read_only = lua_pool; // Make a copy for use by the control thread
}

LuaMultiThread(const LuaMultiThread&) = delete;

LuaMultiThread& operator=(const LuaMultiThread&) = delete;

// these are used to setup the function pointers
std::vector<std::shared_ptr<LuaThreadContext>>::iterator begin() { return lua_cv.begin(); }
std::vector<std::shared_ptr<LuaThreadContext>>::iterator end() { return lua_cv.end(); }
std::vector<std::shared_ptr<LuaThreadContext>>::iterator begin()
{ return lua_read_only.begin(); }

void report(const LoginTuple& lt) {
auto lt_context = getLuaState();
std::vector<std::shared_ptr<LuaThreadContext>>::iterator end()
{ return lua_read_only.end(); }

void report(const LoginTuple& lt)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the report function
lt_context->report_func(lt);
}

void background(const std::string& func_name) {
auto lt_context = getLuaState();
void background(const std::string& func_name)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the background function
Expand All @@ -96,31 +105,49 @@ public:
fn->second();
}

CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa) {
auto lt_context = getLuaState();
CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the custom function
for (const auto& i : lt_context->custom_func_map) {
for (const auto& i: lt_context->custom_func_map) {
if (command.compare(i.first) == 0) {
return i.second.c_func(cfa);
return i.second.c_func(cfa);
}
}
return CustomFuncReturn(false, KeyValVector{});
}

protected:
std::shared_ptr<LuaThreadContext> getLuaState()
{
class SharedPoolMember {
public:
SharedPoolMember(std::shared_ptr<LuaThreadContext> ptr, LuaMultiThread* pool) : d_pool_item(ptr), d_pool(pool) {}
~SharedPoolMember() { if (d_pool != nullptr) { d_pool->returnPoolMember(d_pool_item); } }
SharedPoolMember(const SharedPoolMember&) = delete;
SharedPoolMember& operator=(const SharedPoolMember&) = delete;
std::shared_ptr<LuaThreadContext> getLuaContext() { return d_pool_item; }
private:
std::shared_ptr<LuaThreadContext> d_pool_item;
LuaMultiThread* d_pool;
};

SharedPoolMember getPoolMember() {
std::lock_guard<std::mutex> lock(mutx);
if (state_index >= num_states)
state_index = 0;
return lua_cv[state_index++];
auto member = lua_pool.back();
lua_pool.pop_back();
return SharedPoolMember(member, this);
}
void returnPoolMember(std::shared_ptr<LuaThreadContext> my_ptr) {
std::lock_guard<std::mutex> lock(mutx);
lua_pool.push_back(my_ptr);
}

private:
std::vector<std::shared_ptr<LuaThreadContext>> lua_cv;
std::vector<std::shared_ptr<LuaThreadContext>> lua_pool;
std::vector<std::shared_ptr<LuaThreadContext>> lua_read_only;
unsigned int num_states;
unsigned int state_index;
std::mutex mutx;
};

Expand Down
111 changes: 73 additions & 38 deletions wforce/luastate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*/

#pragma once

#include "ext/luawrapper/include/LuaContext.hpp"
#include "misc.hh"
#include <mutex>
Expand All @@ -41,7 +42,7 @@ typedef std::function<std::string(const std::string&)> canonicalize_t;

struct CustomFuncMapObject {
custom_func_t c_func;
bool c_reportSink;
bool c_reportSink;
};

typedef std::map<std::string, CustomFuncMapObject> CustomFuncMap;
Expand All @@ -50,7 +51,10 @@ extern CustomFuncMap g_custom_func_map;
typedef std::map<std::string, custom_get_func_t> CustomGetFuncMap;
extern CustomGetFuncMap g_custom_get_func_map;

vector<std::function<void(void)>> setupLua(bool client, bool allow_report, LuaContext& c_lua, allow_t& allow_func, report_t& report_func, reset_t& reset_func, canonicalize_t& canon_func, CustomFuncMap& custom_func_map, CustomGetFuncMap& custom_get_func_map, const std::string& config);
vector<std::function<void(void)>>
setupLua(bool client, bool allow_report, LuaContext& c_lua, allow_t& allow_func, report_t& report_func,
reset_t& reset_func, canonicalize_t& canon_func, CustomFuncMap& custom_func_map,
CustomGetFuncMap& custom_get_func_map, const std::string& config);

struct LuaThreadContext {
LuaContext lua_context;
Expand All @@ -65,101 +69,132 @@ struct LuaThreadContext {

#define NUM_LUA_STATES 6

class LuaMultiThread
{
class LuaMultiThread {
public:
LuaMultiThread() : num_states(NUM_LUA_STATES),
state_index(0)

LuaMultiThread() : num_states(NUM_LUA_STATES)
{
LuaMultiThread{num_states};
}

LuaMultiThread(unsigned int nstates) : num_states(nstates),
state_index(0)
LuaMultiThread(unsigned int nstates) : num_states(nstates)
{
for (unsigned int i=0; i<num_states; i++) {
lua_cv.push_back(std::make_shared<LuaThreadContext>());
}
for (unsigned int i = 0; i < num_states; i++) {
lua_pool.push_back(std::make_shared<LuaThreadContext>());
}
lua_read_only = lua_pool; // Make a copy for use by the control thread
}

LuaMultiThread(const LuaMultiThread&) = delete;

LuaMultiThread& operator=(const LuaMultiThread&) = delete;

// these are used to setup the allow and report function pointers
std::vector<std::shared_ptr<LuaThreadContext>>::iterator begin() { return lua_cv.begin(); }
std::vector<std::shared_ptr<LuaThreadContext>>::iterator end() { return lua_cv.end(); }
std::vector<std::shared_ptr<LuaThreadContext>>::iterator begin()
{ return lua_read_only.begin(); }

std::vector<std::shared_ptr<LuaThreadContext>>::iterator end()
{ return lua_read_only.end(); }

bool reset(const std::string& type, const std::string& login_value, const ComboAddress& ca_value) {
auto lt_context = getLuaState();
bool reset(const std::string& type, const std::string& login_value, const ComboAddress& ca_value)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the reset function
return lt_context->reset_func(type, login_value, ca_value);
}

AllowReturn allow(const LoginTuple& lt) {
auto lt_context = getLuaState();
AllowReturn allow(const LoginTuple& lt)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the allow function
return lt_context->allow_func(lt);
}

void report(const LoginTuple& lt) {
auto lt_context = getLuaState();
void report(const LoginTuple& lt)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the report function
lt_context->report_func(lt);
}

std::string canonicalize(const std::string& login) {
auto lt_context = getLuaState();
std::string canonicalize(const std::string& login)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the canonicalize function
return lt_context->canon_func(login);
}

CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa, bool& reportSinkReturn) {
auto lt_context = getLuaState();
CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa, bool& reportSinkReturn)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the custom function
for (const auto& i : lt_context->custom_func_map) {
for (const auto& i: lt_context->custom_func_map) {
if (command.compare(i.first) == 0) {
reportSinkReturn = i.second.c_reportSink;
return i.second.c_func(cfa);
reportSinkReturn = i.second.c_reportSink;
return i.second.c_func(cfa);
}
}
return CustomFuncReturn(false, KeyValVector{});
}

std::string custom_get_func(const std::string& command) {
auto lt_context = getLuaState();
std::string custom_get_func(const std::string& command)
{
auto pool_member = getPoolMember();
auto lt_context = pool_member.getLuaContext();
// lock the lua state mutex
std::lock_guard<std::mutex> lock(lt_context->lua_mutex);
// call the custom function
for (const auto& i : lt_context->custom_get_func_map) {
for (const auto& i: lt_context->custom_get_func_map) {
if (command.compare(i.first) == 0) {
return i.second();
return i.second();
}
}
return string();
}

protected:
std::shared_ptr<LuaThreadContext> getLuaState()
{

class SharedPoolMember {
public:
SharedPoolMember(std::shared_ptr<LuaThreadContext> ptr, LuaMultiThread* pool) : d_pool_item(ptr), d_pool(pool) {}
~SharedPoolMember() { if (d_pool != nullptr) { d_pool->returnPoolMember(d_pool_item); } }
SharedPoolMember(const SharedPoolMember&) = delete;
SharedPoolMember& operator=(const SharedPoolMember&) = delete;
std::shared_ptr<LuaThreadContext> getLuaContext() { return d_pool_item; }
private:
std::shared_ptr<LuaThreadContext> d_pool_item;
LuaMultiThread* d_pool;
};
SharedPoolMember getPoolMember() {
std::lock_guard<std::mutex> lock(mutx);
auto member = lua_pool.back();
lua_pool.pop_back();
return SharedPoolMember(member, this);
}
void returnPoolMember(std::shared_ptr<LuaThreadContext> my_ptr) {
std::lock_guard<std::mutex> lock(mutx);
if (state_index >= num_states)
state_index = 0;
return lua_cv[state_index++];
lua_pool.push_back(my_ptr);
}

private:
std::vector<std::shared_ptr<LuaThreadContext>> lua_cv;
std::vector<std::shared_ptr<LuaThreadContext>> lua_pool;
std::vector<std::shared_ptr<LuaThreadContext>> lua_read_only;
unsigned int num_states;
unsigned int state_index;
std::mutex mutx;
};

Expand Down

0 comments on commit 4141a4a

Please sign in to comment.