summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Bindings/LuaState.cpp149
-rw-r--r--src/Bindings/LuaState.h89
2 files changed, 231 insertions, 7 deletions
diff --git a/src/Bindings/LuaState.cpp b/src/Bindings/LuaState.cpp
index 200878cf7..8e4f48275 100644
--- a/src/Bindings/LuaState.cpp
+++ b/src/Bindings/LuaState.cpp
@@ -20,6 +20,10 @@ extern "C"
#include "../Entities/Entity.h"
#include "../BlockEntities/BlockEntity.h"
+
+
+
+
// fwd: "SQLite/lsqlite3.c"
extern "C"
{
@@ -39,6 +43,10 @@ extern "C"
const cLuaState::cRet cLuaState::Return = {};
+/** Each Lua state stores a pointer to its creating cLuaState in Lua globals, under this name.
+This way any cLuaState can reference the main cLuaState's TrackedCallbacks, mutex etc. */
+static const char * g_CanonLuaStateGlobalName = "_CuberiteInternal_CanonLuaState";
+
@@ -114,6 +122,72 @@ cLuaStateTracker & cLuaStateTracker::Get(void)
////////////////////////////////////////////////////////////////////////////////
+// cLuaState::cCallback:
+
+bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos)
+{
+ // Check if the stack contains a function:
+ if (!lua_isfunction(a_LuaState, a_StackPos))
+ {
+ return false;
+ }
+
+ // Clear any previous callback:
+ Clear();
+
+ // Add self to LuaState's callback-tracking:
+ a_LuaState.TrackCallback(*this);
+
+ // Store the new callback:
+ cCSLock Lock(m_CS);
+ m_Ref.RefStack(a_LuaState, a_StackPos);
+ return true;
+}
+
+
+
+
+
+void cLuaState::cCallback::Clear(void)
+{
+ // Free the callback reference:
+ lua_State * luaState = nullptr;
+ {
+ cCSLock Lock(m_CS);
+ if (!m_Ref.IsValid())
+ {
+ return;
+ }
+ luaState = m_Ref.GetLuaState();
+ m_Ref.UnRef();
+ }
+
+ // Remove from LuaState's callback-tracking:
+ cLuaState(luaState).UntrackCallback(*this);
+}
+
+
+
+
+
+void cLuaState::cCallback::Invalidate(void)
+{
+ cCSLock Lock(m_CS);
+ if (!m_Ref.IsValid())
+ {
+ LOGD("%s: Invalidating an already invalid callback at %p, this should not happen",
+ __FUNCTION__, reinterpret_cast<void *>(this)
+ );
+ return;
+ }
+ m_Ref.UnRef();
+}
+
+
+
+
+
+////////////////////////////////////////////////////////////////////////////////
// cLuaState:
cLuaState::cLuaState(const AString & a_SubsystemName) :
@@ -170,6 +244,10 @@ void cLuaState::Create(void)
luaL_openlibs(m_LuaState);
m_IsOwned = true;
cLuaStateTracker::Add(*this);
+
+ // Add the CanonLuaState value into the Lua state, so that we can get it from anywhere:
+ lua_pushlightuserdata(m_LuaState, reinterpret_cast<void *>(this));
+ lua_setglobal(m_LuaState, g_CanonLuaStateGlobalName);
}
@@ -206,6 +284,16 @@ void cLuaState::Close(void)
Detach();
return;
}
+
+ // Invalidate all callbacks:
+ {
+ cCSLock Lock(m_CSTrackedCallbacks);
+ for (auto & c: m_TrackedCallbacks)
+ {
+ c->Invalidate();
+ }
+ }
+
cLuaStateTracker::Del(*this);
lua_close(m_LuaState);
m_LuaState = nullptr;
@@ -871,6 +959,15 @@ bool cLuaState::GetStackValue(int a_StackPos, cRef & a_Ref)
+bool cLuaState::GetStackValue(int a_StackPos, cCallback & a_Callback)
+{
+ return a_Callback.RefStack(*this, a_StackPos);
+}
+
+
+
+
+
bool cLuaState::GetStackValue(int a_StackPos, double & a_ReturnedVal)
{
if (lua_isnumber(m_LuaState, a_StackPos))
@@ -1626,6 +1723,52 @@ int cLuaState::BreakIntoDebugger(lua_State * a_LuaState)
+void cLuaState::TrackCallback(cCallback & a_Callback)
+{
+ // Get the CanonLuaState global from Lua:
+ auto cb = WalkToNamedGlobal(g_CanonLuaStateGlobalName);
+ if (!cb.IsValid())
+ {
+ LOGWARNING("%s: Lua state %p has invalid CanonLuaState!", __FUNCTION__, reinterpret_cast<void *>(m_LuaState));
+ return;
+ }
+ auto & canonState = *reinterpret_cast<cLuaState *>(lua_touserdata(m_LuaState, -1));
+
+ // Add the callback:
+ cCSLock Lock(canonState.m_CSTrackedCallbacks);
+ canonState.m_TrackedCallbacks.push_back(&a_Callback);
+}
+
+
+
+
+
+void cLuaState::UntrackCallback(cCallback & a_Callback)
+{
+ // Get the CanonLuaState global from Lua:
+ auto cb = WalkToNamedGlobal(g_CanonLuaStateGlobalName);
+ if (!cb.IsValid())
+ {
+ LOGWARNING("%s: Lua state %p has invalid CanonLuaState!", __FUNCTION__, reinterpret_cast<void *>(m_LuaState));
+ return;
+ }
+ auto & canonState = *reinterpret_cast<cLuaState *>(lua_touserdata(m_LuaState, -1));
+
+ // Remove the callback:
+ cCSLock Lock(canonState.m_CSTrackedCallbacks);
+ auto & trackedCallbacks = canonState.m_TrackedCallbacks;
+ trackedCallbacks.erase(std::remove_if(trackedCallbacks.begin(), trackedCallbacks.end(),
+ [&a_Callback](cCallback * a_StoredCallback)
+ {
+ return (a_StoredCallback == &a_Callback);
+ }
+ ));
+}
+
+
+
+
+
////////////////////////////////////////////////////////////////////////////////
// cLuaState::cRef:
@@ -1681,7 +1824,7 @@ void cLuaState::cRef::RefStack(cLuaState & a_LuaState, int a_StackPos)
{
UnRef();
}
- m_LuaState = &a_LuaState;
+ m_LuaState = a_LuaState;
lua_pushvalue(a_LuaState, a_StackPos); // Push a copy of the value at a_StackPos onto the stack
m_Ref = luaL_ref(a_LuaState, LUA_REGISTRYINDEX);
}
@@ -1692,11 +1835,9 @@ void cLuaState::cRef::RefStack(cLuaState & a_LuaState, int a_StackPos)
void cLuaState::cRef::UnRef(void)
{
- ASSERT(m_LuaState->IsValid()); // The reference should be destroyed before destroying the LuaState
-
if (IsValid())
{
- luaL_unref(*m_LuaState, LUA_REGISTRYINDEX, m_Ref);
+ luaL_unref(m_LuaState, LUA_REGISTRYINDEX, m_Ref);
}
m_LuaState = nullptr;
m_Ref = LUA_REFNIL;
diff --git a/src/Bindings/LuaState.h b/src/Bindings/LuaState.h
index 215980033..63f419d2b 100644
--- a/src/Bindings/LuaState.h
+++ b/src/Bindings/LuaState.h
@@ -80,8 +80,11 @@ public:
/** Allows to use this class wherever an int (i. e. ref) is to be used */
explicit operator int(void) const { return m_Ref; }
+ /** Returns the Lua state associated with the value. */
+ lua_State * GetLuaState(void) { return m_LuaState; }
+
protected:
- cLuaState * m_LuaState;
+ lua_State * m_LuaState;
int m_Ref;
// Remove the copy-constructor:
@@ -112,6 +115,69 @@ public:
} ;
+ /** Represents a callback to Lua that C++ code can call.
+ Is thread-safe and unload-safe.
+ When the Lua state is unloaded, the callback returns an error instead of calling into non-existent code.
+ To receive the callback instance from the Lua side, use RefStack() or (better) cLuaState::GetStackValue().
+ Note that instances of this class are tracked in the canon LuaState instance, so that they can be invalidated
+ when the LuaState is unloaded; due to multithreading issues they can only be tracked by-ptr, which has
+ an unfortunate effect of disabling the copy and move constructors. */
+ class cCallback
+ {
+ public:
+ /** Creates an unbound callback instance. */
+ cCallback(void) = default;
+
+ ~cCallback()
+ {
+ Clear();
+ }
+
+ /** Calls the Lua callback, if still available.
+ Returns true if callback has been called.
+ Returns false if the Lua state isn't valid anymore. */
+ template <typename... Args>
+ bool Call(Args &&... args)
+ {
+ cCSLock Lock(m_CS);
+ if (!m_Ref.IsValid())
+ {
+ return false;
+ }
+ cLuaState(m_Ref.GetLuaState()).Call(m_Ref, std::forward<Args>(args)...);
+ return true;
+ }
+
+ /** Set the contained callback to the function in the specified Lua state's stack position.
+ If a callback has been previously contained, it is freed first. */
+ bool RefStack(cLuaState & a_LuaState, int a_StackPos);
+
+ /** Frees the contained callback, if any. */
+ void Clear(void);
+
+ protected:
+ friend class cLuaState;
+
+ /** The mutex protecting m_Ref against multithreaded access */
+ cCriticalSection m_CS;
+
+ /** Reference to the Lua callback */
+ cRef m_Ref;
+
+
+ /** Invalidates the callback, without untracking it from the cLuaState.
+ Called only from cLuaState when closing the Lua state. */
+ void Invalidate(void);
+
+ /** This class cannot be copied, because it is tracked in the LuaState by-ptr. */
+ cCallback(const cCallback &) = delete;
+
+ /** This class cannot be moved, because it is tracked in the LuaState by-ptr. */
+ cCallback(cCallback &&) = delete;
+ };
+ typedef SharedPtr<cCallback> cCallbackPtr;
+
+
/** A dummy class that's used only to delimit function args from return values for cLuaState::Call() */
class cRet
{
@@ -268,6 +334,7 @@ public:
bool GetStackValue(int a_StackPos, bool & a_Value);
bool GetStackValue(int a_StackPos, cPluginManager::CommandResult & a_Result);
bool GetStackValue(int a_StackPos, cRef & a_Ref);
+ bool GetStackValue(int a_StackPos, cCallback & a_Ref);
bool GetStackValue(int a_StackPos, double & a_Value);
bool GetStackValue(int a_StackPos, eBlockFace & a_Value);
bool GetStackValue(int a_StackPos, eWeather & a_Value);
@@ -441,8 +508,7 @@ protected:
bool m_IsOwned;
/** The subsystem name is used for reporting errors to the console, it is either "plugin %s" or "LuaScript"
- whatever is given to the constructor
- */
+ whatever is given to the constructor. */
AString m_SubsystemName;
/** Name of the currently pushed function (for the Push / Call chain) */
@@ -451,6 +517,15 @@ protected:
/** Number of arguments currently pushed (for the Push / Call chain) */
int m_NumCurrentFunctionArgs;
+ /** The tracked callbacks.
+ This object will invalidate all of these when it is about to be closed.
+ Protected against multithreaded access by m_CSTrackedCallbacks. */
+ std::vector<cCallback *> m_TrackedCallbacks;
+
+ /** Protects m_TrackedTallbacks against multithreaded access. */
+ cCriticalSection m_CSTrackedCallbacks;
+
+
/** Variadic template terminator: If there's nothing more to push / pop, just call the function.
Note that there are no return values either, because those are prefixed by a cRet value, so the arg list is never empty. */
bool PushCallPop(void)
@@ -533,6 +608,14 @@ protected:
/** Tries to break into the MobDebug debugger, if it is installed. */
static int BreakIntoDebugger(lua_State * a_LuaState);
+
+ /** Adds the specified callback to tracking.
+ The callback will be invalidated when this Lua state is about to be closed. */
+ void TrackCallback(cCallback & a_Callback);
+
+ /** Removes the specified callback from tracking.
+ The callback will no longer be invalidated when this Lua state is about to be closed. */
+ void UntrackCallback(cCallback & a_Callback);
} ;