diff options
Diffstat (limited to 'private/ole32/com/dcomrem')
47 files changed, 31384 insertions, 0 deletions
diff --git a/private/ole32/com/dcomrem/callctrl.cxx b/private/ole32/com/dcomrem/callctrl.cxx new file mode 100644 index 000000000..f393fdd27 --- /dev/null +++ b/private/ole32/com/dcomrem/callctrl.cxx @@ -0,0 +1,1820 @@ +//+------------------------------------------------------------------------- +// +// Microsoft Windows +// Copyright (C) Microsoft Corporation, 1992 - 1993. +// +// File: callctrl.cxx +// +// Contents: Contains the ORPC CallControl code +// +// History: 21-Dec-93 Johannp Original Version +// 04-Nov-94 Rickhi ReWrite as layer over channel +// +//-------------------------------------------------------------------------- +#include <ole2int.h> +#include <thkreg.h> // OLETHK_ defines +#include <dde.h> +#include <callctrl.hxx> // Class Definition +#include <objsrv.h> // IID_IObjServer + + +// private defines used only in this file +#define WM_SYSTIMER 0x0118 +#define SYS_ALTDOWN 0x2000 +#define WM_NCMOUSEFIRST WM_NCMOUSEMOVE +#define WM_NCMOUSELAST WM_NCMBUTTONDBLCLK + + +// empty slot in window registration +#define WD_EMPTY (HWND)-1 + +// the following table is used to quickly determine what windows +// message queue inputflag to specify for the various categories of +// outgoing calls in progress. The table is indexed by CALLCATEGORY. + +DWORD gMsgQInputFlagTbl[4] = { + QS_ALLINPUT | QS_TRANSFER | QS_ALLPOSTMESSAGE, // NOCALL + QS_ALLINPUT | QS_TRANSFER | QS_ALLPOSTMESSAGE, // SYNCHRONOUS + QS_ALLINPUT | QS_TRANSFER | QS_ALLPOSTMESSAGE, // ASYNC + QS_SENDMESSAGE}; // INPUTSYNC + + +// the following table is used to map bit flags in the Rpc Message to +// the equivalent OLE CALLCATEGORY. + +DWORD gRpcFlagToCallCatMap[3] = { + CALLCAT_SYNCHRONOUS, // no flags set + CALLCAT_INPUTSYNC, // RPCFLG_INPUT_SYNCHRONOUS + CALLCAT_ASYNC}; // RPCFLG_ASYNCHRONOUS + + +// prototype +HRESULT CopyMsgForRetry(RPCOLEMESSAGE *pMsg, + IRpcChannelBuffer *pChnl, + HRESULT hrIn); + + +//+------------------------------------------------------------------------- +// +// Function: CoRegisterMessageFilter, public +// +// Synopsis: registers an applications message filter with the call control +// +// Arguments: [pMsgFilter] - message filter to register +// [ppMsgFilter] - optional, where to return previous IMF +// +// Returns: S_OK - registered successfully +// +// History: 21-Dec-93 JohannP Created +// +//-------------------------------------------------------------------------- +STDAPI CoRegisterMessageFilter(LPMESSAGEFILTER pMsgFilter, + LPMESSAGEFILTER *ppMsgFilter) +{ + ComDebOut((DEB_MFILTER, "CoRegisterMessageFilter pMF:%x ppMFOld:%x\n", + pMsgFilter, ppMsgFilter)); + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IMessageFilter,(IUnknown **)&pMsgFilter); + + // validate the parameters. NULL acceptable for either or both parameters. + if (pMsgFilter != NULL && !IsValidInterface(pMsgFilter)) + { + return E_INVALIDARG; + } + + if(ppMsgFilter != NULL && !IsValidPtrOut(ppMsgFilter, sizeof(ppMsgFilter))) + { + return E_INVALIDARG; + } + + // this operation is not allowed on MTA Threads + if (IsMTAThread()) + return CO_E_NOT_SUPPORTED; + + // find the callcontrol for this apartment and replace the existing + // message filter. if no callctrl has been created yet, just stick + // the pMsgFilter in tls. + + COleTls tls; + CAptCallCtrl *pACC = tls->pCallCtrl; + + IMessageFilter *pOldMF; + + if (pACC) + { + pOldMF = pACC->InstallMsgFilter(pMsgFilter); + } + else + { + pOldMF = tls->pMsgFilter; + + if (pMsgFilter) + { + pMsgFilter->AddRef(); + } + tls->pMsgFilter = pMsgFilter; + } + if (ppMsgFilter) + { + // return old MF to the caller + *ppMsgFilter = pOldMF; + } + else if (pOldMF) + { + // release the old MF + pOldMF->Release(); + } + + return S_OK; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::InstallMsgFilter +// +// Synopsis: called to install a new application provided message filter +// +// Arguments: [pMF] - new message filter to install (or NULL) +// +// Returns: previous message filter if there was one +// +// History: 20-Dec-93 JohannP Created +// +//-------------------------------------------------------------------------- +INTERNAL_(IMessageFilter *) CAptCallCtrl::InstallMsgFilter(IMessageFilter *pMF) +{ + IMessageFilter *pMFOld = _pMF; // save the old one to return + + _pMF = pMF; // install the new one + if (_pMF) + { + _pMF->AddRef(); + } + + return pMFOld; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::CAptCallCtrl +// +// Synopsis: constructor for per apartment call control state +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +CAptCallCtrl::CAptCallCtrl() : + _fInMsgFilter(FALSE), + _pTopCML(NULL) +{ + // The first one is reserved for ORPC. An hWnd value of WD_EMPTY + // means the slot is available. + _WD[0].hWnd = WD_EMPTY; + + // The second slot has fixed values for DDE + _WD[1].hWnd = NULL; + _WD[1].wFirstMsg = WM_DDE_FIRST; + _WD[1].wLastMsg = WM_DDE_LAST; + + // put our pointer into thread local storage, and retrieve any previously + // registered message filter. + + COleTls tls; + tls->pCallCtrl = this; + + _pMF = tls->pMsgFilter; + tls->pMsgFilter = NULL; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::~CAptCallCtrl +// +// Synopsis: destructor for per apartment call control state +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +CAptCallCtrl::~CAptCallCtrl() +{ + Win4Assert(_pTopCML == NULL); // no outgoing calls. + + if (_pMF) + { + _pMF->Release(); + } + + // remove our pointer from thread local storage + COleTls tls; + tls->pCallCtrl = NULL; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::Register/Revoke +// +// Synopsis: register or revoke RPC window data +// +// Arguments: [hWnd] - window handle to look for calls on +// [wFirstMsg] - msgid of first message in range to look for +// [wLastMsg] - msgid of last message in range to look for +// +// Returns: nothing +// +// Notes: This code is only ever called by the RpcChannel and by +// the DDE layer, and so error checking is kept to a minimum. +// +// History: 30-Apr-95 Rickhi Created +// +//-------------------------------------------------------------------------- +void CAptCallCtrl::Register(HWND hWnd, UINT wFirstMsg, UINT wLastMsg) +{ + Win4Assert(_WD[0].hWnd == WD_EMPTY && "Register Out of Space"); + + _WD[0].hWnd = hWnd; + _WD[0].wFirstMsg = wFirstMsg; + _WD[0].wLastMsg = wLastMsg; +} + +void CAptCallCtrl::Revoke(HWND hWnd) +{ + Win4Assert(_WD[0].hWnd == hWnd && "Revoke not found"); + _WD[0].hWnd = WD_EMPTY; +} + +//+------------------------------------------------------------------------- +// +// Function: GetSlowTimeFactor +// +// Synopsis: Get the time slowing factor for Wow apps +// +// Returns: The factor by which we need to slow time down. +// +// Algorithm: If there is a factor in the registry, we open and read the +// registry. Otherwise we just set it to the default. +// +// History: 22-Jul-94 Ricksa Created +// 09-Jun-95 Susia ANSI Chicago optimization +// +//-------------------------------------------------------------------------- +#ifdef _CHICAGO_ +#undef RegOpenKeyEx +#define RegOpenKeyEx RegOpenKeyExA +#undef RegQueryValueEx +#define RegQueryValueEx RegQueryValueExA +#endif +DWORD GetSlowTimeFactor(void) +{ + // Default slowing time so we can just exit if there is no key which + // is assumed to be the common case. + DWORD dwSlowTimeFactor = OLETHK_DEFAULT_SLOWRPCTIME; + + // Key for reading the value from the registry + HKEY hkeyOleThk; + + // Get the Ole Thunk special value key + LONG lStatus = RegOpenKeyEx(HKEY_CLASSES_ROOT, OLETHK_KEY, 0, KEY_READ, + &hkeyOleThk); + + if (lStatus == ERROR_SUCCESS) + { + DWORD dwType; + DWORD dwSizeData = sizeof(dwSlowTimeFactor); + + lStatus = RegQueryValueEx(hkeyOleThk, OLETHK_SLOWRPCTIME_VALUE, NULL, + &dwType, (LPBYTE) &dwSlowTimeFactor, &dwSizeData); + + if ((lStatus != ERROR_SUCCESS) || dwType != REG_DWORD) + { + // Guarantee that value is reasonable if something went wrong. + dwSlowTimeFactor = OLETHK_DEFAULT_SLOWRPCTIME; + } + + // Close the key since we are done with it. + RegCloseKey(hkeyOleThk); + } + + return dwSlowTimeFactor; +} + +//+------------------------------------------------------------------------- +// +// Function: CanMakeOutCall +// +// Synopsis: called when the client app wants to make an outgoing call to +// determine if it is OK to do it now or not. Common subroutine +// to CAptRpcChnl::GetBuffer and RemoteReleaseRifRef. +// +// Arguments: [dwCallCatOut] - call category of call the app wants to make +// [pChnl] - ptr to channel call is being made on +// [riid] - interface call is being made on +// +// Returns: S_OK - ok to make the call +// RPC_E_CANTCALLOUT_INEXTERNALCALL - inside IMessageFilter +// RPC_E_CANTCALLOUT_INASYNCCALL - inside async call +// RPC_E_CANTCALLOUT_ININPUTSYNCCALL - inside input sync or SendMsg +// +// History: 21-Dec-93 Johannp Original Version +// 04-Nov-94 Rickhi ReWrite +// 03-Oct-95 Rickhi Made into common subroutine +// +//-------------------------------------------------------------------------- +INTERNAL CanMakeOutCall(DWORD dwCallCatOut, REFIID riid) +{ + // get the topmost incoming call state from Tls. + + HRESULT hr; + COleTls tls(hr); + if (FAILED(hr)) + return hr; + + CSrvCallState *pSCS = tls->pTopSCS; + + DWORD dwCallCatIn = (pSCS) ? pSCS->GetCallCatIn() : CALLCAT_NOCALL; + + // if handling an incoming ASYNC call, only allow ASYNC outgoing calls, + // and local calls on IRemUnknown (which locally is actually IRundown). + + if (dwCallCatIn == CALLCAT_ASYNC && + dwCallCatOut != CALLCAT_ASYNC && + !IsEqualGUID(riid, IID_IRundown)) + { + return RPC_E_CANTCALLOUT_INASYNCCALL; + } + + // if handling an incoming INPUTSYNC call, or if we are handling a + // SendMessage, dont allow SYNCHRONOUS calls out or we could deadlock + // since SYNC uses PostMessage and INPUTSYNC uses SendMessage. + + if (dwCallCatOut == CALLCAT_SYNCHRONOUS && + (dwCallCatIn == CALLCAT_INPUTSYNC || InSendMessage())) + { + return RPC_E_CANTCALLOUT_ININPUTSYNCCALL; + } + + return S_OK; +} + +//+------------------------------------------------------------------------- +// +// Method: CMTARpcChnl::CMTARpcChnl/~CMTARpcChnl +// +// Synopsis: constructor/destructor +// +// Parameters: [pStdId] - std identity for the object +// [pOXIDEntry] - OXIDEntry for the object server +// [eState] - state flags passed thru to CRpcChannelBuffer +// (ignored by CMTARpcCnl). +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +CMTARpcChnl::CMTARpcChnl(CStdIdentity *pStdId, + OXIDEntry *pOXIDEntry, + DWORD eState) : + CRpcChannelBuffer(pStdId, pOXIDEntry, eState), + _dwTIDCallee(pOXIDEntry->dwTid), + _dwAptId(GetCurrentApartmentId()) +{ + ComDebOut((DEB_CALLCONT,"CMTARpcChnl::CMTARpcChnl this:%x\n", this)); +} + +CMTARpcChnl::~CMTARpcChnl() +{ + ComDebOut((DEB_CALLCONT,"CMTARpcChnl::~CMTARpcChnl this:%x\n", this)); +} + +//+------------------------------------------------------------------------- +// +// Method: CMTARpcChnl::GetBuffer +// +// Synopsis: Ensure it is legal to call out now, then get a buffer. +// +// Parameters: [pMsg] - ptr to message structure +// [riid] - interface call is being made on +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +STDMETHODIMP CMTARpcChnl::GetBuffer(RPCOLEMESSAGE *pMsg, REFIID riid) +{ + HRESULT hr; + COleTls tls(hr); // use this form incase no calls made on this thread yet + if (FAILED(hr)) + return hr; + + if (!IsMTAThread()) + { + ComDebOut((DEB_WARN,"CMTARpcChnl::GetBuffer - MTA proxy called on apartment thread, this: 0x%x\n", + this)); + return RPC_E_WRONG_THREAD; + } + + // Make sure we are allowed to make this outgoing call. We do that here + // so that we dont marshal all the parameters only to discover that we + // cant call out and then have to free all the marshalled parameters + // (especially the ones where marshalling has side effects). + + if (!(_dwAptId == GetCurrentApartmentId() || CallableOnAnyApt())) + { + // we are not being called on a thread in the MTA apartment + return RPC_E_WRONG_THREAD; + } + + if (pMsg->rpcFlags & RPCFLG_INPUT_SYNCHRONOUS) + { + // dont allow INPUTSYNC calls from an MTA apartment to anybody. + return RPC_E_CANTCALLOUT_ININPUTSYNCCALL; + } + + // All ASYNC calls from an MTA apartment are treated as SYNCHRONOUS, + // so convert the call category here before proceeding. + + pMsg->rpcFlags &= ~RPCFLG_ASYNCHRONOUS; + + // ask the real channel for a buffer. + return CRpcChannelBuffer::ClientGetBuffer(pMsg, riid); +} + +//+------------------------------------------------------------------------- +// +// Method: CAptRpcChnl::CAptRpcChnl/~CAptRpcChnl +// +// Synopsis: constructor/destructor +// +// Parameters: [pStdId] - std identity for the object +// [pOXIDEntry] - OXIDEntry for the object server +// [eState] - state flags passed thru to CRpcChannelBuffer +// (ignored by CAptRpcCnl). +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +CAptRpcChnl::CAptRpcChnl(CStdIdentity *pStdId, + OXIDEntry *pOXIDEntry, + DWORD eState) : + CRpcChannelBuffer(pStdId, pOXIDEntry, eState), + _dwTIDCallee(pOXIDEntry->dwTid), + _dwAptId(GetCurrentApartmentId()) +{ + ComDebOut((DEB_CALLCONT,"CAptRpcChnl::CAptRpcChnl this:%x\n", this)); +} + +CAptRpcChnl::~CAptRpcChnl() +{ + ComDebOut((DEB_CALLCONT,"CAptRpcChnl::~CAptRpcChnl this:%x\n", this)); +} + +//+------------------------------------------------------------------------- +// +// Method: CAptRpcChnl::GetBuffer +// +// Synopsis: Ensure it is legal to call out now, then get a buffer. +// +// Parameters: [pMsg] - ptr to message structure +// [riid] - interface call is being made on +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +STDMETHODIMP CAptRpcChnl::GetBuffer(RPCOLEMESSAGE *pMsg, REFIID riid) +{ + HRESULT hr; + COleTls tls(hr); // use this form incase no calls made on this thread yet + if (FAILED(hr)) + return hr; + + // first, make sure that we are being called on the correct thread + // so that we are sure tls->pCallCtrl is set. + + if (!IsSTAThread() || !(_dwAptId == GetCurrentApartmentId() || + CallableOnAnyApt())) + { + return RPC_E_WRONG_THREAD; + } + + // dont allow the application to call out while handling an + // IMessageFilter call because it screws up the call sequencing. + + CAptCallCtrl *pACC = tls->pCallCtrl; + + if (pACC && pACC->InMsgFilter()) + { + ComDebOut((DEB_ERROR, "Illegal callout from within IMessageFilter\n")); + return RPC_E_CANTCALLOUT_INEXTERNALCALL; + } + + // if the call is async and remote, or async and to an MTA apartment, + // then change the category to sync, since we dont support async remotely + // or to MTA apartments locally. This must be done before calling + // CanMakeOutCall in order to avoid deadlocks. If the call is input sync + // and remote or to an MTA apartment, dissallow the call. + + if (pMsg->rpcFlags & (RPCFLG_ASYNCHRONOUS | RPCFLG_INPUT_SYNCHRONOUS)) + { + DWORD dwCtx; + CRpcChannelBuffer::GetDestCtx(&dwCtx, NULL); + + if (dwCtx == MSHCTX_DIFFERENTMACHINE || + (GetOXIDEntry()->dwFlags & OXIDF_MTASERVER)) + { + if (pMsg->rpcFlags & RPCFLG_INPUT_SYNCHRONOUS) + return RPC_E_CANTCALLOUT_ININPUTSYNCCALL; + + // turn off the async flag so that the call looks (and acts) + // like it is synchronous. + + pMsg->rpcFlags &= ~RPCFLG_ASYNCHRONOUS; + } + } + + // Make sure we are allowed to make this outgoing call. We do that here + // so that we dont marshal all the parameters only to discover that we + // cant call out and then have to free all the marshalled parameters + // (especially the ones where marshalling has side effects). + + // figure out the call category of this call by looking at bit + // values in the rpc message flags. + + DWORD dwCallCatOut = RpcFlagToCallCat(pMsg->rpcFlags); + + // check other outgoing call restrictions common to multi and single + // threaded apartments + + hr = CanMakeOutCall(dwCallCatOut, riid); + + if (hr == S_OK) + { + // ask the real channel for a buffer. + hr = CRpcChannelBuffer::ClientGetBuffer(pMsg, riid); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptRpcChnl::SendReceive +// +// Synopsis: instantiate a modal loop object and then transmit the call +// +// Parameters: [pMsg] - ptr to message structure +// [pulStatus] - place to return a status code +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +STDMETHODIMP CAptRpcChnl::SendReceive(RPCOLEMESSAGE *pMsg, ULONG *pulStatus) +{ + // Figure out the call category of this call by looking at the bit + // values in the rpc message flags. + + DWORD dwCallCatOut = RpcFlagToCallCat(pMsg->rpcFlags); + DWORD dwMsgQInputFlag = gMsgQInputFlagTbl[dwCallCatOut]; + + // Now for a spectacular hack. IRemUnknown::Release had slightly + // different dwMsgQInputFlag semantic in the old code base, so we + // check for that one case here and set the flag accordingly. Not + // doing this would allow SYSCOMMAND calls in during Release which + // we throw away, thus preventing an app from shutting down correctly. + // SimpSvr.exe is a good example of this. + + if ((pMsg->iMethod & ~RPC_FLAGS_VALID_BIT) == 5 && + (IsEqualIID(IID_IRundown, *MSG_TO_IIDPTR(pMsg)) || + IsEqualIID(IID_IRemUnknown, *MSG_TO_IIDPTR(pMsg)))) + { + dwMsgQInputFlag = (QS_POSTMESSAGE | QS_SENDMESSAGE | QS_TRANSFER | + QS_ALLPOSTMESSAGE); + } + + + // Now construct a modal loop object for the call that is about to + // be made. It maintains the call state and exits when the call has + // been completed, cancelled, or rejected. + + HRESULT hr; + CCliModalLoop CML(_dwTIDCallee, dwMsgQInputFlag); + + do + { + hr = CML.SendReceive(pMsg, pulStatus, this); + + if (hr == RPC_E_SERVERCALL_RETRYLATER) + { + // the call was rejected by the server and the client Msg Filter + // decided to retry the call. We have to make a copy of the + // message and re-send it. + + hr = CopyMsgForRetry(pMsg); + } + else if (hr == RPC_E_CALL_REJECTED) + { + // the call was rejected by the server and the client Msg Filter + // decided NOT to retry the call. We have to free the buffer + // that was returned since the proxy is not expecting it. + + FreeBuffer(pMsg); + } + + } while (hr == RPC_E_SERVERCALL_RETRYLATER); + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptRpcChnl::CopyMsgForRetry +// +// Synopsis: Makes a copy of the message we sent. We have to ask Rpc +// for another buffer and then copy the original buffer into +// the new one so we can make another call. +// +// Parameters: [pMsg] - ptr to message structure to copy +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +HRESULT CAptRpcChnl::CopyMsgForRetry(RPCOLEMESSAGE *pMsg) +{ + ComDebOut((DEB_CALLCONT,"CAptRpcChnl::CopyMsgForRetry pMsg:%x\n", pMsg)); + + // CODEWORK: this is dumb, but the channel blows chunks in FreeBuffer + // if i dont do this double copy. + + void *pTmpBuf = PrivMemAlloc(pMsg->cbBuffer); + if (pTmpBuf) + { + memcpy(pTmpBuf, pMsg->Buffer, pMsg->cbBuffer); + } + + // save copy of the contents of the old message so we can free it later + + HRESULT hr = E_OUTOFMEMORY; + RPCOLEMESSAGE MsgToFree = *pMsg; + FreeBuffer(&MsgToFree); + + if (pTmpBuf) + { + // allocate a new message, dont have to worry about checking the + // CanMakeOutCall again, so we just ask the Rpc channel directly. + + hr = CRpcChannelBuffer::GetBuffer(pMsg, *MSG_TO_IIDPTR(pMsg)); + + if (SUCCEEDED(hr)) + { + // copy the temp buffer into the new buffer + memcpy(pMsg->Buffer, pTmpBuf, pMsg->cbBuffer); + hr = RPC_E_SERVERCALL_RETRYLATER; + } + + PrivMemFree(pTmpBuf); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::SendReceive +// +// Synopsis: called to transmit a call to the server and enter a modal +// loop. +// +// Arguments: [pMsg] - message to send +// [pulStatus] - place to return status code +// [pChnl] - IRpcChannelBuffer pointer +// +// Returns: result of the call. May return RETRYLATER if the call should +// be retransmitted. +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +INTERNAL CCliModalLoop::SendReceive(RPCOLEMESSAGE *pMsg, ULONG *pulStatus, + IRpcChannelBuffer2 *pChnl) +{ + // SendReceive is a blocking call. The channel will transmit the call + // asynchronously then call us back in BlockFn where we wait for an + // event such as the call completing, or a windows message arriving, + // or the user cancelling the call. Because of the callback, we need + // to set _hr before calling SR. + + _hr = RPC_S_CALLPENDING; + _hr = pChnl->SendReceive2(pMsg, pulStatus); + + // By this point the call has completed. Now check if it was rejected + // and if so, whether we need to retry immediately, later, or never. + // Handling of Rejected calls must occur here, not in the BlockFn, due + // to the fact that some calls and some protocols are synchronous, and + // other calls and protocols are asynchronous. + + if (_hr == RPC_E_CALL_REJECTED || _hr == RPC_E_SERVERCALL_RETRYLATER) + { + // this function decides on 1 of 3 different courses of action + // 1. fail the call - sets the state to Call_Rejected + // 2. retry immediately - sets _hr to RETRYLATER, fall out + // 3. retry later - starts the timer, we block below + + _hr = HandleRejectedCall(pChnl); + + // if a timer was installed to retry the call later, then we have + // to go into modal loop until the timer expires. if the call is + // cancelled while in this loop, the loop will be exited. + + while (!IsTimerAtZero()) + { + BlockFn(NULL); + } + + // Either it is time to retransmit the call, or the call was + // cancelled or rejected. + } + + return _hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::HandleRejectedCall +// +// Synopsis: called when the response to a remote call is rejected or +// retry later. +// +// Arguments: [pChnl] - channel we are calling on. +// +// Returns: RPC_E_CALL_REJECTED - call is rejected +// RPC_E_SERVERCALL_RETRYLATER - the call should be retried +// (Timer is set if retry is to be delayed) +// +// Algorithm: Calls the app's message filter (if there is one) to +// determine whether the call should be failed, retried +// immediately, or retried at some later time. If there is +// no message filter, or the client is on a different machine, +// then the call is always rejected. +// +// History: 21-Dec-93 Johannp Created +// 30-Apr-95 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +INTERNAL CCliModalLoop::HandleRejectedCall(IRpcChannelBuffer2 *pChnl) +{ + // default return value - rejected + DWORD dwRet = 0xffffffff; + + DWORD dwDestCtx; + HRESULT hr = pChnl->GetDestCtx(&dwDestCtx, NULL); + + if (SUCCEEDED(hr) && dwDestCtx != MSHCTX_DIFFERENTMACHINE) + { + // the call is local to this machine, ask the message filter + // what to do. For remote calls we never allow retry, since + // the parameters were not sent back to us in the packet. + + IMessageFilter *pMF = _pACC->GetMsgFilter(); + if (pMF) + { + ComDebOut((DEB_MFILTER, + "pMF->RetryRejectedCall(dwTIDCallee:%x ElapsedTime:%x Type:%x)\n", + _dwTIDCallee, GetElapsedTime(), + (_hr == RPC_E_CALL_REJECTED) ? SERVERCALL_REJECTED + : SERVERCALL_RETRYLATER)); + + dwRet = pMF->RetryRejectedCall((MF_HTASK)_dwTIDCallee, GetElapsedTime(), + (_hr == RPC_E_CALL_REJECTED) ? SERVERCALL_REJECTED + : SERVERCALL_RETRYLATER); + + ComDebOut((DEB_MFILTER,"pMF->RetryRejected() dwRet:%x\n", dwRet)); + + _pACC->ReleaseMsgFilter(); + } + } + + if (dwRet == 0xffffffff) + { + // Really rejected. Mark it as such incase it was actually + // Call_RetryLater, also ensures that IsWaiting returns FALSE + return RPC_E_CALL_REJECTED; + } + else if (dwRet >= 100) + { + // Retry Later. Start the timer. This ensures that IsTimerAtZero + // returns FALSE and IsWaiting returns TRUE + return StartTimer(dwRet); + } + else + { + // Retry Immediately. The state is set so that IsTimerAtZero + // returns TRUE. + + Win4Assert(IsTimerAtZero()); + return RPC_E_SERVERCALL_RETRYLATER; + } +} + +//+------------------------------------------------------------------------- +// +// Function: OleModalLoopBlockFn +// +// Synopsis: Called by the RpcChannel during an outgoing call while +// waiting for the reply message. +// +// Arguments: [pvWnd] - Window handle to expect the reply on +// [pvCtx] - Call Context (the CCliModalLoop) +// [hCallWaitEvent] - optional event to have CallControl wait on +// +// Returns: result of the call +// +// Algorithm: pvCtx is the topmost modal loop for the current apartment. +// Just call it's block function. +// +// History: Dec-93 JohannP Created +// +//-------------------------------------------------------------------------- +RPC_STATUS OleModalLoopBlockFn(void *pvWnd, void *pvCtx, HANDLE hCallWaitEvent) +{ + Win4Assert( pvCtx != NULL ); + return ((CCliModalLoop *) pvCtx)->BlockFn(hCallWaitEvent); +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::BlockFn (private) +// +// Synopsis: Implements the blocking part of the modal loop. This function +// blocks until an event of interest occurs, then it goes and +// processes that event and returns. +// +// Arguments: [hCallWaitEvent] - event to wait on (optional) +// +// Returns: RPC_S_CALLPENDING - the call is still pending a reply +// RPC_E_CALL_CANCELLED - the call was cancelled. +// RPC_E_SERVERCALL_RETRYLATER - the call should be retried later +// +// History: Dec-93 JohannP Created +// 30-Apr-95 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +HRESULT CCliModalLoop::BlockFn(HANDLE hEventCallComplete) +{ + ComDebOut((DEB_CALLCONT, + "CCliModalLoop::BlockFn this:%x dwMsgQInputFlag:%x hEvent:%x\n", + this, _dwMsgQInputFlag, hEventCallComplete)); + + Win4Assert(IsWaiting() && "ModalLoop::BlockFn - not waiting on call"); + + // First, we wait for an event of interest to occur, either for the call + // to complete, or a new windows message to arrive on the queue. + + DWORD dwWakeReason = WAIT_TIMEOUT; + HANDLE rgEvents[1] = {hEventCallComplete}; + DWORD cEvents = 0; + + if (hEventCallComplete != NULL) + { + // Check if the event is already signalled. This ensures that + // when we return from nested calls and the upper calls have already + // been acknowledged, that no windows messages can come in. + + ComDebOut((DEB_CALLCONT, "WaitForSingleObject hEvent:%x\n", rgEvents[0])); + cEvents = 1; + dwWakeReason = WaitForSingleObject(rgEvents[0], 0); + } + + if (dwWakeReason == WAIT_TIMEOUT) + { + DWORD dwWaitTime = TicksToWait(); + + // If we want to wake up for a posted message, we need to make + // sure that we haven't missed any because of the queue status + // being affected by prior PeekMessages. We don't worry about + // QS_SENDMESSAGE because if PeekMessage got called, the pending + // send got dispatched. Further, if we are in an input sync call, + // we don't want to start dispatching regular RPC calls here by + // accident. + + if (_dwMsgQInputFlag & QS_POSTMESSAGE) + { + DWORD dwStatus = GetQueueStatus(_dwMsgQInputFlag); + + // We care about any message on the queue not just new messages + // because PeekMessage affects the queue state. It resets the + // state so even if a message is not processed, the queue state + // represents this as an old message even though no one has + // ever looked at it. So even though the message queue tells us + // there are no new messages in the queue. A new message we are + // interested in could be in the queue. + + WORD wNew = (WORD) dwStatus | HIWORD(dwStatus); + + // Note that we look for send as well as post because our + // queue status could have reset the state of the send message + // bit and therefore, MsgWaitForMultipleObject below will not + // wake up to dispatch the send message. + + if (wNew & (QS_POSTMESSAGE | QS_SENDMESSAGE)) + { + // the acknowledge message might be already in the queue + if (PeekRPCAndDDEMessage()) + { + // we know that *some* RPC message came in and was + // processed. It could have been the Reply we were waiting + // for OR some other incoming call. Since we cant tell + // which, we return to RPC land. If it was not our Reply + // then RPC will call our modal loop again. + return _hr; + } + } + +#ifdef _CHICAGO_ + //Note:POSTPPC + WORD wOld = HIWORD(dwStatus); + + if (wOld & (QS_POSTMESSAGE)) + { + ComDebOut((DEB_CALLCONT, "Set timeout time to 100\n")); + dwWaitTime = 100; + } +#endif //_CHICAGO_ + } + + ComDebOut((DEB_CALLCONT, + "Call MsgWaitForMultiple time:%ld, cEvents:%x hEvent:%x,\n", + dwWaitTime, cEvents, rgEvents[0] )); + + dwWakeReason = MsgWaitForMultipleObjects(cEvents, rgEvents, FALSE, + dwWaitTime, _dwMsgQInputFlag); + + ComDebOut((DEB_CALLCONT, + "MsgWaitForMultipleObjects hr:%ld\n", dwWakeReason)); + } + + + // OK, we've done whatever blocking we were going to do and now we have + // been woken up, so figure out what event of interest occured to wake + // us up and go handle it. + + if (dwWakeReason == (WAIT_OBJECT_0 + cEvents)) + { + // Windows message came in - go process it + ComDebOut((DEB_CALLCONT, "BlockFn: Windows Message Arrived\n")); + HandleWakeForMsg(); + } + else if (dwWakeReason == WAIT_TIMEOUT) + { + if (_hr == RPC_S_WAITONTIMER && IsTimerAtZero()) + { + // The Retrytimer timed out - just exit and retransmit the call + ComDebOut((DEB_CALLCONT, "BlockFn: Timer at zero\n")); + _hr = RPC_E_SERVERCALL_RETRYLATER; + } + else + { + // we may have missed a message before we called MsgWaitForMult... + // so we go check now for any incoming messages. + ComDebOut((DEB_CALLCONT, "BlockFn: Timeout-Look for msgs\n")); + HandleWakeForMsg(); + } + } + else + { + // CallComplete signalled - the call is done. + Win4Assert(rgEvents[dwWakeReason - WAIT_OBJECT_0] == hEventCallComplete); + ComDebOut((DEB_CALLCONT, "BlockFn: CallComplete Event Signaled\n")); + _hr = S_OK; + } + + ComDebOut((DEB_CALLCONT, "CCliModalLoop::BlockFn this:%x returns:%x\n", + this, _hr)); + return _hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::HandleWakeForMsg (private) +// +// Synopsis: Handle wake for the arrival of some kind of message +// +// Returns: nothing +// fClearedQueue flag set if appropriate +// +// Algorithm: If this is called to wake up for a posted message, we +// check the queue status. If the message queue status indicates +// that there is some kind of a modal loop going on, then we +// clear all the keyboard and mouse messages in our queue. Then +// if we wake up for all input, we check the message queue to +// see whether we need to notify the application that a message +// has arrived. Then, we dispatch any messages that have to do +// with the ORPC system. Finally we yield just in case we need +// to dispatch a send message in the VDM. For an input sync +// RPC, all we do is a call that will yield to get the pending +// send message dispatched. +// +// History: Dec-93 JohannP Created +// 13-Aug-94 Ricksa Created +// +//-------------------------------------------------------------------------- +INTERNAL_(void) CCliModalLoop::HandleWakeForMsg() +{ + MSG msg; // Used for various peeks. + + // Is this an input sync call? + if (_dwMsgQInputFlag != QS_SENDMESSAGE) + { + // No, so we have to worry about the state of the message queue. + // We have to be careful that we aren't holding the input focus + // on an input synchronized queue. + + // So what is the state of the queue? - note we or QS_TRANSFER because + // this an undocumented flag which tells us the the input focus has + // changed to us. + + DWORD dwQueueFlags = GetQueueStatus(QS_ALLINPUT | QS_TRANSFER); + ComDebOut((DEB_CALLCONT, "Queue Status %lx\n", dwQueueFlags)); + + // Call through to the application if we are going to. We do this here + // so that the application gets a chance to process any + // messages that it wants to and also allows the call control to + // dispatch certain messages that it knows how to, thus making the + // queue more empty. + + if (((_dwMsgQInputFlag & QS_ALLINPUT) == QS_ALLINPUT) && + FindMessage(dwQueueFlags)) + { + // pending message in the queue + HandlePendingMessage(); + } + + // Did the input focus change to us? + if ((LOWORD(dwQueueFlags) & QS_TRANSFER) || _dwFlags & CMLF_CLEAREDQUEUE) + { + ComDebOut((DEB_CALLCONT, "Message Queue is being cleared\n")); + _dwFlags |= CMLF_CLEAREDQUEUE; + + // Try to clear the queue as best we can of any messages that + // might be holding off some other modal loop from executing. + // So we eat all mouse and key events. + if (HIWORD(dwQueueFlags) & QS_KEY) + { + while (MyPeekMessage(&msg, NULL, WM_KEYFIRST, WM_KEYLAST, + PM_REMOVE | PM_NOYIELD)) + { + ; + } + } + + // Clear mouse releated messages if there are any + if (HIWORD(dwQueueFlags) & QS_MOUSE) + { + while (MyPeekMessage(&msg, NULL, WM_MOUSEFIRST, WM_MOUSELAST, + PM_REMOVE | PM_NOYIELD)) + { + ; + } + + while (MyPeekMessage(&msg, NULL, WM_NCMOUSEFIRST, + WM_NCMOUSELAST, PM_REMOVE | PM_NOYIELD)) + { + ; + } + + while (MyPeekMessage(&msg, NULL, WM_QUEUESYNC, WM_QUEUESYNC, + PM_REMOVE | PM_NOYIELD)) + { + ; + } + } + + // Get rid of paint message if we can as well -- this makes + // the screen look so much better. + if (HIWORD(dwQueueFlags) & QS_PAINT) + { + if (MyPeekMessage(&msg, NULL, WM_PAINT, WM_PAINT, PM_REMOVE | PM_NOYIELD)) + { + ComDebOut((DEB_CALLCONT, "Dispatch paint\n")); + DispatchMessage(&msg); + } + } + } + } + else if (!IsWOWThread() || !IsWOWThreadCallable()) + { + // We need to give user control so that the send message + // can get dispatched. Thus the following is simply a no-op + // which gets into user to let it dispatch the message. + PeekMessage(&msg, 0, WM_NULL, WM_NULL, PM_NOREMOVE); + } + + if (IsWOWThread() && IsWOWThreadCallable()) + { + // In WOW, a genuine yield is the only thing to guarantee + // that SendMessage will get through + g_pOleThunkWOW->YieldTask16(); + } +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::PeekRPCAndDDEMessage +// +// Synopsis: Called when a windows message arrives to look for incoming +// Rpc messages which might be the reply to an outstanding call +// or may be new incoming request messages. Also looks for +// DDE messages. +// +// Returns: TRUE - found and processed an RPC message +// FALSE - did not find an RPC message +// +// History: 21-Dec-93 JohannP Created +// 30-Apr-95 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +BOOL CCliModalLoop::PeekRPCAndDDEMessage() +{ + // loop over all windows looking for incoming Rpc messages. Note that + // it is possible for a dispatch here to cause one of the windows to + // be deregistered or another to be registered, so our loop has to account + // for that, hence the check for NULL hWnd. + + BOOL fRet = FALSE; + MSG Msg; + + for (UINT i = 0; i < 2; i++) + { + // get window info and peek on it if the hWnd is still OK + SWindowData *pWD = _pACC->GetWindowData(i); + + if (pWD->hWnd != WD_EMPTY) + { + if (MyPeekMessage(&Msg, pWD->hWnd, pWD->wFirstMsg, pWD->wLastMsg, + PM_REMOVE | PM_NOYIELD)) + { + Win4Assert(IsWaiting()); + DispatchMessage(&Msg); + + // exit on the first dispatched message. If the message was + // not the reply we were waiting for, then the channel will + // call us back again. + return TRUE; + } + } + } + + return FALSE; +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::FindMessage +// +// Synopsis: Called by HandleWakeForMsg when a message arrives on the +// windows msg queue. Determines if there is something of +// interest to us, and pulls timer msgs. Dispatches RPC, DDE, +// and RPC timer messages. +// +// Arguments: [dwStatus] - current Queue status (from GetQueueStatus) +// +// Returns: TRUE - there is a message to process +// FALSE - no messages to process +// +// Algorithm: Find the next message in the queue by using the following +// priority list: +// +// 1. RPC and DDE messages +// 2. mouse and keyboard messages +// 3. other messages +// +// History: 21-Dec-93 Johannp Created +// +//-------------------------------------------------------------------------- +INTERNAL_(BOOL) CCliModalLoop::FindMessage(DWORD dwStatus) +{ + WORD wOld = HIWORD(dwStatus); + WORD wNew = (WORD) dwStatus; + + if (!wNew) + { + if (!(wOld & QS_POSTMESSAGE)) + return FALSE; // no messages to take care of + else + wNew |= QS_POSTMESSAGE; + } + + MSG Msg; + + // Priority 1: look for RPC and DDE messages + if (wNew & (QS_POSTMESSAGE | QS_SENDMESSAGE | QS_TIMER)) + { + if (PeekRPCAndDDEMessage()) + { + // we know that *some* RPC message came in, might be our + // reply or may be some incoming call. In any case, return to + // the modal loop to guy so we can figure out if we need to + // keep going. + return FALSE; + } + } + + if (wNew & QS_TIMER) + { + // throw the system timer messages away + while (MyPeekMessage(&Msg, 0, WM_SYSTIMER, WM_SYSTIMER, PM_REMOVE | PM_NOYIELD)) + ; + } + + // Priority 2: messages from the hardware queue + if (wNew & (QS_KEY | QS_MOUSEMOVE | QS_MOUSEBUTTON)) + { + return TRUE; // these messages are always removed + } + else if (wNew & QS_TIMER) + { + if (MyPeekMessage(&Msg, 0, WM_TIMER, WM_TIMER, PM_NOREMOVE | PM_NOYIELD) ) + return TRUE; + } + else if (wNew & QS_PAINT) + { + return TRUE; // this message might not get removed + } + else if (wNew & (QS_POSTMESSAGE | QS_SENDMESSAGE)) + { + if (MyPeekMessage(&Msg, 0, 0, 0, PM_NOREMOVE)) + return TRUE; // Priority 3: all other messages + } + + return FALSE; +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::HandlePendingMessage +// +// Synopsis: this function is called for system messages and other +// pending messages +// +// Arguments: none +// +// Returns: nothing, _hr may be updated if call is cancelled. +// +// Algorithm: +// +// History: 21-Dec-93 Johannp Created +// 30-Apr-95 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +INTERNAL_(void) CCliModalLoop::HandlePendingMessage() +{ + // get and call the message filter if there is one + IMessageFilter *pMF = _pACC->GetMsgFilter(); + + if (pMF) + { + ComDebOut((DEB_MFILTER, + "pMF->MessagePending(dwTIDCallee:%x ElapsedTime:%x Type:%x)\n", + _dwTIDCallee, GetElapsedTime(), + (_pPrev) ? PENDINGTYPE_NESTED : PENDINGTYPE_TOPLEVEL)); + + DWORD dwRet = pMF->MessagePending((MF_HTASK)_dwTIDCallee, + GetElapsedTime(), + (_pPrev) ? PENDINGTYPE_NESTED + : PENDINGTYPE_TOPLEVEL); + + ComDebOut((DEB_MFILTER,"pMF->MessagePending() dwRet:%x\n", dwRet)); + + + _pACC->ReleaseMsgFilter(); + + if (dwRet == PENDINGMSG_CANCELCALL) + { + _hr = RPC_E_CALL_CANCELED; + return; + } + + Win4Assert((dwRet == PENDINGMSG_WAITDEFPROCESS || + dwRet == PENDINGMSG_WAITNOPROCESS) && + "Invalid return value from pMF->MessagePending"); + } + + // if we get here we are going to do the default message processing. + // Default Processing: Continue to wait for the call return and + // don't dispatch the new message. Perform default processing on + // special system messages. + + MSG msg; + + // we have to take out all syscommand messages + if (MyPeekMessage(&msg, 0, WM_SYSCOMMAND, WM_SYSCOMMAND, PM_REMOVE | PM_NOYIELD)) + { + // only dispatch some syscommands + if (msg.wParam == SC_HOTKEY || msg.wParam == SC_TASKLIST) + { + ComDebOut((DEB_CALLCONT,">>>> Dispatching SYSCOMMAND message: %x; wParm: %x \r\n",msg.message, msg.wParam)); + DispatchMessage(&msg); + } + else + { + ComDebOut((DEB_CALLCONT,">>>> Received/discarded SYSCOMMAND message: %x; wParm: %x \r\n",msg.message, msg.wParam)); + MessageBeep(0); + } + } + else if (MyPeekMessage(&msg, 0, WM_SYSKEYDOWN, WM_SYSKEYDOWN, PM_NOREMOVE | PM_NOYIELD)) + { + if (msg.message == WM_KEYDOWN) + { + if (msg.wParam != VK_CONTROL && msg.wParam != VK_SHIFT) + MessageBeep(0); + } + else if (msg.message == WM_SYSKEYDOWN && msg.lParam & SYS_ALTDOWN && + (msg.wParam == VK_TAB || msg.wParam == VK_ESCAPE)) + { + MyPeekMessage(&msg, 0, WM_SYSKEYDOWN, WM_SYSKEYDOWN, PM_REMOVE | PM_NOYIELD); + TranslateMessage(&msg); + DispatchMessage(&msg); + } + } + else if (MyPeekMessage(&msg, 0, WM_ACTIVATE, WM_ACTIVATE, PM_REMOVE | PM_NOYIELD) + || MyPeekMessage(&msg, 0, WM_ACTIVATEAPP, WM_ACTIVATEAPP, PM_REMOVE | PM_NOYIELD) + || MyPeekMessage(&msg, 0, WM_NCACTIVATE, WM_NCACTIVATE, PM_REMOVE | PM_NOYIELD) ) + { + DispatchMessage(&msg); + } +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::MyPeekMessage +// +// Synopsis: This function is called whenever we want to do a PeekMessage. +// It intercepts WM_QUIT messages and remembers them so that +// they can be reposted when the modal loop is exited. +// +// Arguments: [pMsg] - message structure +// [hWnd] - window to peek on +// [min/max] - min and max message numbers +// [wFlag] - peek flags +// +// Returns: TRUE - a message is available +// FALSE - no messages available +// +// History: 21-Dec-93 Johannp Created +// +//-------------------------------------------------------------------------- +INTERNAL_(BOOL) CCliModalLoop::MyPeekMessage(MSG *pMsg, HWND hwnd, + UINT min, UINT max, WORD wFlag) +{ + BOOL fRet = PeekMessage(pMsg, hwnd, min, max, wFlag); + + while (fRet) + { + ComDebOut((DEB_CALLCONT, "MyPeekMessage: hwnd:%x msg:%d time:%ld\n", + pMsg->hwnd, pMsg->message, pMsg->time)); + + if (pMsg->message != WM_QUIT) + { + // it is not a QUIT message so exit the loop and return TRUE + break; + } + + // just remember that we saw a QUIT message. we will ignore it for + // now and repost it after our call has completed. + + ComDebOut((DEB_CALLCONT, "WM_QUIT received.\n")); + _wQuitCode = pMsg->wParam; + _dwFlags |= CMLF_QUITRECEIVED; + + if (!(wFlag & PM_REMOVE)) // NOTE: dont use PM_NOREMOVE + { + // quit message is still on queue so pull it off + PeekMessage(pMsg, hwnd, WM_QUIT, WM_QUIT, PM_REMOVE | PM_NOYIELD); + } + + // peek again to see if there is another message + fRet = PeekMessage(pMsg, hwnd, min, max, wFlag); + } + + return fRet; +} + +#if DBG==1 +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::DispatchMessage +// +// Synopsis: This function is called whenever we want to dispatch a +// message we have peeked. It is just a debug wrapper to provide +// debug out statements about dispatched messages. +// +// Arguments: [pMsg] - message structure +// +//-------------------------------------------------------------------------- +INTERNAL_(void) CCliModalLoop::DispatchMessage(MSG *pMsg) +{ + ComDebOut((DEB_CALLCONT, "Dispatching Message hWnd:%x msg:%d wParam:%x\n", + pMsg->hwnd, pMsg->message, pMsg->wParam)); + + ::DispatchMessage(pMsg); +} +#endif + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::GetElapsedTime +// +// Synopsis: Get the elapsed time for an RPC call +// +// Returns: Elapsed time of current call +// +// Algorithm: This checks whether we have the slow time factor. If not, +// and we are in WOW we read it from the registry. Otherwise, +// it is just set to one. Then we calculate the time of the +// RPC call and divide it by the slow time factor. +// +// History: 22-Jul-94 Ricksa Created +// +//-------------------------------------------------------------------------- +INTERNAL_(DWORD) CCliModalLoop::GetElapsedTime() +{ + // Define slow time factor to something invalid + static dwSlowTimeFactor = 0; + + if (dwSlowTimeFactor == 0) + { + if (IsWOWProcess()) + { + // Get time factor from registry otherwise set to the default + dwSlowTimeFactor = GetSlowTimeFactor(); + } + else + { + // Time is unmodified for 32 bit apps + dwSlowTimeFactor = 1; + } + } + + DWORD dwTickCount = GetTickCount(); + DWORD dwElapsedTime = dwTickCount - _dwTimeOfCall; + if (dwTickCount < _dwTimeOfCall) + { + // the timer wrapped + dwElapsedTime = 0xffffffff - _dwTimeOfCall + dwTickCount; + } + + return (dwElapsedTime / dwSlowTimeFactor); +} + +//+------------------------------------------------------------------------- +// +// Member: CCliModalLoop::FindPrevCallOnLID [server side] +// +// Synopsis: When an incoming call arrives this is used to find any +// previous call for the same logical thread, ignoring +// INTERNAL calls. The result is used to determine if this +// is a nested call or not. +// +// Arguments: [lid] - logical threadid of incoming call +// +// Returns: pCML - if a previous CliModalLoop found for this lid +// NULL - otherwise +// +// Algorithm: just walk backwards on the _pPrev chain +// +// History: 17-Dec-93 JohannP Created +// 30-Apr-95 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +CCliModalLoop *CCliModalLoop::FindPrevCallOnLID(REFLID lid) +{ + CCliModalLoop *pCML = this; + + do + { + if (pCML->_lid == lid) + { + break; // found a match, return it + } + + } while ((pCML = pCML->_pPrev) != NULL); + + return pCML; +} + +//+------------------------------------------------------------------------- +// +// Function: STAInvoke +// +// Synopsis: Called whenever an incoming call arrives in a single-threaded +// apartment. It asks the apps message filter (if there is one) +// whether it wants to handle the call or not, and dispatches +// the call if OK. +// +// Arguments: [pMsg] - Incoming Rpc message +// [pStub] - stub to call if MF says it is OK +// [pChnl] - channel ptr to give to stub +// [pv] - real interface being called +// [pdwFault] - where to store fault code if there is a fault +// +// Returns: result for MF or from call to stub +// +// History: 21-Dec-93 Johannp Original Version +// 22-Jul-94 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +INTERNAL STAInvoke(RPCOLEMESSAGE *pMsg, DWORD CallCatIn, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, void *pv, DWORD *pdwFault) +{ + ComDebOut((DEB_CALLCONT, + "STAInvoke pMsg:%x CallCatIn:%x pStub:%x pChnl:%x\n", + pMsg, CallCatIn, pStub, pChnl)); + + HRESULT hr = HandleIncomingCall(*MSG_TO_IIDPTR(pMsg), + (WORD)pMsg->iMethod, + CallCatIn, pv); + if (hr == S_OK) + { + // the message filter says its OK to invoke the call. + + // construct a server call state. This puts the current incoming + // call's CallCat in Tls so we can check it if the server tries to + // make an outgoing call while handling this call. See CanMakeOutCall. + CSrvCallState SCS(CallCatIn); + + // invoke the call + hr = MTAInvoke(pMsg, CallCatIn, pStub, pChnl, pdwFault); + } + else if (hr == RPC_E_CALL_REJECTED || hr == RPC_E_SERVERCALL_RETRYLATER) + { + // server is rejecting the call, try to copy the incomming buffer so + // that the client has the option of retrying the call. + hr = CopyMsgForRetry(pMsg, pChnl, hr); + } + + ComDebOut((DEB_CALLCONT,"STAInvoke returns:%x\n",hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Function: HandleIncomingCall, internal +// +// Synopsis: Called whenever an incoming call arrives in a single-threaded +// apartment. It asks the app's message filter (if there is one) +// whether it wants to handle the call or not +// +// Arguments: [piid] - ptr to interface the call is being made on +// [iMethod] - method number being called +// [CallCatIn] - category of incoming call +// [pv] - real interface being called +// +// Returns: result from MF +// +// History: 11-Oct-96 Rickhi Separated from STAInvoke +// +//-------------------------------------------------------------------------- +INTERNAL HandleIncomingCall(REFIID riid, WORD iMethod, DWORD CallCatIn, void *pv) +{ + ComDebOut((DEB_CALLCONT, + "HandleIncomingCall iid:%I iMethod:%x CallCatIn:%x pv:%x:%x\n", + &riid, iMethod, CallCatIn, pv)); + + COleTls tls; + if (!(tls->dwFlags & OLETLS_APARTMENTTHREADED)) + { + // free-threaded apartments don't have a message filter + return S_OK; + } + + HRESULT hr = S_OK; + CAptCallCtrl *pACC = tls->pCallCtrl; + + + // We dont call the message filter for IUnknown since older versions + // of OLE did not, and doing so (unfortunately) breaks compatibility. + // Also check for IRundown since local clients call on it instead of + // IRemUnknown. + + IMessageFilter *pMF = (riid == IID_IRundown || riid == IID_IRemUnknown) + ? NULL : pACC->GetMsgFilter(); + + if (pMF) + { + // the app has installed a message filter, call it. + + INTERFACEINFO IfInfo; + IfInfo.pUnk = (IUnknown *)pv; + IfInfo.iid = riid; + IfInfo.wMethod = iMethod; + + ComDebOut((DEB_CALLCONT, "Calling iMethod:%x riid:%I\n", + IfInfo.wMethod, &IfInfo.iid)); + + CCliModalLoop *pCML = NULL; + REFLID lid = tls->LogicalThreadId; + DWORD TIDCaller = tls->dwTIDCaller; + + DWORD dwCallType = pACC->GetCallTypeForInCall(&pCML, lid, CallCatIn); + DWORD dwElapsedTime = (pCML) ? pCML->GetElapsedTime() : 0; + + // The DDE layer doesn't provide any interface information. This + // was true on the 16-bit implementation, and has also been + // brought forward into this implementation to insure + // compatibility. However, the CallCat of the IfInfo is still + // provided. + // + // Therefore, if pIfInfo has its pUnk member set to NULL, then + // we are going to send a NULL pIfInfo to the message filter. + + ComDebOut((DEB_MFILTER, + "pMF->HandleIncomingCall(dwCallType:%x TIDCaller:%x dwElapsedTime:%x IfInfo:%x)\n", + dwCallType, TIDCaller, dwElapsedTime, (IfInfo.pUnk) ? &IfInfo : NULL)); + + DWORD dwRet = pMF->HandleInComingCall(dwCallType, + (MF_HTASK)TIDCaller, + dwElapsedTime, + IfInfo.pUnk ? &IfInfo : NULL); + + ComDebOut((DEB_MFILTER,"pMF->HandleIncomingCall() dwRet:%x\n", dwRet)); + + pACC->ReleaseMsgFilter(); + + // strict checking of app return code for win32 + Win4Assert(dwRet == SERVERCALL_ISHANDLED || + dwRet == SERVERCALL_REJECTED || + dwRet == SERVERCALL_RETRYLATER || + IsWOWThread() && "Invalid Return code from App IMessageFilter"); + + + if (dwRet != SERVERCALL_ISHANDLED) + { + if (CallCatIn == CALLCAT_ASYNC || CallCatIn == CALLCAT_INPUTSYNC) + { + // Note: input-sync and async calls can not be rejected + // Even though they can not be rejected, we still have to + // call the MF above to maintain 16bit compatability. + hr = S_OK; + } + else if (dwRet == SERVERCALL_REJECTED) + { + hr = RPC_E_CALL_REJECTED; + } + else if (dwRet == SERVERCALL_RETRYLATER) + { + hr = RPC_E_SERVERCALL_RETRYLATER; + } + else + { + // 16bit OLE let bogus return codes go through and of course + // apps rely on that behaviour so we let them through too, but + // we are more strict on 32bit. + hr = (IsWOWThread()) ? S_OK : RPC_E_UNEXPECTED; + } + } + } + + ComDebOut((DEB_CALLCONT, "HandleIncomingCall hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Function: MTAInvoke +// +// Synopsis: Multi-Threaded Apartment Invoke. Called whenever an incoming +// call arrives in the MTA apartment (or as a subroutine to +// STAInvoke). It just dispatches to a common sub-routine. +// +// Arguments: [pMsg] - Incoming Rpc message +// [pStub] - stub to call if MF says it is OK +// [pChnl] - channel ptr to give to stub +// [pdwFault] - where to store fault code if there is a fault +// +// Returns: result from calling the stub +// +// History: 03-Oct-95 Rickhi Made into subroutine from STAInvoke +// +//-------------------------------------------------------------------------- +INTERNAL MTAInvoke(RPCOLEMESSAGE *pMsg, DWORD CallCatIn, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, DWORD *pdwFault) +{ +#if DBG==1 + ComDebOut((DEB_CALLCONT, + "MTAInvoke pMsg:%x CallCatIn:%x pStub:%x pChnl:%x\n", + pMsg, CallCatIn, pStub, pChnl)); + IID *piid = MSG_TO_IIDPTR(pMsg); + DebugPrintORPCCall(ORPC_INVOKE_BEGIN, *piid, pMsg->iMethod, CallCatIn); + RpcSpy((CALLIN_BEGIN, NULL, *piid, pMsg->iMethod, 0)); +#endif + + // call a common subroutine to do the dispatch. The subroutine also + // catches exceptions and provides some debug help. + + HRESULT hr = StubInvoke(pMsg, pStub, pChnl, pdwFault); + +#if DBG==1 + RpcSpy((CALLIN_END, NULL, *piid, pMsg->iMethod, hr)); + DebugPrintORPCCall(ORPC_INVOKE_END, *piid, pMsg->iMethod, CallCatIn); + ComDebOut((DEB_CALLCONT,"MTAInvoke returns:%x\n",hr)); +#endif + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Function: CopyMsgForRetry +// +// Synopsis: Makes a copy of the server-side message buffer to return to +// the client so that the client can retry the call later. +// Returns an error if the client is on a different machine. +// +// Parameters: [pMsg] - ptr to message to copy +// [pChnl] - ptr to channel call is being made on +// [hr] - result code +// +// History: 30-05-95 Rickhi Created +// +//+------------------------------------------------------------------------- +HRESULT CopyMsgForRetry(RPCOLEMESSAGE *pMsg, IRpcChannelBuffer *pChnl, HRESULT hrIn) +{ + ComDebOut((DEB_CALLCONT,"CopyMsgForRetry pMsg:%x pChnl:%x pBuffer:%x\n", + pMsg, pChnl, pMsg->Buffer)); + + DWORD dwDestCtx; + HRESULT hr = pChnl->GetDestCtx(&dwDestCtx, NULL); + + if (SUCCEEDED(hr) && dwDestCtx != MSHCTX_DIFFERENTMACHINE && + !IsEqualGUID(IID_IObjServer, *MSG_TO_IIDPTR(pMsg))) + { + // client on same machine as server. + void *pSavedBuffer = pMsg->Buffer; + hr = pChnl->GetBuffer(pMsg, *MSG_TO_IIDPTR(pMsg)); + + if (SUCCEEDED(hr)) + { + // copy original buffer to the new buffer + memcpy(pMsg->Buffer, pSavedBuffer, pMsg->cbBuffer); + hr = hrIn; + } + } + else + { + // client on different machine than server, or the call was on + // the activation interface, fail the call and dont send back + // a copy of the parameter packet. + hr = RPC_E_CALL_REJECTED; + } + + ComDebOut((DEB_CALLCONT,"CopyMsgForRetry pBuffer:%x hr:%x\n", + pMsg->Buffer, hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::GetCallTypeForInCall +// +// Synopsis: called when an incoming call arrives in order to determine +// what CALLTYPE to pass to the applications message filter. +// +// Arguments: [ppCML] - Client Modal Loop of prev call on same lid (if any) +// [lid] - logical thread id of this call +// [dwCallCat] - call category of incoming call +// +// Returns: the CALLTYPE to give to the message filter +// +// History: 21-Dec-93 Johannp Created +// 30-Apr-95 Rickhi ReWrite +// +// Notes: +// +// 1 = CALLTYPE_TOPLEVEL // sync or inputsync call - no outgoing call +// 2 = CALLTYPE_NESTED // callback on behalf of previous outgoing call +// 3 = CALLTYPE_ASYNC // asynchronous call - no outstanding call +// 4 = CALLTYPE_TOPLEVEL_CALLPENDING // call with new LID - outstand call +// 5 = CALLTYPE_ASYNC_CALLPENDING // async call - outstanding call +// +//-------------------------------------------------------------------------- +DWORD CAptCallCtrl::GetCallTypeForInCall(CCliModalLoop **ppCML, + REFLID lid, DWORD dwCallCatIn) +{ + DWORD CallType; + CCliModalLoop *pCML = GetTopCML(); + + if (dwCallCatIn == CALLCAT_ASYNC) // asynchronous call has arrived + { + if (pCML == NULL) + CallType = CALLTYPE_ASYNC; // no outstanding calls + else + CallType = CALLTYPE_ASYNC_CALLPENDING; // outstanding call + } + else // non-async call has arrived + { + if (pCML == NULL) + CallType = CALLTYPE_TOPLEVEL; // no outstanding call + else if ((*ppCML = pCML->FindPrevCallOnLID(lid)) != NULL) + CallType = CALLTYPE_NESTED; // outstanding call on same lid + else + CallType = CALLTYPE_TOPLEVEL_CALLPENDING; // different lid + } + + ComDebOut((DEB_CALLCONT,"GetCallTypeForInCall return:%x\n", CallType)); + return CallType; +} + diff --git a/private/ole32/com/dcomrem/callctrl.hxx b/private/ole32/com/dcomrem/callctrl.hxx new file mode 100644 index 000000000..44b2efcd1 --- /dev/null +++ b/private/ole32/com/dcomrem/callctrl.hxx @@ -0,0 +1,525 @@ +//+------------------------------------------------------------------------- +// +// Microsoft Windows +// Copyright (C) Microsoft Corporation, 1992 - 1993. +// +// File: callctrl.hxx +// +// Contents: OLE Call Control +// +// Functions: +// +// History: 21-Dec-93 Johannp Original Version +// 04-Nov-94 Rickhi ReWrite +// +//-------------------------------------------------------------------------- +#ifndef __CALLCTRL_HXX__ +#define __CALLCTRL_HXX__ + +#include <channelb.hxx> // CRpcChannelBuffer + +#undef RPC_S_CALLPENDING +#undef RPC_S_WAITONTIMER +#define RPC_S_CALLPENDING 21 // BUGBUG +#define RPC_S_WAITONTIMER 22 // BUGBUG + +// Max time we wait for MsgWaitForMultiple before waking up and +// checking the queue. This is needed because the API is broken +// (ie it does not wake up on messages posted before it is called). +#define MAX_TICKS_TO_WAIT 1000 + + +// Private definition for change of input focus +// BUGBUG: fix for CHICAGO when chicago's USER supports QS_TRANSFER +#ifdef _CHICAGO_ +#define QS_TRANSFER 0x0000 +#else +#define QS_TRANSFER 0x4000 +#endif + + +typedef IID LID; // logical thread id +typedef REFIID REFLID; // ref to logical thread id + +#define MF_HTASK struct HTASK__ * + + +// the following table is used to quickly determine what windows +// message queue inputflag to specify for the various categories of +// outgoing calls in progress. The table is indexed by CALLCATEGORY. + +extern DWORD gMsgQInputFlagTbl[4]; // see callctrl.cxx + +// the following table is used to map bit flags in the Rpc Message to +// the equivalent OLE CALLCATEGORY. + +extern DWORD gRpcFlagToCallCatMap[3]; // see callctrl.cxx + +// the following inline funtion is used to compute the CALLCATEGORY from +// the RpcFlags field in the RPC message + +inline DWORD RpcFlagToCallCat(DWORD RpcFlags) +{ + return gRpcFlagToCallCatMap[(RpcFlags & 0x60000000) >> 29]; +} + +// convenient mapping from RPCOLEMESSAGE to IID in the message +#define MSG_TO_IIDPTR(pMsg) \ + &((RPC_SERVER_INTERFACE *)((RPC_MESSAGE *)pMsg)->RpcInterfaceInformation)->InterfaceId.SyntaxGUID + + +// private structure used to hold the window handles and message ranges +// to peek to see if there is more work to do when in the modal loop. + +typedef struct tagSWindowData +{ + HWND hWnd; // window handle to peek on + UINT wFirstMsg; // first msg in range to peek + UINT wLastMsg; // Last msg in range to peek +} SWindowData; + + +// function prototypes. This function is called by the channel during +// transmission in the apartment model. + +RPC_STATUS OleModalLoopBlockFn(void *, void *, HANDLE hEventComplete); + + +// function called by the channel during dispatch in the apartment model. +// STAInvoke is used for single-threaded apartments, MTAInvoke is used +// for Multi-threaded apartments. + +INTERNAL STAInvoke(RPCOLEMESSAGE *pMsg, DWORD dwCallCat, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, void *pv, DWORD *pdwFault); + +INTERNAL MTAInvoke(RPCOLEMESSAGE *pMsg, DWORD dwCallCat, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, DWORD *pdwFault); + +INTERNAL CanMakeOutCall(DWORD dwCallCatOut, REFIID riid); + +class CAptCallCtrl; + + +//+------------------------------------------------------------------------- +// +// Function: GetAptCallCtrl +// +// Synopsis: Gets the current apartment's call control ptr from TLS. +// +//+------------------------------------------------------------------------- +inline CAptCallCtrl *GetAptCallCtrl(void) +{ + COleTls tls; + return tls->pCallCtrl; +} + +//+------------------------------------------------------------------------- +// +// Class: CAptRpcChnl +// +// Synopsis: Client side Apartment model Rpc Channel. +// +// History: 11-Nov-94 Rickhi Created +// +// Notes: This object inherits the Rpc channel and adds some +// functionality to it that is needed by the apartment model, +// (eg deadlock prevention, call retry, nested call support). +// For each outgoing call, it verifies the app is allowed to +// make the call, and instantiates a modal loop for it, thereby +// allowing callbacks and new calls to be handled by the app +// thread. +// +// Important: Since no mutual exclusion primitives are used in this code, +// the derived class CAptRpcChnl must be stateless, as some +// proxies are freethreaded even in the apartment model, in +// particular, IRemUnknown and the SCM activation interface. All +// relevant state is maintained in the CCliModalLoop object which +// is constructed on the stack on a per call basis. Note that the +// _dwTIDCallee state is safe because it is set in the ctor and +// never changes. The base class, CRpcChannelBuffer *is* thread +// safe since it is used in the freethreaded model also. +// +//-------------------------------------------------------------------------- +class CAptRpcChnl : public CRpcChannelBuffer +{ +public: + CAptRpcChnl(CStdIdentity *pStdId, OXIDEntry *pOXIDEntry, DWORD eState); + + // CRpcChannelBuffer methods that we override + STDMETHOD (GetBuffer) (RPCOLEMESSAGE *pMsg, REFIID riid); + STDMETHOD (SendReceive) (RPCOLEMESSAGE *pMsg, ULONG *pulStatus); + +private: + ~CAptRpcChnl(); // can only be called from Release + HRESULT CopyMsgForRetry(RPCOLEMESSAGE *pMsg); + + DWORD _dwTIDCallee; // TID of thread server lives on + DWORD _dwAptId; // Apartment ID proxy lives in. +}; + + +//+------------------------------------------------------------------------- +// +// Class: CMTARpcChnl +// +// Synopsis: Client side Multi-Threaded Apartment Rpc Channel. +// +// History: 11-Nov-94 Rickhi Created +// +// Notes: This object inherits the Rpc channel and adds some +// functionality to it that is needed by Multi-threaded apartment. +// For each outgoing call, it verifies the app is allowed to +// make the call. +// +// Important: Since no mutual exclusion primitives are used in this code, +// the derived class CMTARpcChnl must be stateless. Note that the +// _dwTIDCallee state is safe because it is set in the ctor and +// never changes. The base class, CRpcChannelBuffer *is* thread +// safe. +// +//-------------------------------------------------------------------------- +class CMTARpcChnl : public CRpcChannelBuffer +{ +public: + CMTARpcChnl(CStdIdentity *pStdId, OXIDEntry *pOXIDEntry, DWORD eState); + + // CRpcChannelBuffer methods that we override + STDMETHOD (GetBuffer) (RPCOLEMESSAGE *pMsg, REFIID riid); + +private: + ~CMTARpcChnl(); // can only be called from Release + + DWORD _dwTIDCallee; // TID of thread server lives on + DWORD _dwAptId; // Apartment ID proxy lives in. +}; + + +//+------------------------------------------------------------------------- +// +// Class: CCliModalLoop +// +// Synopsis: Each outgoing client call enters a modal loop. This object +// maintains the state of the modal loop for one outgoing +// call. +// +// History: 11-Nov-94 Rickhi Created +// +// Notes: This object is constructed on the stack on a per call basis +// and needs no mutual exclusion mechanisms. A pointer to this +// state is stored in TLS (actually in CAptCallCtrl in TLS) and +// later referenced when OleModalLoopBlockFn is called from deep +// within the bowls of SendReceive in the channel (or Rpc Runtime +// if the MSWMSG transport is used). +// +//-------------------------------------------------------------------------- +class CCliModalLoop +{ +public: + CCliModalLoop(DWORD TIDCallee, DWORD CallCatOut); + ~CCliModalLoop(); + + CCliModalLoop *FindPrevCallOnLID(REFLID lid); + HRESULT SendReceive(RPCOLEMESSAGE *pMsg, ULONG *pulStatus, + IRpcChannelBuffer2 *pChnl); + HRESULT BlockFn(HANDLE hEventCallComplete); + BOOL IsWaiting(void) + { + return (_hr == RPC_S_CALLPENDING || _hr == RPC_S_WAITONTIMER); + } + + INTERNAL_(BOOL) MyPeekMessage(MSG *pMsg, HWND hwnd, UINT min, UINT max, WORD wFlag); + INTERNAL_(DWORD) GetElapsedTime(); + +private: + // message processing in modal loop + INTERNAL_(void) HandleWakeForMsg(void); + INTERNAL_(BOOL) FindMessage(DWORD dwStatus); + INTERNAL_(void) HandlePendingMessage(void); + INTERNAL_(BOOL) PeekRPCAndDDEMessage(void); + +#if DBG==1 + INTERNAL_(void) DispatchMessage(MSG *pMsg); +#endif + + // rejected call processing + INTERNAL HandleRejectedCall(IRpcChannelBuffer2 *pChnl); + INTERNAL StartTimer(DWORD dwMilliSecToWait); + INTERNAL_(BOOL) IsTimerAtZero(); + INTERNAL_(DWORD) TicksToWait(); + + + HRESULT _hr; // the return value of this call + CCliModalLoop *_pPrev; // Previous CCliModalLoop for this apartment + DWORD _dwTIDCallee; // TID of thread we are calling + DWORD _dwMsgQInputFlag; // message queue input flag + LID _lid; // logical threadid of call + + DWORD _dwFlags; // internal flags (see CMLFLAGS) + UINT _wQuitCode; // quit code if WM_QUIT received + + DWORD _dwTimeOfCall; // time when call was made + DWORD _dwWakeup; // absolute time to wake up + DWORD _dwMillSecToWait; // relative time + + CAptCallCtrl *_pACC; // apartment call control object +}; + +// bit values for the CliModalLoop _dwFlags field +typedef enum tagCMLFLAGS +{ + CMLF_QUITRECEIVED = 1, // WM_QUIT was received + CMLF_CLEAREDQUEUE = 2 // the msg queue has been cleared +} CMLFLAGS; + + +//+------------------------------------------------------------------------- +// +// Class: CSrvCallState +// +// Synopsis: Each incoming server call generates one of these objects. +// It maintains the state of the incoming call. +// +// History: 11-Nov-94 Rickhi Created +// +//-------------------------------------------------------------------------- +class CSrvCallState +{ +public: + CSrvCallState(DWORD CallCatIn); + ~CSrvCallState(); + DWORD GetCallCatIn(void) { return _dwCallCatIn; } + +private: + DWORD _dwCallCatIn; // category of this incoming call + CSrvCallState *_pPrev; // previous CSrvCallState on the stack +}; + + +//+------------------------------------------------------------------------- +// +// Class: CAptCallCtrl +// +// Synopsis: Represents per apartment Call Control state that is shared +// between both the client side and server side call control +// objects. +// +// History: 11-Nov-94 Rickhi Created +// +// Notes: Two LIFO stacks are maintained, one for client call modal +// loops, and one for incoming server calls. The incoming +// server calls are used in both Single-Threaded apartments +// and multi-threaded apartments, and so are stored in TLS +// directly, rather than chained off this object. +// +//-------------------------------------------------------------------------- +class CAptCallCtrl +{ +public: + CAptCallCtrl(); + ~CAptCallCtrl(); + + // message filter handling methods + IMessageFilter *InstallMsgFilter(IMessageFilter *pMF); + IMessageFilter *GetMsgFilter(); + void ReleaseMsgFilter() { _fInMsgFilter = FALSE; } + BOOL InMsgFilter() { return _fInMsgFilter; } + + // client side LIFO modal loop queue + void SetTopCML(CCliModalLoop *pCML) { _pTopCML = pCML; } + CCliModalLoop *GetTopCML(void) { return _pTopCML; } + + // modal loop helper functions + SWindowData *GetWindowData(UINT i) { return &_WD[i]; } + DWORD GetCallTypeForInCall(CCliModalLoop **ppCML, + REFLID lid, DWORD dwCallCatIn); + + // window registration/revocation methods (used by channel & dde) + void Register(HWND hWnd, UINT wFirstMsg, UINT wLastMsg); + void Revoke(HWND hWnd); + +private: + IMessageFilter *_pMF; // app supplied Msg Filter + BOOL _fInMsgFilter; // TRUE when calling the Apps MF + + CCliModalLoop *_pTopCML; // topmost Client Modal Loop + + SWindowData _WD[2]; // RPC and DDE Window Data +}; + + + +//+------------------------------------------------------------------------- +// +// Method: CCliModalLoop::CCliModalLoop +// +// Synopsis: constructor for the client side modal loop +// +//+------------------------------------------------------------------------- +inline CCliModalLoop::CCliModalLoop(DWORD dwTIDCallee, DWORD dwMsgQInputFlag) : + _dwTIDCallee(dwTIDCallee), + _dwMsgQInputFlag(dwMsgQInputFlag), + _dwFlags(0) // all flags start FALSE +{ + COleTls tls; + + _lid = tls->LogicalThreadId; + + // push self on top of the per apartment modal loop stack + _pACC = tls->pCallCtrl; + _pPrev = _pACC->GetTopCML(); + _pACC->SetTopCML(this); + + _dwTimeOfCall = GetTickCount(); // record start time of the call + + // the rest of the fields are initialized when first used + + ComDebOut((DEB_CALLCONT, "CCliModalLoop::CCliModalLoop at:%x\n", this)); +} + +//+------------------------------------------------------------------------- +// +// Method: CCliModalLoop::~CCliModalLoop +// +// Synopsis: destructor for the client side modal loop +// +//+------------------------------------------------------------------------- +inline CCliModalLoop::~CCliModalLoop() +{ + // pop self off the per apartment modal loop stack by resetting the + // top of stack to the previous value. + _pACC->SetTopCML(_pPrev); + + // repost any WM_QUIT message we intercepted during the call + if (_dwFlags & CMLF_QUITRECEIVED) + { + ComDebOut((DEB_CALLCONT, "posting WM_QUIT\n")); + PostQuitMessage(_wQuitCode); + } + + ComDebOut((DEB_CALLCONT, "CCliModalLoop::~CCliModalLoop at:%x\n", this)); +} + +//+------------------------------------------------------------------------- +// +// Method: CCliModalLoop::StartTimer +// +// Synopsis: starts a timer when a call is rejected and the client +// wants to retry it later. +// +//+------------------------------------------------------------------------- +inline HRESULT CCliModalLoop::StartTimer(DWORD dwMilliSecToWait) +{ + // Set time when we should awake and retry the call. Note that + // if the GetTickCount + dwMilliSecToWait wraps the timer, then + // we may wakeup earlier than expected, but at least we wont + // deadlock. + + ComDebOut((DEB_CALLCONT, + "Timer installed for %lu msec.\n", dwMilliSecToWait)); + + _dwMillSecToWait = dwMilliSecToWait; + _dwWakeup = GetTickCount() + dwMilliSecToWait; + + // caller should place the return value in _hr + return RPC_S_WAITONTIMER; +} + +//+------------------------------------------------------------------------- +// +// Method: CCliModalLoop::IsTimerAtZero +// +// Synopsis: returns TRUE if the timer is not started or the timer has +// run down. +// +//+------------------------------------------------------------------------- +inline BOOL CCliModalLoop::IsTimerAtZero() +{ + // if no timer installed, return TRUE + if (_hr != RPC_S_WAITONTIMER) + return TRUE; + + DWORD dwTickCount = GetTickCount(); + + // the second test is in case GetTickCount wrapped during + // the call. see also the comment in StartTimer. + + if (dwTickCount > _dwWakeup || + dwTickCount < _dwWakeup - _dwMillSecToWait) + { + // this _hr will tell SendReceive to retransmit the call + _hr = RPC_E_SERVERCALL_RETRYLATER; + return TRUE; + } + + return FALSE; +} + +//+------------------------------------------------------------------------- +// +// Method: CCliModalLoop::TicksToWait +// +// Synopsis: returns the amount of time to wait for a message to arrive. +// +//+------------------------------------------------------------------------- +inline DWORD CCliModalLoop::TicksToWait() +{ + if (_hr != RPC_S_WAITONTIMER) + return MAX_TICKS_TO_WAIT; + + // waiting to retry a rejected call + DWORD dwTick = GetTickCount(); + return (_dwWakeup < dwTick) ? 0 : _dwWakeup - dwTick; +} + +//+------------------------------------------------------------------------- +// +// Method: CSrvCallState::CSrvCallState +// +// Synopsis: constructor for server side call state. Pushes the call +// state on the call control stack. +// +//+------------------------------------------------------------------------- +inline CSrvCallState::CSrvCallState(DWORD dwCallCatIn) : + _dwCallCatIn(dwCallCatIn) +{ + // push self on top of the per apartment server call state stack + COleTls tls; + + _pPrev = tls->pTopSCS; + tls->pTopSCS = this; +} + +//+------------------------------------------------------------------------- +// +// Method: CSrvCallState::~CSrvCallState +// +// Synopsis: destructor for server side call state. Pops the call +// state off the call control stack. +// +//+------------------------------------------------------------------------- +inline CSrvCallState::~CSrvCallState() +{ + // pop self on top of the per apartment server call state stack + COleTls tls; + tls->pTopSCS = _pPrev; +} + +//+------------------------------------------------------------------------- +// +// Method: CAptCallCtrl::GetMsgFilter +// +// Synopsis: returns the IMessageFilter and set the flag indicating we +// are currently calling the IMF, so that apps are prevented +// from making outgoing calls while inside their IMF. +// +//+------------------------------------------------------------------------- +inline IMessageFilter *CAptCallCtrl::GetMsgFilter() +{ + if (_pMF) + { + _fInMsgFilter = TRUE; + } + return _pMF; +} + +#endif // __CALLCTRL_HXX__ diff --git a/private/ole32/com/dcomrem/chancont.cxx b/private/ole32/com/dcomrem/chancont.cxx new file mode 100644 index 000000000..8c936f93c --- /dev/null +++ b/private/ole32/com/dcomrem/chancont.cxx @@ -0,0 +1,1098 @@ +/*++ + +copyright (c) 1992 Microsoft Corporation + +Module Name: + + chancont.cxx + +Abstract: + + This module contains thread switching code for the single threaded mode + and the message filter hooks + +Author: + + Alex Mitchell + +Revision History: + + Mar 1994 JohannP Added call category support. + 29 Dec 1993 Alex Mitchell Creation. + 19 Jul 1994 CraigWi Added support for ASYNC calls + 27-Jan-95 BruceMa Don't get on CChannelControl list unless + constructor is successsful + +Functions: + +--*/ + +#include <ole2int.h> +#include <userapis.h> + +#include <chancont.hxx> +#include <channelb.hxx> +#include <threads.hxx> +#include <objerror.h> +#include <callctrl.hxx> +#include <service.hxx> +#include <ipidtbl.hxx> + +/* Prototypes. */ +void Cancel ( CChannelCallInfo ** ); +HRESULT ModalLoop ( CChannelCallInfo *call ); +HRESULT ProtectedPostToSTA( OXIDEntry *, CChannelCallInfo *call ); +HRESULT TransmitCall( OXIDEntry *, CChannelCallInfo * ); + +/***************************************************************************/ +/* Globals. */ + +// Rpc worker thread cache. +CRpcThreadCache gRpcThreadCache; + +// Event cache. +CEventCache gEventCache; + +HANDLE CEventCache::_list[] = {0,0,0,0,0,0,0,0}; +DWORD CEventCache::_ifree = 0; + + +extern LPTSTR gOleWindowClass; + +extern BOOL gfChannelProcessInitialized; + +extern BOOL gfDestroyingMainWindow; + + +/***************************************************************************/ +CChannelCallInfo::CChannelCallInfo( CALLCATEGORY callcat, + RPCOLEMESSAGE *message, + DWORD flags, + REFIPID ipidServer, + DWORD destctx, + CRpcChannelBuffer *channel, + DWORD authn_level ) +{ + // The call info must hold a reference to the channel on the client side + // because the channel holds the binding handle that ThreadSendReceive + // uses. + category = callcat; + event = NULL; + iFlags = flags; + eState = in_progress_cs; + pmessage = message; + ipid = ipidServer; + iDestCtx = destctx; + pNext = NULL; + pHeader = NULL; + pChannel = channel; + lSavedAuthnLevel = 0; + lAuthnLevel = authn_level; + if (pChannel != NULL) + pChannel->AddRef(); +} + + +/***************************************************************************/ +CChannelCallInfo::~CChannelCallInfo() +{ + if (event != NULL) + gEventCache.Free(event); + + // Release the reply buffer. + if (eState == canceled_cs && pmessage->Buffer != NULL) + DeallocateBuffer(pmessage); + + // Release the channel. + if (pChannel != NULL) + pChannel->Release(); +} + +/***************************************************************************/ +#if DBG==1 +void DebugIsValidWindow(void *hWnd) +{ + // USER could be out of memory and unable to validate the handle. + // GetDesktopWindow only returns NULL if USER is out of memory. So + // we only assert if USER is not out of memory and our window handle + // is invalid. + if (GetDesktopWindow() == NULL) + return; + + Win4Assert( IsWindow((HWND) hWnd)); +} +#else +inline void DebugIsValidWindow(void *hWnd) {} +#endif + +/***************************************************************************/ +void Cancel( CChannelCallInfo **call ) +{ + DWORD result; + + // If the call is still in progress, change it to canceled. + LOCK + if ((*call)->eState == in_progress_cs) + (*call)->eState = canceled_cs; + UNLOCK + + // If the call completed before it could be canceled, wait for it to + // signal the completion event and clean up. + if ((*call)->eState == server_done_cs || (*call)->eState == got_done_msg_cs) + { + (*call)->eState = canceled_cs; + if (IsWOWThread() && (*call)->Local()) + { + // If the reply has arrived, the call can be deleted. + if ((*call)->eState == got_done_msg_cs) + { + delete *call; + } + // Otherwise + // the completion routine will have posted a message instead of + // setting an event, so we have to mark it as canceled and cleanup + // when the Reply msg comes in. + return; + } + else + { + // A call that completed in TransmitCall (ie, didn't create an event) + // cannot be canceled. + + Win4Assert( (*call)->event != NULL ); + result = WaitForSingleObject((*call)->event, INFINITE); + Win4Assert( result == WAIT_OBJECT_0 ); + + delete *call; + } + } + + // Null the CChannelCallInfo pointer so no one tries to access it. + *call = NULL; +} + +/***************************************************************************/ +HRESULT GetToSTA( OXIDEntry *pOxid, CChannelCallInfo *call ) +{ + TRACECALL(TRACE_RPC, "GetToSTA"); + ComDebOut((DEB_CHANNEL, "GetToSTA pCall:%x\n", call)); + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_HELD + + HRESULT result; + + Win4Assert(call->event == NULL); + Win4Assert(pOxid->dwTid != GetCurrentThreadId()); + + + // Don't accept calls if this thread has been uninitialized. + if (pOxid->dwFlags & OXIDF_STOPPED) + return RPC_E_SERVER_DIED_DNE; + + if (call->category == CALLCAT_INPUTSYNC) + { + UNLOCK + ASSERT_LOCK_RELEASED + // On CoUninitialize this may fail when the window is destroyed. + // Pass the thread id to aid debugging. + SetLastError( 0 ); + SendMessage((HWND)pOxid->hServerSTA, WM_OLE_ORPC_SEND, + GetCurrentThreadId(), (DWORD) call); + if (GetLastError() == 0) + result = call->hResult; + else + result = RPC_E_SERVER_DIED; + ASSERT_LOCK_RELEASED + LOCK + } + else if (call->category == CALLCAT_ASYNC) + { + // async call; copy message, post message and return. + // NOTE that in the MTA case, async was converted to SYNC by + // the call control. + + CChannelCallInfo *copy = MakeAsyncCopy( call ); + if (copy == NULL) + { + result = RPC_E_OUT_OF_RESOURCES; + } + else + { + // Post a message and wait for the app to get back to GetMessage. + result = ProtectedPostToSTA( pOxid, copy ); + + if (result != S_OK) + { + // error in posting; free packet and return error (result set above) + delete copy; + } + } + } + else + { + Win4Assert( call->category == CALLCAT_SYNCHRONOUS ); + + // Get completion event. May cause an event to be created. + result = gEventCache.Get( &call->event ); + if (result == S_OK) + { + result = ProtectedPostToSTA( pOxid, call ); + + if (result == S_OK) + { + UNLOCK + ASSERT_LOCK_RELEASED + + // Wait for the app to finish processing the request. + if (WaitForSingleObject(call->event, INFINITE) == WAIT_OBJECT_0) + result = call->hResult; + else + result = RPC_E_SYS_CALL_FAILED; + + ASSERT_LOCK_RELEASED + LOCK + } + } + } + + ASSERT_LOCK_HELD + gOXIDTbl.ValidateOXID(); + return result; +} + +/***************************************************************************/ +HRESULT ModalLoop( CChannelCallInfo *pcall ) +{ + ASSERT_LOCK_RELEASED + DWORD result; + + // we should only enter the modal loop for synchronous calls or input + // synchronous calls to another process or to an MTA apartment within + // the current process. + + Win4Assert(pcall->category == CALLCAT_SYNCHRONOUS || + pcall->category == CALLCAT_INPUTSYNC); + + + // detemine if we are using an event or a postmessage for the call + // completion signal. We use PostMessage only for process local + // calls in WOW, otherwise we use events and the OleModalLoop determines + // if the call completed or not. + + BOOL fMsg = (pcall->Local() && IsWOWThread()); + BOOL fWait = TRUE; + CAptCallCtrl *pACC = GetAptCallCtrl(); + CCliModalLoop *pCML = pACC->GetTopCML(); + + ComDebOut((DEB_CALLCONT,"ModalLoop: wait on %s\n",(fMsg) ? "Msg" : "Event")); + + // Wait at least once so the event is returned to the cache in the + // unsignalled state. + do + { + Win4Assert(fMsg || pcall->event); + + result = OleModalLoopBlockFn(NULL, pCML, pcall->event); + + if (fMsg) + { + if (result == RPC_E_CALL_CANCELED) + { + fWait = FALSE; + } + else + { + // loop until the call's state indicates the arrival of the + // reply message. + fWait = (pcall->eState != got_done_msg_cs); + result = S_OK; + } + } + else + { + // loop until the OleModalLoop tells us the call is no longer + // pending. + fWait = (result == RPC_S_CALLPENDING); + } + + } while (fWait); + + ASSERT_LOCK_RELEASED + return result; +} + +#if DBG==1 +/***************************************************************************/ +LONG ProtectedPostExceptionFilter( DWORD lCode, + LPEXCEPTION_POINTERS lpep ) +{ + ComDebOut((DEB_ERROR, "Exception 0x%x in ProtectedPostToCOMThread at address 0x%x\n", + lCode, lpep->ExceptionRecord->ExceptionAddress)); + DebugBreak(); + return EXCEPTION_EXECUTE_HANDLER; +} +#endif // DBG + +/***************************************************************************/ +// executed on client thread (in local case) and RPC thread (in remote case); +// posts a message to the server thread, guarding against disconnected threads +HRESULT ProtectedPostToSTA( OXIDEntry *pOxid, CChannelCallInfo *call ) +{ + ComDebOut((DEB_CHANNEL, "ProtectedPostToSTA hWnd:%x pCall:%x\n", + pOxid->hServerSTA, call)); + + // ensure we are not posting to ourself and that the apartment is not + // an MTA apartment. + Win4Assert((pOxid->dwTid != GetCurrentThreadId()) && + ((pOxid->dwFlags & OXIDF_MTASERVER) == 0)); + ASSERT_LOCK_HELD + + HRESULT result; + + if (!(pOxid->dwFlags & OXIDF_STOPPED)) + { +#if DBG==1 + DebugIsValidWindow(pOxid->hServerSTA); + _try + { +#endif + // Pass the thread id to aid debugging. + if (PostMessage((HWND)pOxid->hServerSTA, WM_OLE_ORPC_POST, + GetCurrentThreadId(), (DWORD)call)) + result = S_OK; + else + result = RPC_E_SYS_CALL_FAILED; + +#if DBG==1 + } + _except( ProtectedPostExceptionFilter(GetExceptionCode(), + GetExceptionInformation()) ) + { + } + Win4Assert( IsWindow((HWND) pOxid->hServerSTA) ); +#endif + } + else + result = RPC_E_SERVER_DIED_DNE; + + return result; +} + + +/***************************************************************************/ +HRESULT SwitchSTA( OXIDEntry *pOxid, CChannelCallInfo **call ) +{ + TRACECALL(TRACE_RPC, "SwitchSTA"); + ComDebOut((DEB_CHANNEL, "SwitchSTA hWnd:%x pCall:%x hEvent:%x\n", + (*call)->hWndCaller, (*call), (*call)->event)); + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_RELEASED + + // Transmit the call. + HRESULT result = TransmitCall( pOxid, *call ); + + // the transmit was successful and the reply isn't already here so wait. + if (result == RPC_S_CALLPENDING) + { + // This is a single-threaded apartment so enter the modal loop. + result = ModalLoop( *call ); + } + + if (result == S_OK) + result = (*call)->hResult; + else if (result == RPC_E_CALL_CANCELED) + Cancel( call ); + + ASSERT_LOCK_RELEASED + gOXIDTbl.ValidateOXID(); + ComDebOut((DEB_CHANNEL, "SwitchSTA hr:%x\n", result)); + return result; +} + +/***************************************************************************/ +/* + This routine is called by the OLE Worker thread on the client side, + and by ThreadWndProc on the server side. + + For the client case, it calls ThreadSendReceive which will send the + the data over to the server side. + This routine notifies the COM thread when the call is complete. If the + call is canceled before completion, the routine cleans up. +*/ +void ThreadDispatch( CChannelCallInfo **ppcall) +{ + CChannelCallInfo *pcall = *ppcall; + gOXIDTbl.ValidateOXID(); + + // Dispatch the call. + if (pcall->edispatch == invoke_wd) + pcall->hResult = ComInvoke( pcall ); + else + pcall->hResult = ThreadSendReceive( pcall ); + + // Change the state to done; we cheat on non-local, recipient side since + // there is only one thread accessing the channel control; no need to + // lock and no need to check for cancel since it can't happen. + if (pcall->edispatch == invoke_wd && !pcall->Local()) + { + // non-local recipient; just set to done + Win4Assert(pcall->eState == in_progress_cs); + pcall->eState = server_done_cs; + } + else + { + // sender or local case; use lock in case other thread accesses it + LOCK + if (pcall->eState == in_progress_cs) + pcall->eState = server_done_cs; + UNLOCK + } + + // If the call completed, wake up the waiting thread. For local calls + // the client thread is waiting. For remote calls the helper thread is + // waiting. + if (pcall->eState == server_done_cs) + { + // only need to wake somebody for synchronous calls + if (pcall->category == CALLCAT_SYNCHRONOUS || + pcall->category == CALLCAT_INPUTSYNC) + { + // Don't do anything in an STA server for input synchronous + // calls since the other thread called here with SendMessage. + + if (pcall->category == CALLCAT_SYNCHRONOUS || + pcall->edispatch == sendreceive_wd || + IsMTAThread()) + { + + if (!pcall->Local() || !IsWOWThread()) + { + // remote calls (outside this process) always use events for + // notification. 32bit uses events for local calls too. + + // someone waiting (e.g., not a SendMessage-type call) + ComDebOut((DEB_CHANNEL,"SetEvent pInfo:%x hEvent:%x\n", + pcall, pcall->event)); + SetEvent( pcall->event ); + } + else + { + // NOTE NOTE NOTE NOTE NOTE NOTE NOTE + // 16bit OLE used to do PostMessage for the Reply; we + // tried using SetEvent (which is faster) but this caused + // compatibility problems for applications which had bugs that + // were hidden by the 16bit OLE DLLs because messages happened + // to be dispatched in a particular order (see NtBug 21616 for + // an example). To retain the old behavior, we do a + // PostMessage here. + + ComDebOut((DEB_CHANNEL, + "PostMessage Reply hWnd:%x pCall:%x hEvent:%x\n", + pcall->hWndCaller, pcall, pcall->event)); + + // Pass the thread id to aid debugging. + Verify(PostMessage(pcall->hWndCaller, + WM_OLE_ORPC_DONE, + GetCurrentThreadId(), (DWORD)pcall)); + } + + // pcall likely invalid here as other thread probably deleted it + } + } + + // Must be asynchronous. + else if (pcall->edispatch == invoke_wd) + { + // async call and on recipient side, free packet (no one waiting) + Win4Assert( pcall->category == CALLCAT_ASYNC ); + delete pcall; + *ppcall = NULL; + } + } + + // If the call was canceled, clean up. + else + { + // can only cancel when on client side or local call + Win4Assert(pcall->edispatch == sendreceive_wd || pcall->Local()); + + delete pcall; + *ppcall = NULL; + } + gOXIDTbl.ValidateOXID(); +} + +//+------------------------------------------------------------------------- +// +// Member: ThreadStart +// +// Synopsis: Apartment model only. Setup the window used for MSWMSG, +// local thread switches and the call control. +// +// History: 08-02-95 Rickhi Created, from various pieces +// +//-------------------------------------------------------------------------- +HRESULT ThreadStart(void) +{ + Win4Assert(IsSTAThread()); + HRESULT hr = S_OK; + RPC_STATUS sc; + + LOCK // lock since GetLocalOXIDEntry expects it + OXIDEntry *pOxid = GetLocalOXIDEntry(); + Win4Assert(pOxid != NULL); //already created so cant fail + UNLOCK + + + if (GetCurrentThreadId() == gdwMainThreadId && hwndOleMainThread != NULL) + { + // this is the main thread, we can just re-use the already + // existing gMainThreadWnd. + + pOxid->hServerSTA = hwndOleMainThread; + } + else + { + // Create a new window for use by the current thread for the + // apartment model. The window is destroyed in ThreadStop. + + Win4Assert(gOleWindowClass != NULL); + pOxid->hServerSTA = CreateWindowEx(0, + gOleWindowClass, + TEXT("OLEChannelWnd"), + // must use WS_POPUP so the window does not get + // assigned a hot key by user. + (WS_DISABLED | WS_POPUP), + CW_USEDEFAULT, + CW_USEDEFAULT, + CW_USEDEFAULT, + CW_USEDEFAULT, + NULL, + NULL, + g_hinst, + NULL); + } + + if (pOxid->hServerSTA) + { + DebugIsValidWindow(pOxid->hServerSTA); + + // Override the window proc function + SetWindowLong((HWND)pOxid->hServerSTA, GWL_WNDPROC, (LONG)ThreadWndProc); + + + // get the local call control object, and register the + // the window with it. Note that it MUST exist cause we + // created it in ChannelThreadInitialize. + + CAptCallCtrl *pCallCtrl = GetAptCallCtrl(); + pCallCtrl->Register((HWND) pOxid->hServerSTA, WM_USER, 0x7fff ); + } + else + { + hr = MAKE_SCODE(SEVERITY_ERROR, FACILITY_WIN32, GetLastError()); + } + + ComDebOut((DEB_CALLCONT, "ThreadStart returns %x\n", hr)); + return hr; +} + + +//+------------------------------------------------------------------------- +// +// Member: ThreadCleanup +// +// Synopsis: Release the window for the thread. +// +//-------------------------------------------------------------------------- +void ThreadCleanup() +{ + LOCK + OXIDEntry *pOxid = GetLocalOXIDEntry(); + UNLOCK + + if (pOxid != NULL) + { + Win4Assert( (pOxid->dwFlags & OXIDF_MTASERVER) == 0 ); + + // Destroy the window. This will unblock any pending SendMessages. + if (pOxid->hServerSTA == hwndOleMainThread) + { + // restore the window proceedure + SetWindowLong(hwndOleMainThread, GWL_WNDPROC, + (LONG)OleMainThreadWndProc); + } + else + { + // This may fail if threads get terminated. + DestroyWindow((HWND) pOxid->hServerSTA); + } + + pOxid->hServerSTA = NULL; + } + + ComDebOut((DEB_CALLCONT, "ThreadCleanup called.\n")); +} + +//+------------------------------------------------------------------------- +// +// Member: ThreadStop +// +// Synopsis: Per thread uninitialization +// +// History: ??-???-?? ? Created +// 05-Jul-94 AlexT Separated thread and process uninit +// +// Notes: We are not holding the single thread mutex during this call +// +//-------------------------------------------------------------------------- +STDAPI_(void) ThreadStop(void) +{ + LOCK + + OXIDEntry *pOxid = GetLocalOXIDEntry(); + if (pOxid != NULL) + { + // Change state + pOxid->dwFlags |= OXIDF_STOPPED; + } + + UNLOCK + + + if (pOxid != NULL) + { + // Stop MSWMSG. + I_RpcServerStopListening(); + + if (pOxid->dwFlags & OXIDF_MTASERVER) + { + if (pOxid->cCalls != 0) + { + Win4Assert( pOxid->hComplete != NULL ); + g_mxsSingleThreadOle.Release(); + WaitForSingleObject( pOxid->hComplete, INFINITE ); + g_mxsSingleThreadOle.Request(); + // a new thread may have been initialized while we released + // the lock, so we cant assert that the cCalls is zero. + } + } + else + { + // Single-threaded apartment so wait for all current calls + // to complete. + + ASSERT_LOCK_RELEASED + + MSG msg; + BOOL got_quit = FALSE; + WPARAM quit_val; + + while(PeekMessage(&msg, (HWND) pOxid->hServerSTA, WM_USER, + 0x7fff, PM_REMOVE | PM_NOYIELD)) + { + if (msg.message == WM_QUIT) + { + got_quit = TRUE; + quit_val = msg.wParam; + } + else + { + DispatchMessage(&msg); + } + } + + if (got_quit) + { + PostQuitMessage( quit_val ); + } + } + } + + ComDebOut((DEB_CALLCONT, "ThreadStop called.\n")); +} + + +//+------------------------------------------------------------------------- +// +// Function: ThreadWndProc, Internal +// +// Synopsis: Dispatch COM windows messages. This routine is only called +// for Single-Threaded Apartments. It dispatches calls and call +// complete messages. If it does not recognize the message, it +// calls MSWMSG to dispatch it. +// +//-------------------------------------------------------------------------- +LRESULT ThreadWndProc(HWND window, UINT message, WPARAM unused, LPARAM params) +{ + Win4Assert(IsSTAThread()); + + if (message == WM_OLE_ORPC_POST || + message == WM_OLE_ORPC_SEND) + { + ASSERT_LOCK_RELEASED + + CChannelCallInfo *call = (CChannelCallInfo *) params; + ComDebOut((DEB_CHANNEL, "ThreadWndProc: Incoming Call pCall:%x\n", call)); + + // Dispatch all calls through ThreadDispatch. Local calls may be + // canceled. Server-side, non-local calls cannot be canceled. Send + // message calls (event == NULL) are handled as well. + + call->edispatch = invoke_wd; + ThreadDispatch( &call ); + + ASSERT_LOCK_RELEASED + return 0; + } + else if (message == WM_OLE_ORPC_DONE) + { + ASSERT_LOCK_RELEASED + + // call completed - only happens InWow() + CChannelCallInfo *call = (CChannelCallInfo *) params; + ComDebOut((DEB_CHANNEL, "ThreadWndProc: Call Completed hWnd:%x pCall:%x\n", window, call)); + + if (call->eState == canceled_cs) + { + // canceled, throw it away + delete call; + } + else + { + // Notify the modal loop that the call is complete. + call->eState = got_done_msg_cs; + } + + ASSERT_LOCK_RELEASED + return 0; + } + else if (message == WM_OLE_ORPC_RELRIFREF) + { + ASSERT_LOCK_RELEASED + + HandlePostReleaseRifRef(params); + + ASSERT_LOCK_RELEASED + return 0; + } + else if (message == WM_OLE_GETCLASS) + { + return OleMainThreadWndProc(window, message, unused, params); + } + else + { + // when the window is first created we are holding the lock, and the + // creation of the window causes some messages to be dispatched. + ASSERT_LOCK_DONTCARE + + // check if the window is being destroyed because of UninitMainWindow + // or because of system shut down. Only destroy it in the former case. + if ((message == WM_DESTROY || message == WM_CLOSE) && + window == hwndOleMainThread && + gfDestroyingMainWindow == FALSE) + { + ComDebOut((DEB_WARN, "Attempted to destroy window outside of UninitMainThreadWnd")); + return 0; + } +#ifdef _CHICAGO_ + // Otherwise let the default window procedure have the message. + return DefWindowProc( window, message, unused, params ); +#else + return I_RpcWindowProc( window, message, unused, params ); +#endif + } +} + + + +/***************************************************************************/ +/* + Return S_OK if the call completed successfully. + Return RPC_S_CALL_PENDING if the caller should block. + Return an error if the call failed. +*/ +HRESULT TransmitCall( OXIDEntry *pOxid, CChannelCallInfo *call ) +{ + ComDebOut((DEB_CHANNEL, "TransmitCall pCall:%x\n", call)); + ASSERT_LOCK_RELEASED + + BOOL fDispCall = FALSE; + BOOLEAN wait = FALSE; + HRESULT result; + + // Don't touch the call hresult after the other thread starts, + // otherwise we might erase the results of the other thread. + // Since we never want signalled events returned to the cache, always + // wait on the event at least once. For example, the post message + // succeeds and the call completes immediately. Return RPC_S_CALLPENDING even + // though the call already has a S_OK in it. + + + if (call->Local()) + { + // server is in this process. + + if (!(pOxid->dwFlags & OXIDF_MTASERVER)) + { + // server is in an STA apartment + + if (call->category == CALLCAT_INPUTSYNC) + { + // Inputsync call. Send the message. + if (!(pOxid->dwFlags & OXIDF_STOPPED)) + { + // On CoUninitialize this may fail when the window is destroyed. + // Pass the thread id to aid debugging. + SetLastError( 0 ); + SendMessage((HWND)pOxid->hServerSTA, WM_OLE_ORPC_SEND, + GetCurrentThreadId(), (DWORD) call); + + if (GetLastError() != 0) + { + call->hResult = RPC_E_SERVER_DIED; + } + } + else + { + call->hResult = RPC_E_SERVER_DIED_DNE; + } + } + else if (call->category == CALLCAT_ASYNC) + { + // Async call. Copy message, post message and return. + + LOCK + CChannelCallInfo *copy = MakeAsyncCopy( call ); + if (copy == NULL) + { + call->hResult = RPC_E_OUT_OF_RESOURCES; + } + else + { + call->hResult = ProtectedPostToSTA( pOxid, copy ); + + if (call->hResult != S_OK) + { + delete copy; + } + } + UNLOCK + } + else + { + // Sync call. Post the message and wait for a reply. + + LOCK + + Win4Assert(call->category == CALLCAT_SYNCHRONOUS); + call->hResult = S_OK; + if (!IsWOWThread()) + { + // Get an event from the cache. In 32bit, replyies are done + // via Events, but for 16bit, repliest are done with PostMsg, + // so we dont need an event. Not having an event makes the + // callctrl modal loop a little faster. + call->hResult = gEventCache.Get( &call->event ); + } + else + { + Win4Assert( GetLocalOXIDEntry() != NULL ); + call->hWndCaller = (HWND) GetLocalOXIDEntry()->hServerSTA; + call->event = NULL; + } + + if (call->hResult == S_OK) + { + // Post a message to server + call->hResult = RPC_S_CALLPENDING; + result = ProtectedPostToSTA( pOxid, call ); + + if (result != S_OK) + call->hResult = result; + else + wait = TRUE; + } + + UNLOCK + } + } + else + { + // server is in an MTA apartment. Transmit the call by having + // a worker thread invoke the server directly. Async calls to + // a FT server are treated as SYNC calls and should have been + // converted by this point, so we never expect to see callcat + // ASYNC. + + Win4Assert(call->category != CALLCAT_ASYNC); + + wait = TRUE; + call->edispatch = invoke_wd; + fDispCall = TRUE; + } + } + else + { + // server is in a different process or on a different machine. + + if (call->category == CALLCAT_ASYNC) + { + // For async calls to other local processes, just make an RPC call. + call->hResult = ThreadSendReceive(call); + } + else + { + // Get a worker thread to do the RPC call. + wait = TRUE; + call->edispatch = sendreceive_wd; + fDispCall = TRUE; + } + } + + if (fDispCall) + { + // Dispatch to a worker thread to make the call + + LOCK + call->hResult = gEventCache.Get( &call->event ); + UNLOCK + + if (call->hResult == S_OK) + { + call->hResult = RPC_S_CALLPENDING; + + result = gRpcThreadCache.Dispatch( call ); + if (result != S_OK) + { + call->hResult = result; + wait = FALSE; + } + } + } + + ComDebOut((DEB_CHANNEL, "TransmitCall call->hResult:%x fWait:%x\n", + call->hResult, wait)); + + Win4Assert(wait || call->hResult != RPC_S_CALLPENDING); + return (wait) ? RPC_S_CALLPENDING : call->hResult; +} + +//+------------------------------------------------------------------------- +// +// Member: CEventCache::Cleanup +// +// Synopsis: Empty the event cache +// +// Notes: This function must be thread safe because Canceled calls +// can complete at any time. +// +//-------------------------------------------------------------------------- +void CEventCache::Cleanup(void) +{ + ASSERT_LOCK_HELD + + while (_ifree > 0) + { + _ifree--; // decrement the index first! + Verify(CloseHandle(_list[_ifree])); + _list[_ifree] = NULL; // NULL slot so we dont need to re-init + } + + // reset the index to 0 so reinitialization is not needed + _ifree = 0; +} + +//+------------------------------------------------------------------------- +// +// Member: CEventCache::Free +// +// Synopsis: returns an event to the cache if there are any available +// slots, frees the event if not. +// +// Notes: This function must be thread safe because Canceled calls +// can complete at any time. +// +//-------------------------------------------------------------------------- +void CEventCache::Free( HANDLE hEvent ) +{ + // there better be an event + Win4Assert(hEvent != NULL); + + LOCK + + // dont return anything to the cache if the process is no longer init'd. + if (_ifree < CEVENTCACHE_MAX_EVENT && gfChannelProcessInitialized) + { + // there is space, save this event. + +#if DBG==1 + // in debug, ensure slot is NULL + Win4Assert(_list[_ifree] == NULL && "Free: _list[_ifree] != NULL"); + + // enusre not already in the list + for (ULONG j=0; j<_ifree; j++) + { + Win4Assert(_list[j] != hEvent && "Free: event already in cache!"); + } + + // ensure that the event is in the non-signalled state + Win4Assert(WaitForSingleObject(hEvent, 0) == WAIT_TIMEOUT && + "Free: Signalled event returned to cache!\n"); +#endif + + _list[_ifree] = hEvent; + _ifree++; + } + else + { + // Otherwise really free it. + Verify(CloseHandle(hEvent)); + } + + UNLOCK +} + +//+------------------------------------------------------------------------- +// +// Member: CEventCache::Get +// +// Synopsis: gets an event from the cache if there are any available, +// allocates one if not. +// +// Notes: This function must be thread safe because Canceled calls +// can complete at any time. +// +//-------------------------------------------------------------------------- +HRESULT CEventCache::Get( HANDLE *hEvent ) +{ + ASSERT_LOCK_HELD + Win4Assert(_ifree <= CEVENTCACHE_MAX_EVENT); + + if (_ifree > 0) + { + // there is an event in the cache, use it. + _ifree--; + *hEvent = _list[_ifree]; + +#if DBG==1 + // in debug, NULL the slot. + _list[_ifree] = NULL; +#endif + + return S_OK; + } + + // no free event in the cache, allocate a new one. +#ifdef _CHICAGO_ + *hEvent = CreateEventA( NULL, FALSE, FALSE, NULL ); +#else //_CHICAGO_ + *hEvent = CreateEvent( NULL, FALSE, FALSE, NULL ); +#endif //_CHICAGO_ + + if (*hEvent) + return S_OK; + + Win4Assert(*hEvent != NULL && "CEventCache:GetEvent returning NULL"); + return RPC_E_OUT_OF_RESOURCES; +} diff --git a/private/ole32/com/dcomrem/chancont.hxx b/private/ole32/com/dcomrem/chancont.hxx new file mode 100644 index 000000000..3fe24bf1a --- /dev/null +++ b/private/ole32/com/dcomrem/chancont.hxx @@ -0,0 +1,171 @@ +#ifndef _CHANCONT_HXX_ +#define _CHANCONT_HXX_ + +#include <wtypes.h> +#include <OleSpy.hxx> +#include <ipidtbl.hxx> + + +//+------------------------------------------------------------------------- +// +// class: CEventCache +// +// Synopsis: Since ORPC uses events very frequently, we keep a small +// internal cache of them. There is only one of them, so +// we use static initializers to reduce Init time. +// +// History: 25-Oct-95 Rickhi Made data static +// +//-------------------------------------------------------------------------- + +// dont change this value without changing the static initializer. +#define CEVENTCACHE_MAX_EVENT 8 + +class CEventCache : public CPrivAlloc +{ +public: + void Free( HANDLE ); + HRESULT Get ( HANDLE * ); + + void Cleanup(void); + +private: + + static HANDLE _list[CEVENTCACHE_MAX_EVENT]; + static DWORD _ifree; +}; + +extern CEventCache gEventCache; + + +/***************************************************************************/ + +typedef enum +{ + in_progress_cs, + server_done_cs, + got_done_msg_cs, + canceled_cs +} ECallState; + +typedef enum +{ + CF_LOCKED = 0x1, // Set when free buffer must call UnlockClient + CF_PROCESS_LOCAL = 0x2, // Set for process local calls + CF_WAS_IMPERSONATING = 0x4, // Client was impersonating before call started +} ECallFlags; + +typedef enum +{ + none_wd, // dont call anything + invoke_wd, // call ComInvoke + sendreceive_wd // call ThreadSendReceive +} EWhichDispatch; + +class CRpcChannelBuffer; + +class CChannelCallInfo +{ + public: + CChannelCallInfo(); + CChannelCallInfo( CALLCATEGORY callcat, + RPCOLEMESSAGE *message, + DWORD iFlags, + REFIPID ipidServer, + DWORD destctx, + CRpcChannelBuffer *channel, + DWORD authn_level ); + ~CChannelCallInfo(); + BOOL Local () { return iFlags & CF_PROCESS_LOCAL; } + BOOL Locked () { return iFlags & CF_LOCKED; } + + // Channel controller fields. + CALLCATEGORY category; + DWORD iFlags; // ECallFlags + ECallState eState; + SCODE hResult; // SCODE or exception code + HANDLE event; // caller wait event + HWND hWndCaller; // caller apartment hWnd (only used InWow) + EWhichDispatch edispatch; // which function to invoke in worker thread + IPID ipid; + + // Channel fields. + RPCOLEMESSAGE *pmessage; + DWORD server_fault; + DWORD iDestCtx; + CChannelCallInfo *pNext; + void *pHeader; + CRpcChannelBuffer *pChannel; + DWORD lSavedAuthnLevel; + DWORD lAuthnLevel; +}; + + +/***************************************************************************/ +/* Classes. */ + +/* + The channel controller switches threads for the channel. It is not +used in the free threaded mode. There are two basic scenarios: a local +call and a remote call. + + A local call looks like this. + + Client Server + SendReceive + SwitchSTA + TransmitCall + PostMessage + ModalLoop + MsgWaitForMultipleObjects + ThreadWndProc + ThreadDispatch + ComInvoke + AppInvoke + AptInvoke + Stub + SetEvent + + A remote call looks like this. + + Client ClientHelper ServerHelper Server + SendReceive + SwitchSTA + TransmitCall + SetEvent + ModalLoop + MsgWaitForMultipleObjects + ThreadSendReceive + RPC + ThreadInvoke + GetToSTA + PostMessage + WaitForSingleObject + ThreadWndProc + ThreadDispatch + ComInvoke + AppInvoke + AptInvoke + Stub + SetEvent + reply + SetEvent + + The actual thread switch mechanism (PostMessage, event) depend on +the call category, whether or not the call is local, whether or not the +call is in WOW, and the direction (request vs. reply). + +*/ + +/***************************************************************************/ +/* Externals. */ +extern CEventCache EventCache; + +HRESULT GetToSTA ( OXIDEntry *, CChannelCallInfo * ); +HRESULT SwitchSTA ( OXIDEntry *, CChannelCallInfo ** ); +void ThreadCleanup ( void ); +void ThreadDispatch ( CChannelCallInfo ** ); +HRESULT ThreadStart ( void ); +LRESULT ThreadWndProc (HWND window, UINT message, WPARAM unused, LPARAM params); + +#endif // _CHANCONT_HXX_ diff --git a/private/ole32/com/dcomrem/channelb.cxx b/private/ole32/com/dcomrem/channelb.cxx new file mode 100644 index 000000000..c602b2f40 --- /dev/null +++ b/private/ole32/com/dcomrem/channelb.cxx @@ -0,0 +1,2665 @@ +//+--------------------------------------------------------------------- +// +// Microsoft Windows +// Copyright (C) Microsoft Corporation, 1993 - 1994. +// +// File: d:\nt\private\cairole\com\remote\channelb.cxx +// +// Contents: This module contains thunking classes that allow proxies +// and stubs to use a buffer interface on top of RPC for Cairo +// +// Classes: CRpcChannelBuffer +// +// Functions: +// ChannelThreadInitialize +// ChannelProcessInitialize +// ChannelRegisterProtseq +// ChannelThreadUninitialize +// ChannelProcessUninitialize +// CRpcChannelBuffer::AddRef +// CRpcChannelBuffer::AppInvoke +// CRpcChannelBuffer::CRpcChannelBuffer +// CRpcChannelBuffer::FreeBuffer +// CRpcChannelBuffer::GetBuffer +// CRpcChannelBuffer::QueryInterface +// CRpcChannelBuffer::Release +// CRpcChannelBuffer::SendReceive +// DebugCoSetRpcFault +// DllDebugObjectRPCHook +// ThreadInvoke +// ThreadSendReceive +// +// History: 22 Jun 93 AlexMi Created +// 31 Dec 93 ErikGav Chicago port +// 15 Mar 94 JohannP Added call category support. +// 09 Jun 94 BruceMa Get call category from RPC message +// 19 Jul 94 CraigWi Added support for ASYNC calls +// 01-Aug-94 BruceMa Memory sift fix +// +//---------------------------------------------------------------------- + +#include <ole2int.h> +#include <channelb.hxx> +#include <hash.hxx> // CUUIDHashTable +#include <riftbl.hxx> // gRIFTbl +#include <callctrl.hxx> // CAptRpcChnl, AptInvoke +#include <threads.hxx> // CRpcThreadCache +#include <service.hxx> // StopListen +#include <resolver.hxx> // CRpcResolver + +extern "C" +{ +#include "orpc_dbg.h" +} + +#include <rpcdcep.h> +#include <rpcndr.h> + +#include <obase.h> +#include <ipidtbl.hxx> +#include <security.hxx> +#include <chock.hxx> + + +// This is needed for the debug hooks. See orpc_dbg.h +#pragma code_seg(".orpc") + +/***************************************************************************/ +/* Defines. */ + +#define MAKE_WIN32( status ) \ + MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, (status) ) + +// This should just return a status to runtime, but runtime does not +// support both comm and fault status yet. +#ifdef _CHICAGO_ +#define RETURN_COMM_STATUS( status ) return (status) +#else +#define RETURN_COMM_STATUS( status ) RpcRaiseException( status ) +#endif + +// Flags for local rpc header. +// These are only valid on a request (in a ORPCTHIS header). +const int ORPCF_INPUT_SYNC = ORPCF_RESERVED1; +const int ORPCF_ASYNC = ORPCF_RESERVED2; + +// These are only valid on a reply (in a ORPCTHAT header). +const int ORPCF_REJECTED = ORPCF_RESERVED1; +const int ORPCF_RETRY_LATER = ORPCF_RESERVED2; + +// Default size of hash table. +const int INITIAL_NUM_BUCKETS = 20; + + +/***************************************************************************/ +/* Typedefs. */ + +// This structure contains a copy of all the information needed to make a +// call. It is copied so it can be canceled without stray pointer references. +const DWORD CALLCACHE_SIZE = 8; +struct working_call : public CChannelCallInfo +{ + working_call( CALLCATEGORY callcat, + RPCOLEMESSAGE *message, + DWORD flags, + REFIPID ipidServer, + DWORD destctx, + CRpcChannelBuffer *channel, + DWORD authn_level ); + void *operator new ( size_t ); + void operator delete( void * ); + static void Cleanup ( void ); + static void Initialize ( void ); + + RPCOLEMESSAGE message; + +private: + static void *list[CALLCACHE_SIZE]; + static DWORD next; +}; + +void *working_call::list[CALLCACHE_SIZE] = + { NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL }; +DWORD working_call::next = 0; + + +/***************************************************************************/ +/* Macros. */ + +// Compute the size needed for the implicit this pointer including the +// various optional headers. +inline DWORD SIZENEEDED_ORPCTHIS( BOOL local, DWORD debug_size ) +{ + if (debug_size == 0) + return sizeof(WireThisPart1) + ((local) ? sizeof(LocalThis) : 0); + else + return sizeof(WireThisPart2) + ((local) ? sizeof(LocalThis) : 0) + + debug_size; +} + +inline DWORD SIZENEEDED_ORPCTHAT( DWORD debug_size ) +{ + if (debug_size == 0) + return sizeof(WireThatPart1); + else + return sizeof(WireThatPart2) + debug_size; +} + +inline CALLCATEGORY GetCallCat( void *header ) +{ + WireThis *inb = (WireThis *) header; + if (inb->c.flags & ORPCF_ASYNC) + return CALLCAT_ASYNC; + else if (inb->c.flags & ORPCF_INPUT_SYNC) + return CALLCAT_INPUTSYNC; + else + return CALLCAT_SYNCHRONOUS; +} + + +/***************************************************************************/ +/* Globals. */ + +// Should the debugger hooks be called? +BOOL DoDebuggerHooks = FALSE; +LPORPC_INIT_ARGS DebuggerArg = NULL; + +// The extension identifier for debug data. +const uuid_t DEBUG_EXTENSION = +{ 0xf1f19680, 0x4d2a, 0x11ce, {0xa6, 0x6a, 0x00, 0x20, 0xaf, 0x6e, 0x72, 0xf4}}; + +#if DBG == 1 +// strings that prefix the call +WCHAR *wszDebugORPCCallPrefixString[4] = { L"--> [BEG]", // Invoke + L" --> [end]", + L"<-- [BEG]", // SendReceive + L" <-- [end]" }; + +LONG ulDebugORPCCallNestingLevel[4] = {1, -1, 1, -1}; +#endif + + +SHashChain OIDBuckets[23] = { {&OIDBuckets[0], &OIDBuckets[0]}, + {&OIDBuckets[1], &OIDBuckets[1]}, + {&OIDBuckets[2], &OIDBuckets[2]}, + {&OIDBuckets[3], &OIDBuckets[3]}, + {&OIDBuckets[4], &OIDBuckets[4]}, + {&OIDBuckets[5], &OIDBuckets[5]}, + {&OIDBuckets[6], &OIDBuckets[6]}, + {&OIDBuckets[7], &OIDBuckets[7]}, + {&OIDBuckets[8], &OIDBuckets[8]}, + {&OIDBuckets[9], &OIDBuckets[9]}, + {&OIDBuckets[10], &OIDBuckets[10]}, + {&OIDBuckets[11], &OIDBuckets[11]}, + {&OIDBuckets[12], &OIDBuckets[12]}, + {&OIDBuckets[13], &OIDBuckets[13]}, + {&OIDBuckets[14], &OIDBuckets[14]}, + {&OIDBuckets[15], &OIDBuckets[15]}, + {&OIDBuckets[16], &OIDBuckets[16]}, + {&OIDBuckets[17], &OIDBuckets[17]}, + {&OIDBuckets[18], &OIDBuckets[18]}, + {&OIDBuckets[19], &OIDBuckets[19]}, + {&OIDBuckets[20], &OIDBuckets[20]}, + {&OIDBuckets[21], &OIDBuckets[21]}, + {&OIDBuckets[22], &OIDBuckets[22]} + }; + +CUUIDHashTable gClientRegisteredOIDs; + + +// flag whether or not the channel has been initialized for current process +BOOL gfChannelProcessInitialized = 0; +BOOL gfMTAChannelInitialized = 0; + +// count of multi-threaded apartment inits (see CoInitializeEx) +extern DWORD g_cMTAInits; + + +// Channel debug hook object. +CDebugChannelHook gDebugHook; + +// Channel error hook object. +CErrorChannelHook gErrorHook; + +#if DBG==1 +//------------------------------------------------------------------------- +// +// Function: GetInterfaceName +// +// synopsis: Gets the human readable name of an Interface given it's IID. +// +// History: 12-Jun-95 Rickhi Created +// +//------------------------------------------------------------------------- +LONG GetInterfaceName(REFIID riid, WCHAR *pwszName) +{ + // convert the iid to a string + CDbgGuidStr dbgsIID(riid); + + // Read the registry entry for the interface to get the interface name + LONG ulcb=256; + WCHAR szKey[80]; + + szKey[0] = L'\0'; + lstrcatW(szKey, L"Interface\\"); + lstrcatW(szKey, dbgsIID._wszGuid); + + LONG result = RegQueryValue( + HKEY_CLASSES_ROOT, + szKey, + pwszName, + &ulcb); + + Win4Assert( result == 0 ); + return result; +} + +//--------------------------------------------------------------------------- +// +// Function: DebugPrintORPCCall +// +// synopsis: Prints the interface name and method number to the debugger +// to allow simple ORPC call tracing. +// +// History: 12-Jun-95 Rickhi Created +// +//--------------------------------------------------------------------------- +void DebugPrintORPCCall(DWORD dwFlag, REFIID riid, DWORD iMethod, DWORD Callcat) +{ + if (CairoleInfoLevel & DEB_USER15) + { + Win4Assert (dwFlag < 4); + + // adjust the nesting level for this thread. + COleTls tls; + tls->cORPCNestingLevel += ulDebugORPCCallNestingLevel[dwFlag]; + + + // set the indentation string according to the nesting level + CHAR szNesting[100]; + memset(szNesting, 0x20, 100); + + if (tls->cORPCNestingLevel > 99) // watch for overflow + szNesting[99] = '\0'; + else + szNesting[tls->cORPCNestingLevel] = '\0'; + + + // construct the debug strings + WCHAR *pwszDirection = wszDebugORPCCallPrefixString[dwFlag]; + WCHAR wszName[100]; + GetInterfaceName(riid, wszName); + + ComDebOut((DEB_USER15, "%s%ws [%x] %ws:: %x\n", + szNesting, pwszDirection, Callcat, wszName, iMethod)); + } +} +#endif + +/***************************************************************************/ +void ByteSwapThis( DWORD drep, WireThis *inb ) +{ + if ((drep & NDR_LOCAL_DATA_REPRESENTATION) != NDR_LITTLE_ENDIAN) + { + // Extensions are swapped later. If we ever use the reserved field, + // swap it. + ByteSwapShort( inb->c.version.MajorVersion ); + ByteSwapShort( inb->c.version.MinorVersion ); + ByteSwapLong( inb->c.flags ); + // ByteSwapLong( inb->c.reserved1 ); + ByteSwapLong( inb->c.cid.Data1 ); + ByteSwapShort( inb->c.cid.Data2 ); + ByteSwapShort( inb->c.cid.Data3 ); + } +} + +/***************************************************************************/ +void ByteSwapThat( DWORD drep, WireThat *outb ) +{ + if ((drep & NDR_LOCAL_DATA_REPRESENTATION) != NDR_LITTLE_ENDIAN) + { + // Extensions are swapped later. + ByteSwapLong( outb->c.flags ); + } +} + +//+------------------------------------------------------------------- +// +// Function: ChannelProcessInitialize, public +// +// Synopsis: Initializes the channel subsystem per process data. +// +// History: 23-Nov-93 AlexMit Created +// +//-------------------------------------------------------------------- +STDAPI ChannelProcessInitialize() +{ + TRACECALL(TRACE_RPC, "ChannelProcessInitialize"); + ComDebOut((DEB_COMPOBJ, "ChannelProcessInitialize [IN]\n")); + + Win4Assert( (sizeof(WireThisPart1) & 7) == 0 ); + Win4Assert( (sizeof(WireThisPart2) & 7) == 0 ); + Win4Assert( (sizeof(LocalThis) & 7) == 0 ); + Win4Assert( (sizeof(WireThatPart1) & 7) == 0 ); + Win4Assert( (sizeof(WireThatPart2) & 7) == 0 ); + + // we want to take the gComLock since that prevents other Rpc + // threads from accessing anything we are about to create, in + // particular, the event cache and working_call cache are accessed + // by Rpc worker threads of cancelled calls. + + ASSERT_LOCK_RELEASED + LOCK + + HRESULT hr = S_OK; + + if (!gfChannelProcessInitialized) + { + // Initialize the interface hash tables, the OID hash table, and + // the MID hash table. We dont need to cleanup these on errors. + + gMIDTbl.Initialize(); + gOXIDTbl.Initialize(); + gRIFTbl.Initialize(); + gIPIDTbl.Initialize(); + gSRFTbl.Initialize(); + gClientRegisteredOIDs.Initialize(OIDBuckets); + + // Register the debug channel hook. + hr = CoRegisterChannelHook( DEBUG_EXTENSION, &gDebugHook ); + + // Register the error channel hook. + if(SUCCEEDED(hr)) + { + hr = CoRegisterChannelHook( ERROR_EXTENSION, &gErrorHook ); + } + + // Initialize security. + if (SUCCEEDED(hr)) + { + hr = InitializeSecurity(); + } + + // always set to TRUE if we initialized ANYTHING, regardless of + // whether there were any errors. That way, ChannelProcessUninit + // will cleanup anything we have initialized. + gfChannelProcessInitialized = TRUE; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (FAILED(hr)) + { + // cleanup anything we have created thus far. + ChannelProcessUninitialize(); + } + + ComDebOut((DEB_COMPOBJ, "ChannelProcessInitialize [OUT] hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CleanupRegOIDs, public +// +// Synopsis: called to delete each node of the registered OID list. +// +//+------------------------------------------------------------------- +void CleanupRegOIDs(SHashChain *pNode) +{ + delete pNode; +} + +//+------------------------------------------------------------------- +// +// Function: ChannelProcessUninitialize, public +// +// Synopsis: Uninitializes the channel subsystem global data. +// +// History: 23-Nov-93 Rickhi Created +// +// Notes: This is called at process uninitialize, not thread +// uninitialize. +// +//-------------------------------------------------------------------- +STDAPI_(void) ChannelProcessUninitialize(void) +{ + TRACECALL(TRACE_RPC, "ChannelProcessUninitialize"); + ComDebOut((DEB_COMPOBJ, "ChannelProcessUninitialize [IN]\n")); + + if (gfChannelProcessInitialized) + { + // Stop accepting calls from the object resolver and flag that service + // is no longer initialized. This can result in calls being + // dispatched. Do not hold the lock around this call. + + UnregisterDcomInterfaces(); + } + + gResolver.ReleaseSCMProxy(); + + // we want to take the gComLock since that prevents other Rpc + // threads from accessing anything we are about to cleanup, in + // particular, the event cache and working_call are accessed by + // Rpc worker threaded for cancelled calls. + ASSERT_LOCK_RELEASED + LOCK + + if (gfChannelProcessInitialized) + { + // Release the interface tables. This causes RPC to stop dispatching + // DCOM calls. This can result in calls being dispatched. + // UnRegisterServerInterface releases and reaquires the lock each + // time it is called. + gRIFTbl.Cleanup(); + + if (gpLocalMIDEntry) + { + // release the local MIDEntry + DecMIDRefCnt(gpLocalMIDEntry); + gpLocalMIDEntry = NULL; + } + + // release the MTA apartment's OXIDEntry if there is one. Do this + // after the RIFTble cleanup so we are not processing any calls + // while it happens. + gOXIDTbl.ReleaseLocalMTAEntry(); + + if (gpsaCurrentProcess) + { + // delete the string bindings for the current process + PrivMemFree(gpsaCurrentProcess); + gpsaCurrentProcess = NULL; + } + + // cleanup the IPID, OXID, and MID tables + gOXIDTbl.FreeExpiredEntries(GetTickCount()+1); + gIPIDTbl.Cleanup(); + gOXIDTbl.Cleanup(); + gMIDTbl.Cleanup(); + gSRFTbl.Cleanup(); + + // Cleanup the OID registration table. + gClientRegisteredOIDs.Cleanup(CleanupRegOIDs); + + // Cleanup the call cache. + working_call::Cleanup(); + + // Release all cached threads. + gRpcThreadCache.ClearFreeList(); + + // cleanup the event cache + gEventCache.Cleanup(); + + // Cleanup the channel hooks. + CleanupChannelHooks(); + } + + // Always cleanup the RPC OXID resolver since security may initialize it. + gResolver.Cleanup(); + + // Cleanup security. + UninitializeSecurity(); + + // mark the channel as no longer intialized for this process + gfChannelProcessInitialized = FALSE; + gfMTAChannelInitialized = FALSE; + + UNLOCK + ASSERT_LOCK_RELEASED + + // release the static unmarshaler + if (gpStdMarshal) + { + ((CStdIdentity *)gpStdMarshal)->UnlockAndRelease(); + gpStdMarshal = NULL; + } + + ComDebOut((DEB_COMPOBJ, "ChannelProcessUninitialize [OUT]\n")); + return; +} + +//+------------------------------------------------------------------- +// +// Function: STAChannelInitialize, public +// +// Synopsis: Initializes the channel subsystem per thread data +// for single-threaded apartments. +// +// History: 23-Nov-93 Rickhi Created +// +// Notes: This is called at thread initialize, not process +// initialize. Cleanup is done in ChannelThreadUninitialize. +// +//-------------------------------------------------------------------- +STDAPI STAChannelInitialize(void) +{ + ComDebOut((DEB_COMPOBJ, "STAChannelInitialize [IN]\n")); + Win4Assert(IsSTAThread()); + + HRESULT hr = S_OK; + + if (!gfChannelProcessInitialized) + { + // process initialization has not been done, do that now. + if (FAILED(hr = ChannelProcessInitialize())) + return hr; + } + + // create the callctrl before calling ThreadStart, since the latter + // tries to register with the call controller. We might already have + // a callctrl if some DDE stuff has already run. + + COleTls tls; + + if (tls->pCallCtrl == NULL) + { + // assume OOM and try to create callctrl. ctor sets tls. + hr = E_OUTOFMEMORY; + CAptCallCtrl *pCallCtrl = new CAptCallCtrl(); + } + + if (tls->pCallCtrl) + { + // mark the channel as initialized now to prevent re-entracy in + // GetLocalEntry. + + tls->dwFlags |= OLETLS_CHANNELTHREADINITIALZED; + + // Precreate the thread window. The window is normally only used + // for requests (and thus created during marshalling). But in WOW + // it is used for responses (and thus created during initialization). + // We do it for normal cases here too in order to avoid recursion + // when marshaling the first interface. + + ASSERT_LOCK_RELEASED + LOCK + + OXIDEntry *pOxid; + hr = gOXIDTbl.GetLocalEntry( &pOxid ); + + UNLOCK + ASSERT_LOCK_RELEASED + + if (SUCCEEDED(hr)) + { + hr = ThreadStart(); + } + + // Clear the channel initialized flag if initialization fails. + // Everything gets cleaned up in uninitialize regardless of the + // channel flag. + if (FAILED(hr)) + tls->dwFlags &= ~OLETLS_CHANNELTHREADINITIALZED; + } + + ComDebOut((DEB_COMPOBJ, "STAChannelInitialize [OUT] hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: MTAChannelInitialize, public +// +// Synopsis: Initializes the channel subsystem per thread data +// for multi-threaded apartments. +// +// History: 19-Mar-96 Rickhi Created +// +// Notes: This is called at thread initialize, not process +// initialize. Cleanup is done in ChannelThreadUninitialize. +// +//-------------------------------------------------------------------- +STDAPI MTAChannelInitialize(void) +{ + ComDebOut((DEB_COMPOBJ, "MTAChannelInitialize [IN]\n")); + Win4Assert(IsMTAThread()); + + HRESULT hr = S_OK; + + if (!gfChannelProcessInitialized) + { + // process initialization has not been done, do that now. + if (FAILED(hr = ChannelProcessInitialize())) + return hr; + } + + ASSERT_LOCK_RELEASED + LOCK + + if (!gfMTAChannelInitialized) + { + // Create the OXID entry for this apartment. Do it now to avoid + // any races with two threads creating it simultaneously. + + OXIDEntry *pOxid; + hr = gOXIDTbl.GetLocalEntry( &pOxid ); + if (SUCCEEDED(hr)) + { + pOxid->dwFlags &= ~ OXIDF_STOPPED; + gfMTAChannelInitialized = TRUE; + } + + } + + UNLOCK + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_COMPOBJ, "MTAChannelInitialize [OUT] hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ChannelThreadUninitialize, private +// +// Synopsis: Uninitializes the channel subsystem per thread data. +// +// History: 23-Nov-93 Rickhi Created +// +// Notes: This is called at thread uninitialize, not process +// uninitialize. +// +//-------------------------------------------------------------------- +STDAPI_(void) ChannelThreadUninitialize(void) +{ + TRACECALL(TRACE_RPC, "ChannelThreadUninitialize"); + ComDebOut((DEB_COMPOBJ, "ChannelThreadUninitialize [IN]\n")); + + COleTls tls; + + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + { + // Cleanup the window for this thread. + ThreadCleanup(); + + // Free the apartment call control. + delete tls->pCallCtrl; + tls->pCallCtrl = NULL; + + // Free any registered MessageFilter that has not been picked + // up by the call ctrl. + if (tls->pMsgFilter) + { + tls->pMsgFilter->Release(); + tls->pMsgFilter = NULL; + } + + // release the OXIDEntry for this thread. + ASSERT_LOCK_RELEASED + LOCK + + gOXIDTbl.ReleaseLocalSTAEntry(); + + UNLOCK + ASSERT_LOCK_RELEASED + + // mark the thread as no longer intialized for the channel + tls->dwFlags &= ~OLETLS_CHANNELTHREADINITIALZED; + } + else + { + // the MTA channel is no longer initialized. + gfMTAChannelInitialized = FALSE; + } + + ComDebOut((DEB_COMPOBJ, "ChannelThreadUninitialize [OUT]\n")); +} + +// count of multi-threaded inits +//+------------------------------------------------------------------- +// +// Function: InitChannelIfNecessary, private +// +// Synopsis: Checks if the ORPC channel has been initialized for +// the current apartment and initializes if not. This is +// required by the delayed initialization logic. +// +// History: 26-Oct-95 Rickhi Created +// +//-------------------------------------------------------------------- +INTERNAL InitChannelIfNecessary() +{ + HRESULT hr; + COleTls tls(hr); + + if (FAILED(hr)) + return hr; + + if (!(tls->dwFlags & OLETLS_APARTMENTTHREADED)) + { + if (!gfMTAChannelInitialized) + { + if (g_cMTAInits > 0) + { + // initialize the MTAChannel + return MTAChannelInitialize(); + } + + // CoInitializeEx(MULTITHREADED) has not been called + return CO_E_NOTINITIALIZED; + } + } + else if (!(tls->dwFlags & OLETLS_CHANNELTHREADINITIALZED)) + { + if (tls->cComInits > 0 && + !(tls->dwFlags & OLETLS_THREADUNINITIALIZING)) + return STAChannelInitialize(); + + // CoInitializeEx(APARTMENTTHREADED) has not been called, + // or the thread is Uninitializing + return CO_E_NOTINITIALIZED; + } + + return S_OK; +} + + +/***************************************************************************/ +/* + Make a copy of a message for an asyncronous call. Also make a fake +reply for the call. +*/ +CChannelCallInfo *MakeAsyncCopy( CChannelCallInfo *original ) +{ + void *pBuffer = NULL; + WireThat *outb; + BOOL success; + + ASSERT_LOCK_HELD + + working_call *copy = new working_call( original->category, + original->pmessage, + original->iFlags, + original->ipid, + original->iDestCtx, + original->pChannel, + original->lAuthnLevel ); + + if (copy != NULL) + { + original->hResult = S_OK; + + if (original->pmessage->rpcFlags & RPCFLG_LOCAL_CALL) + { + // no need to duplicate the buffer, just use it as is. + copy->message.Buffer = original->pmessage->Buffer; + } + else + { + pBuffer = PrivMemAlloc8(original->pmessage->cbBuffer); + Win4Assert(((ULONG)pBuffer & 0x7) == 0 && "Buffer not aligned properly"); + + if (pBuffer != NULL) + { + copy->message.Buffer = pBuffer; + memcpy(pBuffer, original->pmessage->Buffer, + original->pmessage->cbBuffer); + } + else + { + original->hResult = RPC_E_OUT_OF_RESOURCES; + } + } + + if (SUCCEEDED(original->hResult)) + { + // pretend local so we don't touch rpc for more buffers, etc. + copy->message.rpcFlags |= RPCFLG_LOCAL_CALL; + + // Create a fake reply containing a result even though the + // client will never see it. + original->pmessage->cbBuffer = SIZENEEDED_ORPCTHAT(0) + 4; + + if (original->pmessage->rpcFlags & RPCFLG_LOCAL_CALL) + { + original->pmessage->Buffer = PrivMemAlloc8(original->pmessage->cbBuffer); + success = original->pmessage->Buffer != NULL; + } + else + { + success = I_RpcGetBuffer((RPC_MESSAGE *) original->pmessage) == RPC_S_OK; + } + + // simulate success in method call + if (success) + { + outb = (WireThat *) original->pmessage->Buffer; + outb->c.flags = ORPCF_NULL; + outb->c.unique = 0; + *(SCODE *)((WireThatPart1 *)outb + 1) = S_OK; + return copy; + } + } + + PrivMemFree(pBuffer); + delete copy; + } + + return NULL; +} + + +/***************************************************************************/ +STDMETHODIMP_(ULONG) CRpcChannelBuffer::AddRef( THIS ) +{ + // can't call AssertValid(FALSE) since it is used in asserts + InterlockedIncrement( (long *) &ref_count ); + return ref_count; +} + +/***************************************************************************/ +HRESULT CRpcChannelBuffer::AppInvoke( CChannelCallInfo *call, + IRpcStubBuffer *stub, + void *pv, + void *orig_stub_buffer, + LocalThis *localb ) +{ + ASSERT_LOCK_RELEASED + + RPC_MESSAGE *message = (RPC_MESSAGE *) call->pmessage; + void *orig_buffer = message->Buffer; + WireThat *outb = NULL; + HRESULT result; + + // Save a pointer to the inbound header. + call->pHeader = message->Buffer; + + // Adjust the buffer. + message->BufferLength -= (char *) orig_stub_buffer - (char *) message->Buffer; + message->Buffer = orig_stub_buffer; + message->ProcNum &= ~RPC_FLAGS_VALID_BIT; + + // if the incoming call is from a non-NDR client, then set a bit in + // the message flags field so the stub can figure out how to dispatch + // the call. This allows a 32bit server to simultaneously service a + // 32bit client using NDR and a 16bit client using non-NDR, in particular, + // to support OLE Automation. + if (localb != NULL && localb->flags & LOCALF_NONNDR) + message->RpcFlags |= RPCFLG_NON_NDR; + + if (IsMTAThread()) + { + // do multi-threaded apartment invoke + result = MTAInvoke((RPCOLEMESSAGE *)message, GetCallCat( call->pHeader ), + stub, this, &call->server_fault); + } + else + { + // do single-threaded apartment invoke + result = STAInvoke((RPCOLEMESSAGE *)message, GetCallCat( call->pHeader ), + stub, this, pv, &call->server_fault); + } + + // For local calls, just free the in buffer. For non-local calls, + // the RPC runtime does this for us. + if (message->RpcFlags & RPCFLG_LOCAL_CALL) + PrivMemFree( orig_buffer ); + + // If an exception occurred before a new buffer was allocated, + // set the Buffer field to point to the original buffer. + if (message->Buffer == orig_stub_buffer) + { + // The buffer pointer in the message must be correct so RPC can free it. + if (message->RpcFlags & RPCFLG_LOCAL_CALL) + message->Buffer = NULL; + else + message->Buffer = orig_buffer; + } + else if (message->Buffer != NULL) + { + // An out buffer exists, get the pointer to the channel header. + Win4Assert( call->pHeader != orig_buffer ); + message->BufferLength += (char *) message->Buffer - (char *) call->pHeader; + message->Buffer = call->pHeader; + outb = (WireThat *) message->Buffer; + } + + // If successful, adjust the buffer. + if (result == S_OK) + { + if (call->iDestCtx == MSHCTX_DIFFERENTMACHINE) + outb->c.flags = 0; + else + outb->c.flags = ORPCF_LOCAL; + + // For asynchronous calls, MSWMSG will delete the out buffer. If MSWMSG + // is not the transport, delete it here. Non-Mswmsg async calls have + // been converted to local calls. + if (message->RpcFlags & RPCFLG_LOCAL_CALL && + call->category == CALLCAT_ASYNC) + { + PrivMemFree( message->Buffer ); + message->Buffer = NULL; + } + } + else if (result == RPC_E_CALL_REJECTED) + { + // Call was rejected. If the caller is on another machine, just fail the + // call. + if (call->iDestCtx != MSHCTX_DIFFERENTMACHINE && outb != NULL) + { + // Otherwise return S_OK so the buffer gets back, but set the flag + // to indicate it was rejected. + outb->c.flags = ORPCF_LOCAL | ORPCF_REJECTED; + result = S_OK; + } + } + else if (result == RPC_E_SERVERCALL_RETRYLATER) + { + // Call was rejected. If the caller is on another machine, just fail the + // call. + if (call->iDestCtx != MSHCTX_DIFFERENTMACHINE && outb != NULL) + { + // Otherwise return S_OK so the buffer gets back, but set the flag + // to indicate it was rejected with retry later. + outb->c.flags = ORPCF_LOCAL | ORPCF_RETRY_LATER; + result = S_OK; + } + } + else if (message->RpcFlags & RPCFLG_LOCAL_CALL) + { + // call failed and the call is local, free the out buffer. For + // non-local calls the RPC runtime does this for us. + PrivMemFree( message->Buffer ); + message->Buffer = NULL; + } + + ASSERT_LOCK_RELEASED + return result; +} + +//+--------------------------------------------------------------------------- +// +// Function: AppInvokeExceptionFilter +// +// Synopsis: Determine if the application as thrown an exception we want +// to report. If it has, then print out enough information for +// the 'user' to debug the problem +// +// Arguments: [lpep] -- Exception context records +// +// History: 6-20-95 kevinro Created +// +// Notes: +// +// At the moment, I was unable to get this to work for Win95, so I have +// commented out the code. +// +//---------------------------------------------------------------------------- + + +#ifdef _CHICAGO_ + +// +// Win95 doesn't appear to support this functionality by default. +// + +inline LONG +AppInvokeExceptionFilter( + LPEXCEPTION_POINTERS lpep + ) +{ + return(EXCEPTION_EXECUTE_HANDLER); +} + +#else + +#include <imagehlp.h> + +#define SYM_HANDLE GetCurrentProcess() + +#if defined(_M_IX86) +#define MACHINE_TYPE IMAGE_FILE_MACHINE_I386 +#elif defined(_M_MRX000) +#define MACHINE_TYPE IMAGE_FILE_MACHINE_R4000 +#elif defined(_M_ALPHA) +#define MACHINE_TYPE IMAGE_FILE_MACHINE_ALPHA +#elif defined(_M_PPC) +#define MACHINE_TYPE IMAGE_FILE_MACHINE_POWERPC +#else +#error( "unknown target machine" ); +#endif + +LONG +AppInvokeExceptionFilter( + LPEXCEPTION_POINTERS lpep + ) +{ +#if DBG == 1 + BOOL rVal; + STACKFRAME StackFrame; + CONTEXT Context; + + + SymSetOptions( SYMOPT_UNDNAME ); + SymInitialize( SYM_HANDLE, NULL, TRUE ); + ZeroMemory( &StackFrame, sizeof(StackFrame) ); + Context = *lpep->ContextRecord; + +#if defined(_M_IX86) + StackFrame.AddrPC.Offset = Context.Eip; + StackFrame.AddrPC.Mode = AddrModeFlat; + StackFrame.AddrFrame.Offset = Context.Ebp; + StackFrame.AddrFrame.Mode = AddrModeFlat; + StackFrame.AddrStack.Offset = Context.Esp; + StackFrame.AddrStack.Mode = AddrModeFlat; +#endif + + + ComDebOut((DEB_FORCE,"An Exception occurred while calling into app\n")); + ComDebOut((DEB_FORCE, + "Exception address = 0x%x Exception number 0x%x\n", + lpep->ExceptionRecord->ExceptionAddress, + lpep->ExceptionRecord->ExceptionCode )); + + ComDebOut((DEB_FORCE,"The following stack trace is where the exception occured\n")); + ComDebOut((DEB_FORCE,"Frame RetAddr mod!symbol\n")); + do + { + rVal = StackWalk(MACHINE_TYPE,SYM_HANDLE,0,&StackFrame,&Context,ReadProcessMemory, + SymFunctionTableAccess,SymGetModuleBase,NULL); + + if (rVal) + { + DWORD dump[200]; + ULONG Displacement; + PIMAGEHLP_SYMBOL sym = (PIMAGEHLP_SYMBOL) &dump; + IMAGEHLP_MODULE ModuleInfo; + LPSTR pModuleName = "???"; + BOOL fSuccess; + + sym->SizeOfStruct = sizeof(dump); + + fSuccess = SymGetSymFromAddr(SYM_HANDLE,StackFrame.AddrPC.Offset, + &Displacement,sym); + + // + // If there is module name information available, then grab it. + // + if(SymGetModuleInfo(SYM_HANDLE,StackFrame.AddrPC.Offset,&ModuleInfo)) + { + pModuleName = ModuleInfo.ModuleName; + } + + if (fSuccess) + { + ComDebOut((DEB_FORCE, + "%08x %08x %s!%s + %x\n", + StackFrame.AddrFrame.Offset, + StackFrame.AddrReturn.Offset, + pModuleName, + sym->Name, + Displacement)); + } + else + { + ComDebOut((DEB_FORCE, + "%08x %08x %s!%08x\n", + StackFrame.AddrFrame.Offset, + StackFrame.AddrReturn.Offset, + pModuleName, + StackFrame.AddrPC.Offset)); + } + } + } while( rVal ); + + SymCleanup( SYM_HANDLE ); + +#endif + + return EXCEPTION_EXECUTE_HANDLER; +} +#endif // _CHICAGO_ + +/***************************************************************************/ +#if DBG == 1 +DWORD AppInvoke_break = 0; +DWORD AppInvoke_count = 0; +#endif + +HRESULT StubInvoke(RPCOLEMESSAGE *pMsg, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, DWORD *pdwFault) +{ + ComDebOut((DEB_CHANNEL, "StubInvoke pMsg:%x pStub:%x pChnl:%x pdwFault:%x\n", + pMsg, pStub, pChnl, pdwFault)); + ASSERT_LOCK_RELEASED + + HRESULT hr; + +#if DBG==1 + DWORD dwMethod = pMsg->iMethod; + IID iidBeingCalled = ((RPC_SERVER_INTERFACE *) pMsg->reserved2[1])->InterfaceId.SyntaxGUID; +#endif + + _try + { + TRACECALL(TRACE_RPC, "StubInvoke"); +#if DBG == 1 + // + // On a debug build, we are able to break on a call by serial number. + // This isn't really 100% thread safe, but is still extremely useful + // when debugging a problem. + // + DWORD dwBreakCount = ++AppInvoke_count; + + ComDebOut((DEB_CHANNEL, "AppInvoke(0x%x) calling method 0x%x iid %I\n", + dwBreakCount,dwMethod, &iidBeingCalled)); + + if(AppInvoke_break == dwBreakCount) + { + DebugBreak(); + } +#endif + +#ifdef WX86OLE + if (! gcwx86.IsN2XProxy(pStub)) + { + IUnknown *pActual; + + hr = pStub->DebugServerQueryInterface((void **)&pActual); + if (SUCCEEDED(hr)) + { + if (gcwx86.IsN2XProxy(pActual)) + { + // If we are going to invoke a native stub that is + // connected to an object on the x86 side then + // set a flag in the Wx86 thread environment to + // let the thunk layer know that the call is a + // stub invoked call and allow any in or out + // custom interface pointers to be thunked as + // IUnknown rather than failing the interface thunking + gcwx86.SetStubInvokeFlag((BOOL)1); + } + pStub->DebugServerRelease(pActual); + } + } +#endif + + hr = pStub->Invoke(pMsg, pChnl); + + } + _except(AppInvokeExceptionFilter( GetExceptionInformation())) + { + hr = RPC_E_SERVERFAULT; + *pdwFault = GetExceptionCode(); + +#if DBG == 1 + // + // OLE catches exceptions when the server generates them. This is so we can + // cleanup properly, and allow the client to continue. + // + if (*pdwFault == STATUS_ACCESS_VIOLATION || + *pdwFault == STATUS_POSSIBLE_DEADLOCK || + *pdwFault == STATUS_INSTRUCTION_MISALIGNMENT || + *pdwFault == STATUS_DATATYPE_MISALIGNMENT ) + { + + WCHAR iidName[256]; + iidName[0] = 0; + char achProgname[256]; + achProgname[0] = 0; + + GetModuleFileNameA(NULL,achProgname,sizeof(achProgname)); + + GetInterfaceName(iidBeingCalled,iidName); + + ComDebOut((DEB_FORCE, + "OLE has caught a fault 0x%08x on behalf of the server %s\n", + *pdwFault, + achProgname)); + + ComDebOut((DEB_FORCE, + "The fault occured when OLE called the interface %I (%ws) method 0x%x\n", + &iidBeingCalled,iidName,dwMethod)); + + Win4Assert(!"The server application has faulted processing an inbound RPC request. Check the kernel debugger for useful output. OLE can continue but you probably want to stop and debug the application."); + } +#endif + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_CHANNEL, "StubInvoke hr:%x dwFault:%x\n", hr, *pdwFault)); + return hr; +} + +/***************************************************************************/ +#if DBG==1 +LONG ComInvokeExceptionFilter( DWORD lCode, + LPEXCEPTION_POINTERS lpep ) +{ + ComDebOut((DEB_ERROR, "Exception 0x%x in ComInvoke at address 0x%x\n", + lCode, lpep->ExceptionRecord->ExceptionAddress)); + DebugBreak(); + return EXCEPTION_EXECUTE_HANDLER; +} +#endif + +/***************************************************************************/ +HRESULT ComInvoke( CChannelCallInfo *call ) +{ + TRACECALL(TRACE_RPC, "ComInvoke"); + ASSERT_LOCK_RELEASED + + RPC_MESSAGE *message = (RPC_MESSAGE *) call->pmessage; + LocalThis *localb; + void *saved_buffer; + RPC_STATUS status; + HRESULT result; + IPIDEntry *ipid_entry = NULL; + CRpcChannelBuffer *server_channel = NULL; + DWORD TIDCallerSaved; + BOOL fLocalSaved; + UUID saved_threadid; + IUnknown *save_context; + DWORD saved_authn_level; + char *stub_data; + WireThis *inb = (WireThis *) message->Buffer; + OXIDEntry *oxid; + HANDLE hWakeup = NULL; + + ComDebOut((DEB_CHANNEL, "ComInvoke callinfo:%x header:%x\n", + call, message->Buffer)); + + COleTls tls(result); + if (FAILED(result)) + return result; + + // Catch exceptions that might keep the lock. +#if DBG == 1 + _try + { +#endif + + // Find the IPID entry. Fail if the IPID or the OXID are not ready. + LOCK + ipid_entry = gIPIDTbl.LookupIPID( call->ipid ); + Win4Assert( ipid_entry == NULL || ipid_entry->pOXIDEntry != NULL ); + if (ipid_entry == NULL || (ipid_entry->dwFlags & IPIDF_DISCONNECTED) || + (ipid_entry->pOXIDEntry->dwFlags & OXIDF_STOPPED) || + ipid_entry->pChnl == NULL) + result = RPC_E_DISCONNECTED; + else if (ipid_entry->pStub == NULL) + result = E_NOINTERFACE; + + // Keep the server object and our associated objects alive during the call. + if (SUCCEEDED(result)) + { + oxid = ipid_entry->pOXIDEntry; + server_channel = ipid_entry->pChnl; + Win4Assert( server_channel != NULL && server_channel->pStdId != NULL ); + server_channel->pStdId->LockServer(); + InterlockedIncrement( &oxid->cCalls ); + } + UNLOCK + ASSERT_LOCK_RELEASED + + if (FAILED(result)) + { + return result; + } + + // Create a new security call context; + CServerSecurity security( call ); + save_context = tls->pCallContext; + tls->pCallContext = &security; + + // save the original threadid & copy in the new one. + if (!(tls->dwFlags & OLETLS_UUIDINITIALIZED)) + { + UuidCreate(&tls->LogicalThreadId); + tls->dwFlags |= OLETLS_UUIDINITIALIZED; + } + saved_threadid = tls->LogicalThreadId; + tls->LogicalThreadId = inb->c.cid; + + ComDebOut((DEB_CALLCONT, "ComInvoke: LogicalThreads Old:%I New:%I\n", + &saved_threadid, &tls->LogicalThreadId)); + + // Save the call info in TLS. + call->pNext = (CChannelCallInfo *) tls->pCallInfo; + tls->pCallInfo = call; + saved_authn_level = tls->dwAuthnLevel; + tls->dwAuthnLevel = call->lAuthnLevel; + + // Call the channel hooks. Set up as much TLS data as possible before + // calling the hooks so they can access it. + result = ServerNotify( + ((RPC_SERVER_INTERFACE *) message->RpcInterfaceInformation)->InterfaceId.SyntaxGUID, + (WireThis *) message->Buffer, + message->BufferLength, + (void **) &stub_data, + message->DataRepresentation ); + + // Find the local header. + if (inb->c.flags & ORPCF_LOCAL) + { + localb = (LocalThis *) stub_data; + stub_data += sizeof(LocalThis); + } + else + localb = NULL; + + // Set the caller TID. This is needed by some interop code in order + // to do focus management via tying queues together. We first save the + // current one so we can restore later to deal with nested calls + // correctly. + TIDCallerSaved = tls->dwTIDCaller; + fLocalSaved = tls->dwFlags & OLETLS_LOCALTID; + tls->dwTIDCaller = localb != NULL ? localb->client_thread : 0; + + if (call->iFlags & CF_PROCESS_LOCAL) + tls->dwFlags |= OLETLS_LOCALTID; // turn the local bit on + else + tls->dwFlags &= ~OLETLS_LOCALTID; // turn the local bit off + + // Continue dispatching the call. + if (result == S_OK) + { + result = server_channel->AppInvoke( + call, + (IRpcStubBuffer *) ipid_entry->pStub, + ipid_entry->pv, + stub_data, + localb ); + } + + // Restore the original thread id, call info, dest context and thread id. + tls->LogicalThreadId = saved_threadid; + tls->pCallInfo = call->pNext; + tls->dwTIDCaller = TIDCallerSaved; + tls->dwAuthnLevel = saved_authn_level; + + if (fLocalSaved) + tls->dwFlags |= OLETLS_LOCALTID; + else + tls->dwFlags &= ~OLETLS_LOCALTID; + + // Restore the security context; + tls->pCallContext = save_context; + security.EndCall(); + + // Decrement the call count. If the MTA is waiting to uninitialize + // and this is the last call, wake up the uninitializing thread, but + // do this *after* calling UnLockServer so the other thread does not + // blow away the server. + if (InterlockedDecrement( &oxid->cCalls ) == 0 && + (oxid->dwFlags & (OXIDF_MTASERVER | OXIDF_STOPPED)) == (OXIDF_MTASERVER | OXIDF_STOPPED)) + hWakeup = oxid->hComplete; + + // Release our hold on the object and channel. + server_channel->pStdId->UnLockServer(); + + if (hWakeup) + SetEvent(hWakeup); + + // Catch exceptions that might keep the lock. +#if DBG == 1 + } + _except( ComInvokeExceptionFilter(GetExceptionCode(), + GetExceptionInformation()) ) + { + } +#endif + + ASSERT_LOCK_RELEASED + return result; +} + +/***************************************************************************/ +CRpcChannelBuffer *CRpcChannelBuffer::Copy(OXIDEntry *pOXIDEntry, + REFIPID ripid, REFIID riid) +{ + Win4Assert( !(state & server_cs) ); + + CRpcChannelBuffer *chan; + + if (IsMTAThread()) + { + // make client side multi-threaded apartment version of channel + chan = new CMTARpcChnl(pStdId, pOXIDEntry, state); + } + else + { + // make client side single-threaded apartment version of channel + chan = new CAptRpcChnl(pStdId, pOXIDEntry, state); + } + + if (chan != NULL) + { + chan->state = proxy_cs | (state & ~client_cs); + chan->lAuthnLevel = lAuthnLevel; + } + + return chan; +} + +/***************************************************************************/ +HRESULT CRpcChannelBuffer::InitClientSideHandle() +{ + Win4Assert((state & proxy_cs)); + ASSERT_LOCK_HELD + + if (state & initialized_cs) + return S_OK; + + // Lookup the interface info. This cant fail. + pInterfaceInfo = gRIFTbl.GetClientInterfaceInfo(pIPIDEntry->iid); + + RPC_STATUS status; +#ifndef _CHICAGO_ + if (state & process_local_cs) + { + handle = NULL; + status = RPC_S_OK; + } + else +#endif + { + status = RpcBindingCopy(pOXIDEntry->hServerSTA, &handle); + + if (status == RPC_S_OK) + + // If this is a single threaded apartment, give LRPC the blocking + // hook. + if (state & mswmsg_cs) + status = I_RpcBindingSetAsync(handle, OleModalLoopBlockFn); + +#ifndef CHICAGO + // If the server is a single threaded apartment, tell LRPC to + // use MSWMSG to dispatch. + else if (pOXIDEntry->dwTid != 0) + status = I_RpcBindingSetAsync(handle, NULL); +#endif + + if (status == RPC_S_OK) + status = RpcBindingSetObject(handle, (GUID *)&pIPIDEntry->ipid); + } + + if (status == RPC_S_OK) + { + state |= initialized_cs; + return S_OK; + } + + return MAKE_WIN32(status); +} + + +/***************************************************************************/ +CRpcChannelBuffer::CRpcChannelBuffer(CStdIdentity *standard_identity, + OXIDEntry *pOXID, + DWORD eState ) +{ + ComDebOut((DEB_MARSHAL, "CRpcChannelBuffer %s Created this:%x pOXID:%x\n", + (eState & client_cs) ? "CLIENT" : "SERVER", this, pOXID)); + + // Fill in the easy fields first. + ref_count = 1; + pStdId = standard_identity; + handle = NULL; + pOXIDEntry = pOXID; + pIPIDEntry = NULL; + pInterfaceInfo = NULL; + hToken = NULL; + lAuthnLevel = gAuthnLevel; + state = eState; + state |= pOXID->dwPid == GetCurrentProcessId() ? process_local_cs : 0; + SetImpLevel( gImpLevel ); + + if ((pOXID->dwFlags & OXIDF_MSWMSG) && IsSTAThread()) + { + // use MSWMSG protocol with the blocking hook + state |= mswmsg_cs; + } + + if (state & (client_cs | proxy_cs)) + { + // Determine the destination context. + if (pOXID->dwFlags & OXIDF_MACHINE_LOCAL) + if (!IsWOWThread() && (state & process_local_cs)) + iDestCtx = MSHCTX_INPROC; + else + iDestCtx = MSHCTX_LOCAL; + else + iDestCtx = MSHCTX_DIFFERENTMACHINE; + } + else + { + // On the server side, the destination context isn't known + // untill a call arrives. + iDestCtx = 0; + } +} + +/***************************************************************************/ +CRpcChannelBuffer::~CRpcChannelBuffer() +{ + ComDebOut((DEB_MARSHAL, "CRpcChannelBuffer %s Deleted this:%x\n", + (state & server_cs) ? "SERVER" : "CLIENT", this)); + + if (handle != NULL) + RpcBindingFree( &handle ); + if (hToken != NULL) + CloseHandle( hToken ); +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::FreeBuffer( RPCOLEMESSAGE *pMessage ) +{ + TRACECALL(TRACE_RPC, "CRpcChannelBuffer::FreeBuffer"); + ASSERT_LOCK_RELEASED + AssertValid(FALSE, TRUE); + + if (pMessage->Buffer == NULL) + return S_OK; + + // Pop the call stack. + COleTls tls; + Win4Assert( tls->pCallInfo != NULL ); + working_call *pCall = (working_call *) tls->pCallInfo; + tls->pCallInfo = pCall->pNext; + tls->dwAuthnLevel = pCall->lSavedAuthnLevel; + pMessage->Buffer = pCall->pHeader;; + + DeallocateBuffer(pCall->pmessage); + + // Resume any outstanding impersonation. + ResumeImpersonate( tls->pCallContext, pCall->iFlags & CF_WAS_IMPERSONATING ); + + // Release the AddRef we did earlier. Note that we cant do this until + // after DeallocateBuffer since it may release a binding handle that + // I_RpcFreeBuffer needs. + if (pCall->Locked()) + pStdId->UnLockClient(); + + pMessage->Buffer = NULL; + delete pCall; + + ASSERT_LOCK_RELEASED + return S_OK; +} + +//------------------------------------------------------------------------- +// +// Member: CRpcChannelBuffer::GetBuffer +// +// Synopsis: Calls ClientGetBuffer or ServerGetBuffer +// +//------------------------------------------------------------------------- +STDMETHODIMP CRpcChannelBuffer::GetBuffer( RPCOLEMESSAGE *pMessage, + REFIID riid ) +{ + gOXIDTbl.ValidateOXID(); + if (state & proxy_cs) + return ClientGetBuffer( pMessage, riid ); + else + return ServerGetBuffer( pMessage, riid ); +} + +//------------------------------------------------------------------------- +// +// Member: CRpcChannelBuffer::ClientGetBuffer +// +// Synopsis: Gets a buffer and sets up client side stuff +// +//------------------------------------------------------------------------- +HRESULT CRpcChannelBuffer::ClientGetBuffer( RPCOLEMESSAGE *pMessage, + REFIID riid ) +{ + TRACECALL(TRACE_RPC, "CRpcChannelBuffer::ClientGetBuffer"); + ASSERT_LOCK_RELEASED + + RPC_STATUS status; + CALLCATEGORY callcat = CALLCAT_SYNCHRONOUS; + ULONG debug_size; + ULONG num_extent; + WireThis *inb; + LocalThis *localb; + IID *logical_thread; + working_call *call; + DWORD flags; + BOOL resume; + DWORD orig_size = pMessage->cbBuffer; + COleTls tls; + + Win4Assert(state & proxy_cs); + AssertValid(FALSE, TRUE); + + // Don't allow remote calls if DCOM is disabled. + if (gDisableDCOM && iDestCtx == MSHCTX_DIFFERENTMACHINE) + return RPC_E_REMOTE_DISABLED; + + // Fetch the call category from the RPC message structure + if (pMessage->rpcFlags & RPCFLG_ASYNCHRONOUS) + { + // only allow async for these two interfaces for now + if (riid != IID_IAdviseSink && riid != IID_IAdviseSink2) + return E_UNEXPECTED; + callcat = CALLCAT_ASYNC; + } + else + { + logical_thread = TLSGetLogicalThread(); + if (logical_thread == NULL) + { + return RPC_E_OUT_OF_RESOURCES; + } + if (pMessage->rpcFlags & RPCFLG_INPUT_SYNCHRONOUS) + { + callcat = CALLCAT_INPUTSYNC; + } + } + + // Set the buffer complete flag for local calls. + pMessage->rpcFlags |= RPC_BUFFER_COMPLETE; + + // Note - RPC requires that the 16th bit of the proc num be set because + // we use the rpcFlags field of the RPC_MESSAGE struct. + pMessage->iMethod |= RPC_FLAGS_VALID_BIT; + + // if service object of destination is in same process, definitely local + // calls; async calls are also forced to be local. + if (state & process_local_cs) + { + pMessage->rpcFlags |= RPCFLG_LOCAL_CALL; + flags = CF_PROCESS_LOCAL; + } + else + flags = 0; + + // Find out if we need hook data. + debug_size = ClientGetSize( riid, &num_extent ); + + LOCK + + // Complete the channel initialization if needed. + status = InitClientSideHandle(); + if (status != RPC_S_OK) + { + UNLOCK; + ASSERT_LOCK_RELEASED; + return status; + } + + // Fill in the binding handle. Adjust the size. Clear the transfer + // syntax. Set the interface identifier. + if ((pMessage->rpcFlags & RPCFLG_LOCAL_CALL) == 0) + pMessage->reserved1 = handle; + pMessage->cbBuffer += SIZENEEDED_ORPCTHIS( pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL, + debug_size ); + pMessage->reserved2[0] = 0; + pMessage->reserved2[1] = pInterfaceInfo; + Win4Assert( pMessage->reserved2[1] != NULL ); + + // Allocate a call record. + call = new working_call( callcat, pMessage, flags, pIPIDEntry->ipid, + iDestCtx, this, lAuthnLevel ); + pMessage->cbBuffer = orig_size; + UNLOCK + ASSERT_LOCK_RELEASED + + if (call == NULL) + return E_OUTOFMEMORY; + + // Suspend any outstanding impersonation and ignore failures. + SuspendImpersonate( tls->pCallContext, &resume ); + + // Get a buffer. + if (call->pmessage->rpcFlags & RPCFLG_LOCAL_CALL) + { + // NDR_DREP_ASCII | NDR_DREP_LITTLE_ENDIAN | NDR_DREP_IEEE + call->pmessage->dataRepresentation = 0x00 | 0x10 | 0x0000; + call->pmessage->Buffer = PrivMemAlloc8( call->pmessage->cbBuffer ); + if (call->pmessage->Buffer == NULL) + status = RPC_S_OUT_OF_MEMORY; + else + status = RPC_S_OK; + } + else + { + TRACECALL(TRACE_RPC, "I_RpcGetBuffer"); + status = I_RpcGetBuffer( (RPC_MESSAGE *) call->pmessage ); + } + + if (status != RPC_S_OK) + { + // Resume any outstanding impersonation. + ResumeImpersonate( tls->pCallContext, resume ); + + // Cleanup. + pMessage->cbBuffer = 0; + tls->fault = MAKE_WIN32( status ); + delete call; + return MAKE_WIN32( status ); + } + + // Save the impersonation flag. + if (resume) + call->iFlags |= CF_WAS_IMPERSONATING; + + // Chain the call info in TLS. + call->pNext = (CChannelCallInfo *)tls->pCallInfo; + tls->pCallInfo = call; + call->pHeader = call->message.Buffer; + + // Adjust the authentication level in TLS. + call->lSavedAuthnLevel = tls->dwAuthnLevel; + tls->dwAuthnLevel = lAuthnLevel; + + // Fill in the COM header. + pMessage->Buffer = call->message.Buffer; + inb = (WireThis *) pMessage->Buffer; + inb->c.version.MajorVersion = COM_MAJOR_VERSION; + inb->c.version.MinorVersion = COM_MINOR_VERSION; + inb->c.reserved1 = 0; + + // Generate a new logical thread for async calls. + if (callcat == CALLCAT_ASYNC) + UuidCreate( &inb->c.cid ); + // Find the logical thread id. + else + inb->c.cid = *logical_thread; + + if (pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL) + inb->c.flags = ORPCF_LOCAL; + else + inb->c.flags = ORPCF_NULL; + + // Fill in any hook data and adjust the buffer pointer. + if (debug_size != 0) + { + pMessage->Buffer = FillBuffer( riid, &inb->d.ea, debug_size, num_extent, + TRUE ); + inb->c.unique = 0x77646853; // Any non-zero value. + } + else + { + pMessage->Buffer = (void *) &inb->d.ea; + inb->c.unique = FALSE; + } + + // Fill in the local header. + if (pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL) + { + localb = (LocalThis *) pMessage->Buffer; + localb->client_thread = GetCurrentApartmentId(); + localb->flags = 0; + pMessage->Buffer = localb + 1; + if (callcat == CALLCAT_ASYNC) + inb->c.flags |= ORPCF_ASYNC; + else if (callcat == CALLCAT_INPUTSYNC) + inb->c.flags |= ORPCF_INPUT_SYNC; + + // if the caller is using a non-NDR proxy, set a bit in the local + // header flags so that server side stub knows which way to unmarshal + // the parameters. This lets a 32bit server simultaneously service calls + // from 16bit non-NDR clients and 32bit NDR clients, in particular, to + // support OLE Automation. + + if (pIPIDEntry->dwFlags & (IPIDF_NONNDRPROXY | IPIDF_NONNDRSTUB)) + localb->flags |= LOCALF_NONNDR; + } + + ComDebOut((DEB_CALLCONT, "ClientGetBuffer: LogicalThreadId:%I\n", + &(tls->LogicalThreadId))); + + ASSERT_LOCK_RELEASED + return S_OK; +} + +//------------------------------------------------------------------------- +// +// Member: CRpcChannelBuffer::ServerGetBuffer +// +// Synopsis: Gets a buffer and sets up server side stuff +// +//------------------------------------------------------------------------- +HRESULT CRpcChannelBuffer::ServerGetBuffer( RPCOLEMESSAGE *pMessage, + REFIID riid ) +{ + TRACECALL(TRACE_RPC, "CRpcChannelBuffer::ServerGetBuffer"); + ASSERT_LOCK_RELEASED + + RPC_STATUS status; + ULONG debug_size; + ULONG num_extent; + HRESULT result = S_OK; + WireThis *inb; + WireThat *outb; + CChannelCallInfo *call; + void *stub_data; + DWORD orig_size = pMessage->cbBuffer; + + Win4Assert( state & server_cs ); + + AssertValid(FALSE, TRUE); + + // Get the call info from TLS. + COleTls tls; + call = (CChannelCallInfo *) tls->pCallInfo; + Win4Assert( call != NULL ); + + // Find out if we need debug data. + pMessage->Buffer = call->pHeader; + debug_size = ServerGetSize( riid, &num_extent ); + + // Adjust the buffer size. + pMessage->cbBuffer += SIZENEEDED_ORPCTHAT( debug_size ); + + // Get a buffer. + if (pMessage->rpcFlags & RPCFLG_LOCAL_CALL) + { + // NDR_DREP_ASCII | NDR_DREP_LITTLE_ENDIAN | NDR_DREP_IEEE + pMessage->dataRepresentation = 0x00 | 0x10 | 0x0000; + pMessage->Buffer = PrivMemAlloc8( pMessage->cbBuffer ); + if (pMessage->Buffer == NULL) + status = RPC_S_OUT_OF_MEMORY; + else + status = RPC_S_OK; + } + else + { + TRACECALL(TRACE_RPC, "I_RpcGetBuffer"); + status = I_RpcGetBuffer( (RPC_MESSAGE *) pMessage ); + Win4Assert( call->pHeader != pMessage->Buffer || status != RPC_S_OK ); + } + + if (status != RPC_S_OK) + { + pMessage->cbBuffer = 0; + pMessage->Buffer = NULL; + tls->fault = MAKE_WIN32( status ); + return MAKE_WIN32( status ); + } + + // Fill in the outbound COM header. + call->pHeader = pMessage->Buffer; + outb = (WireThat *) pMessage->Buffer; + outb->c.flags = ORPCF_NULL; + pMessage->cbBuffer = orig_size; + if (debug_size != 0) + { + stub_data = FillBuffer( riid, &outb->d.ea, debug_size, num_extent, FALSE ); + outb->c.unique = 0x77646853; // Any non-zero value. + pMessage->Buffer = stub_data; + } + else + { + outb->c.unique = 0; + pMessage->Buffer = &outb->d.ea; + } + + ComDebOut((DEB_CALLCONT, "ServerGetBuffer: LogicalThreadId:%I\n", + &(tls->LogicalThreadId))); + ASSERT_LOCK_RELEASED + return S_OK; +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::GetDestCtx( DWORD FAR* lpdwDestCtx, + LPVOID FAR* lplpvDestCtx ) +{ + TRACECALL(TRACE_RPC, "CRpcChannelBuffer::GetDestCtx"); + AssertValid(FALSE, FALSE); + + // On the client side, get the destination context from the channel. + if (state & (client_cs | proxy_cs)) + { + *lpdwDestCtx = iDestCtx; + } + + // On the server side, get the destination context from TLS. + else + { + COleTls tls; + Win4Assert( tls->pCallInfo != NULL ); + *lpdwDestCtx = ((CChannelCallInfo *) tls->pCallInfo)->iDestCtx; + } + + if (lplpvDestCtx != NULL) + *lplpvDestCtx = NULL; + + return S_OK; +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::IsConnected( THIS ) +{ + // must be on right thread because it is only called by proxies and stubs. + AssertValid(FALSE, TRUE); + + // Server channels never know if they are connected. The only time the + // client side knows it is disconnected is after the standard identity + // has disconnected the proxy from the channel. In that case it doesn't + // matter. + return S_OK; +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::QueryInterface( THIS_ REFIID riid, LPVOID FAR* ppvObj) +{ + AssertValid(FALSE, FALSE); + + // IMarshal is queried more frequently than any other interface, so + // check for that first. + + if (IsEqualIID(riid, IID_IMarshal)) + { + *ppvObj = (IMarshal *) this; + } + else if (IsEqualIID(riid, IID_IUnknown) || + IsEqualIID(riid, IID_IRpcChannelBuffer)) + { + *ppvObj = (IRpcChannelBuffer *) this; + } + else if (IsEqualIID(riid, IID_INonNDRStub) && + (state & proxy_cs) && pIPIDEntry && + (pIPIDEntry->dwFlags & IPIDF_NONNDRSTUB)) + { + // this interface is used to tell proxies whether the server side speaks + // NDR or not. Returns S_OK if NOT NDR. + *ppvObj = (IUnknown *) this; + } + else + { + *ppvObj = NULL; + return E_NOINTERFACE; + } + + AddRef(); + return S_OK; +} + +/***************************************************************************/ +STDMETHODIMP_(ULONG) CRpcChannelBuffer::Release( THIS ) +{ + // can't call AssertValid(FALSE) since it is used in asserts + ULONG lRef = ref_count - 1; + + if (InterlockedDecrement( (long*) &ref_count ) == 0) + { + delete this; + return 0; + } + else + { + return lRef; + } +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::SendReceive( THIS_ RPCOLEMESSAGE *pMessage, + ULONG *status ) +{ + return CRpcChannelBuffer::SendReceive2(pMessage, status); +} + +/***************************************************************************/ +STDMETHODIMP CRpcChannelBuffer::SendReceive2( THIS_ RPCOLEMESSAGE *pMessage, + ULONG *status ) +{ + TRACECALL(TRACE_RPC, "CRpcChannelBuffer::SendReceive"); + ComDebOut((DEB_CHANNEL, "CRpcChannelBuffer::SendReceive pChnl:%x pMsg:%x\n", + this, pMessage)); + + AssertValid(FALSE, TRUE); + Win4Assert( state & proxy_cs ); + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_RELEASED + + HRESULT result; + working_call *call; + working_call *next_call; + IID iid; + WireThis *inb; + WireThat *outb; + DWORD saved_authn_level; + BOOL resume; + char *stub_data; + + // Get the information about the call stored in TLS + COleTls tls; + call = (working_call *) tls->pCallInfo; + Win4Assert( call != NULL ); + next_call = (working_call *) call->pNext; + saved_authn_level = call->lSavedAuthnLevel; + resume = call->iFlags & CF_WAS_IMPERSONATING; + + // Set up the header pointers. + inb = (WireThis *) call->pHeader; + iid = + ((RPC_CLIENT_INTERFACE *) ((RPC_MESSAGE *) call->pmessage)->RpcInterfaceInformation)->InterfaceId.SyntaxGUID; + + // we must ensure that we dont go away during this call. we will Release + // ourselves in the FreeBuffer call, or in the error handling at the + // end of this function. + pStdId->LockClient(); + +#if DBG==1 + DWORD CallCat = GetCallCat( inb ); + DebugPrintORPCCall(ORPC_SENDRECEIVE_BEGIN, iid, call->message.iMethod, CallCat); + RpcSpy((CALLOUT_BEGIN, inb, iid, call->message.iMethod, 0)); +#endif + + // Send the request. + if ((state & mswmsg_cs) || (IsMTAThread() && !call->Local())) + { + // For MSWMSG or non-local MTA, call ThreadSendReceive directly. + result = ThreadSendReceive( call ); + } + else + { + if (call->Local()) + call->message.reserved2[3] = NULL; + + if (IsMTAThread()) + { + LOCK + result = GetToSTA( pOXIDEntry, call); + UNLOCK + } + else + { + result = SwitchSTA( pOXIDEntry, (CChannelCallInfo **) &call ); + } + } + +#if DBG==1 + DebugPrintORPCCall(ORPC_SENDRECEIVE_END, iid, pMessage->iMethod, CallCat); + RpcSpy((CALLOUT_END, inb, iid, pMessage->iMethod, result)); +#endif + + // We can't look at the call structure if the call was canceled. + if (result != RPC_E_CALL_CANCELED) + { + // Get the reply header if there is a reply buffer. + if ((state & mswmsg_cs) && (pMessage->rpcFlags & RPCFLG_ASYNCHRONOUS)) + outb = NULL; + else + outb = (WireThat *) call->message.Buffer; + + // Local calls reuse pNext on the server side. + call->pNext = next_call; + + // Save the real buffer pointer for FreeBuffer. + call->pHeader = call->message.Buffer; + } + else + outb = NULL; + + // Figure out when to retry. + // FreeThreaded - treat retry as a failure. + // Apartment - return the buffer and let call control decide. + + if (result == S_OK) + { + // No buffer was returned for async calls on MSWMSG. + if (outb == NULL) + *status = S_OK; + else if (IsMTAThread()) + { + if (outb->c.flags & ORPCF_REJECTED) + result = RPC_E_CALL_REJECTED; + else if (outb->c.flags & ORPCF_RETRY_LATER) + result = RPC_E_SERVERCALL_RETRYLATER; + else + *status = S_OK; + } + else if (outb->c.flags & ORPCF_REJECTED) + *status = (ULONG) RPC_E_CALL_REJECTED; + else if (outb->c.flags & ORPCF_RETRY_LATER) + *status = (ULONG) RPC_E_SERVERCALL_RETRYLATER; + else + *status = S_OK; + } + + // Check the packet extensions. + if (result != RPC_E_CALL_CANCELED) + { + stub_data = (char *) call->message.Buffer; + result = ClientNotify( iid, outb, call->message.cbBuffer, + (void **) &stub_data, + call->message.dataRepresentation, + result ); + } + else + result = ClientNotify( iid, outb, 0, (void **) &stub_data, 0, result ); + + // Call succeeded. + if (result == S_OK && outb != NULL) + { + // The locked flag lets FreeBuffer know that it has to call + // RH->UnlockClient. + call->iFlags |= CF_LOCKED; + pMessage->Buffer = stub_data; + pMessage->cbBuffer = call->message.cbBuffer - + (stub_data - (char *) call->message.Buffer); + pMessage->dataRepresentation = call->message.dataRepresentation; + result = *status; + + // Copy a portion of the message structure that RPC updated on SendReceive. + // This is needed to free the buffer. Note that we still have to free + // the buffer in certain failure cases (reject). + pMessage->reserved2[2] = call->message.reserved2[2]; + + } + else + { + // Resume any outstanding impersonation. + ResumeImpersonate( tls->pCallContext, resume ); + + // Clean up the call. + pStdId->UnLockClient(); + tls->pCallInfo = next_call; + tls->dwAuthnLevel = saved_authn_level; + delete call; + + // Make sure FreeBuffer doesn't try to free the in buffer. + pMessage->Buffer = NULL; + + // If the result is server fault, get the exception code from the CChannelCallInfo. + if (result == RPC_E_SERVERFAULT) + { + *status = call->server_fault; + } + // Everything else is a comm fault. + else if (result != S_OK) + { + *status = result; + result = RPC_E_FAULT; + } + tls->fault = *status; + + // Since result is almost always mapped to RPC_E_FAULT, display the + // real error here to assist debugging. + if (*status != S_OK) + ComDebOut((DEB_CHANNEL, "ORPC call failed. status = %x\n", *status)); + } + + ASSERT_LOCK_RELEASED + gOXIDTbl.ValidateOXID(); + ComDebOut((DEB_CHANNEL, "CRpcChannelBuffer::SendReceive hr:%x\n", result)); + return result; +} + +/***************************************************************************/ +HANDLE CRpcChannelBuffer::SwapSecurityToken( HANDLE hNew ) +{ + HANDLE hOld = hToken; + hToken = hNew; + return hOld; +} + +#if DBG == 1 +//+------------------------------------------------------------------- +// +// Member: CRpcChannelBuffer::AssertValid +// +// Synopsis: Validates that the state of the object is consistent. +// +// History: 25-Jan-94 CraigWi Created. +// +// DCOMWORK - Put in some asserts. +// +//-------------------------------------------------------------------- +void CRpcChannelBuffer::AssertValid(BOOL fKnownDisconnected, + BOOL fMustBeOnCOMThread) +{ + Win4Assert(state & (proxy_cs | client_cs | server_cs )); + + if (state & (client_cs | proxy_cs)) + { + ; + } + else if (state & server_cs) + { + Win4Assert( !(state & freethreaded_cs) ); + if (fMustBeOnCOMThread && IsSTAThread()) + Win4Assert(IsMTAThread() || pOXIDEntry->dwTid == GetCurrentThreadId()); + // ref count can be 0 in various stages of connection and disconnection + Win4Assert(ref_count < 0x7fff && "Channel ref count unreasonably high"); + + // the pStdId pointer can not be NULL + // Win4Assert(IsValidInterface(pStdId)); + } +} +#endif // DBG == 1 + + +/***************************************************************************/ +STDAPI_(ULONG) DebugCoGetRpcFault() +{ + HRESULT hr; + COleTls tls(hr); + + if (SUCCEEDED(hr)) + return tls->fault; + + return 0; +} + +/***************************************************************************/ +STDAPI_(void) DebugCoSetRpcFault( ULONG fault ) +{ + HRESULT hr; + COleTls tls(hr); + + if (SUCCEEDED(hr)) + tls->fault = fault; +} + +/***************************************************************************/ +extern "C" +BOOL _stdcall DllDebugObjectRPCHook( BOOL trace, LPORPC_INIT_ARGS pass_through ) +{ + if (!IsWOWThread()) + { + DoDebuggerHooks = trace; + DebuggerArg = pass_through; + return TRUE; + } + else + return FALSE; +} + +/***************************************************************************/ +BOOL LocalCall() +{ + CChannelCallInfo *call; + + // Get the call info from TLS. + COleTls tls; + call = (CChannelCallInfo *) tls->pCallInfo; + Win4Assert( call != NULL ); + return call->iFlags & CF_PROCESS_LOCAL; +} + +/***************************************************************************/ +LONG ThreadInvokeExceptionFilter( DWORD lCode, + LPEXCEPTION_POINTERS lpep ) +{ + ComDebOut((DEB_ERROR, "Exception 0x%x in ThreadInvoke at address 0x%x\n", + lCode, lpep->ExceptionRecord->ExceptionAddress)); + DebugBreak(); + return EXCEPTION_EXECUTE_HANDLER; +} + +/***************************************************************************/ +/* This routine returns both comm status and server faults to the runtime + by raising exceptions. If FreeThreading is true, ComInvoke will throw + exceptions to indicate server faults. These will not be caught and will + propogate directly to the runtime. If FreeThreading is false, ComInvoke + will return the result and fault in the CChannelCallInfo record. + + NOTE: + This function switches to the 32 bit stack under WIN95. + An exception has to be caught while switched to the 32 bit stack. + The exceptions has to be pass as a value and rethrown again on the + 16 bit stack (see SSInvoke in stkswtch.cxx) +*/ + +#ifdef _CHICAGO_ +DWORD +#else +void +#endif +SSAPI(ThreadInvoke)(RPC_MESSAGE *message ) +{ + HRESULT result = S_OK; + + TRACECALL(TRACE_RPC, "ThreadInvoke"); + ComDebOut((DEB_CHANNEL,"ThreadInvoke pMsg:%x\n", message)); + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_RELEASED + + BOOL success; + WireThis *inb = (WireThis *) message->Buffer; + IPID ipid; + RPC_STATUS status; + OXIDEntry *pOxid; + unsigned int transport_type; + DWORD authn_level; + + // Byte swap the header. + ByteSwapThis( message->DataRepresentation, inb ); + + // Validate several things: + // The packet size is larger then the first header size. + // No extra flags are set. + // The procedure number is greater then 2 (not QI, AddRef, Release). + if (sizeof(WireThisPart1) > message->BufferLength || + (inb->c.flags & ~(ORPCF_LOCAL | ORPCF_RESERVED1 | + ORPCF_RESERVED2 | ORPCF_RESERVED3 | ORPCF_RESERVED4)) != 0 || + message->ProcNum < 3) + RETURN_COMM_STATUS( RPC_E_INVALID_HEADER ); + + // Validate the version. + if (inb->c.version.MajorVersion != COM_MAJOR_VERSION || + inb->c.version.MinorVersion > COM_MINOR_VERSION) + RETURN_COMM_STATUS( RPC_E_VERSION_MISMATCH ); + + // Get the transport the call arrived on. + status = I_RpcServerInqTransportType( &transport_type ); + if (status != RPC_S_OK) + RETURN_COMM_STATUS( RPC_E_SYS_CALL_FAILED ); + + // Don't accept the local header on remote calls. + if (inb->c.flags & ORPCF_LOCAL) + { + if (transport_type != TRANSPORT_TYPE_LPC && + transport_type != TRANSPORT_TYPE_WMSG) + RETURN_COMM_STATUS( RPC_E_INVALID_HEADER ); + + // For local calls the authentication level will always be encrypt. + authn_level = RPC_C_AUTHN_LEVEL_PKT_PRIVACY; + } + + // Don't accept remote calls if DCOM is diabled. + else if (gDisableDCOM && + (transport_type == TRANSPORT_TYPE_CN || transport_type == TRANSPORT_TYPE_DG)) + RETURN_COMM_STATUS( RPC_E_CALL_REJECTED ); + + // Lookup the authentication level. + else + { + result = RpcBindingInqAuthClient( message->Handle, NULL, + NULL, &authn_level, NULL, NULL ); + if (result == RPC_S_BINDING_HAS_NO_AUTH) + authn_level = RPC_C_AUTHN_LEVEL_NONE; + else if (result != RPC_S_OK) + { + Win4Assert( result == RPC_S_OUT_OF_RESOURCES ); + RETURN_COMM_STATUS( MAKE_WIN32( result ) ); + } + + // Verify the authentication level. + if (gAuthnLevel > RPC_C_AUTHN_LEVEL_NONE || + gImpLevel > 0) + { + if (authn_level < gAuthnLevel) + RETURN_COMM_STATUS( RPC_E_ACCESS_DENIED ); + } + } + +#if DBG==1 + _try + { +#endif + + // Find the ipid entry from the ipid. + status = RpcBindingInqObject( message->Handle, &ipid ); + if (status == RPC_S_OK) + { + // The CChannelCallInfo is created in a nested scope so that it + // is destroyed before the calls to throw an exception at the + // end of this function. + CChannelCallInfo call( + GetCallCat( inb ), + (RPCOLEMESSAGE *) message, + 0, + ipid, + (inb->c.flags & ORPCF_LOCAL) ? MSHCTX_LOCAL : MSHCTX_DIFFERENTMACHINE, + NULL, + authn_level ); + + + // Find the OXIDEntry of the server apartment. + + ASSERT_LOCK_RELEASED + LOCK + + IPIDEntry *ipid_entry = gIPIDTbl.LookupIPID( ipid ); + + if (ipid_entry == NULL || (ipid_entry->dwFlags & IPIDF_DISCONNECTED) + || ipid_entry->pChnl == NULL ) + { + UNLOCK + ASSERT_LOCK_RELEASED + result = RPC_E_DISCONNECTED; + } + else + { + pOxid = ipid_entry->pOXIDEntry; + + // NCALRPC always gets the thread right (except on Chicago). + // For MTAs, any thread will do. + if (transport_type == TRANSPORT_TYPE_WMSG || +#ifndef _CHICAGO_ + transport_type == TRANSPORT_TYPE_LPC || +#endif + (pOxid->dwFlags & OXIDF_MTASERVER)) + { + UNLOCK + ASSERT_LOCK_RELEASED + result = ComInvoke( &call ); + } + else + { + // Pass the message to the app thread. + + IncOXIDRefCnt( pOxid ); + result = GetToSTA( pOxid, &call ); + DecOXIDRefCnt( pOxid ); + + UNLOCK + ASSERT_LOCK_RELEASED + } + } + } + else + { + result = MAKE_WIN32( status ); + } + +#if DBG==1 + } + _except( ThreadInvokeExceptionFilter(GetExceptionCode(), + GetExceptionInformation()) ) + { + } +#endif + + // For comm and server faults, generate an exception. Otherwise the buffer + // is set up correctly. + gOXIDTbl.ValidateOXID(); + if (result == RPC_E_SERVERFAULT) + { + ASSERT_LOCK_RELEASED + RETURN_COMM_STATUS( RPC_E_SERVERFAULT ); + } + else if (result != S_OK) + { + ASSERT_LOCK_RELEASED + RETURN_COMM_STATUS( result ); + } + +#ifdef _CHICAGO_ + return 0; +#endif //_CHICAGO_ +} + + +/***************************************************************************/ +HRESULT ThreadSendReceive( CChannelCallInfo *call ) +{ + TRACECALL(TRACE_RPC, "ThreadSendReceive"); + ComDebOut((DEB_CHANNEL, "ThreadSendReceive pCall:%x\n", call)); + + ASSERT_LOCK_RELEASED + + HRESULT result; + RPCOLEMESSAGE *message = call->pmessage; + WireThat *outb; + + // Call the runtime. In the future, detect server faults and + // change the value of result to RPC_E_SERVERFAULT. + if (call->pChannel->state & mswmsg_cs) + { + CAptCallCtrl *pACC = GetAptCallCtrl(); + CCliModalLoop *pCML = (pACC) ? pACC->GetTopCML() : NULL; + OXIDEntry *pOxidClient; + HWND hwnd = NULL; + + if (IsWOWThread()) + { + LOCK + result = gOXIDTbl.GetLocalEntry( &pOxidClient ); + UNLOCK + Win4Assert( result == S_OK ); + hwnd = (HWND) pOxidClient->hServerSTA; + } + TRACECALL(TRACE_RPC, "I_RpcAsyncSendReceive"); + result = I_RpcAsyncSendReceive( (RPC_MESSAGE *) message, pCML, hwnd ); + + // If the call was canceled, the rest of the code path assumes that + // the call was deleted (by SwitchComThread). So delete it. + if (result == RPC_S_CALL_CANCELLED) + { + // Convert the win32 error to a hresult. + result = RPC_E_CALL_CANCELED; + delete call; + } + } + else + { + TRACECALL(TRACE_RPC, "I_RpcSendReceive"); + result = I_RpcSendReceive( (RPC_MESSAGE *) message ); + } + + // If the result is small, it is probably a Win32 code. + if (result != 0) + { + message->Buffer = NULL; + if ((ULONG) result > 0xfffffff7 || (ULONG) result < 0x2000) + result = MAKE_WIN32( result ); + } + else + { + // No buffer is returned for asynchronous calls on MSWMSG. + if ((call->pChannel->state & mswmsg_cs) == 0 || + (message->rpcFlags & RPCFLG_ASYNCHRONOUS) == 0) + { + // Byte swap the reply header. Fail the call if the buffer is too + // small. + outb = (WireThat *) message->Buffer; + if (message->cbBuffer >= sizeof(WireThatPart1)) + ByteSwapThat( message->dataRepresentation, outb); + else + result = RPC_E_INVALID_HEADER; + } + } + + ComDebOut((DEB_CHANNEL, "ThreadSendReceive pCall:%x hr:%x\n", call, result)); + return result; +} + +/***************************************************************************/ +/* static */ + +void working_call::Cleanup() +{ + ASSERT_LOCK_HELD + + DWORD i; + + // Release everything. + if (next <= CALLCACHE_SIZE) + { + for (i = 0; i < next; i++) + if (list[i] != NULL) + { + PrivMemFree( list[i] ); + list[i] = NULL; + } + + next = 0; + } +} + +/***************************************************************************/ +/* static */ + +void working_call::Initialize() +{ + ASSERT_LOCK_HELD + next = 0; +} + +//--------------------------------------------------------------------------- +// +// Method: working_call:: operator delete +// +// Synopsis: Cache or actually free a working call. +// +// Notes: gComLock need not be held before calling this function. +// +//--------------------------------------------------------------------------- +void working_call::operator delete( void *call ) +{ + // Add the structure to the list if the list is not full and + // if the process is still initialized (since latent threads may try + // to return stuff). + + LOCK + if (next < CALLCACHE_SIZE && gfChannelProcessInitialized) + { + list[next] = call; + next += 1; + } + + // Otherwise just free it. + else + { + PrivMemFree( call ); + } + UNLOCK +} + +//--------------------------------------------------------------------------- +// +// Method: working_call:: operator new +// +// Synopsis: Keep a cache of working_calls. Since the destructor is +// virtual, the correct delete will be called if any base +// class is deleted. +// +// Notes: gComLock must be held before calling this function. +// +//--------------------------------------------------------------------------- +void *working_call::operator new( size_t size ) +{ + ASSERT_LOCK_HELD + + void *call; + + // Get the last entry from the cache. + Win4Assert( size == sizeof( working_call ) ); + if (next > 0 && next < CALLCACHE_SIZE+1) + { + next -= 1; + call = list[next]; + list[next] = NULL; + } + + // If there are none, allocate a new one. + else + call = PrivMemAlloc(size); + return call; +} + +/**********************************************************************/ +working_call::working_call( CALLCATEGORY callcat, + RPCOLEMESSAGE *original_msg, + DWORD flags, + REFIPID ipidServer, + DWORD destctx, + CRpcChannelBuffer *channel, + DWORD authn_level ) : + CChannelCallInfo( callcat, &message, flags, ipidServer, destctx, channel, + authn_level ) +{ + message = *original_msg; +} + diff --git a/private/ole32/com/dcomrem/channelb.hxx b/private/ole32/com/dcomrem/channelb.hxx new file mode 100644 index 000000000..75e41f23f --- /dev/null +++ b/private/ole32/com/dcomrem/channelb.hxx @@ -0,0 +1,259 @@ + +#ifndef _CHANNELB_HXX_ +#define _CHANNELB_HXX_ + +#include <sem.hxx> +#include <rpc.h> +#include <rpcndr.h> +#include <chancont.hxx> +#include <stdid.hxx> + +extern "C" +{ +#include "orpc_dbg.h" +} + +/* Type definitions. */ +typedef enum EChannelState +{ + // The channel on the client side held by the remote handler. + client_cs = 0x1, + + // The channels on the client side held by proxies. + proxy_cs = 0x2, + + // The server channels held by remote handlers. + server_cs = 0x4, + + // Flag to indicate that the channel may be used on any thread. + freethreaded_cs = 0x8, + + // Client side only. Use mswmsg transport. + mswmsg_cs = 0x10, + + // Client side only. handle and pInterfaceInfo initialized. + initialized_cs = 0x20, + + // The server and client are in this process. + process_local_cs = 0x40, + + // The proxy has been set to identify level impersonation (process local only). + identify_cs = 0x80 + +} EChannelState; + + +// The size of this structure must be a multiple of 8. +typedef struct LocalThis +{ + DWORD flags; + DWORD client_thread; +} LocalThis; + +// LocalThis flag indicates parameters in buffer not marshalled NDR +const DWORD LOCALF_NONNDR = 0x800; + + +/***************************************************************************/ + +// Debug Code + +#define ORPC_INVOKE_BEGIN 0 +#define ORPC_INVOKE_END 1 +#define ORPC_SENDRECEIVE_BEGIN 2 +#define ORPC_SENDRECEIVE_END 3 + +#if DBG==1 +void DebugPrintORPCCall(DWORD dwFlag, REFIID riid, DWORD iMethod, DWORD Callcat); +#else +inline void DebugPrintORPCCall(DWORD dwFlag, REFIID riid, DWORD iMethod, DWORD Callcat) {} +#endif + +//+------------------------------------------------------------------------- +// +// Interface: IRpcChannelBuffer2 +// +// Synopsis: Interface to add one more method to the IRpcChannelBuffer +// for use by the call control. +// +//+------------------------------------------------------------------------- +class IRpcChannelBuffer2 : public IRpcChannelBuffer +{ +public: + STDMETHOD (QueryInterface) (REFIID riid, LPVOID FAR* ppvObj) = 0; + STDMETHOD_(ULONG,AddRef) (void) = 0; + STDMETHOD_(ULONG,Release) (void) = 0; + + STDMETHOD (GetBuffer) (RPCOLEMESSAGE *pMessage, REFIID) = 0; + STDMETHOD (FreeBuffer) (RPCOLEMESSAGE *pMessage) = 0; + STDMETHOD (SendReceive) (RPCOLEMESSAGE *pMessage, ULONG *) = 0; + STDMETHOD (GetDestCtx) (DWORD *lpdwDestCtx, LPVOID *lplpvDestCtx) = 0; + STDMETHOD (IsConnected) (void) = 0; + + // method on apartment channels called by CCliModalLoop + STDMETHOD (SendReceive2) (RPCOLEMESSAGE *pMsg, ULONG *pulStatus) = 0; +}; + + + +//+---------------------------------------------------------------- +// +// Class: CRpcChannelBuffer +// +// Purpose: Three distinct uses: +// Client side channel +// State: +// When not connected (after create and after disconnect): +// ref_count 1 + 1 per addref during pending +// calls that were in progress during +// a disconnect. +// pStdId back pointer to Id; not Addref'd +// state client_cs +// +// When connected (after unmarshal): +// ref_count 1 + 1 for each proxy +// pStdId same +// state client_cs +// +// Server side channel; free standing; comes and goes with each +// connection; addref owned disconnected via last release. +// State: +// ref_count > 0 +// pStdId pointer to other Id; AddRef'd +// state server_cs +// +// Interface: IRpcChannelBuffer +// +//----------------------------------------------------------------- + +class CRpcChannelBuffer : public IRpcChannelBuffer2, + public CPrivAlloc +{ + friend HRESULT ComInvoke ( CChannelCallInfo * ); + friend HRESULT ThreadSendReceive ( CChannelCallInfo * ); + + public: + STDMETHOD (QueryInterface) ( REFIID riid, LPVOID FAR* ppvObj); + STDMETHOD_(ULONG,AddRef) ( void ); + STDMETHOD_(ULONG,Release) ( void ); + STDMETHOD ( GetBuffer ) ( RPCOLEMESSAGE *pMessage, REFIID ); + STDMETHOD ( FreeBuffer ) ( RPCOLEMESSAGE *pMessage ); + STDMETHOD ( SendReceive ) ( RPCOLEMESSAGE *pMessage, ULONG * ); + STDMETHOD ( GetDestCtx ) ( DWORD FAR* lpdwDestCtx, + LPVOID FAR* lplpvDestCtx ); + STDMETHOD ( IsConnected ) ( void ); + STDMETHOD (SendReceive2) (RPCOLEMESSAGE *pMsg, ULONG *pulStatus); + + + CRpcChannelBuffer *Copy (OXIDEntry *pOXIDEntry, + REFIPID ripid, REFIID riid); + HRESULT GetHandle ( handle_t * ); + DWORD GetImpLevel ( void ); + REFMOXID GetMOXID ( void ) { return pOXIDEntry->moxid;} + OXIDEntry *GetOXIDEntry ( void ) { return pOXIDEntry; } + HANDLE GetSecurityToken ( void ) { return hToken; } + CStdIdentity *GetStdId ( void ); + BOOL ProcessLocal ( void ) { return state & process_local_cs; } + void SetAuthnLevel ( DWORD level ) { lAuthnLevel = level; } + void SetImpLevel ( DWORD level ); + void SetIPIDEntry(IPIDEntry *pEntry) { pIPIDEntry = pEntry; } + HANDLE SwapSecurityToken ( HANDLE ); + BOOL UsingMswmsg ( void ) { return state & mswmsg_cs; } + +#if DBG == 1 + void AssertValid(BOOL fKnownDisconnected, BOOL fMustBeOnCOMThread); +#else + void AssertValid(BOOL fKnownDisconnected, BOOL fMustBeOnCOMThread) { } +#endif + + CRpcChannelBuffer( CStdIdentity *, + OXIDEntry *, + DWORD eState ); + ~CRpcChannelBuffer(); + + protected: + BOOL CallableOnAnyApt( void ) { return state & freethreaded_cs; } + HRESULT ClientGetBuffer ( RPCOLEMESSAGE *, REFIID ); + + private: + HRESULT AppInvoke ( CChannelCallInfo *, IRpcStubBuffer *, + void *object, void *stub_data, LocalThis * ); + void CheckDestCtx ( void *pDestProtseq ); + HRESULT ServerGetBuffer( RPCOLEMESSAGE *, REFIID ); + HRESULT InitClientSideHandle(); + + ULONG ref_count; + CStdIdentity *pStdId; + DWORD state; // See EChannelState + handle_t handle; + OXIDEntry *pOXIDEntry; + IPIDEntry *pIPIDEntry; + DWORD iDestCtx; + void *pInterfaceInfo; + HANDLE hToken; + DWORD lAuthnLevel; +}; + +inline void DeallocateBuffer(RPCOLEMESSAGE *message ) +{ + if (message->rpcFlags & RPCFLG_LOCAL_CALL) + PrivMemFree( message->Buffer ); + else + I_RpcFreeBuffer( (RPC_MESSAGE *) message ); +} + +// returns the std identity object; not addref'd. +inline CStdIdentity *CRpcChannelBuffer::GetStdId() +{ + AssertValid(FALSE, FALSE); + Win4Assert( pStdId != NULL ); + return pStdId; +} + +inline DWORD CRpcChannelBuffer::GetImpLevel() +{ + if (state & identify_cs) + return RPC_C_IMP_LEVEL_IDENTIFY; + else + return RPC_C_IMP_LEVEL_IMPERSONATE; +} + +inline void CRpcChannelBuffer::SetImpLevel( DWORD level ) +{ + if (level == RPC_C_IMP_LEVEL_IDENTIFY) + state |= identify_cs; + else + state &= ~identify_cs; +} + +inline HRESULT CRpcChannelBuffer::GetHandle( handle_t *pHandle ) +{ + HRESULT status; + LOCK + status = InitClientSideHandle(); + if (status == RPC_S_OK) + *pHandle = handle; + UNLOCK + return status; +} + +/* Prototypes. */ +HRESULT ComInvoke ( CChannelCallInfo * ); +BOOL LocalCall ( void ); +CChannelCallInfo *MakeAsyncCopy ( CChannelCallInfo * ); +void ThreadInvoke ( RPC_MESSAGE *message ); +HRESULT ThreadSendReceive( CChannelCallInfo * ); +HRESULT StubInvoke(RPCOLEMESSAGE *pMsg, IRpcStubBuffer *pStub, + IRpcChannelBuffer *pChnl, DWORD *pdwFault); + +#if DBG==1 +LONG GetInterfaceName(REFIID riid, WCHAR *wszName); +#endif + +// Externs +extern BOOL DoDebuggerHooks; +extern LPORPC_INIT_ARGS DebuggerArg; +extern const uuid_t DEBUG_EXTENSION; + +#endif //_CHANNELB_HXX_ + diff --git a/private/ole32/com/dcomrem/chock.cxx b/private/ole32/com/dcomrem/chock.cxx new file mode 100644 index 000000000..23bd268de --- /dev/null +++ b/private/ole32/com/dcomrem/chock.cxx @@ -0,0 +1,974 @@ +//+------------------------------------------------------------------- +// +// File: chock.cxx +// +// Contents: Channel hook APIs +// +// Classes: CDebugChannelHook +// +//-------------------------------------------------------------------- +#include <ole2int.h> +extern "C" +{ +#include "orpc_dbg.h" +} +#include <channelb.hxx> +#include <ipidtbl.hxx> +#include <chock.hxx> +#include <stream.hxx> + + +//+---------------------------------------------------------------- +// Definitions. + +typedef struct SHookList +{ + struct SHookList *pNext; + IChannelHook *pHook; + UUID uExtension; +} SHookList; + + +//+---------------------------------------------------------------- +// Global variables. +SHookList gHookList = { &gHookList, NULL }; +ULONG gNumExtent = 0; + + +//+------------------------------------------------------------------- +// +// Function: CleanupChannelHooks +// +// Synopsis: Releases all the hooks in the list. +// +//-------------------------------------------------------------------- +void CleanupChannelHooks() +{ + SHookList *pCurr = gHookList.pNext; + + // Release and free each entry. + while (pCurr != &gHookList) + { + pCurr->pHook->Release(); + gHookList.pNext = pCurr->pNext; + PrivMemFree( pCurr ); + pCurr = gHookList.pNext; + } + gNumExtent = 0; +} + +//+------------------------------------------------------------------- +// +// Function: CoRegisterChannelHook +// +// Synopsis: Adds a hook object to the list of hook objects. +// +//-------------------------------------------------------------------- +WINOLEAPI CoRegisterChannelHook( REFGUID uExtension, IChannelHook *pCaptain ) +{ + SHookList *pCurr; + HRESULT hr = S_OK; + + // ChannelProcessIntialize calls while holding the lock. + ASSERT_LOCK_DONTCARE + LOCK + +#if DBG==1 + // See if the extenstion is already on the list. + pCurr = gHookList.pNext; + while (pCurr != &gHookList) + { + if (pCurr->uExtension == uExtension) + break; + pCurr = pCurr->pNext; + } + Win4Assert( pCurr == &gHookList ); + Win4Assert( pCaptain != NULL ); +#endif + + // Add a node at the head. + pCurr = (SHookList *) PrivMemAlloc( sizeof(SHookList) ); + if (pCurr != NULL) + { + pCaptain->AddRef(); + pCurr->uExtension = uExtension; + pCurr->pHook = pCaptain; + pCurr->pNext = gHookList.pNext; + gHookList.pNext = pCurr; + gNumExtent += 1; + } + else + hr = E_OUTOFMEMORY; + + UNLOCK + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ClientGetSize +// +// Synopsis: Asks each hook in the list how much data it wishes to +// place in the next request on this thread. +// +//-------------------------------------------------------------------- +ULONG ClientGetSize( REFIID riid, ULONG *cNumExtent ) +{ + SHookList *pCurr = gHookList.pNext; + ULONG lSize = sizeof(WireExtentArray) - 8; + ULONG lPiece = 0; + *cNumExtent = 0; + + // Ignore any hooks added to the head of the list. + ASSERT_LOCK_DONTCARE + + // Ask each hook. + while (pCurr != &gHookList) + { + pCurr->pHook->ClientGetSize( pCurr->uExtension, riid, &lPiece ); + if (lPiece != 0) + { + lPiece = ((lPiece + 7) & ~7) + sizeof(WireExtent); + lSize += lPiece; + *cNumExtent += 1; + } + pCurr = pCurr->pNext; + } + + // Round up the number of extents and add size for an array of unique + // flags. + *cNumExtent = (*cNumExtent + 1) & ~1; + lSize += sizeof(DWORD) * *cNumExtent; + + if (*cNumExtent != 0) + return lSize; + else + return 0; +} + +//+------------------------------------------------------------------- +// +// Function: FillBuffer +// +// Synopsis: Asks each hook in the list to place data in the buffer +// for the next request on this thread. Returns the final +// buffer pointer. +// +//-------------------------------------------------------------------- +void *FillBuffer( REFIID riid, WireExtentArray *pArray, ULONG cMax, + ULONG cNumExtent, BOOL fClient ) +{ + SHookList *pCurr; + WireExtent *pExtent; + ULONG lPiece; + ULONG cNumFill; + ULONG i; + + // Ignore any hooks added to the head of the list. + ASSERT_LOCK_DONTCARE + + // Figure out where the extents start. + pCurr = gHookList.pNext; + pExtent = (WireExtent *) ((void **) (pArray + 1) + cNumExtent - 2); + cNumFill = 0; + cMax -= sizeof(WireExtentArray) - 8 + sizeof(void*)*cNumExtent; + + // Ask each hook. + while (pCurr != &gHookList && cMax > 0) + { + lPiece = cMax - sizeof(WireExtent); + if (fClient) + pCurr->pHook->ClientFillBuffer( pCurr->uExtension, riid, &lPiece, + pExtent+1 ); + else + pCurr->pHook->ServerFillBuffer( pCurr->uExtension, riid, &lPiece, + pExtent+1, S_OK ); + Win4Assert( ((lPiece+7)&~7) + sizeof(WireExtent) <= cMax ); + + // If the hook put in data, initialize this extent and find the next. + if (lPiece != 0) + { + pExtent->size = lPiece; + pExtent->rounded_size = (lPiece+7) & ~7; + pExtent->id = pCurr->uExtension; + cNumFill += 1; + cMax -= pExtent->rounded_size + sizeof(WireExtent); + pExtent = (WireExtent *) ((char *) (pExtent+1) + + pExtent->rounded_size); + + Win4Assert( cNumFill <= cNumExtent ); + } + pCurr = pCurr->pNext; + } + + // If any hooks put in data, fill in the header. + if (cNumFill != 0) + { + pArray->size = cNumFill; + pArray->reserved = 0; + pArray->unique = 0x6d727453; // Any non-zero value. + pArray->rounded_size = (cNumFill+1) & ~1; + for (i = 0; i < cNumExtent; i++) + if (i < cNumFill) + pArray->unique_flag[i] = 0x79614b44; // Any non-zero value. + else + pArray->unique_flag[i] = 0; + return pExtent; + } + + // Otherwise return the original buffer. + else + { + return pArray; + } +} + +//+------------------------------------------------------------------- +// +// Function: FindExtentId +// +// Synopsis: Search for the specified extension id in the list of +// registered extensions. Return the index of the entry +// if found +// +//-------------------------------------------------------------------- +ULONG FindExtentId( SHookList *pHead, UUID uExtension ) +{ + ULONG i = 0; + while (pHead != &gHookList) + if (pHead->uExtension == uExtension) + return i; + else + { + i += 1; + pHead = pHead->pNext; + } + return 0xffffffff; +} + +//+------------------------------------------------------------------- +// +// Function: VerifyExtent +// +// Synopsis: Verifies extent array and extents. +// +//-------------------------------------------------------------------- +void *VerifyExtent( SHookList *pHead, WireExtentArray *pArray, ULONG cMax, + WireExtent **aExtent, DWORD dwRep ) +{ + WireExtent *pExtent; + ULONG i; + ULONG j; + ULONG cNumExtent; + WireExtent *pEnd; + + // Fail if the buffer isn't larger then the extent array header. + if (cMax < sizeof(WireExtentArray) - 8) + return NULL; + + // Byte swap the array header. + if ((dwRep & NDR_LOCAL_DATA_REPRESENTATION) != NDR_LITTLE_ENDIAN) + { + ByteSwapLong( pArray->size ); + // ByteSwapLong( pArray->reserved ); + ByteSwapLong( pArray->rounded_size ); + } + + // Validate the array header. + if (cMax < sizeof(WireExtentArray) - 8 + + sizeof(ULONG) * pArray->rounded_size || + (pArray->rounded_size & 1) != 0 || + pArray->size > pArray->rounded_size || + pArray->reserved != 0) + return NULL; + + // Count how many unique flags are set. + cNumExtent = 0; + for (i = 0; i < pArray->size; i++) + if (pArray->unique_flag[i]) + cNumExtent += 1; + + // Look up each extent from the packet in the registered list. + pEnd = (WireExtent *) ((char *) pArray + cMax); + pExtent = (WireExtent *) &pArray->unique_flag[pArray->rounded_size]; + for (i = 0; i < cNumExtent; i++) + { + // Fail if the next extent header doesn't fit in the buffer. + if (pExtent + 1 > pEnd) + return NULL; + + // Byte swap the extent header. + if ((dwRep & NDR_LOCAL_DATA_REPRESENTATION) != NDR_LITTLE_ENDIAN) + { + ByteSwapLong( pExtent->rounded_size ); + ByteSwapLong( pExtent->size ); + ByteSwapLong( pExtent->id.Data1 ); + ByteSwapShort( pExtent->id.Data2 ); + ByteSwapShort( pExtent->id.Data3 ); + } + + // Validate the extent. + if (pExtent->size > pExtent->rounded_size || + (pExtent->rounded_size & 1) != 0 || + ((char *) (pExtent+1)) + pExtent->rounded_size > (char *) pEnd) + return NULL; + + // If the extension is registered, save a pointer to it. + j = FindExtentId( pHead, pExtent->id ); + if (j != 0xffffffff) + aExtent[j] = pExtent; + + // Find the next extension. + pExtent = (WireExtent *) ((char *) (pExtent + 1) + + pExtent->rounded_size); + } + return pExtent; +} + +//+------------------------------------------------------------------- +// +// Function: ClientNotify +// +// Synopsis: Calls each hook and passes data to those that received +// data in a reply. +// +// Notes: pOut is NULL for failed calls or async calls. +// +//-------------------------------------------------------------------- +HRESULT ClientNotify( REFIID riid, WireThat *pOut, ULONG cMax, void **pStubData, + DWORD dwRep, HRESULT hr ) +{ + SHookList *pHead = gHookList.pNext; + SHookList *pCurr; + WireExtent **aExtent; + ULONG cMaxExtent = gNumExtent; + ULONG i; + + // Return immediately if there is nothing to do. + *pStubData = &pOut->d.ea; + if (pHead == &gHookList && + (pOut == NULL || pOut->c.unique == FALSE)) + return hr; + + // Initialize the array of extent pointers. + aExtent = (WireExtent **) _alloca( cMaxExtent * sizeof(WireExtent *) ); + memset( aExtent, 0, cMaxExtent * sizeof( WireExtent *) ); + + // If there are any extents, verify them and sort them. + if (SUCCEEDED(hr) && pOut != NULL && pOut->c.unique) + { + *pStubData = VerifyExtent( pHead, &pOut->d.ea, cMax - sizeof(WireThatPart1), + aExtent, dwRep ); + if (*pStubData == NULL) + return RPC_E_INVALID_EXTENSION; + } + + // Notify all the hooks + for (pCurr = pHead, i = 0; pCurr != &gHookList; pCurr = pCurr->pNext, i++) + pCurr->pHook->ClientNotify( pCurr->uExtension, riid, + aExtent[i] != NULL ? aExtent[i]->size : 0, + aExtent[i] != NULL ? aExtent[i] + 1 : NULL, + dwRep, hr ); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ServerNotify +// +// Synopsis: Calls each hook and passes data to those that receive +// data in a request. +// +//-------------------------------------------------------------------- +HRESULT ServerNotify( REFIID riid, WireThis *pIn, ULONG cMax, void **pStubData, + DWORD dwRep ) +{ + SHookList *pHead = gHookList.pNext; + SHookList *pCurr; + WireExtent **aExtent; + ULONG cMaxExtent = gNumExtent; + ULONG i; + + // Return immediately if there is nothing to do. + *pStubData = &pIn->d.ea; + if (pHead == &gHookList && pIn->c.unique == FALSE) + return S_OK; + + // Initialize the array of extent pointers. + aExtent = (WireExtent **) _alloca( cMaxExtent * sizeof(WireExtent *) ); + memset( aExtent, 0, cMaxExtent * sizeof( WireExtent *) ); + + // If there are any extents, verify them and sort them. + if (pIn->c.unique) + { + *pStubData = VerifyExtent( pHead, &pIn->d.ea, cMax - sizeof(WireThisPart1), + aExtent, dwRep ); + if (*pStubData == NULL) + return RPC_E_INVALID_EXTENSION; + } + + // Notify all the hooks + for (pCurr = pHead, i = 0; pCurr != &gHookList; pCurr = pCurr->pNext, i++) + pCurr->pHook->ServerNotify( pCurr->uExtension, riid, + aExtent[i] != NULL ? aExtent[i]->size : 0, + aExtent[i] != NULL ? aExtent[i] + 1 : NULL, + dwRep ); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: ServerGetSize +// +// Synopsis: Asks each hook in the list how much data it wishes to +// place in the next reply on this thread. +// +//-------------------------------------------------------------------- +ULONG ServerGetSize( REFIID riid, ULONG *cNumExtent ) +{ + SHookList *pCurr = gHookList.pNext; + ULONG lSize = sizeof(WireExtentArray) - 8; + ULONG lPiece = 0; + *cNumExtent = 0; + + // Ask each hook. + while (pCurr != &gHookList) + { + pCurr->pHook->ServerGetSize( pCurr->uExtension, riid, S_OK, &lPiece ); + if (lPiece != 0) + { + lPiece = ((lPiece + 7) & ~7) + sizeof(WireExtent); + lSize += lPiece; + *cNumExtent += 1; + } + pCurr = pCurr->pNext; + } + + // Round up the number of extents and add size for an array of unique + // flags. + *cNumExtent = (*cNumExtent + 1) & ~1; + lSize += sizeof(DWORD) * *cNumExtent; + if (*cNumExtent != 0) + return lSize; + else + return 0; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ClientGetSize +// +// Synopsis: Asks the VC debugger how much data to put in the next +// request on this thread. Stores the result in TLS. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ClientGetSize( REFGUID uExtension, REFIID riid, + ULONG *pSize ) +{ + COleTls tls; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + + if (DoDebuggerHooks) + tls->cDebugData = DebugORPCClientGetBufferSize( NULL, + riid, NULL, NULL, DebuggerArg, DoDebuggerHooks ); + else + tls->cDebugData = 0; + + *pSize = tls->cDebugData; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ClientFillBuffer +// +// Synopsis: Asks the VC debugger to place data in the buffer for +// the next request on this thread. Uses the size stored +// in TLS. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ClientFillBuffer( REFGUID uExtension, + REFIID riid, + ULONG *pSize, void *pBuffer ) +{ + COleTls tls; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + Win4Assert( tls->cDebugData <= *pSize ); + + if (tls->cDebugData != 0) + DebugORPCClientFillBuffer( + tls->pCallInfo->pmessage, + riid, + NULL, + NULL, + pBuffer, + tls->cDebugData, + DebuggerArg, + DoDebuggerHooks ); + *pSize = tls->cDebugData; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ClientNotify +// +// Synopsis: Passes data to the VC debugger received on the last +// reply on this thread. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ClientNotify( + REFGUID uExtension, REFIID riid, + ULONG lSize, void *pBuffer, + DWORD dwRep, HRESULT hr ) +{ + COleTls tls; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + + if (pBuffer != NULL || DoDebuggerHooks) + DebugORPCClientNotify( + tls->pCallInfo->pmessage, + riid, + NULL, + NULL, + hr, + pBuffer, + lSize, + DebuggerArg, + DoDebuggerHooks ); +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ServerNotify +// +// Synopsis: Passes data to the VC debugger receive on the last +// request on this thread. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ServerNotify( + REFGUID uExtension, REFIID riid, + ULONG lSize, void *pBuffer, + DWORD dwRep ) +{ + COleTls tls; + IPIDEntry *pIpid; + void *pv = NULL; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + + if (pBuffer != NULL || DoDebuggerHooks) + { + // Lookup the IPID entry. + LOCK + pIpid = gIPIDTbl.LookupIPID( tls->pCallInfo->ipid ); + UNLOCK + Win4Assert( pIpid != NULL ); + + // Get the object pointer from the stub because the IPID entry + // might have a different pointer. + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerQueryInterface( &pv ); + + // Call the debugger. + DebugORPCServerNotify( + tls->pCallInfo->pmessage, + riid, + pIpid->pChnl, + pv, + NULL, + pBuffer, + lSize, + DebuggerArg, + DoDebuggerHooks ); + + // Release the object pointer. + if (pv != NULL) + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerRelease( pv ); + } +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ServerGetSize +// +// Synopsis: Asks the VC debugger how much data to place in the buffer +// for the next reply on this thread. Stores the result +// in TLS. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ServerGetSize( REFGUID uExtension, REFIID riid, + HRESULT hrFault, ULONG *pSize ) +{ + COleTls tls; + IPIDEntry *pIpid; + void *pv = NULL; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + + if (DoDebuggerHooks) + { + // Lookup the IPID entry. + LOCK + pIpid = gIPIDTbl.LookupIPID( tls->pCallInfo->ipid ); + UNLOCK + Win4Assert( pIpid != NULL ); + + // Get the object pointer from the stub because the IPID entry + // might have a different pointer. + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerQueryInterface( &pv ); + + // Ask the debugger how much data it has. + tls->cDebugData = DebugORPCServerGetBufferSize( + tls->pCallInfo->pmessage, + riid, + pIpid->pChnl, + pv, + NULL, + DebuggerArg, + DoDebuggerHooks ); + + // Release the object pointer. + if (pv != NULL) + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerRelease( pv ); + } + else + tls->cDebugData = 0; + + *pSize = tls->cDebugData; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::ServerFillBuffer +// +// Synopsis: Asks the VC debugger to place data in the buffer for the +// next reply on this thread. Uses the size from TLS. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CDebugChannelHook::ServerFillBuffer( REFGUID uExtension, REFIID riid, + ULONG *pSize, void *pBuffer, HRESULT hrFault ) +{ + COleTls tls; + IPIDEntry *pIpid; + void *pv = NULL; + + Win4Assert( DEBUG_EXTENSION == uExtension ); + Win4Assert( tls->cDebugData <= *pSize ); + + if (tls->cDebugData != 0) + { + // Lookup the IPID entry. + LOCK + pIpid = gIPIDTbl.LookupIPID( tls->pCallInfo->ipid ); + UNLOCK + Win4Assert( pIpid != NULL ); + + // Get the object pointer from the stub because the IPID entry + // might have a different pointer. + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerQueryInterface( &pv ); + + // Ask the debugger to write its data. + DebugORPCServerFillBuffer( + tls->pCallInfo->pmessage, + riid, + pIpid->pChnl, + pv, + NULL, + pBuffer, + tls->cDebugData, + DebuggerArg, + DoDebuggerHooks ); + + // Release the object pointer. + if (pv != NULL) + ((IRpcStubBuffer *) pIpid->pStub)->DebugServerRelease( pv ); + } + + *pSize = tls->cDebugData; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::QueryInterface +// +// Synopsis: Queries this object for interfaces +// +//-------------------------------------------------------------------- +STDMETHODIMP CDebugChannelHook::QueryInterface( REFIID riid, LPVOID FAR* ppvObj) +{ + if (IsEqualIID(riid, IID_IUnknown) || + IsEqualIID(riid, IID_IChannelHook)) + { + *ppvObj = this; + } + else + { + *ppvObj = NULL; + return E_NOINTERFACE; + } + + // This object is not reference counted. + // AddRef(); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::AddRef +// +// Synopsis: Increments object reference count. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CDebugChannelHook::AddRef( ) +{ + return 1; +} + +//+------------------------------------------------------------------- +// +// Member: CDebugChannelHook::Release +// +// Synopsis: Decrements object reference count and deletes if zero. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CDebugChannelHook::Release( ) +{ + return 1; +} + + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ClientGetSize +// +// Synopsis: Does nothing. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ClientGetSize( REFGUID uExtension, REFIID riid, + ULONG *pSize ) +{ + Win4Assert( ERROR_EXTENSION == uExtension ); + + *pSize = 0; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ClientFillBuffer +// +// Synopsis: Does nothing. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ClientFillBuffer( REFGUID uExtension, + REFIID riid, + ULONG *pSize, void *pBuffer ) +{ + Win4Assert( ERROR_EXTENSION == uExtension ); + + *pSize = 0; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ClientNotify +// +// Synopsis: Unmarshals the COM extended error information. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ClientNotify( + REFGUID uExtension, REFIID riid, + ULONG lSize, void *pBuffer, + DWORD dwRep, HRESULT hr ) +{ + COleTls tls; + + Win4Assert( ERROR_EXTENSION == uExtension ); + + + //Unmarshal the new error object. + if ((pBuffer != NULL) && (lSize > 0)) + { + CNdrStream MemStream((unsigned char *)pBuffer, lSize); + + //Release the old error object. + if(tls->punkError != NULL) + { + tls->punkError->Release(); + tls->punkError = NULL; + } + + CoUnmarshalInterface(&MemStream, + IID_IUnknown, + (void **) &tls->punkError); + } + else if((tls->punkError != NULL) && + !IsEqualIID(riid, IID_IRundown) && + !IsEqualIID(riid, IID_IRemUnknown) && + !IsEqualIID(riid, IID_ISupportErrorInfo)) + { + //Release the old error object. + tls->punkError->Release(); + tls->punkError = NULL; + } + +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ServerNotify +// +// Synopsis: Clears the COM extended error information on an +// incoming call. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ServerNotify( + REFGUID uExtension, REFIID riid, + ULONG lSize, void *pBuffer, + DWORD dwRep ) +{ + COleTls tls; + + Win4Assert( ERROR_EXTENSION == uExtension ); + + //Release the old error object. + if(tls->punkError != NULL) + { + tls->punkError->Release(); + tls->punkError = NULL; + } + +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ServerGetSize +// +// Synopsis: Calculates the size of the marshalled error object. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ServerGetSize( REFGUID uExtension, REFIID riid, + HRESULT hrFault, ULONG *pSize ) +{ + HRESULT hr; + COleTls tls; + + Win4Assert( ERROR_EXTENSION == uExtension ); + + tls->cbErrorData = 0; + + //Compute the size of the marshalled error object. + if(tls->punkError != NULL) + { + hr = CoGetMarshalSizeMax( &tls->cbErrorData, + IID_IUnknown, + tls->punkError, + ((CChannelCallInfo *) tls->pCallInfo)->iDestCtx, + NULL, + MSHLFLAGS_NORMAL ); + if(FAILED(hr)) + { + //Release the error object. + tls->punkError->Release(); + tls->punkError = NULL; + tls->cbErrorData = 0; + } + } + + *pSize = tls->cbErrorData; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::ServerFillBuffer +// +// Synopsis: Marshals the error object. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CErrorChannelHook::ServerFillBuffer( REFGUID uExtension, REFIID riid, + ULONG *pSize, void *pBuffer, HRESULT hrFault ) +{ + HRESULT hr; + COleTls tls; + ULONG cbSize = 0; + + Win4Assert( ERROR_EXTENSION == uExtension ); + Win4Assert( tls->cbErrorData <= *pSize ); + + if(tls->punkError != NULL) + { + //Marshal the error object. + if(tls->cbErrorData > 0) + { + CNdrStream MemStream((unsigned char *)pBuffer, tls->cbErrorData); + + hr = CoMarshalInterface(&MemStream, + IID_IUnknown, + tls->punkError, + ((CChannelCallInfo *) tls->pCallInfo)->iDestCtx, + NULL, + MSHLFLAGS_NORMAL); + + if(FAILED(hr)) + { + tls->cbErrorData = 0; + } + } + + //Release the error object. + tls->punkError->Release(); + tls->punkError = NULL; + } + + *pSize = tls->cbErrorData; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::QueryInterface +// +// Synopsis: Queries this object for interfaces +// +//-------------------------------------------------------------------- +STDMETHODIMP CErrorChannelHook::QueryInterface( REFIID riid, LPVOID FAR* ppvObj) +{ + if (IsEqualIID(riid, IID_IUnknown) || + IsEqualIID(riid, IID_IChannelHook)) + { + *ppvObj = this; + } + else + { + *ppvObj = NULL; + return E_NOINTERFACE; + } + + // This object is not reference counted. + // AddRef(); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::AddRef +// +// Synopsis: Increments object reference count. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CErrorChannelHook::AddRef( ) +{ + return 1; +} + +//+------------------------------------------------------------------- +// +// Member: CErrorChannelHook::Release +// +// Synopsis: Decrements object reference count and deletes if zero. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CErrorChannelHook::Release( ) +{ + return 1; +} + diff --git a/private/ole32/com/dcomrem/chock.hxx b/private/ole32/com/dcomrem/chock.hxx new file mode 100644 index 000000000..e8c0fd5a7 --- /dev/null +++ b/private/ole32/com/dcomrem/chock.hxx @@ -0,0 +1,173 @@ +//+------------------------------------------------------------------- +// +// File: chock.hxx +// +// Contents: APIs for channel hooks +// +//-------------------------------------------------------------------- +#ifndef _CHOCK_HXX_ +#define _CHOCK_HXX_ + +/***************************************************************************/ +const DWORD MASK_A_C_ = 0xFF00FF00; +const DWORD MASK__B_D = 0x00FF00FF; +const DWORD MASK_AB__ = 0xFFFF0000; +const DWORD MASK___CD = 0x0000FFFF; + +inline void ByteSwapLong( DWORD &l ) +{ + // Start with ABCD. + // First change it to BADC. + l = ((l & MASK_A_C_) >> 8) | ((l & MASK__B_D) << 8); + + // Then change it to DCBA. + l = ((l & MASK_AB__) >> 16) | ((l & MASK___CD) << 16); +} + +/***************************************************************************/ +inline void ByteSwapShort( unsigned short &s ) +{ + s = ((s & 0x00FF) << 8) | ((s & 0xFF00) >> 8); +} + +/***************************************************************************/ +typedef struct +{ + unsigned long rounded_size; // Actual number of extents. + uuid_t id; // Extension identifier. + unsigned long size; // Extension size. + + // byte data[]; // Extension data. +} WireExtent; + + +// Array of extensions. +typedef struct +{ + unsigned long size; // Number of extents. + unsigned long reserved; // Must be zero. + unsigned long unique; // Flag to indicate presence of unique_flag array. + unsigned long rounded_size; // Actual number of extents. + + unsigned long unique_flag[2]; // Flags to indicate presense of ORPC_EXTENTs +} WireExtentArray; + +// These two structures are laid out to match the NDR wire represenation +// of the type ORPCTHIS in obase.idl. +typedef struct +{ + COMVERSION version; // COM version number + unsigned long flags; // ORPCF flags for presence of other data + unsigned long reserved1; // set to zero + CID cid; // causality id of caller + unsigned long unique; // tag to indicate presence of extensions +} WireThisPart1; + +typedef struct +{ + WireThisPart1 part1; + + // Debug data. + WireExtentArray ea; + WireExtent e; +} WireThisPart2; + +typedef union +{ + WireThisPart1 c; + WireThisPart2 d; +} WireThis; + +// These two structures are laid out to match the NDR wire represenation +// of the type ORPCTHAT in obase.idl. +typedef struct +{ + unsigned long flags; // ORPCF flags for presence of other data + unsigned long unique; // tag to indicate presence of extensions +} WireThatPart1; + +typedef struct +{ + WireThatPart1 part1; + + // Debug data. + WireExtentArray ea; + WireExtent e; +} WireThatPart2; + +typedef union +{ + WireThatPart1 c; + WireThatPart2 d; +} WireThat; + + +//+---------------------------------------------------------------- +// +// Class: CDebugChannelHook, private +// +// Purpose: Translates channel hook calls to special calls the VC +// debugger expects. +// +//----------------------------------------------------------------- + +class CDebugChannelHook : public IChannelHook +{ + public: + STDMETHOD (QueryInterface) ( REFIID riid, LPVOID FAR* ppvObj); + STDMETHOD_(ULONG,AddRef) ( void ); + STDMETHOD_(ULONG,Release) ( void ); + + STDMETHOD_(void,ClientGetSize) ( REFGUID, REFIID, ULONG *DataSize ); + STDMETHOD_(void,ClientFillBuffer)( REFGUID, REFIID, ULONG *DataSize, void *DataBuffer ); + STDMETHOD_(void,ClientNotify) ( REFGUID, REFIID, ULONG DataSize, void *DataBuffer, + DWORD DataRep, HRESULT hrFault ); + STDMETHOD_(void,ServerNotify) ( REFGUID, REFIID, ULONG DataSize, void *DataBuffer, + DWORD DataRep ); + STDMETHOD_(void,ServerGetSize) ( REFGUID, REFIID, HRESULT hrFault, + ULONG *DataSize ); + STDMETHOD_(void,ServerFillBuffer)( REFGUID, REFIID, ULONG *DataSize, void *DataBuffer, + HRESULT hrFault ); +}; + + +//+---------------------------------------------------------------- +// +// Class: CErrorChannelHook, private +// +// Purpose: Channel hook for marshalling COM extended error +// information. +// +//----------------------------------------------------------------- +class CErrorChannelHook : public IChannelHook +{ + public: + STDMETHOD (QueryInterface) ( REFIID riid, LPVOID FAR* ppvObj); + STDMETHOD_(ULONG,AddRef) ( void ); + STDMETHOD_(ULONG,Release) ( void ); + + STDMETHOD_(void,ClientGetSize) ( REFGUID, REFIID, ULONG *DataSize ); + STDMETHOD_(void,ClientFillBuffer)( REFGUID, REFIID, ULONG *DataSize, void *DataBuffer ); + STDMETHOD_(void,ClientNotify) ( REFGUID, REFIID, ULONG DataSize, void *DataBuffer, + DWORD DataRep, HRESULT hrFault ); + STDMETHOD_(void,ServerNotify) ( REFGUID, REFIID, ULONG DataSize, void *DataBuffer, + DWORD DataRep ); + STDMETHOD_(void,ServerGetSize) ( REFGUID, REFIID, HRESULT hrFault, + ULONG *DataSize ); + STDMETHOD_(void,ServerFillBuffer)( REFGUID, REFIID, ULONG *DataSize, void *DataBuffer, + HRESULT hrFault ); +}; + +/***************************************************************************/ +// Functions called by channel. +void CleanupChannelHooks(); +ULONG ClientGetSize( REFIID riid, ULONG *cNumExtent ); +HRESULT ClientNotify ( REFIID riid, WireThat *, ULONG cMax, void **pStubData, + DWORD DataRep, HRESULT hr ); +void *FillBuffer ( REFIID riid, WireExtentArray *, ULONG cMaxSize, + ULONG cNumExtent, BOOL fClient ); +ULONG ServerGetSize( REFIID riid, ULONG *cNumExtent ); +HRESULT ServerNotify ( REFIID riid, WireThis *, ULONG cMax, void **pStubData, + DWORD DataRep ); + +#endif // _CHOCK_HXX_ diff --git a/private/ole32/com/dcomrem/coapi.cxx b/private/ole32/com/dcomrem/coapi.cxx new file mode 100644 index 000000000..681a99c12 --- /dev/null +++ b/private/ole32/com/dcomrem/coapi.cxx @@ -0,0 +1,773 @@ +//+------------------------------------------------------------------- +// +// File: coapi.cxx +// +// Contents: Public COM remote subsystem APIs +// +// Functions: CoGetStandardMarshal - returns IMarshal for given interface +// CoGetMarshalSizeMax - returns max size buffer needed +// CoMarshalInterface - marshals an interface +// CoUnmarshalInterface - unmarshals an interface +// CoReleaseMarshalData - releases data from marshaled interface +// CoLockObjectExternal - keep object alive or releases it +// CoDisconnectObject - kills sessions held by remote clients +// CoIsHandlerConnected - try to determine if handler connected +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to identity object +// 05-Jul-94 BruceMa Check for end of stream +// 20-Feb-95 Rickhi Major changes for DCOM +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <olerem.h> +#include <marshal.hxx> // CStdMarshal +#include <stdid.hxx> // CStdIdentity, IDTable APIs +#include <service.hxx> // SASIZE + + +// static unmarshaler +IMarshal *gpStdMarshal = NULL; + + +//+------------------------------------------------------------------- +// +// Function: CoGetStandardMarshal, public +// +// Synopsis: Returns an instance of the standard IMarshal for the +// specifed object. +// +// Algorithm: lookup or create a CStdIdentity (and CStdMarshal) for +// the object. +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to identity object +// 20-Feb-95 Rickhi Switched to CStdMarshal +// +//-------------------------------------------------------------------- +STDAPI CoGetStandardMarshal(REFIID riid, IUnknown *pUnk, DWORD dwDestCtx, + void *pvDestCtx, DWORD mshlflags, IMarshal **ppMarshal) +{ + TRACECALL(TRACE_MARSHAL, "CoGetStandardMarshal"); + ComDebOut((DEB_MARSHAL, + "CoGetStandardMarshal riid:%I pUnk:%x dwDest:%x pvDest:%x flags:%x\n", + &riid, pUnk, dwDestCtx, pvDestCtx, mshlflags)); + + // validate the input parameters + if (ppMarshal == NULL || + dwDestCtx > MSHCTX_INPROC || pvDestCtx != NULL || + (mshlflags & ~MSHLFLAGS_ALL)) + { + return E_INVALIDARG; + } + + *ppMarshal = NULL; + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + if (pUnk == NULL) + { + // this is the unmarshal side. any instance will do so we return + // the static one. Calling UnmarshalInterface will return the real + // proxy. + + hr = GetStaticUnMarshaler(ppMarshal); + } + else + { + // this is the marshal side. We put a strong reference on the StdId + // so that the ID does not get disconnected when the last external + // Release occurs. + + CALLHOOKOBJECT(S_OK,CLSID_NULL,riid,&pUnk); + + DWORD dwFlags = IDLF_CREATE | IDLF_STRONG; + if (mshlflags & MSHLFLAGS_NOPING) + { + // requesting NOPING, so set the IDL flags accordingly + dwFlags |= IDLF_NOPING; + } + + CStdIdentity *pStdId; + hr = LookupIDFromUnk(pUnk, dwFlags, &pStdId); + *ppMarshal = (IMarshal *)pStdId; + } + + ComDebOut((DEB_MARSHAL, "CoGetStandardMarshal: pIM:%x hr:%x\n", + *ppMarshal, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoGetMarshalSizeMax, public +// +// synopsis: returns max size needed to marshal the specified interface. +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to static marshaler +// 20-Feb-95 Rickhi Return correct sizes once again. +// +//-------------------------------------------------------------------- +STDAPI CoGetMarshalSizeMax(ULONG *pulSize, REFIID riid, IUnknown *pUnk, + DWORD dwDestCtx, void *pvDestCtx, DWORD mshlflags) +{ + TRACECALL(TRACE_MARSHAL, "CoGetMarshalSizeMax"); + CALLHOOKOBJECT(S_OK,CLSID_NULL,riid,&pUnk); + ComDebOut((DEB_MARSHAL, + "CoGetMarshalSizeMax: riid:%I pUnk:%x dwDest:%x pvDest:%x flags:%x\n", + &riid, pUnk, dwDestCtx, pvDestCtx, mshlflags)); + + // validate the input parameters + if (pulSize == NULL || pUnk == NULL || + dwDestCtx > MSHCTX_INPROC || pvDestCtx != NULL || + (mshlflags & ~MSHLFLAGS_ALL)) + { + return E_INVALIDARG; + } + + *pulSize = 0; + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + IMarshal *pIM; + hr = pUnk->QueryInterface(IID_IMarshal, (void **)&pIM); + + if (SUCCEEDED(hr)) + { + // object supports custom marshalling, ask it how much space it needs + hr = pIM->GetMarshalSizeMax(riid, (void *)pUnk, dwDestCtx, + pvDestCtx, mshlflags, pulSize); + pIM->Release(); + + // add in the size of the stuff CoMarshalInterface will write + *pulSize += sizeof(OBJREF); + } + else + { + // uses standard marshalling, we know the max size already. + *pulSize = sizeof(OBJREF) + SASIZE(gpsaLocalResolver->wNumEntries); + hr = S_OK; + } + + ComDebOut((DEB_MARSHAL, "CoGetMarshalSizeMax: pUnk:%x size:%x hr:%x\n", + pUnk, *pulSize, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoMarshalInterface, public +// +// Synopsis: marshals the specified interface into the given stream +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to identity object and +// new marshaling format +// 20-Feb-95 Rickhi switched to newer marshal format +// +//-------------------------------------------------------------------- +STDAPI CoMarshalInterface(IStream *pStm, REFIID riid, IUnknown *pUnk, + DWORD dwDestCtx, void *pvDestCtx, DWORD mshlflags) +{ + TRACECALL(TRACE_MARSHAL, "CoMarshalInterface"); + ComDebOut((DEB_MARSHAL, + "CoMarshalInterface: pStm:%x riid:%I pUnk:%x dwDest:%x pvDest:%x flags:%x\n", + pStm, &riid, pUnk, dwDestCtx, pvDestCtx, mshlflags)); + + // validate the input parameters + if (pStm == NULL || pUnk == NULL || + dwDestCtx > MSHCTX_INPROC || pvDestCtx != NULL || + (mshlflags & ~MSHLFLAGS_ALL)) + { + return E_INVALIDARG; + } + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + CALLHOOKOBJECT(S_OK,CLSID_NULL,riid,(IUnknown **)&pUnk); + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IStream,(IUnknown **)&pStm); + + + // determine whether to do custom or standard marshaling + IMarshal *pIM; + hr = pUnk->QueryInterface(IID_IMarshal, (void **)&pIM); + + if (SUCCEEDED(hr)) + { + // object supports custom marshaling, use it. we package the + // custom data inside an OBJREF. + Win4Assert(pIM); + + OBJREF objref; + objref.signature = OBJREF_SIGNATURE; + objref.flags = OBJREF_CUSTOM; + objref.iid = riid; + + // get the clsid for unmarshaling + hr = pIM->GetUnmarshalClass(riid, pUnk, dwDestCtx, pvDestCtx, + mshlflags, &ORCST(objref).clsid); + + if (SUCCEEDED(hr) && + !IsEqualCLSID(CLSID_StdMarshal, ORCST(objref).clsid)) + { + // get the size of data to marshal + hr = pIM->GetMarshalSizeMax(riid, (void *)pUnk, dwDestCtx, + pvDestCtx, mshlflags, + &ORCST(objref).size); + + // currently we dont write any extensions into the custom + // objref. The provision is there so we can do it in the + // future, for example, if the unmarshaler does not have the + // unmarshal class code available we could to provide a callback + // mechanism by putting the OXID, and saResAddr in there. + ORCST(objref).cbExtension = 0; + + // write the objref header info into the stream + ULONG cbToWrite = (BYTE *)(&ORCST(objref).pData) - (BYTE *)&objref; + hr = pStm->Write(&objref, cbToWrite, NULL); + } + + if (SUCCEEDED(hr)) + { + // tell the marshaler to write the rest of the data + hr = pIM->MarshalInterface(pStm, riid, pUnk, dwDestCtx, + pvDestCtx, mshlflags); + } + + pIM->Release(); + } + else + { + // use standard marshaling - find or create a standard marshaler + // note this may include handler marshaling. + + // HACKALERT: + // Figure out what flags to pass. If marshaling TABLEWEAK, don't + // add then remove a strong connection, since many objects have a + // bogus implementation of IExternalConnection that shuts down the + // object when the last strong count goes to zero regardless of the + // fLastReleaseCloses flag. + + DWORD dwFlags = IDLF_CREATE; + if (!(mshlflags & MSHLFLAGS_TABLEWEAK)) + { + dwFlags |= IDLF_STRONG; + } + + CStdIdentity *pStdId; + hr = LookupIDFromUnk(pUnk, dwFlags, &pStdId); + + if (SUCCEEDED(hr)) + { + hr = pStdId->MarshalInterface(pStm, riid, pUnk, dwDestCtx, + pvDestCtx, mshlflags); + + if (!(mshlflags & MSHLFLAGS_TABLEWEAK)) + { + // If marshaling succeeded, removing the last strong connection + // should keep the object alive. If marshaling failed, + // removing the last strong connection should shut it down. + + BOOL fKeepAlive = (SUCCEEDED(hr)) ? TRUE : FALSE; + pStdId->DecStrongCnt(fKeepAlive); + } + else + { + pStdId->Release(); + } + } + } + + ComDebOut((DEB_MARSHAL,"CoMarshalInterface: pUnk:%x hr:%x\n",pUnk,hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoUnmarshalInterface, public +// +// Synopsis: Unmarshals a marshaled interface pointer from the stream. +// +// Notes: when a controlling unknown is supplied, it is assumed that +// the HANDLER for the class has done a CreateInstance and wants +// to aggregate just the proxymanager, ie. we dont want to +// instantiate a new class handler (the default unmarshalling +// behaviour). +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to static marshaler and +// new marshaling format +// 20-Feb-95 Rickhi switched to newer marshal format +// +//-------------------------------------------------------------------- +STDAPI CoUnmarshalInterface(IStream *pStm, + REFIID riid, + void **ppv) +{ + TRACECALL(TRACE_MARSHAL, "CoUnmarshalInterface"); + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IStream,(IUnknown **)&pStm); + ComDebOut((DEB_MARSHAL, + "CoUnmarshalInterface: pStm:%x riid:%I\n", pStm, &riid)); + + // validate the input parameters + if (pStm == NULL || ppv == NULL) + { + return E_INVALIDARG; + } + + *ppv = NULL; + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + // read the objref from the stream. + OBJREF objref; + hr = ReadObjRef(pStm, objref); + + if (SUCCEEDED(hr)) + { + if (objref.flags & OBJREF_CUSTOM) + { + // uses custom marshaling, create an instance and ask that guy + // to do the unmarshaling. special case createinstance for the + // freethreaded marshaler. + + IMarshal *pIM; + + if (IsEqualCLSID(CLSID_InProcFreeMarshaler, ORCST(objref).clsid)) + { + hr = GetInProcFreeMarshaler(&pIM); + } + else + { + hr = CoCreateInstance(ORCST(objref).clsid, NULL, CLSCTX_INPROC, + IID_IMarshal, (void **)&pIM); + } + + if (SUCCEEDED(hr)) + { + hr = pIM->UnmarshalInterface(pStm, objref.iid, ppv); + pIM->Release(); + } + else + { + // seek past the custom marshalers data so we leave the + // stream at the correct position. + + LARGE_INTEGER libMove; + libMove.LowPart = ORCST(objref).size; + libMove.HighPart = 0; + pStm->Seek(libMove, STREAM_SEEK_CUR, NULL); + } + } + else + { + // uses standard marshaling, call API to find or create the + // instance of CStdMarshal for the oid inside the objref, and + // ask that instance to unmarshal the interface. This covers + // handler unmarshaling also. + + hr = UnmarshalObjRef(objref, ppv); + } + + // free the objref we read above + FreeObjRef(objref); + + if (!InlineIsEqualGUID(riid, GUID_NULL) && + !InlineIsEqualGUID(riid, objref.iid) && SUCCEEDED(hr)) + { + // the interface iid requested was different than the one that + // was marshaled (and was not GUID_NULL), so go get the requested + // one and release the marshaled one. GUID_NULL is used by the Ndr + // unmarshaling engine and means return whatever interface was + // marshaled. + + IUnknown *pUnk = (IUnknown *)*ppv; + +#ifdef WX86OLE + if (gcwx86.IsN2XProxy(pUnk)) + { + // Tell wx86 thunk layer to thunk as IUnknown + gcwx86.SetStubInvokeFlag((BOOL)1); + } +#endif + + hr = pUnk->QueryInterface(riid, ppv); + pUnk->Release(); + } + } + + ComDebOut((DEB_MARSHAL, "CoUnmarshalInterface: pUnk:%x hr:%x\n", + *ppv, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoReleaseMarshalData, public +// +// Synopsis: release the reference created by CoMarshalInterface +// +// Algorithm: +// +// History: 23-Nov-92 Rickhi +// 11-Dec-93 CraigWi Switched to static marshaler and +// new marshaling format +// 20-Feb-95 Rickhi switched to newer marshal format +// +//-------------------------------------------------------------------- +STDAPI CoReleaseMarshalData(IStream *pStm) +{ + TRACECALL(TRACE_MARSHAL, "CoReleaseMarshalData"); + ComDebOut((DEB_MARSHAL, "CoReleaseMarshalData pStm:%x\n", pStm)); + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IStream,(IUnknown **) &pStm); + + // validate the input parameters + if (pStm == NULL) + { + return E_INVALIDARG; + } + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + // read the objref from the stream. + OBJREF objref; + hr = ReadObjRef(pStm, objref); + + if (SUCCEEDED(hr)) + { + if (objref.flags & OBJREF_CUSTOM) + { + // object uses custom marshaling. create an instance of + // the unmarshaling code and ask it to release the marshaled + // data. + + IMarshal *pIM; + + if (IsEqualCLSID(CLSID_InProcFreeMarshaler, ORCST(objref).clsid)) + { + hr = GetInProcFreeMarshaler(&pIM); + } + else + { + hr = CoCreateInstance(ORCST(objref).clsid, NULL, CLSCTX_INPROC, + IID_IMarshal, (void **)&pIM); + } + + if (SUCCEEDED(hr)) + { + hr = pIM->ReleaseMarshalData(pStm); + pIM->Release(); + } + else + { + // seek past the custom marshalers data so we leave the + // stream at the correct position. + + LARGE_INTEGER libMove; + libMove.LowPart = ORCST(objref).size; + libMove.HighPart = 0; + pStm->Seek(libMove, STREAM_SEEK_CUR, NULL); + } + } + else + { + // uses standard marshaling, find or create the instance of + // CStdMarshal for the oid inside the objref, and ask that + // instance to unmarshal the interface. + + CStdMarshal *pStdMshl; + hr = FindStdMarshal(objref, &pStdMshl); + + if (SUCCEEDED(hr)) + { + hr = pStdMshl->ReleaseMarshalObjRef(objref); + pStdMshl->Release(); + } + else if (hr == CO_E_OBJNOTCONNECTED) + { + // it was for this process but the object is already dead + hr = S_OK; + } + } + + // free the objref we read above + FreeObjRef(objref); + } + + ComDebOut((DEB_MARSHAL, "CoReleaseMarshalData hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoDisconnectObject, public +// +// synopsis: disconnects all clients of an object by marking their +// connections as terminted abnormaly. +// +// History: 04-Oct-93 Rickhi Created +// 11-Dec-93 CraigWi Switched to identity object +// +//-------------------------------------------------------------------- +STDAPI CoDisconnectObject(IUnknown *pUnk, DWORD dwReserved) +{ + TRACECALL(TRACE_MARSHAL, "CoDisconnectObject"); + ComDebOut((DEB_MARSHAL, "CoDisconnectObject pUnk:%x dwRes:%x\n", + pUnk, dwReserved)); + + // validate the input parameters + if (pUnk == NULL || dwReserved != 0) + { + return E_INVALIDARG; + } + + if (!IsValidInterface(pUnk)) + return E_INVALIDARG; + + if (!IsApartmentInitialized()) + { + return CO_E_NOTINITIALIZED; + } + + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IUnknown,&pUnk); + + IMarshal *pIM = NULL; + HRESULT hr = pUnk->QueryInterface(IID_IMarshal, (void **)&pIM); + + if (FAILED(hr)) + { + // object does not support IMarshal directly. Find its standard + // marshaler if there is one, otherwise return an error. + + CStdIdentity *pStdId; + hr = LookupIDFromUnk(pUnk, 0, &pStdId); + pIM = (IMarshal *)pStdId; + } + + if (SUCCEEDED(hr)) + { + hr = pIM->DisconnectObject(dwReserved); + pIM->Release(); + } + else + { + // could not get std marshal, must be disconnected already + return S_OK; + } + + ComDebOut((DEB_MARSHAL,"CoDisconnectObject pIM:%x hr:%x\n", pIM, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoLockObjectExternal, public +// +// synopsis: adds/revokes a strong reference count to/from the +// identity for the given object. +// +// parameters: [punkObject] - IUnknown of the object +// [fLock] - lock/unlock the object +// [fLastUR] - last unlock releases. +// +// History: 23-Nov-92 Rickhi Created +// 11-Dec-93 CraigWi Switched to identity object +// +//-------------------------------------------------------------------- +STDAPI CoLockObjectExternal(IUnknown *pUnk, BOOL fLock, BOOL fLastUR) +{ + TRACECALL(TRACE_MARSHAL, "CoLockObjectExternal"); + ComDebOut((DEB_MARSHAL, + "CoLockObjectExternal pUnk:%x fLock:%x fLastUR:%x\n", pUnk, fLock, fLastUR)); + + if (!IsValidInterface(pUnk)) + return E_INVALIDARG; + + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IUnknown,(IUnknown **)&pUnk); + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + CStdIdentity *pStdID; + hr = LookupIDFromUnk(pUnk, (fLock) ? IDLF_CREATE : 0, &pStdID); + + switch (hr) + { + case S_OK: + + // REF COUNTING: inc or dec external ref count + hr = pStdID->LockObjectExternal(fLock, fLastUR); + pStdID->Release(); + break; + + case CO_E_OBJNOTREG: + // unlock when not registered; 16bit code returned NOERROR; + // disconnected handler goes to S_OK case above. + hr = S_OK; + break; + + case E_OUTOFMEMORY: + break; + + default: + hr = E_UNEXPECTED; + break; + } + + ComDebOut((DEB_MARSHAL, + "CoLockObjectExternal pStdID:%x hr:%x\n", pStdID, hr)); + + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoIsHandlerConnected, public +// +// Synopsis: Returns whether or not handler is connected to remote +// +// Algorithm: QueryInterface to IProxyManager. If this is supported, +// then this is a handler. We ask the handler +// for its opinion otherwise we simply return TRUE. +// +// History: 04-Oct-93 Rickhi Created +// +// Notes: The answer of this routine may be wrong by the time +// the routine returns. This is correct behavior as +// this routine is primilary to cleanup state associated +// with connections. +// +//-------------------------------------------------------------------- +STDAPI_(BOOL) CoIsHandlerConnected(LPUNKNOWN pUnk) +{ + // validate input parameters + if (!IsValidInterface(pUnk)) + return FALSE; + + // Assume it is connected + BOOL fResult = TRUE; + + // Handler should support IProxyManager + IProxyManager *pPM; + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IUnknown,(IUnknown **)&pUnk); + if (SUCCEEDED(pUnk->QueryInterface(IID_IProxyManager, (void **)&pPM))) + { + // We have something that thinks its is an Ole handler so we ask + fResult = pPM->IsConnected(); + pPM->Release(); + } + + return fResult; +} + +//+------------------------------------------------------------------- +// +// Function: GetStaticUnMarshaler, private +// +// Synopsis: Returns the static instance of the CStdMarshal. +// +// History: 20-Feb-95 Rickhi Created. +// +// Notes: The standard marshaler must be able to resolve identity, that +// is two proxies for the same object must never be created in +// the same apartment. Given that, it makes sense to let the +// standard guy do the unmarshaling. Since we dont know the +// identity of the object upfront, and any instance will do, we +// use a static instance to handle unmarshal. +// +//-------------------------------------------------------------------- +INTERNAL GetStaticUnMarshaler(IMarshal **ppIM) +{ + HRESULT hr = S_OK; + + LOCK + if (gpStdMarshal == NULL) + { + // the global instance has not been created yet, so go make it now. + hr = CreateIdentityHandler(NULL, 0, IID_IMarshal, + (void **)&gpStdMarshal); + if (SUCCEEDED(hr)) + { + // dont let anybody but us delete this thing. + ((CStdIdentity *)gpStdMarshal)->SetLockedInMemory(); + hr = S_OK; + } + } + + *ppIM = gpStdMarshal; + if (gpStdMarshal) + { + gpStdMarshal->AddRef(); + } + UNLOCK; + return hr; +} + +#ifdef WX86OLE +//+------------------------------------------------------------------- +// +// Function: CoGetIIDFromMarshaledInterface, public +// +// Synopsis: Returns the IID embedded inside a marshaled interface +// pointer. Needed by the x86 thunking code. +// +// History: 16-Apr-96 Rickhi Created +// +//-------------------------------------------------------------------- +STDAPI CoGetIIDFromMarshaledInterface(IStream *pStm, IID *piid) +{ + ULARGE_INTEGER ulSeekEnd; + LARGE_INTEGER lSeekStart, lSeekEnd; + LISet32(lSeekStart, 0); + + // remember the current position + HRESULT hr = pStm->Seek(lSeekStart, STREAM_SEEK_CUR, &ulSeekEnd); + + if (SUCCEEDED(hr)) + { + // read the first part of the objref which contains the IID + // also check to ensure the objref is at least partially sane + + OBJREF objref; + hr = StRead(pStm, &objref, 2*sizeof(ULONG) + sizeof(IID)); + + if (SUCCEEDED(hr)) + { + if ((objref.signature != OBJREF_SIGNATURE) || + (objref.flags & OBJREF_RSRVD_MBZ) || + (objref.flags == 0)) + { + // the objref signature is bad, or one of the reserved + // bits in the flags is set, or none of the required bits + // in the flags is set. the objref cant be interpreted so + // fail the call. + + Win4Assert(!"Invalid Objref Flags"); + return RPC_E_INVALID_OBJREF; + } + + // extract the IID + *piid = objref.iid; + } + + // put the seek pointer back to the original location + lSeekEnd.LowPart = ulSeekEnd.LowPart; + lSeekEnd.HighPart = (LONG)ulSeekEnd.HighPart; + hr = pStm->Seek(lSeekEnd, STREAM_SEEK_SET, NULL); + } + + return hr; +} +#endif diff --git a/private/ole32/com/dcomrem/daytona/makefile b/private/ole32/com/dcomrem/daytona/makefile new file mode 100644 index 000000000..1d3728d41 --- /dev/null +++ b/private/ole32/com/dcomrem/daytona/makefile @@ -0,0 +1,10 @@ +############################################################################ +# +# Copyright (C) 1992, Microsoft Corporation. +# +# All rights reserved. +# +############################################################################ + +!include $(NTMAKEENV)\makefile.def + diff --git a/private/ole32/com/dcomrem/daytona/sources b/private/ole32/com/dcomrem/daytona/sources new file mode 100644 index 000000000..4ac4f5b09 --- /dev/null +++ b/private/ole32/com/dcomrem/daytona/sources @@ -0,0 +1,87 @@ +!IF 0 + +Copyright (c) 1989 Microsoft Corporation + +Module Name: + + sources. + +Abstract: + + This file specifies the target component being built and the list of + sources files needed to build that component. Also specifies optional + compiler switches and libraries that are unique for the component being + built. + + +Author: + + Donna Liu (DonnaLi) 19-Dec-1993 + +!ENDIF + + +MAJORCOMP = cairole +MINORCOMP = com + +!include ..\..\..\daytona.inc + +# +# This is the name of the target built from the source files specified +# below. The name should include neither the path nor the file extension. +# + +TARGETNAME= remote + +# +# This specifies where the target is to be built. A private target of +# type LIBRARY or DYNLINK should go to obj, whereas a public target of +# type LIBRARY or DYNLINK should go to $(BASEDIR)\public\sdk\lib. +# + +TARGETPATH= obj + +# +# This specifies the type of the target, such as PROGRAM, DYNLINK, LIBRARY, +# etc. +# + +TARGETTYPE= LIBRARY + +INCLUDES = ..\..\..\common\daytona;..\..\..\ih;..;..\..\inc; +INCLUDES = $(INCLUDES);..\..\dcomidl\daytona;..\..\class;..\..\objact; +INCLUDES = $(INCLUDES);$(BASEDIR)\private\dcomidl\obj + + +C_DEFINES= \ + $(C_DEFINES) -DMSWMSG + +SOURCES= \ + ..\callctrl.cxx \ + ..\chancont.cxx \ + ..\channelb.cxx \ + ..\chock.cxx \ + ..\coapi.cxx \ + ..\hash.cxx \ + ..\idtable.cxx \ + ..\ipidtbl.cxx \ + ..\ipmrshl.cxx \ + ..\locks.cxx \ + ..\marshal.cxx \ + ..\orpc_dbg.c \ + ..\pgalloc.cxx \ + ..\remoteu.cxx \ + ..\resolver.cxx \ + ..\riftbl.cxx \ + ..\security.cxx \ + ..\service.cxx \ + ..\stdid.cxx \ + ..\stream.cxx \ + ..\threads.cxx + +UMTYPE= windows +UMAPPL= +UMTEST= +UMLIBS= + +!include ..\..\precomp2.inc diff --git a/private/ole32/com/dcomrem/dirs b/private/ole32/com/dcomrem/dirs new file mode 100644 index 000000000..80cd267ad --- /dev/null +++ b/private/ole32/com/dcomrem/dirs @@ -0,0 +1,37 @@ +!IF 0 + +Copyright (c) 1989 Microsoft Corporation + +Module Name: + + dirs. + +Abstract: + + This file specifies the subdirectories of the current directory that + contain component makefiles. + + +Author: + + Donna Liu (DonnaLi) 19-Dec-1993 + +!ENDIF + +# +# This is a list of all subdirectories that build required components. +# Each subdirectory name should appear on a line by itself. The build +# follows the order in which the subdirectories are specified. +# + +DIRS= + +# +# This is a list of all subdirectories that build optional components. +# Each subdirectory name should appear on a line by itself. The build +# follows the order in which the subdirectories are specified. +# + +OPTIONAL_DIRS= \ + daytona \ + diff --git a/private/ole32/com/dcomrem/filelist.mk b/private/ole32/com/dcomrem/filelist.mk new file mode 100644 index 000000000..c69313e63 --- /dev/null +++ b/private/ole32/com/dcomrem/filelist.mk @@ -0,0 +1,58 @@ +############################################################################ +# +# Copyright (C) 1992, Microsoft Corporation. +# +# All rights reserved. +# +############################################################################ + + +# +# Name of target. Include an extension (.dll, .lib, .exe) +# If the target is part of the release, set RELEASE to 1. +# + +TARGET = remote.lib + +RELEASE = + +# +# Source files. Remember to prefix each name with .\ +# + +CXXFILES = \ + .\channelb.cxx \ + .\chancont.cxx \ + .\callcont.cxx \ + .\callmain.cxx \ + .\imchnl.cxx \ + .\sichnl.cxx \ + .\service.cxx \ + .\endpnt.cxx \ + .\remhdlr.cxx \ + .\remapi.cxx \ + .\coapi.cxx \ + .\dd.cxx \ + .\stdid.cxx \ + .\idtable.cxx + +CFILES = .\orpc_dbg.c + +# +# Libraries and other object files to link. +# + +OBJFILES = + +# +# Precompiled headers. +# + +PXXFILE = headers.cxx +PFILE = + +CINC = $(CINC) -I..\inc -I..\idl -I..\class -I..\objact $(TRACELOG) + +MTHREAD = 1 + +MULTIDEPEND = MERGED diff --git a/private/ole32/com/dcomrem/hash.cxx b/private/ole32/com/dcomrem/hash.cxx new file mode 100644 index 000000000..ec4c519dd --- /dev/null +++ b/private/ole32/com/dcomrem/hash.cxx @@ -0,0 +1,432 @@ +//+-------------------------------------------------------------------------- +// +// File: hash.cxx +// +// Contents: class for maintaining a hash table. +// +// Classes: CUUIDHashTable +// +//--------------------------------------------------------------------------- +#include <ole2int.h> +#include <hash.hxx> // CUUIDHashTable +#include <locks.hxx> // ASSERT_LOCK_HELD +#include <service.hxx> // SASIZE + + +//+------------------------------------------------------------------------ +// Type definitions + +typedef struct +{ + const IPID *pIpid; + SECURITYBINDING *pName; +} SNameKey; + +//+------------------------------------------------------------------------ +// +// Secure references hash table buckets. This is defined as a global +// so that we dont have to run any code to initialize the hash table. +// +//+------------------------------------------------------------------------ +SHashChain SRFBuckets[23] = +{ + {&SRFBuckets[0], &SRFBuckets[0]}, + {&SRFBuckets[1], &SRFBuckets[1]}, + {&SRFBuckets[2], &SRFBuckets[2]}, + {&SRFBuckets[3], &SRFBuckets[3]}, + {&SRFBuckets[4], &SRFBuckets[4]}, + {&SRFBuckets[5], &SRFBuckets[5]}, + {&SRFBuckets[6], &SRFBuckets[6]}, + {&SRFBuckets[7], &SRFBuckets[7]}, + {&SRFBuckets[8], &SRFBuckets[8]}, + {&SRFBuckets[9], &SRFBuckets[9]}, + {&SRFBuckets[10], &SRFBuckets[10]}, + {&SRFBuckets[11], &SRFBuckets[11]}, + {&SRFBuckets[12], &SRFBuckets[12]}, + {&SRFBuckets[13], &SRFBuckets[13]}, + {&SRFBuckets[14], &SRFBuckets[14]}, + {&SRFBuckets[15], &SRFBuckets[15]}, + {&SRFBuckets[16], &SRFBuckets[16]}, + {&SRFBuckets[17], &SRFBuckets[17]}, + {&SRFBuckets[18], &SRFBuckets[18]}, + {&SRFBuckets[19], &SRFBuckets[19]}, + {&SRFBuckets[20], &SRFBuckets[20]}, + {&SRFBuckets[21], &SRFBuckets[21]}, + {&SRFBuckets[22], &SRFBuckets[22]} +}; + +CNameHashTable gSRFTbl; + + +//--------------------------------------------------------------------------- +// +// Function: DummyCleanup +// +// Synopsis: Callback for CHashTable::Cleanup that does nothing. +// +//--------------------------------------------------------------------------- +void DummyCleanup( SHashChain *pIgnore ) +{ +} + +//--------------------------------------------------------------------------- +// +// Method: CHashTable::Cleanup +// +// Synopsis: Cleans up the hash table by deleteing leftover entries. +// +//--------------------------------------------------------------------------- +void CHashTable::Cleanup(PFNCLEANUP *pfnCleanup) +{ + Win4Assert(pfnCleanup); + ASSERT_LOCK_HELD + + for (ULONG iHash=0; iHash < NUM_HASH_BUCKETS; iHash++) + { + // the ptrs could be NULL if the hash table was never initialized. + + while (_buckets[iHash].pNext != NULL && + _buckets[iHash].pNext != &_buckets[iHash]) + { + // remove the entry from the list and call it's cleanup function + SHashChain *pNode = _buckets[iHash].pNext; + + Remove(pNode); + (pfnCleanup)(pNode); + } + } + +#if DBG==1 + // Verify that the hash table is empty or uninitialized. + for (iHash = 0; iHash < NUM_HASH_BUCKETS; iHash++) + { + Win4Assert( _buckets[iHash].pNext == &_buckets[iHash] || + _buckets[iHash].pNext == NULL); + Win4Assert( _buckets[iHash].pPrev == &_buckets[iHash] || + _buckets[iHash].pPrev == NULL); + } +#endif +} + +//--------------------------------------------------------------------------- +// +// Method: CHashTable::Lookup +// +// Synopsis: Searches for a given key in the hash table. +// +// Note: iHash is between 0 and -1, not 0 and NUM_HASH_BUCKETS +// +//--------------------------------------------------------------------------- +SHashChain *CHashTable::Lookup(DWORD dwHash, const void *k) +{ + ASSERT_LOCK_HELD + + // compute the index to the hash chain (it's the hash value of the key + // mod the number of buckets in the hash table) + + DWORD iHash = dwHash % NUM_HASH_BUCKETS; + + SHashChain *pNode = _buckets[iHash].pNext; + + // Search the destination bucket for the key. + while (pNode != &_buckets[iHash]) + { + if (Compare( k, pNode, dwHash )) + return pNode; + + pNode = pNode->pNext; + } + + return NULL; +} + +//--------------------------------------------------------------------------- +// +// Method: CHashTable::Add +// +// Synopsis: Adds an element to the hash table. The Cleanup method will +// call a Cleanup function that can be used to delete the +// element. +// +// Note: iHash is between 0 and -1, not 0 and NUM_HASH_BUCKETS +// +//--------------------------------------------------------------------------- +void CHashTable::Add(DWORD dwHash, SHashChain *pNode) +{ + ASSERT_LOCK_HELD + + // Add the node to the bucket chain. + SHashChain *pHead = &_buckets[dwHash % NUM_HASH_BUCKETS]; + SHashChain *pNew = pNode; + + pNew->pPrev = pHead; + pHead->pNext->pPrev = pNew; + pNew->pNext = pHead->pNext; + pHead->pNext = pNew; +} + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::Remove +// +// Synopsis: Removes an element from the hash table. +// +//--------------------------------------------------------------------------- +void CHashTable::Remove(SHashChain *pNode) +{ + ASSERT_LOCK_HELD + + pNode->pPrev->pNext = pNode->pNext; + pNode->pNext->pPrev = pNode->pPrev; +} + + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::HashNode +// +// Synopsis: Computes the hash value for a given node. +// +//--------------------------------------------------------------------------- +DWORD CUUIDHashTable::HashNode(SHashChain *pNode) +{ + return Hash( ((SUUIDHashNode *) pNode)->key ); +} + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::Compare +// +// Synopsis: Compares a node and a key. +// +//--------------------------------------------------------------------------- +BOOL CUUIDHashTable::Compare(const void *k, SHashChain *pNode, DWORD dwHash ) +{ + return InlineIsEqualGUID(*(const UUID *)k, + ((SUUIDHashNode *)pNode)->key); +} + +//--------------------------------------------------------------------------- +// +// Method: CStringHashTable::Hash +// +// Synopsis: Computes the hash value for a given key. +// +//--------------------------------------------------------------------------- +DWORD CStringHashTable::Hash(DUALSTRINGARRAY *psaKey) +{ + DWORD dwHash = 0; + DWORD *pdw = (DWORD *) &psaKey->aStringArray[0]; + + for (USHORT i=0; i< (psaKey->wNumEntries/2); i++) + { + dwHash = (dwHash << 8) ^ *pdw++; + } + + return dwHash; +} + +//--------------------------------------------------------------------------- +// +// Method: CStringHashTable::HashNode +// +// Synopsis: Computes the hash value for a given node. +// +//--------------------------------------------------------------------------- +DWORD CStringHashTable::HashNode(SHashChain *pNode) +{ + return Hash( ((SStringHashNode *) pNode)->psaKey ); +} + +//--------------------------------------------------------------------------- +// +// Method: CStringHashTable::Compare +// +// Synopsis: Compares a node and a key. +// +//--------------------------------------------------------------------------- +BOOL CStringHashTable::Compare(const void *k, SHashChain *pNode, DWORD dwHash ) +{ + SStringHashNode *pSNode = (SStringHashNode *) pNode; + const DUALSTRINGARRAY *psaKey = (const DUALSTRINGARRAY *) k; + + if (dwHash == pSNode->dwHash) + { + // a quick compare of the hash values found a match, now do + // a full compare of the key (Note: if the sizes of the two + // Keys are different, we exit the memcmp on the first dword, + // so we dont have to worry about walking off the endo of one + // of the Keys during the memcmp). + + return !memcmp(psaKey, pSNode->psaKey, SASIZE(psaKey->wNumEntries)); + } + return FALSE; +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::Cleanup +// +// Synopsis: Call the base cleanup routine with a dummy callback function +// +//--------------------------------------------------------------------------- +void CNameHashTable::Cleanup() +{ + CHashTable::Cleanup( DummyCleanup ); +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::Hash +// +// Synopsis: Computes the hash value for a given key. +// +//--------------------------------------------------------------------------- +DWORD CNameHashTable::Hash( REFIPID ipid, SECURITYBINDING *pName ) +{ + DWORD dwHash = 0; + DWORD *pdw = (DWORD *) &ipid; + DWORD dwLen = lstrlenW( (WCHAR *) pName ) >> 1; + ULONG i; + + // First hash the IPID. + for (i=0; i < 4; i++) + { + dwHash = (dwHash << 8) ^ *pdw++; + } + + // Then hash the name. + pdw = (DWORD *) pName; + for (i=0; i < dwLen; i++) + { + dwHash = (dwHash << 8) ^ *pdw++; + } + + return dwHash; +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::HashNode +// +// Synopsis: Computes the hash value for a given node. +// +//--------------------------------------------------------------------------- +DWORD CNameHashTable::HashNode(SHashChain *pNode) +{ + SNameHashNode *pNNode = (SNameHashNode *) pNode; + return Hash( pNNode->ipid, &pNNode->sName ); +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::Compare +// +// Synopsis: Compares a node and a key. +// +//--------------------------------------------------------------------------- +BOOL CNameHashTable::Compare(const void *k, SHashChain *pNode, DWORD dwHash ) +{ + SNameHashNode *pNNode = (SNameHashNode *) pNode; + const SNameKey *pKey = (const SNameKey *) k; + + if (dwHash == pNNode->dwHash) + { + // a quick compare of the hash values found a match, now do + // a full compare of the key + if (*pKey->pIpid == pNNode->ipid) + return !lstrcmpW( (WCHAR *) pKey->pName, (WCHAR *) &pNNode->sName ); + else + return FALSE; + } + + return FALSE; +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::IncRef +// +// Synopsis: Find or create an entry for the specified name. Increment +// its reference count. +// +//--------------------------------------------------------------------------- +HRESULT CNameHashTable::IncRef( ULONG cRefs, REFIPID ipid, + SECURITYBINDING *pName ) +{ + SNameHashNode *pNode; + DWORD dwHash = Hash( ipid, pName ); + HRESULT hr = S_OK; + ULONG lLen; + SNameKey key; + + ASSERT_LOCK_HELD + + // See if there is already a node in the table. + key.pIpid = &ipid; + key.pName = pName; + pNode = (SNameHashNode *) Lookup( dwHash, &key ); + + // If not, create one. + if (pNode == NULL) + { + lLen = lstrlenW( (WCHAR *) pName ); + pNode = (SNameHashNode *) PrivMemAlloc( sizeof(SNameHashNode) + + lLen*sizeof(WCHAR) ); + if (pNode != NULL) + { + pNode->cRef = 0; + pNode->dwHash = dwHash; + pNode->ipid = ipid; + memcpy( &pNode->sName, pName, (lLen + 1) * sizeof(WCHAR) ); + Add( dwHash, &pNode->chain ); + } + else + hr = E_OUTOFMEMORY; + } + + // Increment the reference count on the node. + if (pNode != NULL) + pNode->cRef += cRefs; + return hr; +} + +//--------------------------------------------------------------------------- +// +// Method: CNameHashTable::DecRef +// +// Synopsis: Decrement references for the specified name. Do not decrement +// more references then exist. Return the actual decrement count. +// +//--------------------------------------------------------------------------- +ULONG CNameHashTable::DecRef( ULONG cRefs, REFIPID ipid, + SECURITYBINDING *pName ) +{ + SNameHashNode *pNode; + DWORD dwHash = Hash( ipid, pName ); + SNameKey key; + + ASSERT_LOCK_HELD + + // Lookup the name. + key.pIpid = &ipid; + key.pName = pName; + pNode = (SNameHashNode *) Lookup( dwHash, &key ); + + if (pNode != NULL) + { + if (pNode->cRef < cRefs) + cRefs = pNode->cRef; + pNode->cRef -= cRefs; + if (pNode->cRef == 0) + { + Remove( &pNode->chain ); + PrivMemFree( pNode ); + } + } + else + cRefs = 0; + + return cRefs; +} + diff --git a/private/ole32/com/dcomrem/hash.hxx b/private/ole32/com/dcomrem/hash.hxx new file mode 100644 index 000000000..38fc27096 --- /dev/null +++ b/private/ole32/com/dcomrem/hash.hxx @@ -0,0 +1,277 @@ +//+-------------------------------------------------------------------------- +// +// File: hash.hxx +// +// Contents: class for maintaining a GUID-based hash table. +// +// Classes: CHashTable +// +// History: 20-Feb-95 Rickhi Created +// +//--------------------------------------------------------------------------- +#ifndef _HASHTBL_HXX_ +#define _HASHTBL_HXX_ + +#include <obase.h> + +//--------------------------------------------------------------------------- +// +// Structure: SHashChain +// +// Synopsis: An element in the double link list. Used by S*HashNode and +// C*HashTable. +// +//--------------------------------------------------------------------------- +typedef struct SHashChain +{ + struct SHashChain *pNext; // ptr to next node in chain + struct SHashChain *pPrev; // ptr to prev node in chain +} SHashChain; + +//--------------------------------------------------------------------------- +// +// Structure: SUUIDHashNode +// +// Synopsis: This is an element in a bucket in the UUID hash table. +// +//--------------------------------------------------------------------------- +typedef struct SUUIDHashNode +{ + SHashChain chain; // double linked list ptrs + UUID key; // node key (the value that is hashed) +} SUUIDHashNode; + +//--------------------------------------------------------------------------- +// +// Structure: SStringHashNode +// +// Synopsis: This is an element in a bucket in the string hash table. +// +//--------------------------------------------------------------------------- +typedef struct SStringHashNode +{ + SHashChain chain; // double linked list ptrs + DWORD dwHash; // hash value of the key + DUALSTRINGARRAY *psaKey; // node key (the value that is hashed) +} SStringHashNode; + +//--------------------------------------------------------------------------- +// +// Structure: SNameHashNode +// +// Synopsis: This is an element in a bucket in the name hash table. +// +//--------------------------------------------------------------------------- +typedef struct SNameHashNode +{ + SHashChain chain; // double linked list ptrs + DWORD dwHash; // hash value of the key + ULONG cRef; // count of references + IPID ipid; // ipid holding the reference + SECURITYBINDING sName; // user name +} SNameHashNode; + + +// ptr to cleanup function that gets called by Cleanup +typedef void (PFNCLEANUP)(SHashChain *pNode); + + +// number of buckets in the hash table array. It should be a prime +// number > 20. + +#define NUM_HASH_BUCKETS 23 + + +//--------------------------------------------------------------------------- +// External definitions. + +class CNameHashTable; +extern SHashChain SRFBuckets[23]; +extern CNameHashTable gSRFTbl; + +//--------------------------------------------------------------------------- +// +// Class: CHashTable +// +// Synopsis: Base hash table. The table uses any key +// and stores nodes in an array of circular double linked lists. +// The hash value of the key is the index in the array to the +// double linked list that the node is chained off. +// +// Nodes must be allocated with new. A cleanup function is +// optionally called for each node when the table is cleaned +// up. +// +// Inheritors of this class must supply the HashNode +// and Compare functions. +// +// Notes: All locking must be done outside the class via LOCK/UNLOCK. +// +//--------------------------------------------------------------------------- +class CHashTable +{ +public: + virtual void Initialize(SHashChain *pChain) { _buckets = pChain; } + virtual void Cleanup(PFNCLEANUP *pfn); + void Remove(SHashChain *pNode); + + virtual BOOL Compare(const void *k, SHashChain *pNode, DWORD dwHash) = 0; + virtual DWORD HashNode(SHashChain *pNode) = 0; + +protected: + SHashChain *Lookup(DWORD dwHash, const void *k); + void Add(DWORD dwHash, SHashChain *pNode); + +private: + SHashChain *_buckets; // ptr to array of double linked lists +}; + + + +//--------------------------------------------------------------------------- +// +// Class: CUUIDHashTable +// +// Synopsis: This table inherits from CHashTable. It hashs based on a UUID. +// +// Nodes must be allocated with new. A cleanup function is +// optionally called for each node when the table is cleaned up. +// +// Notes: All locking must be done outside the class via LOCK/UNLOCK. +// +//--------------------------------------------------------------------------- +class CUUIDHashTable : public CHashTable +{ +public: + virtual BOOL Compare(const void *k, SHashChain *pNode, DWORD dwHash); + virtual DWORD HashNode(SHashChain *pNode); + + DWORD Hash(REFGUID k); + SUUIDHashNode *Lookup(DWORD dwHash, REFGUID k); + void Add(DWORD dwHash, REFGUID k, SUUIDHashNode *pNode); +}; + + +//--------------------------------------------------------------------------- +// +// Class: CStringHashTable +// +// Synopsis: String based hash table, uses a DUALSTRINGARRAY as the key, +// +// Nodes must be allocated with new. A cleanup function is +// optionally called for each node when the table is cleaned up. +// +// Notes: All locking must be done outside the class via LOCK/UNLOCK. +// +//--------------------------------------------------------------------------- +class CStringHashTable : public CHashTable +{ +public: + virtual BOOL Compare(const void *k, SHashChain *pNode, DWORD dwHash); + virtual DWORD HashNode(SHashChain *pNode); + + DWORD Hash(DUALSTRINGARRAY *psaKey); + SStringHashNode *Lookup(DWORD dwHash, DUALSTRINGARRAY *psaKey); + void Add(DWORD dwHash, DUALSTRINGARRAY *psaKey, SStringHashNode *pNode); +}; + + +//--------------------------------------------------------------------------- +// +// Class: CNameHashTable +// +// Synopsis: Name based hash table, uses a string as the key, +// +// Nodes must be allocated with new. A cleanup function is +// optionally called for each node when the table is cleaned up. +// +// Notes: All locking must be done outside the class via LOCK/UNLOCK. +// +//--------------------------------------------------------------------------- +class CNameHashTable : public CHashTable +{ +public: + virtual BOOL Compare(const void *k, SHashChain *pNode, DWORD dwHash); + virtual DWORD HashNode(SHashChain *pNode); + + void Cleanup(); + ULONG DecRef( ULONG cRefs, REFIPID ipid, SECURITYBINDING *pName ); + DWORD Hash ( REFIPID ipid, SECURITYBINDING *pName ); + HRESULT IncRef( ULONG cRefs, REFIPID ipid, SECURITYBINDING *pName ); + void Initialize() { CHashTable::Initialize( SRFBuckets ); } +}; + + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::Hash +// +// Synopsis: Computes the hash value for a given key. +// +//--------------------------------------------------------------------------- +inline DWORD CUUIDHashTable::Hash(REFIID k) +{ + const DWORD *tmp = (DWORD *) &k; + DWORD sum = tmp[0] + tmp[1] + tmp[2] + tmp[3]; + return sum % NUM_HASH_BUCKETS; +} + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::Lookup +// +// Synopsis: Finds the node with the requested key. +// +//--------------------------------------------------------------------------- +inline SUUIDHashNode *CUUIDHashTable::Lookup(DWORD iHash, REFGUID k) +{ + return (SUUIDHashNode *) CHashTable::Lookup( iHash, (const void *) &k ); +} + +//--------------------------------------------------------------------------- +// +// Method: CUUIDHashTable::Add +// +// Synopsis: Inserts the specified node. +// +//--------------------------------------------------------------------------- +inline void CUUIDHashTable::Add(DWORD iHash, REFGUID k, SUUIDHashNode *pNode) +{ + // set the key + pNode->key = k; + + CHashTable::Add( iHash, (SHashChain *) pNode ); +} + +//--------------------------------------------------------------------------- +// +// Method: CStringHashTable::Lookup +// +// Synopsis: Searches for a given key in the hash table. +// +//--------------------------------------------------------------------------- +inline SStringHashNode *CStringHashTable::Lookup(DWORD dwHash, DUALSTRINGARRAY *psaKey) +{ + return (SStringHashNode *) CHashTable::Lookup( dwHash, (const void *) psaKey); +} + +//--------------------------------------------------------------------------- +// +// Method: CStringHashTable::Add +// +// Synopsis: Adds an element to the hash table. The element must +// be allocated using new. The Cleanup method will call +// an optional Cleanup function, then will call delete on +// the element. +// +//--------------------------------------------------------------------------- +inline void CStringHashTable::Add(DWORD dwHash, DUALSTRINGARRAY *psaKey, SStringHashNode *pNode) +{ + // set the key and hash values + pNode->psaKey = psaKey; + pNode->dwHash = dwHash; + + CHashTable::Add( dwHash, (SHashChain *) pNode ); +} + +#endif // _HASHTBL_HXX_ diff --git a/private/ole32/com/dcomrem/idtable.cxx b/private/ole32/com/dcomrem/idtable.cxx new file mode 100644 index 000000000..10f8c3b6c --- /dev/null +++ b/private/ole32/com/dcomrem/idtable.cxx @@ -0,0 +1,625 @@ +//+------------------------------------------------------------------- +// +// File: idtable.cxx +// +// Contents: identity table +// +// Functions: +// +// History: 1-Dec-93 CraigWi Created +// 14-Apr-95 Rickhi ReVamped +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <idtable.hxx> +#include <locks.hxx> // LOCK/UNLOCK +#include <resolver.hxx> // gResolver +#include <comsrgt.hxx> + +CIDArray *gpOIDTable = NULL; + +//+------------------------------------------------------------------- +// +// Function: LookupIDFromUnk, private +// +// Synopsis: Looks up and may create the identity object for the given +// object. If the identity object is created, it is not +// aggregated to the given object. +// +// Identity lookup is based on pUnkControl. +// +// Arguments: [pUnk] -- the object; not necessarily the controlling unknown +// [dwflags] -- see IDLFLAGS in idtable.hxx +// [ppStdId] -- when S_OK is returned, this is the identity +// +// Returns: S_OK - identity now exists (whether created here or not) +// CO_E_OBJNOTREG - no identity and !fCreate +// E_OUTOFMEMORY - +// E_UNEXPECTED - at least: no controlling unknown +// +// +// Notes: If the StdId is client-side, the returned pointer will hold +// the object alive. +// +// If the StdId is server-side, the returned pointer will hold +// the object alive only if IDLF_STRONG is set, otherwise it +// just holds the identity alive. +// +// History: 11-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +INTERNAL LookupIDFromUnk(IUnknown *pUnk, DWORD dwFlags, CStdIdentity **ppStdId) +{ + // QI for IStdID; if ok, return that + if (pUnk->QueryInterface(IID_IStdIdentity, (void **)ppStdId) == NOERROR) + return S_OK; + + // QI for controlling unknown; should succeed + IUnknown *pUnkControl; + if (pUnk->QueryInterface(IID_IUnknown, (void **)&pUnkControl) != NOERROR) + return E_UNEXPECTED; + + + HRESULT hr = S_OK; + CStdIdentity *pStdId = NULL; + CStdIdentity *pStdIdForRelease = NULL; + + // scan for value in map; may find one attached to object created by now + IDENTRY identry; + identry.m_tid = GetCurrentApartmentId(); + identry.m_pUnkControl = pUnkControl; + + // lock others out of the table while we do our stuff... + ASSERT_LOCK_RELEASED + LOCK + + int iID; + + if (gpOIDTable == NULL) + { + iID = -1; + hr = CO_E_OBJNOTREG; + + if (dwFlags & IDLF_CREATE) + { + hr = E_OUTOFMEMORY; + gpOIDTable = new CIDArray; + } + + if (gpOIDTable == NULL) + { + UNLOCK + ASSERT_LOCK_RELEASED + pUnkControl->Release(); + *ppStdId = NULL; + return hr; + } + + // change the GrowBy value to something better than 1 + gpOIDTable->SetSize(0, 20); + } + else + { + iID = gpOIDTable->IndexOf((void *)&identry.m_tid, + sizeof(identry.m_tid) + sizeof(identry.m_pUnkControl), + offsetof(IDENTRY, m_tid)); + } + + Win4Assert(gpOIDTable != NULL); + + if (iID == -1) + { + hr = CO_E_OBJNOTREG; // assume no creation + + if (dwFlags & IDLF_CREATE) + { + // try to create one. Must release the lock to do this since + // we have to go ask the app a bunch of questions. + + UNLOCK + ASSERT_LOCK_RELEASED + + hr = S_OK; + IUnknown *pUnkID; // internal unknown of Identity, ignored on + // the server side. + + pStdId = new CStdIdentity(STDID_SERVER, NULL,pUnkControl, &pUnkID); + if (pStdId == NULL) + { + hr = E_OUTOFMEMORY; + } + + ASSERT_LOCK_RELEASED + LOCK + + if (SUCCEEDED(hr)) + { + MOID moid; + if (dwFlags & IDLF_NOPING) + { + // object wont be pinged so dont bother using a + // pre-registered oid, just use a reserved one. Save + // the pre-registered ones for pinged objects. + hr = gResolver.ServerGetReservedMOID(&moid); + } + else + { + // object will be pinged, so get a pre-registered OID. + // Do this while the lock is released incase we have + // to Rpc to the resolver. Note this could yield if we + // have to pre-register more OIDs so do this before + // checking the table again. + hr = gResolver.ServerGetPreRegMOID(&moid); + } + + if (SUCCEEDED(hr)) + { + // while we released the lock, another thread could have + // come along and created the identity for this object, + // so we need to check again. + + iID = gpOIDTable->IndexOf((void *)&identry.m_tid, + sizeof(identry.m_tid) + sizeof(identry.m_pUnkControl), + offsetof(IDENTRY, m_tid)); + + if (iID == -1) + { + // make the created StdId the identity for the object. + hr = pStdId->SetOID(moid); + + if (SUCCEEDED(hr)) + { + // need to set the marshal time of the object to + // ensure that it does not run down when if the lock + // is released before our first marshal is complete. + pStdId->SetMarshalTime(); + + if (dwFlags & IDLF_STRONG) + { + pStdId->IncStrongCnt(); + pStdId->Release(); + } + + if (dwFlags & IDLF_NOPING) + { + pStdId->SetNoPing(); + } + } + else + { + // OOM on SetOID, release Identity and return error + pStdIdForRelease = pStdId; + pStdId = NULL; + Win4Assert(iID == -1); + } + } + else + { + // release the one we created and use the one in the + // tbl. we get it below in the (iID != -1) case. + pStdIdForRelease = pStdId; + } + } + else + { + // cant allocate an OID. Release the StdId we created + // when we exit the lock and return an error + + pStdIdForRelease = pStdId; + pStdId = NULL; + Win4Assert(iID == -1); + } + } + } + } + + if (iID != -1) + { + // found, addref pStdId which holds the identity alive + pStdId = gpOIDTable->ElementAt(iID).m_pStdID; + + if (dwFlags & IDLF_STRONG) + pStdId->IncStrongCnt(); + else + pStdId->AddRef(); + + Win4Assert(hr == S_OK); + } + + +#if DBG == 1 + if (pStdId != NULL) + { + if (iID == -1) + { + // object was created, need to get the iID for debug + iID = gpOIDTable->IndexOf((void *)&identry.m_tid, + sizeof(identry.m_tid) + sizeof(identry.m_pUnkControl), + offsetof(IDENTRY, m_tid)); + Win4Assert(iID != -1); + } + + // verify correctness of entry + Win4Assert(pUnkControl == gpOIDTable->ElementAt(iID).m_pUnkControl); + Win4Assert(IsEqualGUID(pStdId->GetOID(), gpOIDTable->ElementAt(iID).m_moid)); + Win4Assert(gpOIDTable->ElementAt(iID).m_tid == identry.m_tid); + } +#endif + + *ppStdId = pStdId; + + UNLOCK + ASSERT_LOCK_RELEASED + + // Release any of the pointers we dont need. Must unlock before + // doing this cause it will call app code. + + if (pStdIdForRelease) + { + ASSERT_LOCK_RELEASED + pStdIdForRelease->Release(); + } + + pUnkControl->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: LookupIDFromID, private +// +// Synopsis: Lookup an identity object based on an OID; does not create. +// +// Arguments: [moid] -- The identity +// [ppStdID] -- The cooresponding identity object if successful +// +// Returns: S_OK - have the identity object +// CO_E_OBJNOTREG - not present (when we looked) +// +// History: 11-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +INTERNAL LookupIDFromID(REFMOID moid, BOOL fAddRef, CStdIdentity **ppStdID) +{ +// ComDebOut((DEB_MARSHAL, "LookupIDFromID fAddRef:%x ppStdId:%x oid:%I\n", +// fAddRef, ppStdID, &moid)); + ASSERT_LOCK_HELD + + *ppStdID = NULL; + + if (gpOIDTable == NULL) + { + // no table, dont do lookup + return CO_E_OBJNOTREG; + } + + IDENTRY identry; + identry.m_moid = moid; + identry.m_tid = GetCurrentApartmentId(); + + int iID = gpOIDTable->IndexOf((void *)&identry.m_moid, + sizeof(identry.m_moid) + sizeof(identry.m_tid), + offsetof(IDENTRY, m_moid)); + + if (iID != -1) + { + // found, addref pStdID which holds the identity alive + *ppStdID = gpOIDTable->ElementAt(iID).m_pStdID; + + if (fAddRef) + { + // I sure hope the app doesn't try anything fancy in AddRef + // that would cause a deadlock here! (That is, in the aggregated + // case we will run app code). + (*ppStdID)->AddRef(); + } + +#if DBG == 1 + // verify correctness of entry + Win4Assert(IsEqualGUID(moid, gpOIDTable->ElementAt(iID).m_moid)); + Win4Assert(IsEqualGUID(moid, (*ppStdID)->GetOID())); + Win4Assert(gpOIDTable->ElementAt(iID).m_tid == identry.m_tid); +#endif + } + + return (*ppStdID == NULL) ? CO_E_OBJNOTREG : NOERROR; +} + + +//+------------------------------------------------------------------- +// +// Function: SetObjectID, private +// +// Synopsis: Called by the object id creation and unmarshal functions +// to establish the identity for an object (handler or server). +// Can fail if we discover an existing identity. +// +// Identity lookup is based on pUnkControl. +// +// Arguments: [moid] -- The id for the object +// [pUnkControl] -- The controlling uknown of the object being +// identitified. +// [pStdID] -- The identity object itself. +// +// Returns: S_OK - identity was set successfully +// CO_E_OBJISREG - object was already registered (as determined +// by pUnkControl); *ppStdIDExisting set (if requested). +// E_OUTOFMEMORY - +// +// History: 11-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +INTERNAL SetObjectID(REFMOID moid, IUnknown *pUnkControl, CStdIdentity *pStdID) +{ + ComDebOut((DEB_MARSHAL, "SetObjectID pUnk:%x pStdId:%x oid:%I\n", + pUnkControl, pStdID, &moid)); + Win4Assert(!IsEqualGUID(moid, GUID_NULL)); + ASSERT_LOCK_HELD + + + HRESULT hr = S_OK; + if (gpOIDTable == NULL) + { + gpOIDTable = new CIDArray; + if (gpOIDTable == NULL) + return E_OUTOFMEMORY; + + // change the GrowBy value to something better than 1 + gpOIDTable->SetSize(0, 20); + } + + IDENTRY identry; + identry.m_moid = moid; + identry.m_tid = GetCurrentApartmentId(); + identry.m_pUnkControl = pUnkControl; + identry.m_pStdID = pStdID; + +#if DBG==1 + // scan for value in map; better not find one + // CODEWORK: for freethreaded handler case we may need to allow + // finding a duplicate entry, and throw away the second copy. + + int iID = gpOIDTable->IndexOf((void *)&identry.m_tid, + sizeof(identry.m_tid) + sizeof(identry.m_pUnkControl), + offsetof(IDENTRY, m_tid)); + + if (iID != -1) + { + // if found, another thread created identity for same object; + // this is an error. + Win4Assert(!"Already Registered OID"); + } +#endif + + // add at end; no addrefs + if (gpOIDTable->Add(identry) == -1) + hr = E_OUTOFMEMORY; + + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ClearObjectID, private +// +// Synopsis: Called during the revokation of the id only. Clears +// the identity entry in the table. +// +// Identity lookup is based on oid. +// +// Arguments: [moid] -- The identity +// [pUnkControl] -- The object for which the identity is being +// revoked; used for asserts only. +// [pStdID] -- The identity object; used for asserts only. +// +// Returns: S_OK - removed successfully +// CO_E_OBJNOTREG - not present (often ignored). +// +// History: 11-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +INTERNAL ClearObjectID(REFMOID moid, IUnknown *pUnkControl, CStdIdentity *pStdID) +{ + ComDebOut((DEB_MARSHAL, "ClearObjectID pUnk:%x pStdId:%x oid:%I\n", + pUnkControl, pStdID, &moid)); + ASSERT_LOCK_HELD + + HRESULT hr = NOERROR; + + IDENTRY identry; + identry.m_moid = moid; + identry.m_tid = GetCurrentApartmentId(); + + int iID = gpOIDTable->IndexOf((void *)&identry.m_moid, + sizeof(identry.m_moid) + sizeof(identry.m_tid), + offsetof(IDENTRY, m_moid)); + + if (iID != -1) + { + // found, remove it. +#if DBG == 1 + // verify correctness of entry + Win4Assert(pUnkControl == gpOIDTable->ElementAt(iID).m_pUnkControl); + Win4Assert(IsEqualGUID(pStdID->GetOID(), gpOIDTable->ElementAt(iID).m_moid)); + Win4Assert(pStdID == gpOIDTable->ElementAt(iID).m_pStdID); + Win4Assert(gpOIDTable->ElementAt(iID).m_tid == identry.m_tid); +#endif + + Win4Assert(gpOIDTable->GetSize() != 0); + int iLast = gpOIDTable->GetSize() - 1; + if (iID != iLast) + { + // element removed is not last; copy last element to current + gpOIDTable->ElementAt(iID) = gpOIDTable->ElementAt(iLast); + } + + // now setsize to one less to remove the now unused last element + gpOIDTable->SetSize(iLast); + + // for surrogates, we need to detect when there are no clients + // using servers in the surrogate process -- we rely on the + // fact that the OIDTable must be empty when there are no clients + + // if there are no external clients, this process should terminate + // if its a surrogate process + if(iLast == 0) + { + (void)CCOMSurrogate::FreeSurrogate(); + } + } + else + { + Win4Assert(!"ClearObjectID not found!"); + hr = CO_E_OBJNOTREG; + + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: IDTableUninitializeHelper, private +// +// Synopsis: Clears the id table memory for the specified thread (or all +// if party model). This involves scanning the table and for +// entries on the current thread, calling +// IMarshal::DisconnectObject. +// +// The purpose of this routine is to simulate inter-thread rundown +// as well as clean up memory. +// +// History: 23-Dec-93 CraigWi Created. +// 26-Apr-94 CraigWi Now called per-thread and disconnects +// +// Note: This function should only be called when the IDTable +// really needs to be uninitialized. For the party model, this +// means that it should only be called when the last thread +// is exiting. +// +// This function must NOT assume that it is being called within +// a critical section. +// +//-------------------------------------------------------------------- +INTERNAL_(void) IDTableThreadUninitializeHelper(DWORD tid) +{ + // The table being uninitialized is resized as items are deleted. Thus + // if an element in the middle is removed, the last element in the table + // will be copied into that slot and the table will shrink. Also, some + // of the calls made while cleaning up an entry will free other entries, + // causing further swapping. + + ASSERT_LOCK_RELEASED + LOCK + + int i = gpOIDTable->GetSize() - 1; + while (i >= 0) + { + if (gpOIDTable->ElementAt(i).m_tid == tid) + { + Win4Assert(IsValidInterface(gpOIDTable->ElementAt(i).m_pStdID)); + CStdIdentity *pStdID = gpOIDTable->ElementAt(i).m_pStdID; + pStdID->AddRef(); + + ComDebOut((DEB_ERROR, + "Object [%s] at %lx still has [%x] connections\n", + pStdID->IsClient() ? "CLIENT" : "SERVER", + gpOIDTable->ElementAt(i).m_pUnkControl, pStdID->GetRC())); + + pStdID->DbgDumpInterfaceList(); + + // release lock since the disconnect could take a long time. + UNLOCK + ASSERT_LOCK_RELEASED + + pStdID->Disconnect(); + pStdID->Release(); + + // re-request the lock since we need to guard the GetSize below + ASSERT_LOCK_RELEASED + LOCK + } + i--; + if (i >= gpOIDTable->GetSize()) + i = gpOIDTable->GetSize() - 1; + } + + UNLOCK + ASSERT_LOCK_RELEASED +} + +//+------------------------------------------------------------------- +// +// Function: IDTableUninitialize, public +// +// Synopsis: Clears the id table memory for the current apartment. +// +// History: 13 Apr 95 AlexMit Created. +// +//-------------------------------------------------------------------- +INTERNAL_(void) IDTableThreadUninitialize(void) +{ + if (gpOIDTable) + { + IDTableThreadUninitializeHelper(GetCurrentApartmentId()); + } +} + +//+------------------------------------------------------------------------- +// +// Function: IDTableProcessUninitialize +// +// Synopsis: Process specific IDTable uninitialization +// +// Effects: Frees up table memory +// +// Requires: All thread specific uninitialization already complete. This +// function assumes that the caller is holding the +// g_mxsSingleThreadOle mutex (so that no other thread is trying +// to use the table while we clean it up). +// +// History: 29-Jun-94 AlexT Created +// +//-------------------------------------------------------------------------- +INTERNAL_(void) IDTableProcessUninitialize(void) +{ + if (gpOIDTable) + { + gpOIDTable->RemoveAll(); + delete gpOIDTable; + gpOIDTable = NULL; + } +} + +#if DBG == 1 +//+------------------------------------------------------------------- +// +// Function: Dbg_FindRemoteHdlr +// +// Synopsis: finds a remote object handler for the specified object, +// and returns an instance of IMarshal on it. This is debug +// code for assert that reference counts are as expected and +// is used by tmarshal.exe. +// +// History: 23-Nov-93 Rickhi Created +// 23-Dec-93 CraigWi Changed to identity object +// +//-------------------------------------------------------------------- +extern "C" IMarshal * _stdcall Dbg_FindRemoteHdlr(IUnknown *punkObj) +{ + // validate input parms + Win4Assert(punkObj); + + IMarshal *pIM = NULL; + CStdIdentity *pStdID; + HRESULT hr = LookupIDFromUnk(punkObj, 0, &pStdID); + if (hr == NOERROR) + { + pIM = (IMarshal *)pStdID; + } + + return pIM; +} +#endif // DBG==1 + diff --git a/private/ole32/com/dcomrem/idtable.hxx b/private/ole32/com/dcomrem/idtable.hxx new file mode 100644 index 000000000..41fad6f61 --- /dev/null +++ b/private/ole32/com/dcomrem/idtable.hxx @@ -0,0 +1,88 @@ +//+------------------------------------------------------------------- +// +// File: idtable.hxx +// +// Contents: internal identity table definitions. +// +// Description: +// The table consists, logically, of two different +// keys which map to the same value. The two +// keys are the object identity (oid) and the +// controlling unknown (pUnkControl). The value +// is the pointer to the identity interface. +// +// Additionally, if the apartment model is in use, +// the thread forms an additional part of both keys. +// +// Presently this is implemented as an array of the +// three values and linear scans are used for lookup. +// Neither of the pointers are addref'd (as far as the +// table is concerned). Lookup does addref and return +// a pointer to the identity object. The pUnkControl +// is used solely for lookup. +// +// Possible changes: use two separate maps, one of +// which is keyed by the oid and the other keyed +// by pUnkControl. The value in the first map is +// the hKey of value in the second map. The id object +// would hold the hKey of the value in the first map +// to save on space. Using two maps increases the +// speed of lookup for large numbers of ids, but +// increases the cost per id from 40bytes to 48bytes. +// (40 = 2 GUIDS + 2 far pointers; 48 = 1 GUID + +// 3 hKeys/ptrs + 16 bytes overhead for two hash +// buckets) +// +// It is also possible to use one map (keyed by +// the oid) and trim the array by the GUID. The +// cost per id is 40bytes (1 GUID + 1 hKey + 3 ptr + +// 8 bytes overhead for one hash bucket). +// +// +// History: 1-Dec-93 CraigWi Created +// +//-------------------------------------------------------------------- +#ifndef __IDTABLE__ +#define __IDTABLE__ + +#include <olerem.h> +#include <stdid.hxx> // CStdIdentity +#include <sem.hxx> // CMutexSem + + +// flags passed in on LookupIDFromUnk. + +typedef enum tagIDLFLAGS +{ + IDLF_CREATE = 0x01, // create if not found + IDLF_STRONG = 0x02, // add a strong connection + IDLF_NOPING = 0x04 // object wont be pinged +} IDLFLAGS; + + +// entry in id array. the array is packed (no NULL holes) +// NOTE: when looking up for the apartment model, we pair the two fields +// m_oid/m_tid and m_tid/m_pUnkControl; + +struct IDENTRY +{ + MOID m_moid; // OID + MID + DWORD m_tid; + IUnknown *m_pUnkControl;// not addref'd directly + CStdIdentity *m_pStdID; // not addref'd directly +}; + +#include <array_id.h> + + +// other functions declared in olerem.h +INTERNAL LookupIDFromUnk(IUnknown *pUnk, DWORD dwFlags, CStdIdentity **ppStdId); +INTERNAL LookupIDFromID(REFMOID moid, BOOL fAddRef, CStdIdentity **ppStdId); +INTERNAL SetObjectID(REFMOID moid, IUnknown *pUnkControl, CStdIdentity *pStdID); +INTERNAL ClearObjectID(REFMOID moid, IUnknown *pUnkControl, CStdIdentity *pStdID); +INTERNAL_(void) IDTableThreadUninitialize(void); +INTERNAL_(void) IDTableProcessUninitialize(void); + +INTERNAL GetStaticUnMarshaler(IMarshal **ppMarshal); + +#endif // __IDTABLE__ diff --git a/private/ole32/com/dcomrem/ipidtbl.cxx b/private/ole32/com/dcomrem/ipidtbl.cxx new file mode 100644 index 000000000..21a7b4fac --- /dev/null +++ b/private/ole32/com/dcomrem/ipidtbl.cxx @@ -0,0 +1,1502 @@ +//+----------------------------------------------------------------------- +// +// File: ipidtbl.cxx +// +// Contents: IPID (interface pointer identifier) table. +// +// Classes: CIPIDTable +// +// History: 02-Feb-95 Rickhi Created +// +// Notes: All synchronization is the responsibility of the caller. +// +//------------------------------------------------------------------------- +#include <ole2int.h> +#include <ipidtbl.hxx> // CIPIDTable +#include <resolver.hxx> // CRpcResolver +#include <service.hxx> // SASIZE +#include <remoteu.hxx> // CRemoteUnknown +#include <marshal.hxx> // UnmarshalObjRef +#include <idtable.hxx> // LookupIDFromUnk +#include <callctrl.hxx> // OleModalLoopBlockFn + + +// global tables +CMIDTable gMIDTbl; // machine ID table +COXIDTable gOXIDTbl; // object exported ID table +CIPIDTable gIPIDTbl; // interface pointer ID table + +MIDEntry *gpLocalMIDEntry = NULL; // local machine MIDEntry +OXIDEntry *gpMTAOXIDEntry = NULL; // MTA OXIDEntry +DUALSTRINGARRAY *gpsaLocalResolver = NULL; // local OXIDResolver address + +OXIDEntry COXIDTable::_InUseHead = { &_InUseHead, &_InUseHead }; +OXIDEntry COXIDTable::_CleanupHead = { &_CleanupHead, &_CleanupHead }; +OXIDEntry COXIDTable::_ExpireHead = { &_ExpireHead, &_ExpireHead }; +DWORD COXIDTable::_cExpired = 0; + +CStringHashTable CMIDTable::_HashTbl; // hash table for MIDEntries +CPageAllocator CMIDTable::_palloc; // allocator for MIDEntries +CPageAllocator COXIDTable::_palloc; // allocator for OXIDEntries +CPageAllocator CIPIDTable::_palloc; // allocator for IPIDEntries + + +//+------------------------------------------------------------------------ +// +// Machine Identifier hash table buckets. This is defined as a global +// so that we dont have to run any code to initialize the hash table. +// +//+------------------------------------------------------------------------ +SHashChain MIDBuckets[23] = { + {&MIDBuckets[0], &MIDBuckets[0]}, + {&MIDBuckets[1], &MIDBuckets[1]}, + {&MIDBuckets[2], &MIDBuckets[2]}, + {&MIDBuckets[3], &MIDBuckets[3]}, + {&MIDBuckets[4], &MIDBuckets[4]}, + {&MIDBuckets[5], &MIDBuckets[5]}, + {&MIDBuckets[6], &MIDBuckets[6]}, + {&MIDBuckets[7], &MIDBuckets[7]}, + {&MIDBuckets[8], &MIDBuckets[8]}, + {&MIDBuckets[9], &MIDBuckets[9]}, + {&MIDBuckets[10], &MIDBuckets[10]}, + {&MIDBuckets[11], &MIDBuckets[11]}, + {&MIDBuckets[12], &MIDBuckets[12]}, + {&MIDBuckets[13], &MIDBuckets[13]}, + {&MIDBuckets[14], &MIDBuckets[14]}, + {&MIDBuckets[15], &MIDBuckets[15]}, + {&MIDBuckets[16], &MIDBuckets[16]}, + {&MIDBuckets[17], &MIDBuckets[17]}, + {&MIDBuckets[18], &MIDBuckets[18]}, + {&MIDBuckets[19], &MIDBuckets[19]}, + {&MIDBuckets[20], &MIDBuckets[20]}, + {&MIDBuckets[21], &MIDBuckets[21]}, + {&MIDBuckets[22], &MIDBuckets[22]} +}; + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::Initialize, public +// +// Synopsis: Initializes the IPID table. +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CIPIDTable::Initialize() +{ + ComDebOut((DEB_OXID, "CIPIDTable::Initialize\n")); + _palloc.Initialize(sizeof(IPIDEntry), IPIDS_PER_PAGE); +} + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::Cleanup, public +// +// Synopsis: Cleanup the ipid table. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CIPIDTable::Cleanup() +{ + ComDebOut((DEB_OXID, "CIPIDTable::Cleanup\n")); + _palloc.Cleanup(); +} + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::LookupIPID, public +// +// Synopsis: Finds an entry in the IPID table with the given IPID. +// This is used by the unmarshalling code, the dispatch +// code, and CRemoteUnknown. +// +// Notes: This method should be called instead of GetEntryPtr +// whenever you dont know if the IPID is valid or not (eg it +// came in off the network), since this validates the IPID +// index to ensure its within the table size, as well as +// validating the rest of the IPID. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +IPIDEntry *CIPIDTable::LookupIPID(REFIPID ripid) +{ + ASSERT_LOCK_HELD + + // Validate the IPID index that is passed in, since this came in off + // off the net it could be bogus and we dont want to fault on it. + // first dword of the ipid is the index into the ipid table. + + if (_palloc.IsValidIndex(ripid.Data1)) + { + IPIDEntry *pIPIDEntry = GetEntryPtr(ripid.Data1); + + // entry must be server side and not vacant + if ((pIPIDEntry->dwFlags & (IPIDF_SERVERENTRY | IPIDF_VACANT)) == + IPIDF_SERVERENTRY) + { + // validate the rest of the guid + if (InlineIsEqualGUID(pIPIDEntry->ipid, ripid)) + return pIPIDEntry; + } + } + + return NULL; +} + +//+------------------------------------------------------------------- +// +// Member: CIPIDTable::ReleaseEntryList +// +// Synopsis: return a linked list of IPIDEntry to the table's free list +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CIPIDTable::ReleaseEntryList(IPIDEntry *pFirst, IPIDEntry *pLast) +{ + ASSERT_LOCK_HELD + Win4Assert(pLast->pNextOID == NULL); + +#if DBG==1 + // In debug, walk the list to ensure they are released, vacant, + // disconnected etc. + IPIDEntry *pEntry = pFirst; + while (pEntry != NULL) + { + Win4Assert(pEntry->pOXIDEntry == NULL); // must already be released + Win4Assert(pEntry->dwFlags & IPIDF_VACANT); + Win4Assert(pEntry->dwFlags & IPIDF_DISCONNECTED); + + pEntry = pEntry->pNextOID; + } +#endif + + _palloc.ReleaseEntryList((PageEntry *)pFirst, (PageEntry *)pLast); +} + +#if DBG==1 +//+------------------------------------------------------------------- +// +// Member: CIPIDTable::ValidateIPIDEntry +// +// Synopsis: Ensures the IPIDEntry is valid. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CIPIDTable::ValidateIPIDEntry(IPIDEntry *pEntry, BOOL fServerSide, + CRpcChannelBuffer *pChnl) +{ + // validate the IPID flags + Win4Assert(!(pEntry->dwFlags & IPIDF_VACANT)); + if (fServerSide) + { + // server side must have SERVERENTRY ipids + Win4Assert(pEntry->dwFlags & IPIDF_SERVERENTRY); + } + else + { + // client side must not have SERVERENTRY ipids + Win4Assert(!(pEntry->dwFlags & IPIDF_SERVERENTRY)); + } + + + // Validate the pStub interface + if (IsEqualIID(pEntry->iid, IID_IUnknown)) + { + // there is no proxy or stub for IUnknown interface + Win4Assert(pEntry->pStub == NULL); + } + else + { + if ((pEntry->dwFlags & IPIDF_DISCONNECTED) && + (pEntry->dwFlags & IPIDF_SERVERENTRY)) + { + // disconnected server side has NULL pStub + Win4Assert(pEntry->pStub == NULL); + } + else + { + // both connected and disconnected client side has valid proxy + Win4Assert(pEntry->pStub != NULL); + Win4Assert(IsValidInterface(pEntry->pStub)); + } + } + + + // Validate the interface pointer (pv) + if (!(pEntry->dwFlags & IPIDF_DISCONNECTED)) + { + Win4Assert(pEntry->pv != NULL); + Win4Assert(IsValidInterface(pEntry->pv)); + } + + + // Validate the channel ptr + if (fServerSide) + { + // all stubs share the same channel on the server side + Win4Assert(pEntry->pChnl == pChnl); + } + else + { + // all proxies have their own different channel on client side + Win4Assert(pEntry->pChnl != pChnl || pEntry->pChnl == NULL); + } + + // Validate the RefCnts + if (!(pEntry->dwFlags & IPIDF_DISCONNECTED) && !fServerSide) + { + // if connected, must be > 0 refcnt on client side. + // potentially not > 0 if TABLE marshal on server side. + Win4Assert(pEntry->cStrongRefs + pEntry->cWeakRefs > 0); + } + + // Validate the OXIDEntry + if (pEntry->pOXIDEntry) + { + OXIDEntry *pOX = pEntry->pOXIDEntry; + if (fServerSide) + { + // check OXID tid and pid + Win4Assert(pOX->dwPid == GetCurrentProcessId()); + if ((pOX->dwFlags & OXIDF_MTASERVER)) + Win4Assert(pOX->dwTid == 0); + else + Win4Assert(pOX->dwTid == GetCurrentThreadId()); + + if (pChnl != NULL) + { + // CODEWORK: ensure OXID is same as the rest of the object + // Win4Assert(IsEqualGUID(pOX->moxid, GetMOXID())); + } + } + } + + + // Validate the pNextOID + if (pEntry->pNextOID != NULL) + { + // ensure it is within the bounds of the table + Win4Assert(GetEntryIndex(pEntry) != -1); + + // cant point back to self or we have a circular list + Win4Assert(pEntry->pNextOID != pEntry); + } +} +#endif + + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::Initialize, public +// +// Synopsis: +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::Initialize() +{ + ComDebOut((DEB_OXID, "COXIDTable::Initialize\n")); + _palloc.Initialize(sizeof(OXIDEntry), OXIDS_PER_PAGE); +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::Cleanup, public +// +// Synopsis: Cleanup the OXID table. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::Cleanup() +{ + ComDebOut((DEB_OXID, "COXIDTable::Cleanup\n")); + ASSERT_LOCK_HELD + + // the lists better be empty before we delete the entries + AssertListsEmpty(); + _palloc.Cleanup(); +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::AddEntry, public +// +// Synopsis: Adds an entry to the OXID table. The entry is AddRef'd. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +HRESULT COXIDTable::AddEntry(REFOXID roxid, OXID_INFO *poxidInfo, + MIDEntry *pMIDEntry, OXIDEntry **ppEntry) +{ + Win4Assert(poxidInfo != NULL); + Win4Assert(pMIDEntry != NULL); + ASSERT_LOCK_HELD + + // find first free entry slot, grow table if necessary + OXIDEntry *pEntry = (OXIDEntry *) _palloc.AllocEntry(); + if (pEntry == NULL) + { + ComDebOut((DEB_ERROR,"Out Of Memory in COXIDTable::AddEntry\n")); + return E_OUTOFMEMORY; + } + + // chain it on the list of inuse entries + pEntry->pPrev = &_InUseHead; + _InUseHead.pNext->pPrev = pEntry; + pEntry->pNext = _InUseHead.pNext; + _InUseHead.pNext = pEntry; + + // Copy oxidInfo into OXIDEntry. + + MOXIDFromOXIDAndMID(roxid, pMIDEntry->mid, &pEntry->moxid); + pEntry->cRefs = 1; // caller gets one reference + pEntry->cWaiters = 0; + pEntry->dwPid = poxidInfo->dwPid; + pEntry->dwTid = poxidInfo->dwTid; + pEntry->dwFlags = (poxidInfo->dwPid == 0) ? 0 : OXIDF_MACHINE_LOCAL; + pEntry->dwFlags |= (poxidInfo->dwTid != 0) ? 0 : OXIDF_MTASERVER; + pEntry->pRUSTA = NULL; + pEntry->pRUMTA = NULL; + pEntry->ipidRundown = poxidInfo->ipidRemUnknown; + pEntry->hServerSTA = NULL; + pEntry->hServerMTA = NULL; + pEntry->pMIDEntry = pMIDEntry; + pEntry->hComplete = NULL; + pEntry->cCalls = 0; + pEntry->cResolverRef = 0; + IncMIDRefCnt(pMIDEntry); + + + HRESULT hr = S_OK; + + if (poxidInfo->dwPid != GetCurrentProcessId()) + { + // This OXID is for an apartment outside the current process. We + // need to make an RPC binding handle from the supplied strings. + + Win4Assert(poxidInfo->psa != NULL && + poxidInfo->psa->aStringArray[0] != 0); + + // Set the MSWMSG flag if the transport is MSWMSG. + RPC_STATUS sc = CheckClientMswmsg(poxidInfo->psa->aStringArray, + &pEntry->dwFlags); + + // Make a binding handle from the string bindings. + if (sc == RPC_S_OK) + { + sc = RpcBindingFromStringBinding(poxidInfo->psa->aStringArray, + &pEntry->hServerSTA); + } + + // Pass our blocking function to MSWMSG. When we make calls out, + // MSWMSG will call the blocking function. + if (sc == RPC_S_OK && (pEntry->dwFlags & OXIDF_MSWMSG)) + { + sc = I_RpcBindingSetAsync(pEntry->hServerSTA, OleModalLoopBlockFn); + } + + // Set security on the binding handle if necessary. + if (sc == RPC_S_OK) + { + hr = SetAuthnService( pEntry->hServerSTA, poxidInfo, pEntry ); + } + else + { + hr = HRESULT_FROM_WIN32(sc); + } + } + + // Get a shutdown event for server side MTAs. Don't use the event + // cache because the event isn't always reset. + else if (pEntry->dwFlags & OXIDF_MTASERVER) + { +#ifdef _CHICAGO_ + pEntry->hComplete = CreateEventA( NULL, FALSE, FALSE, NULL ); +#else //_CHICAGO_ + pEntry->hComplete = CreateEvent( NULL, FALSE, FALSE, NULL ); +#endif //_CHICAGO_ + if (pEntry->hComplete == NULL) + hr = RPC_E_OUT_OF_RESOURCES; + } + + if (FAILED(hr)) + { + // failed, release the OXIDEntry + DecOXIDRefCnt(pEntry); + pEntry = NULL; + } + + *ppEntry = pEntry; + gOXIDTbl.ValidateOXID(); + ComDebOut((DEB_OXID,"COXIDTable::AddEntry pEntry:%x moxid:%I\n", + pEntry, (pEntry) ? &pEntry->moxid : &GUID_NULL)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::LookupOXID, public +// +// Synopsis: finds an entry in the OXID table with the given OXID. +// This is used by the unmarshalling code. The returned +// entry has been AddRef'd. +// +// History: 02-Feb-95 Rickhi Created +// +// PERFWORK: we could move the OXIDEntry to the head of the InUse list on +// the assumption that it will be the most frequently used item +// in the near future. +// +//------------------------------------------------------------------------- +OXIDEntry *COXIDTable::LookupOXID(REFOXID roxid, REFMID rmid) +{ + ASSERT_LOCK_HELD + + MOXID moxid; + MOXIDFromOXIDAndMID(roxid, rmid, &moxid); + + // first, search the InUse list. + OXIDEntry *pEntry = SearchList(moxid, &_InUseHead); + + if (pEntry == NULL) + { + // not found on InUse list, search the Expire list. + if ((pEntry = SearchList(moxid, &_ExpireHead)) != NULL) + { + // found it, unchain it from the list of Expire entries + pEntry->pPrev->pNext = pEntry->pNext; + pEntry->pNext->pPrev = pEntry->pPrev; + + // chain it on the list of InUse entries + pEntry->pPrev = &_InUseHead; + _InUseHead.pNext->pPrev = pEntry; + pEntry->pNext = _InUseHead.pNext; + _InUseHead.pNext = pEntry; + + // reset the cRefs field (which was overloaded with the + // expire time by ReleaseEntry), and count one less entry. + + pEntry->cRefs = 1; + _cExpired--; + } + } + + ComDebOut((DEB_OXID,"COXIDTable::LookupOXID pEntry:%x moxid:%I\n", + pEntry, &moxid)); + gOXIDTbl.ValidateOXID(); + return pEntry; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::SearchList, private +// +// Synopsis: Searches the specified list for a matching OXID entry. +// This is a subroutine of LookupOXID. +// +// History: 25-Aug-95 Rickhi Created +// +//------------------------------------------------------------------------- +OXIDEntry *COXIDTable::SearchList(REFMOXID rmoxid, OXIDEntry *pStart) +{ + ASSERT_LOCK_HELD + + OXIDEntry *pEntry = pStart->pNext; + while (pEntry != pStart) + { + if (InlineIsEqualGUID(rmoxid, pEntry->moxid)) + { + IncOXIDRefCnt(pEntry); + return pEntry; // found a match, return it + } + + pEntry = pEntry->pNext; // try next one in use + } + + return NULL; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::ReleaseEntry, public +// +// Synopsis: removes an entry from the OXID table InUse list and +// places it on the Expire list. Entries on the Expire list +// will be cleaned up by a worker thread at a later time, or +// placed back on the InUse list by LookupOXID. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::ReleaseEntry(OXIDEntry *pEntry) +{ + Win4Assert(pEntry); + Win4Assert(pEntry->cRefs == 0); // must be no users of this entry + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_HELD + + if (pEntry->dwFlags & OXIDF_PENDINGRELEASE) + { + return; // already being deleted, just ignore. + } + + // unchain it from the list of InUse entries + pEntry->pPrev->pNext = pEntry->pNext; + pEntry->pNext->pPrev = pEntry->pPrev; + + // chain it on the *END* of the list of Expire entries, and + // count one more expired entry. + pEntry->pPrev = _ExpireHead.pPrev; + pEntry->pNext = &_ExpireHead; + _ExpireHead.pPrev->pNext= pEntry; + _ExpireHead.pPrev = pEntry; + + _cExpired++; + + // set the time when it was placed on the Expire list. This (may be) + // used to determine when this entry should really expire. + pEntry->cRefs = GetTickCount(); + + // Free anything hanging around on the cleanup list. This may release + // the lock. + FreeCleanupEntries(); + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"COXIDTable::ReleaseEntry pEntry:%x\n", pEntry)); +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::FreeExpiredEntries, public +// +// Synopsis: Walks the Expire list and deletes the OXIDEntries that +// were placed on the expire list before the given time. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::FreeExpiredEntries(DWORD dwTime) +{ + ASSERT_LOCK_HELD + + while (_ExpireHead.pNext != &_ExpireHead) + { +#if 0 + // CODEWORK: currently we never use the ExpireTime function, + // we only call this routine from ChannelProcessUninit, so ignore + // the expire time and release all the entries. + + // there is an entry on the list. check its time stamp (which + // was placed in the cRefs field) + + if ((DWORD)_ExpireHead.pNext->cRefs - dwTime > 0) + { + // this entry has not yet expired. All entries after this + // one must not have expired either, so exit early. + break; + } +#endif + // unchain it from the list of Expire entries, and count one less + // expired entry. + OXIDEntry *pEntry = _ExpireHead.pNext; + + pEntry->pPrev->pNext = pEntry->pNext; + pEntry->pNext->pPrev = pEntry->pPrev; + + _cExpired--; + + ExpireEntry(pEntry); + } + + // The worker thread moves entries to the cleanup list while holding the + // lock. Since the expire list is now empty no more OXIDs can be added + // to the cleanup list. Now would be a good time to free items on the + // cleanup list. + FreeCleanupEntries(); + + AssertListsEmpty(); // the lists better be empty now + ComDebOut((DEB_OXID, "COXIDTable::FreeExpiredEntries dwTime:%x\n", dwTime)); +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::FreeCleanupEntries, public +// +// Synopsis: Deletes all OXID entries on the Cleanup list. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::FreeCleanupEntries() +{ + ASSERT_LOCK_HELD + + while (_CleanupHead.pNext != &_CleanupHead) + { + // Unchain the entries and free all resources it holds. + OXIDEntry *pEntry = _CleanupHead.pNext; + _CleanupHead.pNext = pEntry->pNext; + ExpireEntry(pEntry); + } + + ComDebOut((DEB_OXID, "COXIDTable::FreeCleanupEntries\n")); +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTable::NumOxidsToRemove +// +// Synopsis: Returns the number of OXIDs on the expired list that can be +// freed. +// +// History: 03-Jun-96 AlexMit Created +// +//------------------------------------------------------------------------- +DWORD COXIDTable::NumOxidsToRemove() +{ + ASSERT_LOCK_HELD + + // Compute how many extra OXIDs are on the expired list. + if (_cExpired > OXIDTBL_MAXEXPIRED) + return _cExpired - OXIDTBL_MAXEXPIRED; + else + return 0; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTable::GetOxidsToRemove +// +// Synopsis: Builds a list of OXIDs old enough to be deleted. Removes +// them from the expired list and puts them on the cleanup list. +// Moves machine local OXIDs directly to the cleanup list. +// +// History: 03-Jun-42 AlexMit Created +// +//------------------------------------------------------------------------- +void COXIDTable::GetOxidsToRemove( OXID_REF *pRef, DWORD *pNum ) +{ + OXIDEntry *pEntry; + ASSERT_LOCK_HELD + + // Expire entries until the expired list is short enough. + *pNum = 0; + while (_cExpired > OXIDTBL_MAXEXPIRED) + { + // Only count machine remote OXIDs. + pEntry = _ExpireHead.pNext; + if ((pEntry->dwFlags & OXIDF_MACHINE_LOCAL) == 0) + { + // Add the OXID to the list to deregister. + MIDFromMOXID( pEntry->moxid, &pRef->mid ); + OXIDFromMOXID( pEntry->moxid, &pRef->oxid ); + pRef->refs = pEntry->cResolverRef; + pRef++; + *pNum += 1; + } + + // Remove the OXID from the expired list and put it on a list + // of OXIDs to be released by some apartment thread. + _cExpired--; + pEntry->pPrev->pNext = pEntry->pNext; + pEntry->pNext->pPrev = pEntry->pPrev; + pEntry->pNext = _CleanupHead.pNext; + _CleanupHead.pNext = pEntry; + } +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::ExpireEntry, private +// +// Synopsis: deletes all state associated with an OXIDEntry that has +// been expired. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void COXIDTable::ExpireEntry(OXIDEntry *pEntry) +{ + ComDebOut((DEB_OXID, "COXIDTable::ExpireEntry pEntry:%x\n", pEntry)); + Win4Assert(pEntry); + Win4Assert(!(pEntry->dwFlags & OXIDF_PENDINGRELEASE)); + ASSERT_LOCK_HELD + + if (pEntry->pRUSTA || pEntry->pRUMTA) + { + // release the IRemUnknown. Note that the IRemUnk is an object + // proxy who's IPIDEntry holds a reference back to the very + // OXIDEntry we are releasing. In order to prevent recursive + // Release's we set a simple flag here and check for it above. + + pEntry->dwFlags |= OXIDF_PENDINGRELEASE; + + UNLOCK + ASSERT_LOCK_RELEASED + + if (pEntry->pRUSTA) + { + pEntry->pRUSTA->Release(); + } + if (pEntry->pRUMTA) + { + pEntry->pRUMTA->Release(); + } + + ASSERT_LOCK_RELEASED + LOCK + } + + if (pEntry->hServerSTA != NULL) + { + // Note that if hServerSTA is an HWND (apartment model, same process) + // then it should have been cleaned up already in ThreadStop. We + // just assert that here. + Win4Assert(pEntry->dwPid != GetCurrentProcessId()); + + // hServerSTA is an RPC binding handle. Free it. + RPC_STATUS sc = RpcBindingFree(&pEntry->hServerSTA); + ComDebErr(sc != RPC_S_OK, "RpcBindingFree failed.\n"); + } + + if (pEntry->hServerMTA != NULL) + { + // hServerMTA is an RPC binding handle. Free it. + Win4Assert(pEntry->dwPid != GetCurrentProcessId()); + RPC_STATUS sc = RpcBindingFree(&pEntry->hServerMTA); + ComDebErr(sc != RPC_S_OK, "RpcBindingFree failed.\n"); + } + + // dec the refcnt on the MIDEntry + DecMIDRefCnt(pEntry->pMIDEntry); + + // Release the call shutdown event. + if (pEntry->hComplete != NULL) + CloseHandle( pEntry->hComplete ); + + // zero out the fields + memset(pEntry, 0, sizeof(OXIDEntry)); + + // return it to the allocator + _palloc.ReleaseEntry((PageEntry *)pEntry); + + ComDebOut((DEB_OXID,"COXIDTable::ExpireEntry pEntry:%x\n", pEntry)); +} + +//+------------------------------------------------------------------------ +// +// Function: COXIDTbl::DecOXIDRefCnt, public +// +// Synopsis: release one reference to the OXIDEntry and release +// the entry if the count goes to zero. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void DecOXIDRefCnt(OXIDEntry *pEntry) +{ + Win4Assert(pEntry); + ASSERT_LOCK_HELD + + ComDebOut((DEB_OXID, + "DecOXIDRefCnt pEntry:%x cRefs[%x]\n", pEntry, pEntry->cRefs-1)); + + pEntry->cRefs--; + if (pEntry->cRefs == 0) + { + gOXIDTbl.ReleaseEntry(pEntry); + } +} + + +//+------------------------------------------------------------------- +// +// Member: COXIDTable::GetRemUnk, public +// +// Synopsis: Find or create the proxy for the IRemUnknown for the +// specified OXID +// +// History: 27-Mar-95 AlexMit Created +// +//-------------------------------------------------------------------- +HRESULT COXIDTable::GetRemUnk(OXIDEntry *pOXIDEntry, IRemUnknown **ppRemUnk) +{ + ComDebOut((DEB_OXID, "COXIDTable::GetRemUnk pOXIDEntry:%x ppRemUnk:%x\n", + pOXIDEntry, ppRemUnk)); + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + if (IsMTAThread()) + { + // return the MTA version of the IRemUnknown proxy. + if (pOXIDEntry->pRUMTA == NULL) + { + hr = MakeRemUnk(pOXIDEntry); + } + *ppRemUnk = pOXIDEntry->pRUMTA; + } + else + { + // return the STA version of the IRemUnknown proxy. + if (pOXIDEntry->pRUSTA == NULL) + { + hr = MakeRemUnk(pOXIDEntry); + } + *ppRemUnk = pOXIDEntry->pRUSTA; + } + + ComDebOut((DEB_OXID, "COXIDTable::GetRemUnk pOXIDEntry:%x pRU:%x hr:%x\n", + pOXIDEntry, *ppRemUnk, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: COXIDTable::MakeRemUnk, private +// +// Synopsis: Create the proxy for the IRemUnknown for the +// specified OXID and current apartments threading model. +// +// History: 27-Mar-95 AlexMit Created +// +//-------------------------------------------------------------------- +HRESULT COXIDTable::MakeRemUnk(OXIDEntry *pOXIDEntry) +{ + // There is no remote unknown proxy for this entry, get one. + // Make up an objref, then unmarshal it to create a proxy to + // the remunk object in the server. + + // on the same machine, we ask for the IRundown interface since we may + // need the RemChangeRef method. IRundown inherits from IRemUnknown2 + // and IRemUnknown. + + REFIID riid = (pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL) + ? IID_IRundown : IID_IRemUnknown; + + OBJREF objref; + HRESULT hr = MakeFakeObjRef(objref, pOXIDEntry, pOXIDEntry->ipidRundown, riid); + + UNLOCK + ASSERT_LOCK_RELEASED + + IRemUnknown *pRU = NULL; + + if (SUCCEEDED(hr)) + { + hr = UnmarshalInternalObjRef(objref, (void **)&pRU); + } + + ASSERT_LOCK_RELEASED + LOCK + + if (SUCCEEDED(hr) && IsMTAThread() && pOXIDEntry->pRUMTA == NULL) + { + pOXIDEntry->pRUMTA = pRU; + + // need to adjust the internal refcnt on the OXIDEntry, since + // the IRemUnknown has an IPID that holds a reference to it. + // Dont use DecOXIDRefCnt since that would delete if it was 0. + + Win4Assert(pOXIDEntry->cRefs > 0); + pOXIDEntry->cRefs--; + } + else if (SUCCEEDED(hr) && IsSTAThread() && pOXIDEntry->pRUSTA == NULL) + { + pOXIDEntry->pRUSTA = pRU; + + // need to adjust the internal refcnt on the OXIDEntry, since + // the IRemUnknown has an IPID that holds a reference to it. + // Dont use DecOXIDRefCnt since that would delete if it was 0. + + Win4Assert(pOXIDEntry->cRefs > 0); + pOXIDEntry->cRefs--; + } + else if (pRU) + { + // either setting of the security failed OR, we released the + // lock and when we took the lock again some other thread had already + // created the proxy. In either case we just release the one we made. + + UNLOCK + ASSERT_LOCK_RELEASED + pRU->Release(); + ASSERT_LOCK_RELEASED + LOCK + } + + ComDebOut((DEB_OXID, "COXIDTable::GetRemUnk pOXIDEntry:%x pRU:%x hr:%x\n", + pOXIDEntry, pRU, hr)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::GetLocalEntry, public +// +// Synopsis: Finds an entry in the OXID table for the local apartment. +// If no entry exists, it creates an entry, and starts RPC +// listening if appropriate. +// +// History: 20-Feb-95 Rickhi Created +// +// Notes: Marshalling the remote unknown causes recursion back to +// this function. The recursion is terminated because +// GetLocalOXIDEntry is not NULL on the second call. +// +//------------------------------------------------------------------------- +HRESULT COXIDTable::GetLocalEntry(OXIDEntry **ppEntry) +{ + ComDebOut((DEB_OXID, "COXIDTable::GetLocalEntry ppEntry:%x\n", ppEntry)); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + MIDEntry *pMIDEntry; + + *ppEntry = GetLocalOXIDEntry(); + + if (*ppEntry == NULL && SUCCEEDED(hr = GetLocalMIDEntry(&pMIDEntry))) + { + // No local OXID entry exists, make one. + + // NOTE: Chicken And Egg Problem. + // + // Marshaling needs the local OXIDEntry. The local OXIDEntry needs + // the local OXID. To get the local OXID we have to call the resolver. + // To call the resolver we need the IPID for IRemUnknown. To get the + // IPID for IRemUnknown, we need to marshal CRemoteUnknown! + // + // To get around this problem, we create a local OXIDEntry (that has + // a 0 OXID and NULL ipidRemUnknown) so that marshaling can find it. + // Then we marshal the RemoteUnknown and extract its IPID value, stick + // it in the local OXIDEntry. When we call the resolver (to get some + // pre-registered OIDs) we get the real OXID value which we then stuff + // in the local OXIDEntry. + + OXID_INFO oxidInfo; + oxidInfo.dwTid = (IsMTAThread()) ? 0 : GetCurrentThreadId(); + oxidInfo.dwPid = GetCurrentProcessId(); + oxidInfo.ipidRemUnknown = GUID_NULL; + oxidInfo.psa = NULL; + oxidInfo.dwAuthnHint = RPC_C_AUTHN_LEVEL_NONE; + + // NOTE: temp creation of OXID. We dont know the real OXID until + // we call the resolver. So, we use 0 temporarily (it wont conflict + // with any other MOXIDs we might be searching for because we already + // have the real MID and our local resolver wont give out a 0 OXID). + // The OXID will be replaced with the real one when we register + // with the resolver in CRpcResolver::ServerAllocateOXIDAndOIDs. + + OXID oxid; + memset(&oxid, 0, sizeof(oxid)); + + hr = AddEntry(oxid, &oxidInfo, pMIDEntry, ppEntry); + + if (SUCCEEDED(hr)) + { + // Set the local OXID index and marshal IRemUnknown. Note + // that the index must be set before we construct the + // CRemoteUnknown since that calls MarshalObjRef which + // recurses back into GetLocalEntry. Setting the LocalOXID + // now allows us to break the recursion. + + SetLocalOXIDEntry(*ppEntry); + + // Create the remote unknown for this apartment. It places + // itself in TLS or in the global gpMTARemoteUnknown. + + hr = E_OUTOFMEMORY; // assume OOM + CRemoteUnknown *pRemUnk = new CRemoteUnknown(hr, + &(*ppEntry)->ipidRundown); + + if (FAILED(hr)) + { + // remove the Local OXID entry. This will also clean up + // pRemUnk if the allocation succeeded but ctor failed. + + if (IsSTAThread()) + { + ReleaseLocalSTAEntry(); + } + else + { + ReleaseLocalMTAEntry(); + } + } + } + } + + ComDebOut((DEB_OXID, "COXIDTable::GetLocalEntry this:%x pEntry:%x\n", + this, *ppEntry)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::ReleaseLocalSTAEntry, public +// +// Synopsis: releases the OXIDEntry for the current STA apartment. +// +// History: 20-Feb-95 Rickhi Created +// +//+------------------------------------------------------------------------ +void COXIDTable::ReleaseLocalSTAEntry(void) +{ + ComDebOut((DEB_OXID, "COXIDTable::ReleaseLocalSTAEntry\n")); + Win4Assert(IsSTAThread()); + ASSERT_LOCK_HELD + + COleTls tls; + + OXIDEntry *pOXIDEntry = (OXIDEntry *)(tls->pOXIDEntry); + + if (pOXIDEntry) + { + // get the CRemoteUnknown for this apartment. + CRemoteUnknown *pRemUnk = tls->pRemoteUnk; + tls->pRemoteUnk = NULL; + + // this guy ignores refcounts so we delete him directly. + delete pRemUnk; + + // de-register the OXID and OIDs with the resolver. + gResolver.ServerFreeOXID(pOXIDEntry); + + // Clear the apartment OXID Entry. + tls->pOXIDEntry = NULL; + + // now decrement its count. + DecOXIDRefCnt(pOXIDEntry); + } +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::ReleaseLocalMTAEntry, public +// +// Synopsis: releases the OXIDEntry for the current apartment. +// +// History: 20-Feb-95 Rickhi Created +// +//+------------------------------------------------------------------------ +void COXIDTable::ReleaseLocalMTAEntry(void) +{ + ComDebOut((DEB_OXID, "COXIDTable::ReleaseLocalMTAEntry\n")); + ASSERT_LOCK_HELD + + OXIDEntry *pOXIDEntry = gpMTAOXIDEntry; + + if (pOXIDEntry) + { + // get the CRemoteUnknown for this apartment. + CRemoteUnknown *pRemUnk = gpMTARemoteUnknown;; + gpMTARemoteUnknown = NULL; + + // this guy ignores refcounts so we delete him directly. + delete pRemUnk; + + // de-register the OXID and OIDs with the resolver. + gResolver.ServerFreeOXID(pOXIDEntry); + + // Clear the MTA apartment OXID Entry. + gpMTAOXIDEntry = NULL; + + // now decrement its count. + DecOXIDRefCnt(pOXIDEntry); + } +} + +//+------------------------------------------------------------------- +// +// Function: FindOrCreateOXIDEntry +// +// Synopsis: finds or adds an OXIDEntry for the given OXID. May +// also create a MIDEntry if one does not yet exist. +// +// History: 22-Jan-96 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT FindOrCreateOXIDEntry(REFOXID roxid, + OXID_INFO &oxidInfo, + FOCOXID eResolverRef, + DUALSTRINGARRAY *psaResolver, + REFMID rmid, + MIDEntry *pMIDEntry, + OXIDEntry **ppOXIDEntry) +{ + ComDebOut((DEB_OXID,"FindOrCreateOXIDEntry oxid:%08x %08x oxidInfo:%x psa:%ws pMIDEntry:%x\n", + roxid, &oxidInfo, psaResolver, pMIDEntry)); + gOXIDTbl.ValidateOXID(); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + + // check if the OXIDEntry was created while we were resolving it. + *ppOXIDEntry = gOXIDTbl.LookupOXID(roxid, rmid); + + if (*ppOXIDEntry == NULL) + { + BOOL fReleaseMIDEntry = FALSE; + + if (pMIDEntry == NULL) + { + // dont yet have a MIDEntry for the machine so go add it + hr = gMIDTbl.FindOrCreateMIDEntry(rmid, psaResolver, &pMIDEntry); + fReleaseMIDEntry = TRUE; + } + + if (pMIDEntry) + { + // add a new the OXIDEntry + hr = gOXIDTbl.AddEntry(roxid, &oxidInfo, pMIDEntry, ppOXIDEntry); + + if (fReleaseMIDEntry) + { + // undo the reference added by FindOrCreateMIDEntry + DecMIDRefCnt(pMIDEntry); + } + } + } + + if (SUCCEEDED(hr) && eResolverRef == FOCOXID_REF) + { + // Increment the count of references handed to us from the resolver. + (*ppOXIDEntry)->cResolverRef += 1; + } + + gOXIDTbl.ValidateOXID(); + ComDebOut((DEB_OXID,"FindOrCreateOXIDEntry pOXIDEntry:%x hr:%x\n", + *ppOXIDEntry, hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------------ +// +// Function: GetLocalOXIDEntry +// +// Synopsis: Get either the global or the TLS OXIDEntry based on the +// threading model of the current thread. +// +// History: 05-May-95 AlexMit Created +// +//------------------------------------------------------------------------- +OXIDEntry *GetLocalOXIDEntry() +{ + ASSERT_LOCK_HELD + + COleTls tls; + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + return (OXIDEntry *)(tls->pOXIDEntry); + + return gpMTAOXIDEntry; +} + +//+------------------------------------------------------------------------ +// +// Function: SetLocalOXIDEntry +// +// Synopsis: Set either the global or the TLS OXIDEntry based on the +// threading model of the current thread. +// +// History: 05-May-95 AlexMit Created +// +//------------------------------------------------------------------------- +void SetLocalOXIDEntry(OXIDEntry *pOXIDEntry) +{ + ASSERT_LOCK_HELD + + COleTls tls; + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + { + tls->pOXIDEntry = (void *)pOXIDEntry; + return; + } + + gpMTAOXIDEntry = pOXIDEntry; +} + +//+------------------------------------------------------------------------ +// +// Function: CoGetTidFromIPID +// +// Synopsis: Take an IPID and return the thread id the object is on. +// MSWMSG calls this function during dispatches. +// +//------------------------------------------------------------------------- +STDAPI_(DWORD) CoGetTIDFromIPID( UUID *pIPID ) +{ + DWORD iTid = 0; + LOCK + + IPIDEntry *pEntry = gIPIDTbl.LookupIPID( *pIPID ); + if (pEntry != NULL && pEntry->pOXIDEntry != NULL) + { + iTid = pEntry->pOXIDEntry->dwTid; + } + + UNLOCK + return iTid; +} + +//+------------------------------------------------------------------------ +// +// Function: CleanupMIDEntry +// +// Synopsis: Called by the MID hash table when cleaning up any leftover +// entries. +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CleanupMIDEntry(SHashChain *pNode) +{ + gMIDTbl.ReleaseEntry((MIDEntry *)pNode); +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTbl::Initialize, public +// +// Synopsis: Initializes the MID table. +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CMIDTable::Initialize() +{ + ComDebOut((DEB_OXID, "CMIDTable::Initialize\n")); + _HashTbl.Initialize(MIDBuckets); + _palloc.Initialize(sizeof(MIDEntry), MIDS_PER_PAGE); +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTbl::Cleanup, public +// +// Synopsis: Cleanup the MID table. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CMIDTable::Cleanup() +{ + ComDebOut((DEB_OXID, "CMIDTable::Cleanup\n")); + _HashTbl.Cleanup(CleanupMIDEntry); + _palloc.Cleanup(); +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTable::FindOrCreateMIDEntry, public +// +// Synopsis: Looks for existing copy of the string array in the MID table, +// creates one if not found +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +HRESULT CMIDTable::FindOrCreateMIDEntry(REFMID rmid, + DUALSTRINGARRAY *psaResolver, + MIDEntry **ppMIDEntry) +{ + ComDebOut((DEB_OXID, "CMIDTable::FindOrCreateMIDEntry psa:%x\n", psaResolver)); + Win4Assert(psaResolver != NULL); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + DWORD dwHash; + + *ppMIDEntry = LookupMID(psaResolver, &dwHash); + + if (*ppMIDEntry == NULL) + { + hr = AddMIDEntry(rmid, dwHash, psaResolver, ppMIDEntry); + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID, "CMIDTable::FindOrCreateEntry pMIDEntry:%x hr:%x\n", *ppMIDEntry, hr)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTable::LookupMID, public +// +// Synopsis: Looks for existing copy of the string array in the MID table. +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +MIDEntry *CMIDTable::LookupMID(DUALSTRINGARRAY *psaResolver, DWORD *pdwHash) +{ + ComDebOut((DEB_OXID, "CMIDTable::LookupMID psa:%x\n", psaResolver)); + Win4Assert(psaResolver != NULL); + ASSERT_LOCK_HELD + + *pdwHash = _HashTbl.Hash(psaResolver); + MIDEntry *pMIDEntry = (MIDEntry *) _HashTbl.Lookup(*pdwHash, psaResolver); + + if (pMIDEntry) + { + // found the node, AddRef it and return + IncMIDRefCnt(pMIDEntry); + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID, "CMIDTable::LookupMID pMIDEntry:%x\n", pMIDEntry)); + return pMIDEntry; +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTable::AddEntry, public +// +// Synopsis: Adds an entry to the MID table. The entry is AddRef'd. +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +HRESULT CMIDTable::AddMIDEntry(REFMID rmid, DWORD dwHash, + DUALSTRINGARRAY *psaResolver, + MIDEntry **ppMIDEntry) +{ + ComDebOut((DEB_OXID, "CMIDTable::AddMIDEntry rmid:%08x %08x dwHash:%x psa:%x\n", + rmid, dwHash, psaResolver)); + Win4Assert(psaResolver != NULL); + ASSERT_LOCK_HELD + + // We must make a copy of the psa to store in the table, since we are + // using the one read in from ReadObjRef (or allocated by MIDL). + + DUALSTRINGARRAY *psaNew; + HRESULT hr = CopyStringArray(psaResolver, NULL, &psaNew); + if (FAILED(hr)) + return hr; + + MIDEntry *pMIDEntry = (MIDEntry *) _palloc.AllocEntry(); + + if (pMIDEntry) + { + pMIDEntry->cRefs = 1; + pMIDEntry->dwFlags = 0; + pMIDEntry->mid = rmid; + + // add the entry to the hash table + _HashTbl.Add(dwHash, psaNew, &pMIDEntry->Node); + + hr = S_OK; + + // set the maximum size of any resolver PSA we have seen. This is used + // when computing the max marshal size during interface marshaling. + + DWORD dwpsaSize = SASIZE(psaNew->wNumEntries); + if (dwpsaSize > gdwPsaMaxSize) + { + gdwPsaMaxSize = dwpsaSize; + } + } + else + { + // cant create a MIDEntry, free the copy of the string array. + PrivMemFree(psaNew); + hr = E_OUTOFMEMORY; + } + + *ppMIDEntry = pMIDEntry; + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID, "CMIDTable::AddMIDEntry pMIDEntry:%x hr:%x\n", *ppMIDEntry, hr)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Member: CMIDTable::ReleaseEntry, public +// +// Synopsis: remove the MIDEntry from the hash table and free the memory +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CMIDTable::ReleaseEntry(MIDEntry *pMIDEntry) +{ + ComDebOut((DEB_OXID, "CMIDTable::ReleaseEntry pMIDEntry:%x\n", pMIDEntry)); + Win4Assert(pMIDEntry->cRefs == 0); + ASSERT_LOCK_HELD + + // delete the string array + PrivMemFree(pMIDEntry->Node.psaKey); + + // remove from the hash chain and delete the node + _HashTbl.Remove(&pMIDEntry->Node.chain); + + _palloc.ReleaseEntry((PageEntry *)pMIDEntry); +} + +//+------------------------------------------------------------------------ +// +// Function: DecMIDRefCnt, public +// +// Synopsis: release one reference to the MIDEntry and release +// the entry if the count goes to zero. +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +void DecMIDRefCnt(MIDEntry *pMIDEntry) +{ + Win4Assert(pMIDEntry); + ASSERT_LOCK_HELD + + ComDebOut((DEB_OXID, + "DecMIDRefCnt pMIDEntry:%x cRefs[%x]\n", pMIDEntry, pMIDEntry->cRefs-1)); + + pMIDEntry->cRefs--; + if (pMIDEntry->cRefs == 0) + { + gMIDTbl.ReleaseEntry(pMIDEntry); + } +} + +//+------------------------------------------------------------------------ +// +// Function: GetLocalMIDEntry +// +// Synopsis: Get or create the MID (Machine ID) entry for the local +// machine. gpLocalMIDEntry holds the network address for the +// local OXID resolver. +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +HRESULT GetLocalMIDEntry(MIDEntry **ppMIDEntry) +{ + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + if (gpLocalMIDEntry == NULL) + { + // make sure we have the local resolver string bindings + RPC_STATUS sc = gResolver.GetConnection(); + if (sc == RPC_S_OK) + { + // Create a MID entry for the Local Resolver + hr = gMIDTbl.FindOrCreateMIDEntry(gLocalMid, gpsaLocalResolver, + &gpLocalMIDEntry); + } + else + { + hr = MAKE_SCODE(SEVERITY_ERROR, FACILITY_WIN32, sc); + } + } + + *ppMIDEntry = gpLocalMIDEntry; + return hr; +} diff --git a/private/ole32/com/dcomrem/ipidtbl.hxx b/private/ole32/com/dcomrem/ipidtbl.hxx new file mode 100644 index 000000000..04962ea79 --- /dev/null +++ b/private/ole32/com/dcomrem/ipidtbl.hxx @@ -0,0 +1,485 @@ +//+------------------------------------------------------------------------ +// +// File: ipidtbl.hxx +// +// Contents: MID (machine identifier) table. +// OXID (object exporter identifier) table. +// IPID (interface pointer identifier) table. +// +// Classes: CMIDTable +// COXIDTable +// CIPIDTable +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +#ifndef _IPIDTBL_HXX_ +#define _IPIDTBL_HXX_ + +#include <pgalloc.hxx> // CPageAllocator +#include <lclor.h> // local OXID resolver interface +#include <remoteu.hxx> // CRemoteUnknown +#include <locks.hxx> // ASSERT_LOCK_HELD +#include <hash.hxx> // CStringHashTable + + +// forward declarations +class CRpcChannelBuffer; + + +//+------------------------------------------------------------------------ +// +// This structure defines an Entry in the MID table. There is one MID +// table for the entire process. There is one MIDEntry per machine that +// the current process is talking to (including one for the local machine). +// +//------------------------------------------------------------------------- +typedef struct tagMIDEntry +{ + SStringHashNode Node; // hash chain and key + MID mid; // machine identifier + LONG cRefs; // count of IPIDs using this OXIDEntry + DWORD dwFlags; // state flags +} MIDEntry; + +// MID Table constants. MIDS_PER_PAGE is the number of MIDEntries +// in one page of the page allocator. + +#define MIDS_PER_PAGE 5 + + +//+------------------------------------------------------------------------ +// +// class: CMIDTable +// +// Synopsis: Table of Machine IDs (MIDs) and associated information. +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +class CMIDTable +{ +public: + void Initialize(); // initialize table + void Cleanup(); // cleanup table + + HRESULT FindOrCreateMIDEntry(REFMID rmid, + DUALSTRINGARRAY *psaResolver, + MIDEntry **ppMIDEntry); + + MIDEntry *LookupMID(DUALSTRINGARRAY *psaResolver, DWORD *pdwHash); + + void ReleaseEntry(MIDEntry *pMIDEntry); + +private: + HRESULT AddMIDEntry(REFMID rmid, + DWORD dwHash, + DUALSTRINGARRAY *psaResolver, + MIDEntry **ppMIDEntry); + + static CStringHashTable _HashTbl; // hash table for MIDEntries + static CPageAllocator _palloc; // page based allocator +}; + + + +//+------------------------------------------------------------------------ +// +// This structure defines an Entry in the OXID table. There is one OXID +// table for the entire process. There is one OXIDEntry per apartment. +// +//------------------------------------------------------------------------- +typedef struct tagOXIDEntry +{ + struct tagOXIDEntry *pPrev; // previous entry on inuse list + struct tagOXIDEntry *pNext; // next entry on free/inuse list + DWORD dwPid; // process id of server + DWORD dwTid; // thread id of server + MOXID moxid; // object exporter identifier + machine id + IPID ipidRundown;// IPID of IRundown and Remote Unknown + DWORD dwFlags; // state flags + handle_t hServerSTA; // rpc binding handle of server + handle_t hServerMTA; // rpc binding handle of server + MIDEntry *pMIDEntry; // MIDEntry for machine where server lives + IRemUnknown *pRUSTA; // STA model proxy for Remote Unknown + IRemUnknown *pRUMTA; // MTA model proxy for Remote Unknown + LONG cRefs; // count of IPIDs using this OXIDEntry + LONG cWaiters; // count of threads waiting for OIDs + HANDLE hComplete; // set when last outstanding call completes + LONG cCalls; // number of calls dispatched + LONG cResolverRef;//References to resolver + DWORD dwPad; // keep structure 16 byte aligned +} OXIDEntry; + +// bit flags for dwFlags of OXIDEntry +typedef enum tagOXIDFLAGS +{ + OXIDF_REGISTERED = 0x1, // oxid is registered with Resolver + OXIDF_MACHINE_LOCAL = 0x2, // oxid is local to this machine + OXIDF_STOPPED = 0x4, // thread can no longer receive calls + OXIDF_PENDINGRELEASE = 0x8, // oxid entry is already being released + OXIDF_MSWMSG = 0x10, // use mswmsg transport + OXIDF_REGISTERINGOIDS= 0x20, // a thread is busy registering OIDs + OXIDF_MTASERVER = 0x40 // the server is an MTA apartment. +} OXIDFLAGS; + +// Parameter to FindOrCreateOXIDEntry +typedef enum tagFOCOXID +{ + FOCOXID_REF = 1, // Got reference from resolver + FOCOXID_NOREF = 2 // No reference from resolver +} FOCOXID; + +// OXID Table constants. +#define OXIDS_PER_PAGE 10 +#define OXIDTBL_MAXEXPIRED 5 // max number of expired entries to keep + + +//+------------------------------------------------------------------------ +// +// class: COXIDTable +// +// Synopsis: Maintains a table of OXIDs and associated information +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +class COXIDTable +{ +public: + HRESULT AddEntry(REFOXID roxid, OXID_INFO *poxidInfo, + MIDEntry *pMIDEntry, OXIDEntry **ppEntry); + + void ReleaseEntry(OXIDEntry *pEntry); + + HRESULT GetLocalEntry(OXIDEntry **ppEntry); + void ReleaseLocalSTAEntry(void); + void ReleaseLocalMTAEntry(void); + OXIDEntry *LookupOXID(REFOXID roxid, REFMID rmid); + + HRESULT GetRemUnk(OXIDEntry *pOXIDEntry, IRemUnknown **ppRemUnk); + + void Initialize(); // initialize table + void Cleanup(); // cleanup table + void FreeExpiredEntries(DWORD dwTime); + void ValidateOXID(); + void FreeCleanupEntries(); + DWORD NumOxidsToRemove(); + void GetOxidsToRemove( OXID_REF *pRef, DWORD *pNum ); + +private: + + void ExpireEntry(OXIDEntry *pEntry); + OXIDEntry *SearchList(REFMOXID rmoxid, OXIDEntry *pStart); + HRESULT MakeRemUnk(OXIDEntry *pOXIDEntry); + void AssertListsEmpty(void); + + static DWORD _cExpired; // count of expired entries + static OXIDEntry _InUseHead; // head of InUse list. + static OXIDEntry _ExpireHead; // head of Expire list. + static OXIDEntry _CleanupHead; // head of Cleanup list. + + static CPageAllocator _palloc; // page alloctor + + // PERFWORK: could save space since only the first two entries of + // the InUseHead and ExpireHead are used (the list ptrs) and hence + // dont need whole OXIDEntries here. +}; + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::ValidateOXID, public +// +// Synopsis: Asserts that no OXIDEntries have trashed window handles. +// +//------------------------------------------------------------------------- +inline void COXIDTable::ValidateOXID() +{ +#if DBG==1 + LOCK + + // Check all entries in use. + OXIDEntry *pCurr = _InUseHead.pNext; + while (pCurr != &_InUseHead) + { + Win4Assert( pCurr->hServerSTA != (void *) 0xC000001C ); + Win4Assert( pCurr->hServerMTA != (void *) 0xC000001C ); + pCurr = pCurr->pNext; + } + UNLOCK +#endif +} + +//+------------------------------------------------------------------------ +// +// Member: COXIDTbl::AssertListsEmpty, public +// +// Synopsis: Asserts that no OXIDEntries are in use +// +// History: 19-Apr-96 Rickhi Created +// +//------------------------------------------------------------------------- +inline void COXIDTable::AssertListsEmpty(void) +{ + // Assert that there are no entries in the InUse or Expired lists. + Win4Assert(_InUseHead.pNext == &_InUseHead); + Win4Assert(_InUseHead.pPrev == &_InUseHead); + Win4Assert(_ExpireHead.pNext == &_ExpireHead); + Win4Assert(_ExpireHead.pPrev == &_ExpireHead); +} + + + +//+------------------------------------------------------------------------ +// +// This structure defines an Entry in the IPID table. There is one +// IPID table for the entire process. It holds IPIDs from local objects +// as well as remote objects. +// +//------------------------------------------------------------------------- +typedef struct tagIPIDEntry +{ + struct tagIPIDEntry *pNextOID; // next IPIDEntry for same object + DWORD dwFlags; // flags (see IPIDFLAGS) + ULONG cStrongRefs; // strong reference count + ULONG cWeakRefs; // weak reference count + ULONG cPrivateRefs;// private reference count + CRpcChannelBuffer *pChnl; // channel pointer + IUnknown *pStub; // proxy or stub pointer + OXIDEntry *pOXIDEntry; // ptr to OXIDEntry in OXID Table + IPID ipid; // interface pointer identifier + IID iid; // interface iid + void *pv; // real interface pointer + DWORD pad[3]; // round size to modulus 16 +} IPIDEntry; + +// bit flags for dwFlags of IPIDEntry +typedef enum tagIPIDFLAGS +{ + IPIDF_CONNECTING = 0x1, // ipid is being connected + IPIDF_DISCONNECTED = 0x2, // ipid is disconnected + IPIDF_SERVERENTRY = 0x4, // SERVER IPID vs CLIENT IPID + IPIDF_NOPING = 0x8, // dont need to ping the server or release + IPIDF_COPY = 0x10, // copy for security only + IPIDF_VACANT = 0x80, // entry is vacant (ie available to reuse) + IPIDF_NONNDRSTUB = 0x100, // stub does not use NDR marshaling + IPIDF_NONNDRPROXY = 0x200, // proxy does not use NDR marshaling + IPIDF_NOTIFYACT = 0x400 // notify activation on marshal/release +} IPIDFLAGS; + + +// IPID Table constants. IPIDS_PER_PAGE is the number of IPIDEntries +// in one page of the page allocator. + +#define IPIDS_PER_PAGE 50 + +//+------------------------------------------------------------------------ +// +// class: CIPIDTbl +// +// Synopsis: Maintains a table of IPIDs and associated information +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +class CIPIDTable +{ +public: + IPIDEntry *LookupIPID(REFIPID ripid); // find entry in the table with + // the matching ipid + + IPIDEntry *FirstFree(void); + void ReleaseEntryList(IPIDEntry *pFirst, IPIDEntry *pLast); + IPIDEntry *GetEntryPtr(LONG iEntry); + LONG GetEntryIndex(IPIDEntry *pEntry); + +#if DBG==1 + void AssertValid(void) {;} + void ValidateIPIDEntry(IPIDEntry *pEntry, BOOL fServerSide, + CRpcChannelBuffer *pChnl); +#else + void AssertValid(void) {;} + void ValidateIPIDEntry(IPIDEntry *pEntry, BOOL fServerSide, + CRpcChannelBuffer *pChnl) {;} +#endif + + void Initialize(); // initialize table + void Cleanup(); // cleanup table + +private: + static CPageAllocator _palloc; // page alloctor +}; + + +//+------------------------------------------------------------------------ +// +// Global Externals +// +//+------------------------------------------------------------------------ + +extern CMIDTable gMIDTbl; // global table, defined in ipidtbl.cxx +extern COXIDTable gOXIDTbl; // global table, defined in ipidtbl.cxx +extern CIPIDTable gIPIDTbl; // global table, defined in ipidtbl.cxx +extern MIDEntry *gpLocalMIDEntry; // ptr to MIDEntry for current process +extern OXIDEntry *gpMTAOXIDEntry; // ptr to local OXIDEntry in MTA +extern DUALSTRINGARRAY *gpsaLocalResolver; // bindings for local OXID resolver. + +//+------------------------------------------------------------------------ +// +// Function Prototypes +// +//+------------------------------------------------------------------------ + +HRESULT GetLocalMIDEntry(MIDEntry **ppMIDEntry); +OXIDEntry *GetLocalOXIDEntry(); +void SetLocalOXIDEntry(OXIDEntry *pOXIDEntry); +void DecOXIDRefCnt(OXIDEntry *pEntry); +void DecMIDRefCnt(MIDEntry *pEntry); + +HRESULT FindOrCreateOXIDEntry(REFOXID roxid, + OXID_INFO &oxidInfo, + FOCOXID eReferenced, + DUALSTRINGARRAY *psaResolver, + REFMID rmid, + MIDEntry *pMIDEntry, + OXIDEntry **ppOXIDEntry); + + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::FirstFree, public +// +// Synopsis: Finds the first available entry in the table and returns +// its index. Returns -1 if no space is available and it +// cant grow the list. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +inline IPIDEntry *CIPIDTable::FirstFree() +{ + return (IPIDEntry *) _palloc.AllocEntry(); +} + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::GetEntryIndex, public +// +// Synopsis: Converts an entry ptr into an entry index +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +inline LONG CIPIDTable::GetEntryIndex(IPIDEntry *pIPIDEntry) +{ + return _palloc.GetEntryIndex((PageEntry *)pIPIDEntry); +} + +//+------------------------------------------------------------------------ +// +// Member: CIPIDTbl::GetEntryPtr, public +// +// Synopsis: Converts an entry index into an entry pointer +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +inline IPIDEntry *CIPIDTable::GetEntryPtr(LONG index) +{ + return (IPIDEntry *) _palloc.GetEntryPtr(index); +} + + + +//+------------------------------------------------------------------------ +// +// Function: IncOXIDRefCnt, public +// +// Synopsis: increment the number of references to the OXIDEntry +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +inline void IncOXIDRefCnt(OXIDEntry *pEntry) +{ + Win4Assert(pEntry); + ASSERT_LOCK_HELD + + ComDebOut((DEB_OXID, + "IncOXIDRefCnt pEntry:%x cRefs[%x]\n", pEntry, pEntry->cRefs+1)); + + pEntry->cRefs++; +} + +//+------------------------------------------------------------------------ +// +// Function: IncMIDRefCnt, public +// +// Synopsis: increment the number of references to the MIDEntry +// +// History: 05-Janb-96 Rickhi Created +// +//------------------------------------------------------------------------- +inline void IncMIDRefCnt(MIDEntry *pEntry) +{ + Win4Assert(pEntry); + ASSERT_LOCK_HELD + + ComDebOut((DEB_OXID, + "IncMIDRefCnt pEntry:%x cRefs[%x]\n", pEntry, pEntry->cRefs+1)); + + pEntry->cRefs++; +} + +//+------------------------------------------------------------------------ +// +// Function: MOXIDFromOXIDAndMID, public +// +// Synopsis: creates a MOXID (machine and object exporter ID) from +// the individual OXID and MID components +// +// History: 05-Janb-96 Rickhi Created +// +//------------------------------------------------------------------------- +inline void MOXIDFromOXIDAndMID(REFOXID roxid, REFMID rmid, MOXID *pmoxid) +{ + BYTE *pb = (BYTE *)pmoxid; + memcpy(pb, &roxid, sizeof(OXID)); + memcpy(pb+8, &rmid, sizeof(MID)); +} + +//+------------------------------------------------------------------------ +// +// Function: OXIDFromMOXID, public +// +// Synopsis: extracts the OXID from a MOXID (machine and OXID) +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +inline void OXIDFromMOXID(REFMOXID rmoxid, OXID *poxid) +{ + memcpy(poxid, (BYTE *)&rmoxid, sizeof(OXID)); +} + +//+------------------------------------------------------------------------ +// +// Function: MIDFromMOXID, public +// +// Synopsis: extracts the MID from a MOXID (machine and OXID) +// +// History: 05-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +inline void MIDFromMOXID(REFMOXID rmoxid, OXID *pmid) +{ + memcpy(pmid, ((BYTE *)&rmoxid)+8, sizeof(MID)); +} + +// OID + MID versions of the above routines. + +#define MOIDFromOIDAndMID MOXIDFromOXIDAndMID +#define OIDFromMOID OXIDFromMOXID +#define MIDFromMOID MIDFromMOXID + +#endif // _IPIDTBL_HXX_ diff --git a/private/ole32/com/dcomrem/ipmrshl.cxx b/private/ole32/com/dcomrem/ipmrshl.cxx new file mode 100644 index 000000000..1833d2b45 --- /dev/null +++ b/private/ole32/com/dcomrem/ipmrshl.cxx @@ -0,0 +1,652 @@ +//+------------------------------------------------------------------- +// +// File: ipmrshl.cpp +// +// Contents: Code the implements the standard free thread in process +// marshaler. +// +// Classes: CFreeMarshaler +// CFmCtrlUnknown +// +// Functions: CoCreateFreeThreadedMarshaler +// +// History: 03-Nov-94 Ricksa +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <stdid.hxx> + +//+------------------------------------------------------------------- +// +// Class: CFreeMarshaler +// +// Synopsis: Generic marshaling class +// +// Methods: IUnknown +// IMarshal +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +class CFreeMarshaler : public IMarshal, public CPrivAlloc +{ +public: + CFreeMarshaler(IUnknown *punk); + + // IUnknown + STDMETHODIMP QueryInterface(REFIID iid, void FAR * FAR * ppv); + STDMETHODIMP_(ULONG) AddRef(void); + STDMETHODIMP_(ULONG) Release(void); + + // IMarshal Interface + STDMETHODIMP GetUnmarshalClass( + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags, + CLSID *pCid); + + STDMETHODIMP GetMarshalSizeMax( + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags, + DWORD *pSize); + + STDMETHODIMP MarshalInterface( + IStream __RPC_FAR *pStm, + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags); + + STDMETHODIMP UnmarshalInterface( + IStream *pStm, + REFIID riid, + void **ppv); + + STDMETHODIMP ReleaseMarshalData(IStream *pStm); + + STDMETHODIMP DisconnectObject(DWORD dwReserved); + +private: + + friend class CFmCtrlUnknown; + + // Pointer to the controlling unknown. + IUnknown * _punkCtrl; + +}; + + + + +//+------------------------------------------------------------------- +// +// Class: CFmCtrlUnknown +// +// Synopsis: Controlling IUnknown for generic marshaling class. +// +// Methods: IUnknown +// IMarshal +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +class CFmCtrlUnknown : public IUnknown, public CPrivAlloc +{ + // IUnknown + STDMETHODIMP QueryInterface(REFIID iid, void **ppv); + + STDMETHODIMP_(ULONG) AddRef(void); + + STDMETHODIMP_(ULONG) Release(void); + +private: + + friend HRESULT CoCreateFreeThreadedMarshaler( + IUnknown *punkCtrl, + IUnknown **punkMarshal); + + friend HRESULT GetInProcFreeMarshaler(IMarshal **ppIM); + + CFmCtrlUnknown(void); + + ~CFmCtrlUnknown(void); + + CFreeMarshaler * _pfm; + + ULONG _cRefs; +}; + + + +//+------------------------------------------------------------------- +// +// Function: CoCreateFreeThreadedMarshaler, public +// +// Synopsis: Create the controlling unknown for the marshaler +// +// Arguments: [punkOuter] - controlling unknown +// [ppunkMarshal] - controlling unknown for marshaler. +// +// Returns: NOERROR +// E_INVALIDARG +// E_OUTOFMEMORY +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +HRESULT CoCreateFreeThreadedMarshaler( + IUnknown *punkOuter, + IUnknown **ppunkMarshal) +{ + HRESULT hr = E_INVALIDARG; + + // Validate the parameters + if (((punkOuter == NULL) || IsValidInterface(punkOuter)) + && IsValidPtrOut(ppunkMarshal, sizeof(IUnknown *))) + { + CALLHOOKOBJECT(S_OK,CLSID_NULL,IID_IUnknown,(IUnknown **)&punkOuter); + // Assume failure + *ppunkMarshal = NULL; + + hr = E_OUTOFMEMORY; + + // Allocate new free marshal object + CFmCtrlUnknown *pfmc = new CFmCtrlUnknown(); + + if (pfmc != NULL) + { + if (punkOuter == NULL) + { + // Caller wants a non-aggreagated object + punkOuter = pfmc; + } + + // Initialize the pointer + pfmc->_pfm = new CFreeMarshaler(punkOuter); + + if (pfmc->_pfm != NULL) + { + *ppunkMarshal = pfmc; + CALLHOOKOBJECTCREATE(S_OK,CLSID_NULL,IID_IUnknown, + (IUnknown **)ppunkMarshal); + hr = S_OK; + } + else + { + delete pfmc; + } + } + } + + return hr; +} + + + +//+------------------------------------------------------------------- +// +// Function: GetInProcFreeMarshaler, public +// +// Synopsis: Create the controlling unknown for the marshaler +// +// Arguments: [ppIM] - where to put inproc marshaler +// +// Returns: NOERROR +// E_OUTOFMEMORY +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +HRESULT GetInProcFreeMarshaler(IMarshal **ppIM) +{ + HRESULT hr = E_OUTOFMEMORY; + + // Allocate new free marshal object + CFmCtrlUnknown *pfmc = new CFmCtrlUnknown(); + + if (pfmc != NULL) + { + // Initialize the pointer + pfmc->_pfm = new CFreeMarshaler(pfmc); + + if (pfmc->_pfm != NULL) + { + *ppIM = pfmc->_pfm; + hr = S_OK; + } + else + { + delete pfmc; + } + } + + return hr; +} + + +//+------------------------------------------------------------------- +// +// Member: CFmCtrlUnknown::CFmCtrlUnknown +// +// Synopsis: The constructor for controling IUnknown of free marshaler +// +// Arguments: None +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +CFmCtrlUnknown::CFmCtrlUnknown(void) : _cRefs(1), _pfm(NULL) +{ + // Header does all the work. +} + + + + +//+------------------------------------------------------------------- +// +// Member: CFmCtrlUnknown::~CFmCtrlUnknown +// +// Synopsis: The destructor for controling IUnknown of free marshaler +// +// Arguments: None +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +CFmCtrlUnknown::~CFmCtrlUnknown(void) +{ + delete _pfm; +} + + + + +//+------------------------------------------------------------------- +// +// Member: CFmCtrlUnknown::QueryInterface +// +// Returns: S_OK +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFmCtrlUnknown::QueryInterface(REFIID iid, void **ppv) +{ + *ppv = NULL; + HRESULT hr = E_NOINTERFACE; + + if (IsEqualGUID(iid, IID_IUnknown)) + { + *ppv = this; + AddRef(); + hr = S_OK; + } + else if (IsEqualGUID(iid, IID_IMarshal)) + { + *ppv = _pfm; + _pfm->AddRef(); + hr = S_OK; + } + + return hr; +} + + + +//+------------------------------------------------------------------- +// +// Member: CFmCtrlUnknown::AddRef +// +// Synopsis: Standard stuff +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CFmCtrlUnknown::AddRef(void) +{ + InterlockedIncrement((LONG *) &_cRefs); + + return _cRefs; +} + + + + +//+------------------------------------------------------------------- +// +// Member: CFmCtrlUnknown::Release +// +// Synopsis: Standard stuff +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CFmCtrlUnknown::Release(void) +{ + ULONG cRefs = InterlockedDecrement((LONG *) &_cRefs); + + if (cRefs == 0) + { + delete this; + } + + return cRefs; +} + + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::CFreeMarshaler() +// +// Synopsis: The constructor for CFreeMarshaler. +// +// Arguments: None +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +CFreeMarshaler::CFreeMarshaler(IUnknown *punkCtrl) + : _punkCtrl(punkCtrl) +{ + // Header does all the work. +} + + + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::QueryInterface +// +// Synopsis: Pass QI to our controlling IUnknown +// +// Returns: S_OK +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::QueryInterface(REFIID iid, void **ppv) +{ + return _punkCtrl->QueryInterface(iid, ppv); +} + + + + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::AddRef +// +// Synopsis: Pass AddRef to our controlling IUnknown +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CFreeMarshaler::AddRef(void) +{ + return _punkCtrl->AddRef(); +} + + + + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::Release +// +// Synopsis: Pass release to our controlling IUnknown +// +// History: 15-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CFreeMarshaler::Release(void) +{ + return _punkCtrl->Release(); +} + + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::GetUnmarshalClass +// +// Synopsis: Return the unmarshaling class +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::GetUnmarshalClass( + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags, + CLSID *pCid) +{ + // Inprocess context? + if (dwDestContext == MSHCTX_INPROC) + { + // If this is an inproc marshal then we are the class + // that can unmarshal. + *pCid = CLSID_InProcFreeMarshaler; + return S_OK; + } + + // we can just use the static guy here and save a lot of work. + IMarshal *pmrshlStd; + HRESULT hr = GetStaticUnMarshaler(&pmrshlStd); + + if (pmrshlStd != NULL) + { + hr = pmrshlStd->GetUnmarshalClass(riid, pv, dwDestContext, + pvDestContext, mshlflags, pCid); + + pmrshlStd->Release(); + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::GetMarshalSizeMax +// +// Synopsis: Return maximum bytes need for marshaling +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::GetMarshalSizeMax( + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags, + DWORD *pSize) +{ + // Inprocess context? + if (dwDestContext == MSHCTX_INPROC) + { + // If this is an inproc marshal then we know the size + *pSize = sizeof(this); + return S_OK; + } + + // we can just use the static guy here and save a lot of work. + IMarshal *pmrshlStd; + HRESULT hr = GetStaticUnMarshaler(&pmrshlStd); + + if (pmrshlStd != NULL) + { + hr = pmrshlStd->GetMarshalSizeMax(riid, pv, dwDestContext, + pvDestContext, mshlflags, pSize); + + pmrshlStd->Release(); + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::MarshalInterface +// +// Synopsis: Marshal the interface +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::MarshalInterface( + IStream *pStm, + REFIID riid, + void *pv, + DWORD dwDestContext, + void *pvDestContext, + DWORD mshlflags) +{ + HRESULT hr; + + // Inprocess context? + if (dwDestContext == MSHCTX_INPROC) + { + // Write the marshal flags into the stream + hr = pStm->Write(&mshlflags, sizeof(mshlflags), NULL); + + if (hr == NOERROR) + { + // Write the pointer into the stream + ULONG cb; + + hr = pStm->Write(&pv, sizeof(pv), NULL); + + // Bump reference count based on type of marshal + if ((hr == NOERROR) && (mshlflags != MSHLFLAGS_TABLEWEAK)) + { + ((IUnknown *) pv)->AddRef(); + } + } + + return hr; + } + + // find or create a stdid for this object. Make sure we get a strong + // reference to gaurd against a simultaneous last release by another + // thread. + + CStdIdentity *pStdId; + hr = LookupIDFromUnk((IUnknown *) pv, IDLF_CREATE | IDLF_STRONG, &pStdId); + + if (SUCCEEDED(hr)) + { + hr = pStdId->MarshalInterface(pStm, riid, pv, dwDestContext, + pvDestContext, mshlflags); + + pStdId->DecStrongCnt(TRUE); // fKeepAlive + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::UnmarshalInterface +// +// Synopsis: Unmarshal the interface +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::UnmarshalInterface( + IStream *pStm, + REFIID riid, + void **ppv) +{ + HRESULT hr; + + // The marshal flags will tell us if we have to AddRef the object + DWORD mshlflags; + + hr = pStm->Read(&mshlflags, sizeof(mshlflags), NULL); + + if (hr == NOERROR) + { + // If Inprocess, we just read the pointer out of the stream + hr = pStm->Read(ppv, sizeof(*ppv), NULL); + + // AddRef the pointer if marshaled for a table. + if ((hr == NOERROR) + && ((mshlflags == MSHLFLAGS_TABLEWEAK) + || (mshlflags == MSHLFLAGS_TABLESTRONG))) + { + ((IUnknown *) *ppv)->AddRef(); + } + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::ReleaseMarshalData +// +// Synopsis: Release the marshaled data +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::ReleaseMarshalData(IStream *pStm) +{ + // Get the marshal flags + DWORD mshlflags; + + HRESULT hr = pStm->Read(&mshlflags, sizeof(mshlflags), NULL); + + if (hr == NOERROR) + { + IUnknown *punk; + + // If Inprocess, we just read the pointer out of the stream + hr = pStm->Read(&punk, sizeof(punk), NULL); + + if ((hr == NOERROR) && (mshlflags != MSHLFLAGS_TABLEWEAK)) + { + // Dump the extra AddRef we put on when we put the object + // during marshal. + punk->Release(); + } + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CFreeMarshaler::DisconnectObject +// +// Synopsis: Disconnect the object +// +// History: 08-Nov-94 Ricksa Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CFreeMarshaler::DisconnectObject(DWORD dwReserved) +{ + CStdIdentity *pStdId; + HRESULT hr = LookupIDFromUnk(_punkCtrl, 0, &pStdId); + + if (SUCCEEDED(hr)) + { + hr = pStdId->DisconnectObject(dwReserved); + pStdId->Release(); + } + else + { + // already disconnected, report success + hr = S_OK; + } + return hr; +} diff --git a/private/ole32/com/dcomrem/locks.cxx b/private/ole32/com/dcomrem/locks.cxx new file mode 100644 index 000000000..98244be90 --- /dev/null +++ b/private/ole32/com/dcomrem/locks.cxx @@ -0,0 +1,75 @@ +//+------------------------------------------------------------------- +// +// File: locks.cxx +// +// Contents: functions used in DBG builds to validate the lock state. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <locks.hxx> + +COleStaticMutexSem gComLock; + +#if DBG==1 + +#define MyAssert Win4Assert +// # define MyAssert(x) if (!(x)) { DebugBreak(); } + +struct tagGLOCK +{ + DWORD tid; // tid of current holder + LONG cLocks; // count of holds on the lock by current holder + DWORD line; // line # where lock taken + const char *file; // file name where lock taken +} glock = {0xffffffff, 0, 0xffffffff, 0}; + +void AssertLockHeld(void) +{ + MyAssert(glock.tid == GetCurrentThreadId()); + MyAssert(glock.cLocks > 0); // && "Lock not Held" +} + +void AssertLockReleased(void) +{ + MyAssert(glock.tid != GetCurrentThreadId() && "Lock not Released"); +} + +void ORPCLock(DWORD line, const char *file) +{ + gComLock.Request(); + + if (glock.cLocks > 0) + { + MyAssert(glock.tid == GetCurrentThreadId()); + } + else + { + glock.line = line; + glock.file = file; + } + + glock.tid = GetCurrentThreadId(); + glock.cLocks++; +} + +void ORPCUnLock(void) +{ + MyAssert(glock.cLocks > 0); // && "Releasing Unheld Lock" + MyAssert(glock.tid == GetCurrentThreadId()); + + glock.cLocks--; + + if (glock.cLocks == 0) + { + // we no longer hold the lock, set the tid to zero + glock.tid = 0; + } + + gComLock.Release(); +} + +#endif // DBG + + diff --git a/private/ole32/com/dcomrem/locks.hxx b/private/ole32/com/dcomrem/locks.hxx new file mode 100644 index 000000000..dfe84d437 --- /dev/null +++ b/private/ole32/com/dcomrem/locks.hxx @@ -0,0 +1,48 @@ +//+------------------------------------------------------------------- +// +// File: locks.hxx +// +// Contents: class and marcros for providing mutual exclusion +// +// Classes: CStaticSem +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +#ifndef _ORPC_LOCKS_ +#define _ORPC_LOCKS_ + +#include "olesem.hxx" + +// global mutex for ORPC +extern COleStaticMutexSem gComLock; + + +//+--------------------------------------------------------------------------- +// +// Macros for use in the code. +// +//---------------------------------------------------------------------------- + +#if DBG==1 +void AssertLockHeld(void); +void AssertLockReleased(void); +void ORPCLock(DWORD line, const char *file); +void ORPCUnLock(void); + +#define LOCK ORPCLock(__LINE__, __FILE__); +#define UNLOCK ORPCUnLock(); +#define ASSERT_LOCK_HELD AssertLockHeld(); +#define ASSERT_LOCK_RELEASED AssertLockReleased(); +#define ASSERT_LOCK_DONTCARE // just exists to comment the code better + +#else + +#define LOCK gComLock.Request(); +#define UNLOCK gComLock.Release(); +#define ASSERT_LOCK_HELD +#define ASSERT_LOCK_RELEASED +#define ASSERT_LOCK_DONTCARE + +#endif // DBG +#endif // _ORPC_LOCKS_ diff --git a/private/ole32/com/dcomrem/makefile b/private/ole32/com/dcomrem/makefile new file mode 100644 index 000000000..e09078703 --- /dev/null +++ b/private/ole32/com/dcomrem/makefile @@ -0,0 +1,24 @@ +############################################################################ +# +# Copyright (C) 1992, Microsoft Corporation. +# +# All rights reserved. +# +############################################################################ + +!ifdef NTMAKEENV + +# We need to do the following so that build will stop reading from the +# pipe. + +all : + echo $(BUILDMSG) + +clean : all + +!else # NTMAKEENV + +!include $(CAIROLE)\com\makefile +!include $(DEPENDFILE) + +!endif # NTMAKEENV diff --git a/private/ole32/com/dcomrem/mapdwp.hxx b/private/ole32/com/dcomrem/mapdwp.hxx new file mode 100644 index 000000000..76f904e95 --- /dev/null +++ b/private/ole32/com/dcomrem/mapdwp.hxx @@ -0,0 +1,108 @@ +//+------------------------------------------------------------------- +// +// File: mapdwp.hxx +// +// Contents: Class to map thread id to thread local storage ptr +// +// Classes: CMapDword +// +// Notes: This class is needed soley for debug builds and then only +// because we dont get THREAD_DETACH notification for all +// threads when a process exits. This allows us to clean up +// the tls so we dont report memory leaks. +// +// In order to keep the implementation simple, we use a fixed +// array of entries, meaning we (may) get memory leaks +// reported if we ever have more than MAP_MAX_SIZE threads +// alive at any given time. +// +//+------------------------------------------------------------------- + +#if !defined(_CAIRO_) && DBG==1 + +#define MAP_MAX_SIZE 100 + +class CMapDword : public CPrivAlloc +{ +public: + CMapDword(void); + ~CMapDword(void); + + void SetAt(DWORD tid, void *pData); + void RemoveKey(DWORD tid); + void RemoveAll(void); + +private: + + DWORD _tid[MAP_MAX_SIZE]; + void * _pData[MAP_MAX_SIZE]; + DWORD _index; +}; + + +CMapDword::CMapDword(void) +{ + _index = 0; + memset(_tid, 0, MAP_MAX_SIZE * sizeof(DWORD)); +} + +CMapDword::~CMapDword(void) +{ + RemoveAll(); +} + +void CMapDword::SetAt(DWORD tid, void *pData) +{ + for (ULONG i=_index; i<MAP_MAX_SIZE; i++) + { + if (_tid[i] == 0) + { + _tid[i] = tid; + _pData[i] = pData; + _index = i; + return; + } + } + + for (i=0; i<_index; i++) + { + if (_tid[i] == 0) + { + _tid[i] = tid; + _pData[i] = pData; + _index = i; + return; + } + } + + Win4Assert(!"Tls Table is FULL"); +} + + +void CMapDword::RemoveKey(DWORD tid) +{ + for (ULONG i=0; i<MAP_MAX_SIZE; i++) + { + if (_tid[i] == tid) + { + _tid[i] = 0; + return; + } + } +} + + +void CMapDword::RemoveAll(void) +{ + for (ULONG i=0; i<MAP_MAX_SIZE; i++) + { + if (_tid[i] != 0) + { + PrivMemFree(_pData[i]); + _tid[i] = 0; + } + } +} + + +#endif // !defined(_CAIRO_) && DBG==1 diff --git a/private/ole32/com/dcomrem/marshal.cxx b/private/ole32/com/dcomrem/marshal.cxx new file mode 100644 index 000000000..42380e4ca --- /dev/null +++ b/private/ole32/com/dcomrem/marshal.cxx @@ -0,0 +1,4974 @@ +//+------------------------------------------------------------------- +// +// File: marshal.cxx +// +// Contents: class implementing standard COM interface marshaling +// +// Classes: CStdMarshal +// +// History: 20-Feb-95 Rickhi Created +// +// DCOMWORK: (maybe) implement Extended form marshal packet +// +// PERFWORK: during unmarshal and RMD compare the MOXID in the STDOBJREF +// to the one for the current apartment. If equal, then i know the IPID is +// just an index into the IPID table and i can index into it, grab the +// channel ptr and hence the stdid ptr and do very fast unmarshal or RMD +// with no table lookup or list walking. +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <marshal.hxx> // CStdMarshal +#include <ipidtbl.hxx> // CIPIDTable, COXIDTable, CMIDTable +#include <riftbl.hxx> // CRIFTable +#include <resolver.hxx> // CRpcResolver +#include <stdid.hxx> // CStdIdentity +#include <channelb.hxx> // CRpcChannelBuffer +#include <callctrl.hxx> // CAptRpcChnl, CSrvCallCtrl +#include <scm.h> // CLSCTX_PS_DLL +#include <service.hxx> // SASIZE +#include <locks.hxx> // LOCK/UNLOCK etc +#include <thunkapi.hxx> // GetAppCompatabilityFlags + + +#if DBG==1 +// this flag and interface are used in debug to enable simpler testing +// of the esoteric NonNDR stub code feature. + +BOOL gfFakeNonNDR = FALSE; +const GUID IID_ICube = + {0x00000139,0x0001,0x0008,{0xC0,0x00,0x00,0x00,0x00,0x00,0x00,0x46}}; +#endif // DBG + + +// BUGBUG: this is not quite reliable enough. Maybe best solution is +// CoGetCurrentProcessId plus sequence number. +LONG gIPIDSeqNum = 0; + +// mappings from MSHLFLAGS to STDOBJREF flags +static ULONG mapMFtoSORF[] = +{ + SORF_NULL, // MSHLFLAGS_NORMAL + SORF_NULL, // MSHLFLAGS_TABLESTRONG + SORF_TBLWEAK // MSHLFLAGS_TABLEWEAK +}; + +// NULL resolver string array +DUALSTRINGARRAY saNULL = {0,0}; + +// number of remote AddRefs to acquire when we need more. +#define REM_ADDREF_CNT 5 + +// out internal psclass factory implementation +EXTERN_C HRESULT PrxDllGetClassObject(REFCLSID clsid, REFIID iid, void **ppv); + + +// structure used to post a delayed remote release call to ourself. +typedef struct tagPOSTRELRIFREF +{ + OXIDEntry *pOXIDEntry; // server OXIDEntry + USHORT cRifRef; // count of entries in arRifRef + REMINTERFACEREF arRifRef; // array of REMINTERFACEREFs +} POSTRELRIFREF; + + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::CStdMarshal/Init, public +// +// Synopsis: constructor/initializer of a standard marshaler +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +CStdMarshal::CStdMarshal() : _dwFlags(0), _pChnl(NULL) +{ + // Caller must call Init before doing anything! This just makes it + // easier for the identity object to figure out the init parameters + // before initializing us. +} + +void CStdMarshal::Init(IUnknown *punkObj, CStdIdentity *pStdId, + REFCLSID rclsidHandler, DWORD dwFlags) +{ + ASSERT_LOCK_DONTCARE // may be released if def handler calls CreateIdHdlr + + // server side we need to do the FirstMarshal work. + // client side we assume disconnected until we connect the first IPIDEntry + // and assume NOPING until we see any interface that needs pinging + + _dwFlags = dwFlags; + _dwFlags |= (ServerSide()) ? SMFLAGS_FIRSTMARSHAL + : SMFLAGS_DISCONNECTED | SMFLAGS_NOPING; + + _pFirstIPID = NULL; + _cIPIDs = 0; + _pStdId = pStdId; + _pChnl = NULL; + _cNestedCalls = 0; + _cTableRefs = 0; + _dwMarshalTime = 0; + _clsidHandler = rclsidHandler; + _pSecureRemUnk = NULL; + + ComDebOut((DEB_MARSHAL,"CStdMarshal %s New this:%x pStdId:%x punkObj:%x\n", + (ClientSide()) ? "CLIENT" : "SERVER", this, pStdId, punkObj)); + + AssertValid(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::~CStdMarshal, public +// +// Synopsis: destructor of a standard marshaler +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +CStdMarshal::~CStdMarshal() +{ + ComDebOut((DEB_MARSHAL, "CStdMarshal %s Deleted this:%x\n", + (ClientSide()) ? "CLIENT" : "SERVER", this)); + ASSERT_LOCK_RELEASED + + if (ClientSide()) + { + // Due to backward compatibility, we are not allowed to release + // interface proxies in Disconnect since the client might try to + // reconnect later and expects the same interface pointer values. + // Since we are going away now, we go release the proxies. + + ReleaseCliIPIDs(); + if (_pSecureRemUnk != NULL) + { + _pSecureRemUnk->Release(); + } + } + + if (_pChnl) + { + // release the channel + _pChnl->Release(); + } + + ASSERT_LOCK_RELEASED +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::GetUnmarshalClass, public +// +// Synopsis: returns the clsid of the standard marshaller +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::GetUnmarshalClass(REFIID riid, LPVOID pv, + DWORD dwDestCtx, LPVOID pvDestCtx, DWORD mshlflags, LPCLSID pClsid) +{ + AssertValid(); + ASSERT_LOCK_RELEASED + + *pClsid = CLSID_StdMarshal; + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::GetMarshalSizeMax, public +// +// Synopsis: Returns an upper bound on the amount of data for +// a standard interface marshal. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::GetMarshalSizeMax(REFIID riid, LPVOID pv, + DWORD dwDestCtx, LPVOID pvDestCtx, DWORD mshlflags, LPDWORD pSize) +{ + AssertValid(); + Win4Assert(gdwPsaMaxSize != 0); + ASSERT_LOCK_RELEASED + + *pSize = sizeof(OBJREF) + gdwPsaMaxSize; + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: MarshalObjRef, private +// +// Synopsis: Marshals interface into the objref. +// +// Arguements: [objref] - object reference +// [riid] - interface id to marshal +// [pv] - interface to marshal +// [mshlflags] - marshal flags +// +// Algorithm: Get the correct standard identity and ask it to do +// all the work. +// +// History: 25-Mar-95 AlexMit Created +// +//-------------------------------------------------------------------- +INTERNAL MarshalObjRef(OBJREF &objref, REFIID riid, void *pv, DWORD mshlflags) +{ + TRACECALL(TRACE_MARSHAL, "MarshalObjRef"); + ComDebOut((DEB_MARSHAL, "MarshalObjRef: riid:%I pv:%x flags:%x\n", + &riid, pv, mshlflags)); + ASSERT_LOCK_RELEASED + + HRESULT hr = InitChannelIfNecessary(); + if (SUCCEEDED(hr)) + { + // Find or create the StdId for this object. We need to get a strong + // reference to guard against an incoming last release on another + // thread which would cause us to Disconnect this StdId. + + DWORD dwFlags = IDLF_CREATE | IDLF_STRONG; + dwFlags |= (mshlflags & MSHLFLAGS_NOPING) ? IDLF_NOPING : 0; + + CStdIdentity *pStdID; + hr = LookupIDFromUnk((IUnknown *)pv, dwFlags, &pStdID); + + if (hr == NOERROR) + { + hr = pStdID->MarshalObjRef(objref, riid, pv, mshlflags); + pStdID->DecStrongCnt(TRUE); // fKeepAlive + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, "MarshalObjRef: hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: MarshalInternalObjRef, private +// +// Synopsis: Marshals an internal interface into the objref. +// +// Arguements: [objref] - object reference +// [riid] - interface id to marshal +// [pv] - interface to marshal +// [mshlflags] - marshal flags +// [ppStdId] - StdId to return (may be NULL) +// +// Algorithm: Create a StdIdentity and ask it to do the work. +// +// Notes: This differs from the normal MarshalObjRef in that it does +// not look in the OID table for an already marshaled interface, +// nor does it register the marshaled interface in the OID table. +// This is used for internal interfaces such as the IObjServer +// and IRemUnknown. +// +// History: 25-Oct-95 Rickhi Created +// +//-------------------------------------------------------------------- +INTERNAL MarshalInternalObjRef(OBJREF &objref, REFIID riid, void *pv, + DWORD mshlflags, void **ppStdId) +{ + TRACECALL(TRACE_MARSHAL, "MarshalInternalObjRef"); + ComDebOut((DEB_MARSHAL, "MarshalInternalObjRef: riid:%I pv:%x flags:%x\n", + &riid, pv, mshlflags)); + ASSERT_LOCK_RELEASED + + HRESULT hr = InitChannelIfNecessary(); + if (SUCCEEDED(hr)) + { + if (!IsEqualGUID(riid, IID_IRundown)) + { + // NOTE: make sure the local OXID is registered with the resolver. + // See the discussion on the Chicken and Egg problem in ipidtbl.cxx + // COXIDTable::GetLocalEntry for why this is necessary. + + LOCK + MOID moid; + hr = gResolver.ServerGetPreRegMOID(&moid); + UNLOCK + } + + if (SUCCEEDED(hr)) + { + // Find or create the StdId for this object. We need to get a strong + // reference to guard against an incoming last release on another + // thread which would cause us to Disconnect this StdId. + + IUnknown *pUnkId; // ignored + CStdIdentity *pStdId = new CStdIdentity(STDID_SERVER, NULL, + (IUnknown *)pv, &pUnkId); + + if (pStdId != NULL) + { + hr = pStdId->MarshalObjRef(objref, riid, pv, mshlflags); + + if (SUCCEEDED(hr) && ppStdId) + { + *ppStdId = (void *)pStdId; + } + else + { + pStdId->Release(); + } + } + else + { + hr = E_OUTOFMEMORY; + } + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, "MarshalInternalObjRef: hr:%x\n", hr)); + return hr; +} + + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::MarshalInterface, public +// +// Synopsis: marshals the interface into the stream. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::MarshalInterface(IStream *pStm, REFIID riid, + LPVOID pv, DWORD dwDestCtx, LPVOID pvDestCtx, DWORD mshlflags) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::MarshalInterface this:%x pStm:%x riid:%I pv:%x dwCtx:%x pvCtx:%x flags:%x\n", + this, pStm, &riid, pv, dwDestCtx, pvDestCtx, mshlflags)); + AssertValid(); + ASSERT_LOCK_RELEASED + + // Marshal the interface into an objref, then write the objref + // into the provided stream. + + OBJREF objref; + HRESULT hr = MarshalObjRef(objref, riid, pv, mshlflags); + + if (SUCCEEDED(hr)) + { + // write the objref into the stream + hr = WriteObjRef(pStm, objref, dwDestCtx); + + if (FAILED(hr)) + { + // undo whatever we just did, ignore error from here since + // the stream write error supercedes any error from here. + ReleaseMarshalObjRef(objref); + } + + // free resources associated with the objref. + FreeObjRef(objref); + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL,"CStdMarshal::MarshalInterface this:%x hr:%x\n", + this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::MarshalObjRef, public +// +// Synopsis: marshals the interface into the objref. +// +// History: 25-Mar-95 AlexMit Seperated from MarshalInterface +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::MarshalObjRef(OBJREF &objref, REFIID riid, + LPVOID pv, DWORD mshlflags) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::MarsalObjRef this:%x riid:%I pv:%x flags:%x\n", + this, &riid, pv, mshlflags)); + AssertValid(); + + // validate the parameters. we dont allow TABLE cases if we are + // a client side object. + + if ((mshlflags & MSHLFLAGS_TABLE) && ClientSide()) + return E_INVALIDARG; + + // count of Refs we are handing out. In the table cases we pass out + // zero refs because we dont know how many times it will be unmarshaled + // (and hence how many references to count). Zero refs will cause the + // client to call back and ask for more references if it does not already + // have any (which has the side effect of making sure the object still + // exists, which is required by RunningObjectTable). + + ULONG cRefs = (mshlflags & MSHLFLAGS_TABLE) ? 0 : + (ClientSide()) ? 1 : REM_ADDREF_CNT; + + ASSERT_LOCK_RELEASED + LOCK + + HRESULT hr = PreventDisconnect(); + if (SUCCEEDED(hr)) + { + // The first time through we have some extra work to do so go off + // and do that now. Next time we can just bypass all that work. + + if (_dwFlags & SMFLAGS_FIRSTMARSHAL) + { + hr = FirstMarshal((IUnknown *)pv, mshlflags); + } + + if (SUCCEEDED(hr)) + { + // Create the IPID table entry. On the server side this may + // cause the creation of an interface stub, on the client side + // it may just take away one of our references or it may call + // the server to get more references for the interface being + // marshaled. + + IPIDEntry *pIPIDEntry; + hr = MarshalIPID(riid, cRefs, mshlflags, &pIPIDEntry); + + if (SUCCEEDED(hr)) + { + // fill in the rest of the OBJREF + FillObjRef(objref, cRefs, mshlflags, pIPIDEntry); + } + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + // it is now OK to allow real disconnects in. + HRESULT hr2 = HandlePendingDisconnect(hr); + if (FAILED(hr2) && SUCCEEDED(hr)) + { + // a disconnect came in while marshaling. The ObjRef has a + // reference to the OXIDEntry so go free that now. + FreeObjRef(objref); + } + + ComDebOut((DEB_MARSHAL, "CStdMarshal::MarshalObjRef this:%x hr:%x\n", + this, hr2)); + return hr2; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::FillObjRef, private +// +// Synopsis: Fill in the fields of an OBJREF +// +// History: 21-Sep-95 Rickhi Created +// +//+------------------------------------------------------------------- +void CStdMarshal::FillObjRef(OBJREF &objref, ULONG cRefs, DWORD mshlflags, + IPIDEntry *pIPIDEntry) +{ + ComDebOut((DEB_MARSHAL, "FillObjRef pObjRef:%x\n", &objref)); + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + Win4Assert(pIPIDEntry); + OXIDEntry **ppOXIDEntry; + + // first, fill in the STDOBJREF section + STDOBJREF *pStd = &ORSTD(objref).std; + FillSTD(pStd, cRefs, mshlflags, pIPIDEntry); + + // next fill in the rest of the OBJREF + objref.signature = OBJREF_SIGNATURE; // 'MEOW' + objref.iid = pIPIDEntry->iid; // interface iid + + if (_dwFlags & SMFLAGS_HANDLER) + { + // handler form, copy in the clsid + objref.flags = OBJREF_HANDLER; + ORHDL(objref).clsid = _clsidHandler; + ppOXIDEntry = (OXIDEntry **) &ORHDL(objref).saResAddr; + } + else + { + objref.flags = OBJREF_STANDARD; + ppOXIDEntry = (OXIDEntry **) &ORSTD(objref).saResAddr; + } + + // TRICK: in order to keep the objref a fixed size internally, + // we use the saResAddr.size field as a ptr to the OXIDEntry. We + // pay attention to this in ReadObjRef, WriteObjRef, and FreeObjRef. + + *ppOXIDEntry = pIPIDEntry->pOXIDEntry; + Win4Assert(*ppOXIDEntry != NULL); + IncOXIDRefCnt(*ppOXIDEntry); + ASSERT_LOCK_HELD +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::FillSTD, public +// +// Synopsis: Fill in the STDOBJREF fields of an OBJREF +// +// History: 21-Sep-95 Rickhi Created +// +//+------------------------------------------------------------------- +void CStdMarshal::FillSTD(STDOBJREF *pStd, ULONG cRefs, DWORD mshlflags, + IPIDEntry *pIPIDEntry) +{ + // fill in the STDOBJREF to return to the caller. + pStd->flags = mapMFtoSORF[mshlflags & MSHLFLAGS_TABLE]; + + pStd->flags |= (pIPIDEntry->dwFlags & IPIDF_NOPING) ? SORF_NOPING : 0; + pStd->flags |= (pIPIDEntry->dwFlags & IPIDF_NONNDRSTUB) ? SORF_NONNDR : 0; + + pStd->cPublicRefs = cRefs; + + pStd->ipid = pIPIDEntry->ipid; + + OIDFromMOID(_pStdId->GetOID(), &pStd->oid); + OXIDFromMOXID(pIPIDEntry->pOXIDEntry->moxid, &pStd->oxid); + + ValidateSTD(pStd); + DbgDumpSTD(pStd); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::FirstMarshal, private +// +// Synopsis: Does some first-time server side marshal stuff +// +// Parameters: [pUnk] - interface being marshalled +// [mshlflags] - flags for marshaling +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::FirstMarshal(IUnknown *pUnk, DWORD mshlflags) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::FirstMarshal this:%x pUnk:%x\n", this, pUnk)); + Win4Assert(ServerSide()); + Win4Assert(_dwFlags & SMFLAGS_FIRSTMARSHAL); + Win4Assert(_pChnl == NULL); + AssertValid(); + AssertDisconnectPrevented(); + ASSERT_LOCK_HELD + + // have now executed this code so dont do it again. + _dwFlags &= ~SMFLAGS_FIRSTMARSHAL; + + if (mshlflags & MSHLFLAGS_NOPING) + { + // if the first interface is marked as NOPING, then all interfaces + // for the object are treated as NOPING, otherwise, all interfaces + // are marked as PING. MakeSrvIPIDEntry will look at _dwFlags to + // determine whether to mark each IPIDEntry as NOPING or not. + + _dwFlags |= SMFLAGS_NOPING; + } + + // get our local OXID. This should have already been created, and + // so wont cause the LOCK to be released. + + OXIDEntry *pOXIDEntry; + HRESULT hr = gOXIDTbl.GetLocalEntry(&pOXIDEntry); + + if (SUCCEEDED(hr)) + { + // create a channel for this object. + CRpcChannelBuffer *pChnl; + hr = CreateChannel(pOXIDEntry, 0, GUID_NULL, GUID_NULL, &pChnl); + } + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::FirstMarshal this:%x hr:%x\n", this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::MarshalIPID, private +// +// Synopsis: finds or creates an interface stub and IPID entry +// for the given object interface. +// +// Arguments: [riid] - interface to look for +// [cRefs] - count of references wanted +// [mshlflags] - marshal flags +// [ppEntry] - place to return IPIDEntry ptr +// +// Returns: S_OK if succeeded +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::MarshalIPID(REFIID riid, ULONG cRefs, DWORD mshlflags, + IPIDEntry **ppIPIDEntry) +{ + TRACECALL(TRACE_MARSHAL, "CStdMarshal::MarshalIPID"); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::MarshalIPID this:%x riid:%I cRefs:%x mshlflags:%x ppEntry:%x\n", + this, &riid, cRefs, mshlflags, ppIPIDEntry)); + AssertValid(); + AssertDisconnectPrevented(); + ASSERT_LOCK_HELD + + // validate input parms. + Win4Assert(!(IsEqualIID(riid, IID_NULL) || IsEqualIID(riid, IID_IMarshal))); + + // look for an existing IPIDEntry for the requested interface + IPIDEntry *pEntry; + HRESULT hr = FindIPIDEntry(riid, &pEntry); + + if (FAILED(hr)) + { + // no entry currently exists. on the server side we try to create one. + // on the client side we do a remote QI for the requested interface. + + if (ServerSide()) + { + // this call fail if we are disconnected during a yield. + hr = MakeSrvIPIDEntry(riid, &pEntry); + } + else + { + hr = RemQIAndUnmarshal(1, (GUID *)&riid, NULL); + if (SUCCEEDED(hr)) + { + hr = FindIPIDEntry(riid, &pEntry); + } + } + } + + if (SUCCEEDED(hr)) + { + // REFCOUNTING: + if (ServerSide()) + { + // remember the latest marshal time so we can tell if the ping + // server has run us down too early. This can happen when an + // existing client dies and we remarshal the interface just + // moments before the pingserver tells us the first guy is gone + // and before the new client has had time to unmarshal and ping. + + _dwMarshalTime = GetCurrentTime(); + + // inc the refcnt for the IPIDEntry and optionaly the stdid. Note + // that for TABLE marshals cRefs is 0 (that's the number that gets + // placed in the packet) but we do want a reference so we ask for + // 1 here. ReleaseMarshalData will undo the 1. + + ULONG cRefs2 = (mshlflags & MSHLFLAGS_TABLE) ? 1 : cRefs; + IncSrvIPIDCnt(pEntry, cRefs2, 0, NULL, mshlflags); + } + else // client side, + { + // we dont support marshaling weak refs on the client side, though + // we do support marshaling strong from a weak client by going to + // the server and getting a strong reference. + Win4Assert(!(mshlflags & MSHLFLAGS_WEAK)); + + if (cRefs >= pEntry->cStrongRefs) + { + // need more references than we own, go get more from server + // to satisfy the marshal. Get a few extra refs for ourselves + // unless we are a weak client. + + ULONG cExtraRefs = (_dwFlags & SMFLAGS_WEAKCLIENT) + ? 0 : REM_ADDREF_CNT; + + hr = RemoteAddRef(pEntry, pEntry->pOXIDEntry, cRefs + cExtraRefs, 0); + + if (SUCCEEDED(hr)) + { + // add in the extra references we asked for (if any). + pEntry->cStrongRefs += cExtraRefs; + } + } + else + { + // we have enough references to satisfy this request (and still + // keep some for ourselves), just subtract from the IPIDEntry + pEntry->cStrongRefs -= cRefs; + } + + // mark this object as having been client-side marshaled so + // that we can tell the resolver whether or not it needs to + // ping this object if we release it before the OID is registered. + + _dwFlags |= SMFLAGS_CLIENTMARSHALED; + } + + // do some debug stuff + ValidateIPIDEntry(pEntry); + ComDebOut((DEB_MARSHAL, "pEntry:%x cRefs:%x cStdId:%x\n", pEntry, + pEntry->cStrongRefs, _pStdId->GetRC())); + } + + *ppIPIDEntry = pEntry; + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + ComDebOut((DEB_MARSHAL, "CStdMarshal::MarshalIPID hr:%x pIPIDEntry\n", hr, *ppIPIDEntry)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::UnmarshalInterface, public +// +// Synopsis: Unmarshals an Interface from a stream. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::UnmarshalInterface(LPSTREAM pStm, + REFIID riid, VOID **ppv) +{ + ComDebOut((DEB_MARSHAL, "CStdMarshal::UnmarsalInterface this:%x pStm:%x riid:%I\n", + this, pStm, &riid)); + AssertValid(); + ASSERT_LOCK_RELEASED + + // read the objref from the stream and find or create an instance + // of CStdMarshal for its OID. Then ask that guy to do the rest of + // the unmarshal (create the interface proxy) + + OBJREF objref; + HRESULT hr = ReadObjRef(pStm, objref); + + if (SUCCEEDED(hr)) + { + // pass objref to subroutine to unmarshal the objref + hr = ::UnmarshalObjRef(objref, ppv); + + // release the objref we read + FreeObjRef(objref); + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, + "UnmarsalInterface this:%x pv:%x hr:\n", this, *ppv, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: UnmarshalObjRef, private +// +// Synopsis: UnMarshals interface from objref. +// +// Arguements: [objref] - object reference +// [ppv] - proxy +// +// Algorithm: Get the correct standard identity and ask it to do +// all the work. +// +// History: 25-Mar-95 AlexMit Created +// +//-------------------------------------------------------------------- +INTERNAL UnmarshalObjRef(OBJREF &objref, void **ppv) +{ + ASSERT_LOCK_RELEASED + + CStdMarshal *pStdMshl; + HRESULT hr = FindStdMarshal(objref, &pStdMshl); + + if (SUCCEEDED(hr)) + { + // pass objref to subroutine to unmarshal the objref + hr = pStdMshl->UnmarshalObjRef(objref, ppv); + CALLHOOKOBJECTCREATE(S_OK,ORHDL(objref).clsid,objref.iid,(IUnknown **)ppv); + pStdMshl->Release(); + } + else + { + // we could not create the indentity or handler, release the + // marshaled objref. + ReleaseMarshalObjRef(objref); + } + + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ChkIfLocalOID, private +// +// Synopsis: Helper function for UnmarshalInternalObjRef & FindStdMarshal +// +// Arguements: [objref] - object reference +// [ppStdMshl] - CStdMarshal returned +// +// Algorithm: Read the objref, get the OID. If we already have an identity +// for this OID return it AddRefd. +// +// History: 21-May-95 MurthyS Created. +// +//-------------------------------------------------------------------- +INTERNAL_(BOOL) ChkIfLocalOID(OBJREF &objref, CStdIdentity **ppStdId) +{ + STDOBJREF *pStd = &ORSTD(objref).std; + BOOL flocal = FALSE; + + ComDebOut((DEB_MARSHAL, "ChkIfLocalOID (IN) poid: %x\n", &pStd->oid)); + Win4Assert((*ppStdId == NULL) && "ChkIfLocalOID: pStdId != NULL"); + + ASSERT_LOCK_RELEASED + LOCK + + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + + if (pOXIDEntry == GetLocalOXIDEntry()) + { + flocal = TRUE; + // OXID is for this apartment, look IPID up in the IPIDTable + // directly, and extract the CStdMarshal from it. + + IPIDEntry *pEntry = gIPIDTbl.LookupIPID(pStd->ipid); + if (pEntry && pEntry->pChnl) + { + // get the Identity + *ppStdId = pEntry->pChnl->GetStdId(); + (*ppStdId)->AddRef(); + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + return flocal; +} + +//+------------------------------------------------------------------- +// +// Function: UnmarshalInternalObjRef, private +// +// Synopsis: UnMarshals an internally-used interface from objref. +// +// Arguements: [objref] - object reference +// [ppv] - proxy +// +// Algorithm: Create a StdId and ask it to do the work. +// +// Notes: This differs from UnmarshalObjRef in that it does not lookup +// or register the OID. This saves a fair amount of work and +// avoids initializing the OID table. +// +// History: 25-Oct-95 Rickhi Created +// +//-------------------------------------------------------------------- +INTERNAL UnmarshalInternalObjRef(OBJREF &objref, void **ppv) +{ + ASSERT_LOCK_RELEASED + + HRESULT hr = S_OK; + CStdIdentity *pStdId = NULL; + + if (ChkIfLocalOID(objref, &pStdId)) + { + if (pStdId) + { + // set OID in objref to match that in returned std identity + OIDFromMOID(pStdId->GetOID(), &ORSTD(objref).std.oid); + } + else + { + hr = CO_E_OBJNOTCONNECTED; + } + } + else + { + ASSERT_LOCK_RELEASED + + hr = CreateIdentityHandler(NULL, ORSTD(objref).std.flags, + IID_IStdIdentity, (void **)&pStdId); + } + + if (SUCCEEDED(hr)) + { + // pass objref to subroutine to unmarshal the objref. tell StdId not + // to register the OID in the OID table. + + pStdId->IgnoreOID(); + hr = pStdId->UnmarshalObjRef(objref, ppv); + CALLHOOKOBJECTCREATE(S_OK,ORHDL(objref).clsid,objref.iid,(IUnknown **)ppv); + pStdId->Release(); + } + + ASSERT_LOCK_RELEASED + return hr; +} + + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::UnmarshalObjRef, private +// +// Synopsis: unmarshals the objref. Called by CoUnmarshalInterface, +// UnmarshalObjRef APIs, and UnmarshalInterface method. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::UnmarshalObjRef(OBJREF &objref, void **ppv) +{ + ComDebOut((DEB_MARSHAL, "CStdMarshal::UnmarsalObjRef this:%x objref:%x riid:%I\n", + this, &objref, &objref.iid)); + AssertValid(); + + STDOBJREF *pStd = &ORSTD(objref).std; + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + DbgDumpSTD(pStd); + + ASSERT_LOCK_RELEASED + LOCK + + // Prevent a disconnect from occuring while unmarshaling the + // interface since we may have to yield the ORPC lock. + + HRESULT hr = PreventPendingDisconnect(); + + if (SUCCEEDED(hr)) + { + if (objref.flags & OBJREF_HANDLER) + { + // handler form, extract the handler clsid and set our flags + _dwFlags |= SMFLAGS_HANDLER; + _clsidHandler = ORHDL(objref).clsid; + } + + // if no OID registered yet, do that now. only possible on client side + // during reconnect. + + MOID moid; + MOIDFromOIDAndMID(pStd->oid, pOXIDEntry->pMIDEntry->mid, &moid); + hr = _pStdId->SetOID(moid); + + if (SUCCEEDED(hr)) + { + // find or create the IPID entry for the interface. On the client + // side this may cause the creation of an interface proxy. It will + // also manipulate the reference counts. + + hr = UnmarshalIPID(objref.iid, pStd, pOXIDEntry, ppv); + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (ClientSide()) + { + if (SUCCEEDED(hr)) + { + if (_pStdId->IsAggregated()) + { + // we are currently holding a proxy pointer. If aggregated, + // the controlling unknown may want to override this pointer + // with his own version, so issue a QI to give it that chance. + IUnknown *pUnk = (IUnknown *)*ppv; + +#ifdef WX86OLE + if (gcwx86.IsN2XProxy(pUnk)) + { + // Tell wx86 thunk layer to thunk as IUnknown + gcwx86.SetStubInvokeFlag((BOOL)1); + } +#endif + + hr = pUnk->QueryInterface(objref.iid, ppv); + pUnk->Release(); + } + } + else + { + // cleanup our state on failure (only meaningful on client side, + // since if the unmarshal failed on the server side, the interface + // is already cleaned up). + ReleaseMarshalObjRef(objref); + } + } + + // now let pending disconnect through. on server-side, ignore any + // error from HPD and pay attention only to the unmarshal result, since + // a successful unmarshal on the server side may result in a disconnect + // if that was the last external reference to the object. + + HRESULT hr2 = HandlePendingDisconnect(hr); + + if (FAILED(hr2) && ClientSide()) + { + if (SUCCEEDED(hr)) + { + // a disconnect came in while unmarshaling. ppv contains an + // AddRef'd interface pointer so go Release that now. + ((IUnknown *)*ppv)->Release(); + } + hr = hr2; + } + + ComDebOut((DEB_MARSHAL, "CStdMarshal::UnmarsalObjRef this:%x hr:%x\n", + this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::UnmarshalIPID, private +// +// Synopsis: finds or creates an interface proxy for the given +// interface. may also do a remote query interface. +// +// Arguements: [riid] - the interface to return +// [std] - standard objref to unmarshal from +// [pOXIDEntry] - ptr to OXIDEntry of the server +// [ppv] - interface ptr of type riid returned, AddRef'd +// +// Returns: S_OK if succeeded +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::UnmarshalIPID(REFIID riid, STDOBJREF *pStd, + OXIDEntry *pOXIDEntry, void **ppv) +{ + TRACECALL(TRACE_MARSHAL, "CStdMarshal::UnmarshalIPID"); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::UnmarshalIPID this:%x riid:%I pStd:%x pOXIDEntry:%x\n", + this, &riid, pStd, pOXIDEntry)); + DbgDumpSTD(pStd); + AssertValid(); + AssertDisconnectPrevented(); + ASSERT_LOCK_HELD + + // validate input params. + Win4Assert(!(IsEqualIID(riid, IID_NULL) || IsEqualIID(riid, IID_IMarshal))); + Win4Assert(pStd != NULL); + ValidateSTD(pStd); + Win4Assert(pOXIDEntry); + + + // look for an existing IPIDEntry for the requested interface. + IPIDEntry *pEntry; + HRESULT hr = FindIPIDEntry(riid, &pEntry); + +#ifdef WX86OLE + BOOL fSameApt = SUCCEEDED(hr); + PVOID pvPSThunk = NULL; +#endif + + + // REFCOUNTING: + if (ClientSide()) + { + if (FAILED(hr)) + { + // no IPID Entry exists yet for the requested interface. We do + // have a STDOBJREF. Create the interface proxy and IPIDEntry + // now, and connect it up. If successful, the proxy will be + // fully connected upon return, with pEntry->cStrongRefs set + // to pStd->cPublicRefs. + + if (ppv) + *ppv = NULL; + hr = MakeCliIPIDEntry(riid, pStd, pOXIDEntry, &pEntry); + } + else if (pEntry->dwFlags & IPIDF_DISCONNECTED) + { + // reconnect the IPID entry to the server. this will set + // pEntry->cStrongRefs to pStd->cPublicRefs. Even though we could + // yield, the IPIDEntry is guarenteed connected on return + // (cause we are holding the lock on return). + + hr = ConnectIPIDEntry(pStd, pOXIDEntry, pEntry); + } + else if ((pStd->flags & SORF_WEAKREF) && + (pEntry->pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL)) + { + // add the refcnt to our weak total for this IPIDEntry + pEntry->cWeakRefs += pStd->cPublicRefs; + } + else + { + // add the refcnt to our strong total for this IPIDEntry + pEntry->cStrongRefs += pStd->cPublicRefs; + } + } + else if (SUCCEEDED(hr)) + { + // unmarshaling in the server apartment. If the cRefs is zero, + // then the interface was TABLE marshalled and we dont do + // anything to the IPID RefCnts since the object must live until + // ReleaseMarshalData is called on it. + +#ifdef WX86OLE + pvPSThunk = gcwx86.UnmarshalledInSameApt(pEntry->pv, riid); +#endif + if (pStd->cPublicRefs > 0) + { + // normal case, dec the ref counts from the IPID entry, + // OLE always passed fLastReleaseCloses = FALSE on + // Unmarshal and RMD so do the same here. + + DWORD mshlflags = (pStd->flags & SORF_WEAKREF) + ? (MSHLFLAGS_WEAK | MSHLFLAGS_KEEPALIVE) + : (MSHLFLAGS_NORMAL | MSHLFLAGS_KEEPALIVE); + + DecSrvIPIDCnt(pEntry, pStd->cPublicRefs, 0, NULL, mshlflags); + } + } + + if (SUCCEEDED(hr) && ppv) + { + ValidateIPIDEntry(pEntry); + + // extract and AddRef the pointer to return to the caller. + // Do this before releasing the lock (which we might do below + // on the server-side in DecSrvIPIDCnt. + + // NOTE: we are calling App code while holding the lock, + // but there is no way to avoid this. + + Win4Assert(IsValidInterface(pEntry->pv)); + *ppv = pEntry->pv; + ((IUnknown *)*ppv)->AddRef(); + AssertOutPtrIface(hr, *ppv); + if (_dwFlags & SMFLAGS_WEAKCLIENT && !(pStd->flags & SORF_WEAKREF)) + { + // make the client interface weak, ignore errors. + UNLOCK + ASSERT_LOCK_RELEASED + RemoteChangeRef(0,0); + ASSERT_LOCK_RELEASED + LOCK + } +#ifdef WX86OLE + // If we unmarshalled in the same apartment as the object and Wx86 + // recognized the interface then change the returned proxy to the + // proxy created for the Wx86 PSThunk. + if (pvPSThunk == (PVOID)-1) + { + // Wx86 recognized the interface, but could not establish a + // PSThunk for it. Force an error return. + *ppv = NULL; + hr = E_NOINTERFACE; + } + else if (pvPSThunk != NULL) + { + // Wx86 recognized the interface and did establish a PSThunk + // for it. Force a successful return with Wx86 proxy interface. + *ppv = pvPSThunk; + } +#endif + } + + ComDebOut((DEB_MARSHAL, "pEntry:%x cRefs:%x cStdId:%x\n", pEntry, + (SUCCEEDED(hr)) ? pEntry->cStrongRefs : 0, _pStdId->GetRC())); + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + ComDebOut((DEB_MARSHAL, "CStdMarshal::UnmarshalIPID hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::PrivateCopyProxy, internal +// +// Synopsis: Creates a copy of a proxy and IPID entry. +// +// Arguements: [pProxy] - Proxy to copy +// [ppProxy] - return copy here. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::PrivateCopyProxy( IUnknown *pv, IUnknown **ppv ) +{ + TRACECALL(TRACE_MARSHAL, "CStdMarshal::PrivateCopyProxy"); + ComDebOut((DEB_MARSHAL, "CStdMarshal::PrivateCopyProxy this:%x pv:%x\n", + this, pv)); + + // Don't copy stubs. + if (ServerSide()) + return E_INVALIDARG; + + ASSERT_LOCK_RELEASED + LOCK + + // Prevent a disconnect from occuring while unmarshaling the + // interface since we may have to yield the ORPC lock. + + HRESULT hr = PreventPendingDisconnect(); + + if (SUCCEEDED(hr)) + { + // Find the proxy to copy. + IPIDEntry *pEntry; + hr = FindIPIDEntryByInterface(pv, &pEntry); + if (SUCCEEDED(hr)) + { + // Don't copy disconnected proxies. + if (pEntry->dwFlags & IPIDF_DISCONNECTED) + hr = RPC_E_DISCONNECTED; + + // IUnknown can't be copied. + else if (IsEqualGUID( pEntry->iid, IID_IUnknown )) + hr = E_INVALIDARG; + + else + { + BOOL fNonNDRProxy; + IRpcProxyBuffer *pProxy; + hr = CreateProxy(pEntry->iid, &pProxy, (void **)ppv, + &fNonNDRProxy); + + if (SUCCEEDED(hr)) + { + IPIDEntry *pIpidCopy; + + // add a disconnected IPID entry to the table. + hr = AddIPIDEntry(NULL, &pEntry->ipid, pEntry->iid, NULL, + pProxy, *ppv, &pIpidCopy); + + if (SUCCEEDED(hr)) + { + // mark this IPID as a copy so we dont free it during + // ReleaseIPIDs. + pIpidCopy->dwFlags |= IPIDF_COPY; + + // connect the IPIDEntry before adding it to the table so + // that we dont have to worry about races between Unmarshal, + // Disconnect, and ReconnectProxies. + + // Make up an objref. Mark it as NOPING since we dont + // really have any references and we dont really need + // any because if we ever try to marshal it we will + // find the original IPIDEntry and use that. NOPING + // also lets us skip this IPID in DisconnectCliIPIDs. + + STDOBJREF std; + OXIDFromMOXID(pEntry->pOXIDEntry->moxid, &std.oxid); + std.ipid = pEntry->ipid; + std.cPublicRefs = 1; + std.flags = SORF_NOPING; + + hr = ConnectIPIDEntry(&std, pEntry->pOXIDEntry, pIpidCopy); + + // Add this IPID entry after the original. + pIpidCopy->pNextOID = pEntry->pNextOID; + pEntry->pNextOID = pIpidCopy; + _cIPIDs++; + } + else + { + // could not get an IPIDEntry, release the proxy, need to + // release the lock to do this. + + UNLOCK + ASSERT_LOCK_RELEASED + + pProxy->Release(); + ((IUnknown *)*ppv)->Release(); + + ASSERT_LOCK_RELEASED + LOCK + } + } + } + } + + if (SUCCEEDED(hr)) + { + ValidateIPIDEntry(pEntry); + AssertOutPtrIface(hr, *ppv); + } + AssertDisconnectPrevented(); + } + ASSERT_LOCK_HELD + UNLOCK + ASSERT_LOCK_RELEASED + + // Now let pending disconnect through. + HRESULT hr2 = HandlePendingDisconnect(hr); + if (FAILED(hr2) && SUCCEEDED(hr)) + { + // a disconnect came in while creating the proxy. ppv contains + // an AddRef'd interface pointer so go Release that now. + ((IUnknown *)*ppv)->Release(); + } + + ComDebOut((DEB_MARSHAL, "CStdMarshal::PrivateCopyProxy hr:%x\n", hr2)); + return hr2; +} + +//+------------------------------------------------------------------- +// +// Member: MakeSrvIPIDEntry, private +// +// Synopsis: creates a server side IPID table entry +// +// Arguements: [riid] - the interface to return +// [ppEntry] - IPIDEntry returned +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::MakeSrvIPIDEntry(REFIID riid, IPIDEntry **ppEntry) +{ + Win4Assert(ServerSide()); + AssertValid(); + AssertDisconnectPrevented(); + ASSERT_LOCK_HELD + + BOOL fNonNDRStub; + void *pv; + IRpcStubBuffer *pStub; + HRESULT hr = CreateStub(riid, &pStub, &pv, &fNonNDRStub); + + if (SUCCEEDED(hr)) + { + OXIDEntry *pOXIDEntry = _pChnl->GetOXIDEntry(); + + IPID ipidDummy; + hr = AddIPIDEntry(pOXIDEntry, &ipidDummy, riid, _pChnl, pStub, pv, + ppEntry); + + if (SUCCEEDED(hr)) + { + if (_dwFlags & SMFLAGS_NOPING) + { + // object does no need pinging, turn on NOPING + (*ppEntry)->dwFlags |= IPIDF_NOPING; + } + + if (fNonNDRStub) + { + // the stub was a custom 16bit one requested by WOW, mark the + // IPIDEntry as holding a non-NDR stub so we know to set the + // SORF_NONNDR flag in the StdObjRef when marshaling. This + // tells local clients whether to create a MIDL generated + // proxy or custom proxy. Functionality to support OLE + // Automation on DCOM. + + (*ppEntry)->dwFlags |= IPIDF_NONNDRSTUB; + } + + // increment the OXIDEntry ref count so that it stays + // around as long as the IPIDEntry points to it. It gets + // decremented when we disconnect the IPIDEntry. + + IncOXIDRefCnt(pOXIDEntry); + + // chain the IPIDEntries for this OID together + + (*ppEntry)->pNextOID = _pFirstIPID; + _pFirstIPID = *ppEntry; + } + else + { + // release the stub. we need to release the lock to do this. + UNLOCK + ASSERT_LOCK_RELEASED + + pStub->Release(); + + ASSERT_LOCK_RELEASED + LOCK + } + } + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: MakeCliIPIDEntry, private +// +// Synopsis: creates a client side IPID table entry +// +// Arguements: [riid] - the interface to return +// [pStd] - standard objref +// [pOXIDEntry] - OXIDEntry of the server +// [ppEntry] - IPIDEntry returned +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::MakeCliIPIDEntry(REFIID riid, STDOBJREF *pStd, + OXIDEntry *pOXIDEntry, + IPIDEntry **ppEntry) +{ + Win4Assert(ClientSide()); + AssertValid(); + AssertDisconnectPrevented(); + Win4Assert(pOXIDEntry); + ASSERT_LOCK_HELD + + BOOL fNonNDRProxy; + void *pv; + IRpcProxyBuffer *pProxy; + HRESULT hr = CreateProxy(riid, &pProxy, &pv, &fNonNDRProxy); + + if (SUCCEEDED(hr)) + { + // add a disconnected IPID entry to the table. + hr = AddIPIDEntry(NULL, &pStd->ipid, riid, NULL, pProxy, pv, ppEntry); + + if (pv) + { + // throw away our reference here, we will get it back later + // in UnmarshalIPID + ((IUnknown *)pv)->Release(); + } + + if (SUCCEEDED(hr)) + { + if (fNonNDRProxy) + { + // the proxy is a custom 16bit one requested by WOW, mark the + // IPIDEntry as holding a non-NDR proxy so we know to set the + // LOCALF_NOTNDR flag in the local header when we call on it + // (see CRpcChannelBuffer::ClientGetBuffer). Functionality to + // support OLE Automation on DCOM. + + (*ppEntry)->dwFlags |= IPIDF_NONNDRPROXY; + } + + if (pStd->flags & SORF_NONNDR) + { + // need to remember this flag so we can tell other + // unmarshalers if we remarshal it. + + (*ppEntry)->dwFlags |= IPIDF_NONNDRSTUB; + } + + // connect the IPIDEntry before adding it to the table so + // that we dont have to worry about races between Unmarshal, + // Disconnect, and ReconnectProxies. + + hr = ConnectIPIDEntry(pStd, pOXIDEntry, *ppEntry); + + // chain the IPIDEntries for this OID together. On client side + // always add the entry to the list regardless of whether connect + // succeeded. + + (*ppEntry)->pNextOID = _pFirstIPID; + _pFirstIPID = *ppEntry; + + _cIPIDs++; + } + else + { + // could not get an IPIDEntry, release the proxy, need to + // release the lock to do this. + + UNLOCK + ASSERT_LOCK_RELEASED + + pProxy->Release(); + + ASSERT_LOCK_RELEASED + LOCK + } + } + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: ConnectIPIDEntry, private +// +// Synopsis: connects a client side IPID table entry to the server +// +// Arguments: [pStd] - standard objref +// [pOXIDEntry] - OXIDEntry for the server +// [pEntry] - IPIDEntry to connect, already has a proxy +// and the IID filled in. +// +// Notes: This routine is re-entrant, it may be called multiple +// times for the same IPIDEntry, with part of the work done +// in one call and part in another. Only if the entry is +// fully set up will it return S_OK and mark the entry as +// connected. DisconnectCliIPIDs handles cleanup of partial +// connections. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::ConnectIPIDEntry(STDOBJREF *pStd, + OXIDEntry *pOXIDEntry, + IPIDEntry *pEntry) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ConnectIPIDEntry this:%x ipid:%I pOXIDEntry:%x pIPIDEntry:%x\n", + this, &pStd->ipid, pOXIDEntry, pEntry)); + Win4Assert(ClientSide()); + AssertDisconnectPrevented(); + AssertValid(); + Win4Assert(pOXIDEntry); + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + // mark the object as having attempted to connect an IPIDEntry so that + // if we fail somewhere in this routine and dont mark the whole object + // as connected, Disconnect will still try to clean things up. + + _dwFlags |= SMFLAGS_TRIEDTOCONNECT; + + if (!(pStd->flags & SORF_NOPING)) + { + // this interface requires pinging, turn off NOPING for this object + // and this IPIDEntry. + _dwFlags &= ~SMFLAGS_NOPING; + pEntry->dwFlags &= ~IPIDF_NOPING; + } + + if (!(_dwFlags & (SMFLAGS_REGISTEREDOID | SMFLAGS_NOPING))) + { + // register the OID with the ping server so it will get pinged + hr = gResolver.ClientRegisterOIDWithPingServer(pStd->oid, pOXIDEntry); + if (FAILED(hr)) + { + return hr; + } + + _dwFlags |= SMFLAGS_REGISTEREDOID; + } + + // Go get any references we need that are not already included in the + // STDOBJREF. These references will have been added to the counts in + // the IPIDEntry upon return. Any references in the STDOBJREF will be + // added to the IPIDEntry count only if the connect succeeds, otherwise + // ReleaseMarshalObjRef (which will clean up STDOBJREF references) will + // get called by higher level code. + + hr = GetNeededRefs(pStd, pOXIDEntry, pEntry); + if (FAILED(hr)) + { + return hr; + } + + if (pEntry->pChnl == NULL) + { + // create a channel for this oxid/ipid pair. On the client side we + // create one channel per proxy (and hence per IPID). + + hr = CreateChannel(pOXIDEntry, pStd->flags, pStd->ipid, + pEntry->iid, &pEntry->pChnl); + + if (SUCCEEDED(hr)) + { + // update this IPID table entry. must update ipid too since + // on reconnect it differs from the old value. + + IncOXIDRefCnt(pOXIDEntry); + pEntry->pOXIDEntry = pOXIDEntry; + pEntry->ipid = pStd->ipid; + pEntry->pChnl->SetIPIDEntry(pEntry); + } + } + + if (SUCCEEDED(hr)) + { + // Release the lock while we connect the proxy. We have to do + // this because the IDispatch proxy makes an Rpc call during + // Connect (Yuk!), which causes the channel to assert that the + // lock is released. The proxy MUST be able to handle multiple + // simultaneous or nested connects to the same channel ptr, since + // it is possible when we yield the lock for another thread to + // come in here and try a connect. + + void *pv = NULL; + IRpcProxyBuffer * pProxy = (IRpcProxyBuffer *)(pEntry->pStub); + + if (pProxy) + { + // HACKALERT: OleAutomation returns NULL pv in CreateProxy + // in cases where they dont know whether to return an NDR + // proxy or a custom-format proxy. So we have to go connect + // the proxy first then Query for the real interface once that + // is done. + + BOOL fGetpv = (pEntry->pv) ? FALSE : TRUE; + + UNLOCK + ASSERT_LOCK_RELEASED + + hr = pProxy->Connect(pEntry->pChnl); + if (fGetpv && SUCCEEDED(hr)) + { +#ifdef WX86OLE + if (gcwx86.IsN2XProxy(pProxy)) + { + // If we are creating a proxy for an object that is + // living on the x86 side then we need to set the + // StubInvoke flag to allow QI to thunk the + // custom interface QI. + gcwx86.SetStubInvokeFlag((BOOL)2); + } +#endif + hr = pProxy->QueryInterface(pEntry->iid, &pv); + AssertOutPtrIface(hr, pv); + + if(SUCCEEDED(hr)) + { +#ifdef WX86OLE + // Call whole32 thunk layer to play with the ref count + // and aggregate the proxy to the controlling unknown. + gcwx86.AggregateProxy(_pStdId->GetCtrlUnk(), + (IUnknown *)pv); +#endif + // Release our reference here. + // We keep a weak reference to pv. + ((IUnknown *)pv)->Release(); + } + } + + ASSERT_LOCK_RELEASED + LOCK + } + + // Regardless of errors from Connect and QI we wont try to cleanup + // any of the work we have done so far in this routine. The routine + // is reentrant (by the same thread or by different threads) and + // those calls could be using some of resources we have already + // allocated. Instead, we rely on DisconnectCliIPIDs to cleanup + // the partial allocation of resources. + + if (pEntry->dwFlags & IPIDF_DISCONNECTED) + { + // Mark the IPIDEntry as connected so we dont try to connect + // again. Also, as long as there is one IPID connected, the + // whole object is considered connected. This allows disconnect + // to find the newly connected IPID and disconnect it later. + // Infact, DisconnectCliIPIDs relies on there being at least + // one IPID with a non-NULL OXIDEntry. It is safe to set this + // now because Disconnects have been temporarily turned off. + + if (SUCCEEDED(hr)) + { + if (pv) + { + // assign the interface pointer + pEntry->pv = pv; + } + + AssertDisconnectPrevented(); + pEntry->dwFlags &= ~IPIDF_DISCONNECTED; + _dwFlags &= ~SMFLAGS_DISCONNECTED; + } + } + else + { + // while the lock was released, the IPIDEntry got connected + // by another thread (or by a nested call on this thread). + // Ignore any errors from Connect or QI since apparently + // things are connected now. + + hr = S_OK; + } + + if (SUCCEEDED(hr)) + { + // Add in any references we were given. If we were given 0 refs + // and the interface is noping, then pretend like we got 1 ref. + + ULONG cRefs = ((pStd->cPublicRefs == 0) && (pStd->flags & SORF_NOPING)) + ? 1 : pStd->cPublicRefs; + + // figure out if we have weak or strong references. To be weak + // they must be local to this machine and the SORF flag set. + BOOL fWeak = ((pStd->flags & SORF_WEAKREF) && + (pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL)); + + if (fWeak) + pEntry->cWeakRefs += cRefs; + else + pEntry->cStrongRefs += cRefs; + } + + // in debug build, ensure that we did not screw up + ValidateIPIDEntry(pEntry); + } + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ConnectIPIDEntry this:%x pOXIDEntry:%x pChnl:%x hr:%x\n", + this, pEntry->pOXIDEntry, pEntry->pChnl, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: GetNeededRefs, private +// +// Synopsis: Figures out if any references are needed and goes and gets +// them from the server. +// +// Arguments: [pStd] - standard objref +// [pOXIDEntry] - OXIDEntry for the server +// [pEntry] - IPIDEntry to connect, already has a proxy +// and the IID filled in. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::GetNeededRefs(STDOBJREF *pStd, OXIDEntry *pOXIDEntry, + IPIDEntry *pEntry) +{ + HRESULT hr = S_OK; + + if ((pStd->flags & (SORF_NOPING | SORF_WEAKREF)) == 0) + { + // if we dont have any and weren't given any strong refs, go get some. + ULONG cNeedStrong = ((pEntry->cStrongRefs + pStd->cPublicRefs) == 0) + ? REM_ADDREF_CNT : 0; + + // if we are using secure refs and we dont have any, go get some. + ULONG cNeedSecure = ((gCapabilities & EOAC_SECURE_REFS) && + (pEntry->cPrivateRefs == 0)) ? 1 : 0; + + if (cNeedStrong || cNeedSecure) + { + // Need to go get some references from the remote server. Note + // that we will yield here but we dont have to worry about it because + // the IPIDEntry is still marked as disconnected. + + hr = RemoteAddRef(pEntry, pOXIDEntry, cNeedStrong, cNeedSecure); + + if (SUCCEEDED(hr)) + { + pEntry->cStrongRefs += cNeedStrong; + pEntry->cPrivateRefs += cNeedSecure; + } + } + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::ReconnectProxies +// +// Synopsis: Reconnects the proxies to a new server (functionality +// used by the OLE default handler). +// +// History: 20-Feb-95 Rickhi Created. +// +// CODEWORK: CreateServer should just ask for all these interfaces +// during the create. +// +// BUGBUG: fail this call if freethreaded +// +//-------------------------------------------------------------------- +void CStdMarshal::ReconnectProxies() +{ + ComDebOut((DEB_MARSHAL,"CStdMarshal::ReconnectProxies this:%x pFirst:%x\n", + this, _pFirstIPID)); + AssertValid(); + Win4Assert(ClientSide()); + ASSERT_LOCK_RELEASED + LOCK + + // must be at least 1 proxy already connected in order to be able + // to reconnect the other proxies. We cant just ASSERT that's true + // because we were not holding the lock on entry. + + HRESULT hr = PreventDisconnect(); + + if (SUCCEEDED(hr)) + { + // allocate a stack buffer to hold the IPIDs + IID *pIIDsAlloc = (IID *) _alloca(_cIPIDs * sizeof(IID)); + IID *pIIDs = pIIDsAlloc; + USHORT cIIDs = 0; + + IPIDEntry *pNextIPID = _pFirstIPID; + + while (pNextIPID) + { + // Don't allow reconnection for fancy new servers or with + // secure proxies. + if (pNextIPID->dwFlags & IPIDF_COPY) + { + hr = E_FAIL; + break; + } + if ((pNextIPID->dwFlags & IPIDF_DISCONNECTED)) + { + // not connected, add it to the list to be connected. + *pIIDs = pNextIPID->iid; + pIIDs++; + cIIDs++; + } + + pNextIPID = pNextIPID->pNextOID; + } + + if (cIIDs != 0 && SUCCEEDED(hr)) + { + // we have looped filling in the IID list, and there are + // entries int he list. go call QI on server now and + // unmarshal the results. + + hr = RemQIAndUnmarshal(cIIDs, pIIDsAlloc, NULL); + } + } + + DbgWalkIPIDs(); + UNLOCK + ASSERT_LOCK_RELEASED + + // this will handle any Disconnect that came in while we were busy. + hr = HandlePendingDisconnect(hr); + + ComDebOut((DEB_MARSHAL,"CStdMarshal::ReconnectProxies [OUT] this:%x\n", this)); + return; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::ReleaseMarshalData, public +// +// Synopsis: Releases the references added by MarshalInterface +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::ReleaseMarshalData(LPSTREAM pStm) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ReleaseMarshalData this:%x pStm:%x\n", this, pStm)); + AssertValid(); + ASSERT_LOCK_RELEASED + + OBJREF objref; + HRESULT hr = ReadObjRef(pStm, objref); + + if (SUCCEEDED(hr)) + { + // call worker API to do the rest of the work + hr = ::ReleaseMarshalObjRef(objref); + + // deallocate the objref we read + FreeObjRef(objref); + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ReleaseMarshalData this:%x hr:%x\n", this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ReleaseMarshalObjRef, private +// +// Synopsis: Releases the references added by MarshalObjRef +// +// Arguements: [objref] - object reference +// +// Algorithm: Get the correct standard identity and ask it to do +// a ReleaseMarshalData. +// +// History: 19-Jun-95 Rickhi Created +// +//-------------------------------------------------------------------- +INTERNAL ReleaseMarshalObjRef(OBJREF &objref) +{ + ComDebOut((DEB_MARSHAL, "ReleaseMarshalObjRef objref:%x\n", &objref)); + ASSERT_LOCK_RELEASED + + HRESULT hr = InitChannelIfNecessary(); + if (SUCCEEDED(hr)) + { + CStdMarshal *pStdMshl; + hr = FindStdMarshal(objref, &pStdMshl); + + if (SUCCEEDED(hr)) + { + // only do the RMD if on the server side. + if (pStdMshl->ServerSide()) + { + // pass objref to subroutine to Release the marshaled data + hr = pStdMshl->ReleaseMarshalObjRef(objref); + } + pStdMshl->Release(); + } + else + { + // we could not find or create an identity. If the server is + // outside this apartment, try to issue a remote release on + // the interface. if the OXID is local and we could not find + // the identity, there is nothing left to cleanup. + + LOCK + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + if (pOXIDEntry != GetLocalOXIDEntry()) + { + // make a remote release call + RemoteReleaseObjRef(objref); + } + UNLOCK + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, "ReleaseMarshalObjRef hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::ReleaseMarshalObjRef, public +// +// Synopsis: Releases the references added by MarshalObjRef +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::ReleaseMarshalObjRef(OBJREF &objref) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ReleaseMarshalObjRef this:%x objref:%x\n", this, &objref)); + AssertValid(); + + HRESULT hr = S_OK; + STDOBJREF *pStd = &ORSTD(objref).std; + ValidateSTD(pStd); + + ASSERT_LOCK_RELEASED + LOCK + + // REFCOUNTING: + if (ServerSide()) + { + // look for an existing IPIDEntry for the given IPID + IPIDEntry *pEntry; + hr = FindIPIDEntryByIPID(pStd->ipid, &pEntry); + + if (SUCCEEDED(hr) && !(pEntry->dwFlags & IPIDF_DISCONNECTED)) + { + // subtract the ref count from the IPIDEntry, may Release the + // StdId if this was the last reference for this IPIDEntry. + + // we need to figure out how it was marshalled, strong/weak etc + // in order to set the flags and cRefs correctly to pass to + // DecSrvIPIDCnt. + + if (pStd->cPublicRefs == 0) + { + // table case + DWORD mshlflags = (pStd->flags & SORF_TBLWEAK) + ? MSHLFLAGS_TABLEWEAK : MSHLFLAGS_TABLESTRONG; + DecSrvIPIDCnt(pEntry, 1, 0, NULL, mshlflags); + } + else + { + // normal or weak case + DWORD mshlflags = (pStd->flags & SORF_WEAKREF) + ? MSHLFLAGS_WEAK : MSHLFLAGS_NORMAL; + DecSrvIPIDCnt(pEntry, pStd->cPublicRefs, 0, NULL, mshlflags); + } + } + } + else // client side + { + if ((pStd->cPublicRefs == 0) || (pStd->flags & SORF_NOPING)) + { + // there are no references, or this interface does not + // need pinging, so there is nothing to do. + ; + } + else + { + // look for an existing IPIDEntry for the given IPID + IPIDEntry *pEntry; + hr = FindIPIDEntryByIPID(pStd->ipid, &pEntry); + + if (SUCCEEDED(hr) && !(pEntry->dwFlags & IPIDF_DISCONNECTED)) + { + // add these to the cRefs of this entry, they will get freed + // when we do the remote release. Saves an Rpc call now. + + if ((pStd->flags & SORF_WEAKREF) && + (pEntry->pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL)) + pEntry->cWeakRefs += pStd->cPublicRefs; + else + pEntry->cStrongRefs += pStd->cPublicRefs; + } + else + { + // client side, no matching IPIDEntry so just contact the remote + // server to remove the reference. ignore errors since there is + // nothing we can do about them anyway. + RemoteReleaseObjRef(objref); + hr = S_OK; + } + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, + "CStdMarshal::ReleaseMarshalObjRef this:%x hr:%x\n", this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::PreventDisconnect, public +// +// Synopsis: Prevents a Disconnect from occurring until a matching +// HandlePendingDisconnect is called. +// +// History: 21-Sep-95 Rickhi Created +// +// The ORPC LOCK is yielded at many places in order to make calls on +// application interfaces (server-side objects, stubs, proxies, +// handlers, remote objects, resolver, etc). In order to keep the +// code (reasonably?) simple, disconnects are prevented from occuring +// while in the middle of (potentially) complex operations, and while +// there are outstanding calls on interfaces to this object. +// +// To accomplish this, a counter (_cNestedCalls) is atomically incremented. +// When _cNestedCalls != 0 and a Disconnect arrives, the object is flagged +// as PendingDisconnect. When HandlePendingDisconnect is called, it +// decrements the _cNestedCalls. If the _cNestedCalls == 0 and there is +// a pending disconnect, the real Disconnect is done. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::PreventDisconnect() +{ + ASSERT_LOCK_HELD + + // treat this as a nested call so that if we yield, a real + // disconnect wont come through, instead it will be treated + // as pending. That allows us to avoid checking our state + // for Disconnected every time we yield the ORPC LOCK. + + InterlockedIncrement(&_cNestedCalls); + + if (_dwFlags & (SMFLAGS_DISCONNECTED | SMFLAGS_PENDINGDISCONNECT)) + return CO_E_OBJNOTCONNECTED; + + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::PreventPendingDisconnect, public +// +// Synopsis: similar to PreventDisconnect but special case for use +// in UnmarshalObjRef (since the client side starts out +// in the Disconnected state until the first unmarshal is done). +// +// History: 21-Sep-95 Rickhi Created +// +//+------------------------------------------------------------------- +HRESULT CStdMarshal::PreventPendingDisconnect() +{ + ASSERT_LOCK_HELD + InterlockedIncrement(&_cNestedCalls); + + if (_dwFlags & + (ClientSide() ? SMFLAGS_PENDINGDISCONNECT + : SMFLAGS_PENDINGDISCONNECT | SMFLAGS_DISCONNECTED)) + return CO_E_OBJNOTCONNECTED; + + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::HandlePendingDisconnect, public +// +// Synopsis: Reverses a call to PreventDisconnect and lets a +// pending disconnect through. +// +// History: 21-Sep-95 Rickhi Created +// +//+------------------------------------------------------------------- +HRESULT CStdMarshal::HandlePendingDisconnect(HRESULT hr) +{ + ASSERT_LOCK_RELEASED + + // treat this as a nested call so that if we yield, a real + // disconnect wont come through, instead it will be treated + // as pending. That allows us to avoid checking our state + // for Disconnected every time we yield the ORPC LOCK. + + if (InterlockedDecrement(&_cNestedCalls) == 0 && + (_dwFlags & SMFLAGS_PENDINGDISCONNECT)) + { + Disconnect(); + hr = FAILED(hr) ? hr : CO_E_OBJNOTCONNECTED; + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::DisconnectObject, public +// +// Synopsis: part of IMarshal interface, this is legal only on the +// server side. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdMarshal::DisconnectObject(DWORD dwReserved) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::DisconnectObject this:%x dwRes:%x\n", this, dwReserved)); + AssertValid(); + ASSERT_LOCK_RELEASED + + // this operation is not legal from the client side (although + // IProxyManager::Disconnect is), but we still have to return S_OK + // in either case for backward compatibility. + + if (ServerSide()) + { + Disconnect(); + } + + ASSERT_LOCK_RELEASED + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::Disconnect, public +// +// Synopsis: client side - disconnects proxies from the channel. +// server side - disconnects stubs from the server object. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::Disconnect(void) +{ + ComDebOut((DEB_MARSHAL, "CStdMarshal::Disconnect this:%x\n", this)); + AssertValid(); + + ASSERT_LOCK_RELEASED + LOCK + + if ((_dwFlags & SMFLAGS_DISCONNECTED) && + !(_dwFlags & SMFLAGS_TRIEDTOCONNECT)) + { + // already disconnected, no partial connects, nothing to do + ComDebOut((DEB_MARSHAL,"CStdMarshal::Disconnect [already done]:%x\n",this)); + UNLOCK + ASSERT_LOCK_RELEASED + return; + } + + // Revoke ID from the ID table if registered. This prevents other + // marshals/unmarshals from finding this identity that is about to + // be disconnected. This is the ONLY state that should change, since + // we dont want to screw up any work-in-progress on other threads + // or in calls higher up the stack. + + _pStdId->RevokeOID(); + + if (_cNestedCalls != 0) + { + // we dont allow disconnect to occur inside a nested call since we + // dont want state to vanish in the middle of a call, but we do + // remember that we want to disconnect and will do it when the + // stack unwinds (or other threads complete). + + _dwFlags |= SMFLAGS_PENDINGDISCONNECT; + + ComDebOut((DEB_MARSHAL,"CStdMarshal::Disconnect [pending]:%x\n",this)); + UNLOCK; + ASSERT_LOCK_RELEASED + return; + } + + + // No calls in progress and not already disconnected, OK to really + // disconnect now. First mark ourself as disconnected incase we + // get reentered while releasing a stub pointer. + + _dwFlags |= SMFLAGS_DISCONNECTED; // turn on disconnected + _dwFlags &= ~(SMFLAGS_PENDINGDISCONNECT | // turn off pending disconnect + SMFLAGS_TRIEDTOCONNECT); // turn off tried to connect + + // disconnect all our IPIDs + if (ServerSide()) + DisconnectSrvIPIDs(); + else + DisconnectCliIPIDs(); + + UNLOCK + ASSERT_LOCK_RELEASED + + if (ServerSide()) + { + // HACK - 16 and 32 bit Word 6.0 crash if you release all the objects + // it left lying around at CoUninitialize. Leak them. + COleTls tls; + // If we are not uninitializing, then call the release. + if ((tls->dwFlags & OLETLS_THREADUNINITIALIZING) == 0 || + + // If we are in WOW and the app is not word, then call the release. + (IsWOWThread() && + (g_pOleThunkWOW->GetAppCompatibilityFlags() & OACF_NO_UNINIT_CLEANUP) == 0) || + + // If the app is not 32 bit word, then call the release. + !IsTaskName( L"winword.exe" )) + { + // on the server side, we have to tell the stdid to release his + // controlling unknown of the real object. + _pStdId->ReleaseCtrlUnk(); + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL,"CStdMarshal::Disconnect [complete]:%x\n",this)); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::DisconnectCliIPIDs +// +// Synopsis: disconnects client side IPIDs for this object. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::DisconnectCliIPIDs() +{ + ComDebOut((DEB_MARSHAL,"CStdMarshal::DisconnectCliIPIDs this:%x pFirst:%x\n", + this, _pFirstIPID)); + Win4Assert(ClientSide()); + Win4Assert(_dwFlags & SMFLAGS_DISCONNECTED); + + // YIELD WARNING: Do not yield between here and the matching comment + // below, since we are mucking with internal state that could get + // messed up if a reconnect (or unmarshal) is done. + + ASSERT_LOCK_HELD + + // on client side, we cant actually release the proxies until the + // object goes away (backward compatibility), so we just release + // our references to the remote guy, disconnect the proxies, and + // delete the channels, but hold on to the IPIDEntries. + + REMINTERFACEREF *pRifRefAlloc = (REMINTERFACEREF *) + _alloca(_cIPIDs * 2 * sizeof(REMINTERFACEREF)); + REMINTERFACEREF *pRifRef = pRifRefAlloc; + + + OXIDEntry *pOXID = NULL; + USHORT cRifRef = 0; + IPIDEntry *pEntry = _pFirstIPID; + + while (pEntry) + { + // we have to handle the case where ConnectIPIDEntry partially (but + // not completely) set up the IPIDEntry, hence we cant just check + // for the IPIDF_DISCONNECTED flag. + + ValidateIPIDEntry(pEntry); + + // NOTE: we are calling Proxy code here while holding the ORPC LOCK. + // There is no way to get around this without introducing race + // conditions. We cant just disconnect the channel and leave the + // proxy connected cause some proxies (like IDispatch) do weird shit, + // like keeping separate pointers to the server. + + if (pEntry->pStub) // NULL for IUnknown IPID + { + ComDebOut((DEB_MARSHAL, "Disconnect pProxy:%x\n", pEntry->pStub)); + ((IRpcProxyBuffer *)pEntry->pStub)->Disconnect(); + pEntry->pv = NULL; + } + + if (!(pEntry->dwFlags & IPIDF_NOPING)) + { + // the object pays attention to pings (and hence refcounts) + + if (pEntry->cStrongRefs > 0 || pEntry->cPrivateRefs > 0) + { + // we own some strong references on this interface, fill + // in an interfaceref so we release them. + + pRifRef->cPublicRefs = pEntry->cStrongRefs; + pRifRef->cPrivateRefs = pEntry->cPrivateRefs; + pRifRef->ipid = pEntry->ipid; + pRifRef++; + cRifRef++; + } + + if (pEntry->cWeakRefs > 0) + { + // we own some weak references on this interface, fill + // in an interfaceref so we release them. + + pRifRef->cPublicRefs = pEntry->cWeakRefs; + pRifRef->cPrivateRefs = 0; + pRifRef->ipid = pEntry->ipid; + + // mark the IPID as weak so that RemRelease on the server + // knows to release weak references instead of strong refs. + + pRifRef->ipid.Data1 |= IPIDFLAG_WEAKREF; + pRifRef++; + cRifRef++; + } + } + + pEntry->cStrongRefs = 0; + pEntry->cWeakRefs = 0; + pEntry->cPrivateRefs = 0; + pEntry->dwFlags |= IPIDF_DISCONNECTED | IPIDF_NOPING; + + if (pEntry->pChnl) + { + // release the channel for this IPID + pEntry->pChnl->Release(); + pEntry->pChnl = NULL; + } + + if (pEntry->pOXIDEntry) + { + // We will be decrementing the OXID refcnt as we release IPIDEntries + // but we dont want the OXIDEntry to go away until after we make the + // RemoteRelease call below, so we hold on to it here. + + if (pOXID == NULL) + { + pOXID = pEntry->pOXIDEntry; + IncOXIDRefCnt(pOXID); + } + + // If we ever go to a model where different IPIDEntries on the + // same object can point to different OXIDEntires, then we need + // to re-write this code to batch the releases by OXID. + Win4Assert(pOXID == pEntry->pOXIDEntry); + + // release the RefCnt on the OXIDEntry + DecOXIDRefCnt(pEntry->pOXIDEntry); + pEntry->pOXIDEntry = NULL; + } + + // get next IPID in chain for this object + pEntry = pEntry->pNextOID; + } + + if (_pChnl) + { + // release the last client side channel + _pChnl->Release(); + _pChnl = NULL; + } + + if (_dwFlags & SMFLAGS_REGISTEREDOID) + { + // Tell the resolver to stop pinging the OID. The OID is only + // registered on the client side. + + Win4Assert(ClientSide()); + gResolver.ClientDeRegisterOIDFromPingServer(_pStdId->GetOID(), + _dwFlags & SMFLAGS_CLIENTMARSHALED); + + } + + // turn these flags off so re-connect (with new OID) will behave properly. + _dwFlags &= ~(SMFLAGS_CLIENTMARSHALED | SMFLAGS_REGISTEREDOID | + SMFLAGS_NOPING); + + + // YIELD WARNING: Up this this point we have been mucking with our + // internal state. We cant yield before this point or a reconnect + // proxies could get all screwed up. It is OK to yield after this point + // because all internal state changes are now complete. The function + // to release the remote references yield. + + if (cRifRef != 0) + { + // we have looped filling in the RifRef and entries exist in the + // array. go call the server now to release the IPIDs. + + Win4Assert(pOXID); // must have been at least one + RemoteReleaseRifRef(pOXID, cRifRef, pRifRefAlloc); + } + + if (pOXID) + { + // Now release the refcnt (if any) we put on the OXIDEntry above + // to hold it + DecOXIDRefCnt(pOXID); + } + + ASSERT_LOCK_HELD + DbgWalkIPIDs(); + ComDebOut((DEB_MARSHAL, "CStdMarshal::DisconnectCliIPIDs this:%x\n",this)); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::DisconnectSrvIPIDs +// +// Synopsis: disconnects the server side IPIDs for this object. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::DisconnectSrvIPIDs() +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::DisconnectSrvIPIDs this:%x pFirst:%x\n",this, _pFirstIPID)); + Win4Assert(ServerSide()); + + // there should be no other threads looking at these IPIDs at this time, + // since Marshal, Unmarshal, and Dispatch all call PreventDisconnect, + // Disconnect checks the disconnected flag directly, RMD holds the + // lock over it's whole execution, RemAddRef and RemRelease hold the + // lock and check the disconnected flag of the IPIDEntry, and + // RemQueryInterface calls PreventDisconnect. + + Win4Assert(_dwFlags & SMFLAGS_DISCONNECTED); + Win4Assert(_cNestedCalls == 0); + ASSERT_LOCK_HELD + + + // while holding the lock, flag each IPID as disconnected so that no + // more incoming calls are dispatched to this object. We also unchain + // the IPIDs to ensure that no other threads are pointing at them. + + IPIDEntry *pFirstIPID = _pFirstIPID; + _pFirstIPID = NULL; + + IPIDEntry *pEntry = pFirstIPID; + while (pEntry) + { + pEntry->dwFlags |= IPIDF_VACANT | IPIDF_DISCONNECTED; + + // release the refcnt on the OXIDEntry and NULL it + DecOXIDRefCnt(pEntry->pOXIDEntry); + pEntry->pOXIDEntry = NULL; + + pEntry = pEntry->pNextOID; + } + + + // now release the LOCK since we will be calling into app code to + // disconnect the stubs, and to release the external connection counts. + // There should be no other pointers to these IPIDEntries now, so it + // is safe to muck with their fields (except the dwFlags which is looked + // at by Dispatch and was already set above). + + UNLOCK + ASSERT_LOCK_RELEASED + + IPIDEntry *pLastIPID; + pEntry = pFirstIPID; + + while (pEntry) + { + if (pEntry->dwFlags & IPIDF_NOTIFYACT) + { + // the activation code asked to be notified when the refcnt + // on this interface reaches zero. Turn the flag off so we + // don't call twice. + pEntry->dwFlags &= ~IPIDF_NOTIFYACT; + NotifyActivation(FALSE, (IUnknown *)(pEntry->pv)); + } + + if (pEntry->pStub) // pStub is NULL for IUnknown IPID + { + ComDebOut((DEB_MARSHAL, "Disconnect pStub:%x\n", pEntry->pStub)); + ((IUnknown *)pEntry->pv)->Release(); + ((IRpcStubBuffer *)pEntry->pStub)->Disconnect(); + pEntry->pStub->Release(); + pEntry->pStub = NULL; + pEntry->pv = NULL; + } + + if (pEntry->cWeakRefs > 0) + { + // Release weak references on the StdId. + pEntry->cWeakRefs = 0; + _pStdId->Release(); + } + + if (pEntry->cStrongRefs > 0) + { + // Release strong references on the StdId. Note that 16bit + // 16bit OLE always passed fLastReleaseCloses = FALSE in + // DisconnectObject so we do the same here. + + pEntry->cStrongRefs = 0; + _pStdId->DecStrongCnt(TRUE); // fKeepAlive + } + + if (pEntry->cPrivateRefs > 0) + { + // Release private references on the StdId. Note that 16bit + // 16bit OLE always passed fLastReleaseCloses = FALSE in + // DisconnectObject so we do the same here. + + pEntry->cPrivateRefs = 0; + _pStdId->DecStrongCnt(TRUE); // fKeepAlive + } + pLastIPID = pEntry; + pEntry = pEntry->pNextOID; + } + + ASSERT_LOCK_RELEASED + LOCK + + if (pFirstIPID) + { + // now we release all entries. + gIPIDTbl.ReleaseEntryList(pFirstIPID, pLastIPID); + } + + ASSERT_LOCK_HELD + DbgWalkIPIDs(); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::DisconnectSrvIPIDs [OUT] this:%x\n",this)); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::InstantiatedProxy, public +// +// Synopsis: return requested interfaces to the caller if instantiated +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +BOOL CStdMarshal::InstantiatedProxy(REFIID riid, void **ppv, HRESULT *phr) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::InstantiatedProxy this:%x riid:%I ppv:%x\n", + this, &riid, ppv)); + AssertValid(); + Win4Assert(ClientSide()); + Win4Assert(*ppv == NULL); + Win4Assert(*phr == S_OK); + + BOOL fRet = FALSE; + + ASSERT_LOCK_RELEASED + LOCK + + // look for an existing IPIDEntry for the requested interface + IPIDEntry *pEntry; + HRESULT hr = FindIPIDEntry(riid, &pEntry); + + if (SUCCEEDED(hr) && pEntry->pv) + { + // found the ipid entry, now extract the interface + // pointer to return to the caller. + + Win4Assert(IsValidInterface(pEntry->pv)); + *ppv = pEntry->pv; + fRet = TRUE; + } + else if (_cIPIDs == 0) + { + // no IPIDEntry for the requested interface, and we have never + // been connected to the server. Return E_NOINTERFACE in this + // case. This is different from having been connected then + // disconnected, where we return CO_E_OBJNOTCONNECTED. + + *phr = E_NOINTERFACE; + Win4Assert(fRet == FALSE); + } + else if (_dwFlags & SMFLAGS_PENDINGDISCONNECT) + { + // no IPIDEntry for the requested interface and disconnect is + // pending, so return an error. + + *phr = CO_E_OBJNOTCONNECTED; + Win4Assert(fRet == FALSE); + } + else + { + // no IPIDEntry, we are not disconnected, and we do have other + // instantiated proxies. QueryMultipleInterfaces expects + // *phr == S_OK and FALSE returned. + + Win4Assert(*phr == S_OK); + Win4Assert(fRet == FALSE); + } + + UNLOCK + ASSERT_LOCK_RELEASED + ComDebOut((DEB_MARSHAL, + "CStdMarshal::InstantiatedProxy hr:%x pv:%x fRet:%x\n", *phr, *ppv, fRet)); + return fRet; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::QueryRemoteInterfaces, public +// +// Synopsis: return requested interfaces to the caller if supported +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::QueryRemoteInterfaces(USHORT cIIDs, IID *pIIDs, SQIResult *pQIRes) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::QueryRemoteInterfaces this:%x pIIDs:%x pQIRes:%x\n", + this, pIIDs, pQIRes)); + AssertValid(); + Win4Assert(ClientSide()); + Win4Assert(cIIDs > 0); + + ASSERT_LOCK_RELEASED + LOCK + + HRESULT hr = PreventDisconnect(); + + if (SUCCEEDED(hr)) + { + // call QI on the remote guy and unmarshal the results + hr = RemQIAndUnmarshal(cIIDs, pIIDs, pQIRes); + } + else + { + // cant call out because we're disconnected so return error for + // each requested interface. + for (USHORT i=0; i<cIIDs; i++, pQIRes++) + { + pQIRes->hr = hr; + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + // if the object was disconnected while in the middle of the call, + // then we still return SUCCESS for any interfaces we acquired. The + // reason is that we do have the proxies, and this matches the + // behaviour of a QI for an instantiated proxy on a disconnected + // object. + + hr = HandlePendingDisconnect(hr); + + ComDebOut((DEB_MARSHAL, + "CStdMarshal::QueryRemoteInterfaces this:%x hr:%x\n", this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::RemQIAndUnmarshal, private +// +// Synopsis: call QI on remote guy, then unmarshal the STDOBJREF +// to create the IPID, and return the interface ptr. +// +// History: 20-Feb-95 Rickhi Created. +// +// Notes: Caller must guarantee at least one IPIDEntry is connected. +// This function does a sparse fill of the result array. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::RemQIAndUnmarshal(USHORT cIIDs, IID *pIIDs, + SQIResult *pQIRes) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::RemQIAndUnmarshal this:%x cIIDs:%x pIIDs:%x pQIRes:%x\n", + this, cIIDs, pIIDs, pQIRes)); + AssertValid(); + AssertDisconnectPrevented(); + Win4Assert(_pFirstIPID); // must be at least 1 IPIDEntry + ASSERT_LOCK_HELD + + // we need an IPID to call RemoteQueryInterface with, any one will + // do so long as it is connected (in the reconnect case there may be + // only one connected IPID) so we pick the first one in the chain that + // is connected. + + IPIDEntry *pIPIDEntry = GetConnectedIPID(); + + // remember what type of reference to get since we yield the lock + // and cant rely on _dwFlags later. + BOOL fWeakClient = (_dwFlags & SMFLAGS_WEAKCLIENT); + + // call the remote guy + REMQIRESULT *pRemQiRes = NULL; + IRemUnknown *pRemUnk; + HRESULT hr = GetSecureRemUnk( &pRemUnk, pIPIDEntry->pOXIDEntry ); + if (SUCCEEDED(hr)) + { + hr = RemoteQueryInterface(pRemUnk, pIPIDEntry, cIIDs, pIIDs, &pRemQiRes, + fWeakClient); + } + + // need to remember the result ptr so we can free it. + REMQIRESULT *pRemQiResNext = pRemQiRes; + + // unmarshal each STDOBJREF returned. Note that while we did the + // RemoteQI we could have yielded (or nested) and did another + // RemoteQI for the same interfaces, so we have to call UnmarshalIPID + // which will find any existing IPIDEntry and bump its refcnt. + + HRESULT hr2; + HRESULT *phr = &hr2; + void *pv; + void **ppv = &pv; + + for (USHORT i=0; i<cIIDs; i++) + { + if (pQIRes) + { + // caller wants the pointers returned, set ppv and phr. + ppv = &pQIRes->pv; + phr = &pQIRes->hr; + pQIRes++; + } + + if (SUCCEEDED(hr)) + { + if (SUCCEEDED(pRemQiResNext->hResult)) + { + if (fWeakClient) + { + // mark the std objref with the weak reference flag so + // that UnmarshalIPID adds the references to the correct + // count. + pRemQiResNext->std.flags |= SORF_WEAKREF; + } + + *phr = UnmarshalIPID(*pIIDs, &pRemQiResNext->std, + pIPIDEntry->pOXIDEntry, + (pQIRes) ? ppv : NULL); + + if (FAILED(*phr)) + { + // could not unmarshal, release the resources with the + // server. + RemoteReleaseStdObjRef(&pRemQiResNext->std, + pIPIDEntry->pOXIDEntry); + } + } + else if (pQIRes) + { + // the requested interface was not returned so set the + // return code and interface ptr. + *phr = pRemQiResNext->hResult; + *ppv = NULL; + } + + pIIDs++; + pRemQiResNext++; + } + else + { + // the whole call failed so return the error for each + // requested interface. + *phr = hr; + *ppv = NULL; + } + + // make sure the ptr value is NULL on failure. It may be NULL or + // non-NULL on success. (ReconnectProxies wants NULL). + Win4Assert(SUCCEEDED(*phr) || *ppv == NULL); + } + + // free the result buffer + CoTaskMemFree(pRemQiRes); + + ASSERT_LOCK_HELD + AssertDisconnectPrevented(); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::RemQIAndUnmarshal this:%x hr:%x\n", this, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::RemIsConnected, private +// +// Synopsis: Returns TRUE if most likely connected, FALSE if definitely +// not connected or pending disconnect. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +BOOL CStdMarshal::RemIsConnected(void) +{ + AssertValid(); + Assert(ClientSide()); + + // the default link depends on us returning FALSE if we are either + // disconnected or just pending disconnect, in order that they avoid + // running their cleanup code twice. + + BOOL fRes = (_dwFlags & (SMFLAGS_DISCONNECTED | SMFLAGS_PENDINGDISCONNECT)) + ? FALSE : TRUE; + + ComDebOut((DEB_MARSHAL, + "CStdMarshal::RemIsConnected this:%x fResult:%x\n", this, fRes)); + return fRes; +} + +//+------------------------------------------------------------------- +// +// Member: CreateChannel, private +// +// Synopsis: Creates an instance of the Rpc Channel. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::CreateChannel(OXIDEntry *pOXIDEntry, DWORD dwFlags, + REFIPID ripid, REFIID riid, CRpcChannelBuffer **ppChnl) +{ + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + if (_pChnl == NULL) + { + DWORD cState = ServerSide() ? server_cs : client_cs; + cState |= (dwFlags & SORF_FREETHREADED) ? freethreaded_cs : 0; + + // make a channel. We dont need the call control stuff so just + // create the base class. + + _pChnl = new CRpcChannelBuffer(_pStdId, pOXIDEntry, cState); + + if (_pChnl == NULL) + { + hr = E_OUTOFMEMORY; + } + } + + if (SUCCEEDED(hr) && ClientSide()) + { + *ppChnl = _pChnl->Copy(pOXIDEntry, ripid, riid); + if (*ppChnl == NULL) + { + hr = E_OUTOFMEMORY; + } + } + else + { + *ppChnl = _pChnl; + } + + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: GetPSFactory, private +// +// Synopsis: loads the proxy/stub factory for given IID +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::GetPSFactory(REFIID riid, IUnknown *pUnkWow, BOOL fServer, + IPSFactoryBuffer **ppIPSF, BOOL *pfNonNDR) +{ + ComDebOut((DEB_MARSHAL, + "CStdMarshal::GetPSFactory this:%x riid:%I pUnkWow:%x\n", + this, &riid, pUnkWow)); + ASSERT_LOCK_RELEASED + + // map iid to classid + CLSID clsid; + HRESULT hr = gRIFTbl.RegisterInterface(riid, fServer, &clsid); +#ifdef WX86OLE + BOOL fWx86 = FALSE; +#endif + + if (SUCCEEDED(hr)) + { + BOOL fWow = FALSE; + + if (IsWOWThread()) + { + // figure out if this is a custom interface from a 16bit + // app, since we have to load the 16bit proxy code if so. + + IThunkManager *pThkMgr; + g_pOleThunkWOW->GetThunkManager(&pThkMgr); + Win4Assert(pThkMgr && "pUnk in WOW does not support IThunkManager."); + + if (pUnkWow) + fWow = pThkMgr->IsCustom3216Proxy(pUnkWow, riid); + else + fWow = pThkMgr->IsIIDRequested(riid); + + pThkMgr->Release(); + } + +#ifdef WX86OLE + // If we are in a Wx86 process then we need to determine if the + // PSFactory needs to be an x86 or native one. + else if (gcwx86.IsWx86Enabled()) + { + // Callout to wx86 to ask it to determine if an x86 PS factory + // is required. Whole32 can tell if the stub needs to be x86 + // by determining if pUnkWow is a custom interface proxy or not. + // Whole32 can determine if a x86 proxy is required by checking + // if the riid is one for a custom interface that is expected + // to be returned. + fWx86 = gcwx86.NeedX86PSFactory(pUnkWow, riid); + } +#endif + + // if we are loading a 16bit custom proxy then mark it as non NDR + *pfNonNDR = (fWow) ? TRUE : FALSE; + + if (IsEqualGUID(clsid, CLSID_PSOlePrx32)) + { + // its our internal CLSID so go straight to our class factory. + hr = PrxDllGetClassObject(clsid, IID_IPSFactoryBuffer, + (void **)ppIPSF); + } + else + { +#ifdef WX86OLE + DWORD dwContext = fWow ? CLSCTX_INPROC_SERVER16 + : (fWx86 ? CLSCTX_INPROC_SERVERX86 : + CLSCTX_INPROC_SERVER) + | CLSCTX_PS_DLL; +#else + DWORD dwContext = fWow ? CLSCTX_INPROC_SERVER16 + : CLSCTX_INPROC_SERVER | CLSCTX_PS_DLL; +#endif + + // load the dll and get the PS class object + hr = ICoGetClassObject(clsid, dwContext, NULL, IID_IPSFactoryBuffer, + (void **)ppIPSF); +#ifdef WX86OLE + if (fWx86 && FAILED(hr)) + { + // if we are looking for an x86 PSFactory and we didn't find + // one on InprocServerX86 key then we need to check + // InprocServer32 key as well. + hr = ICoGetClassObject(clsid, + CLSCTX_INPROC_SERVER | CLSCTX_PS_DLL, + NULL, IID_IPSFactoryBuffer, + (void **)ppIPSF); + + if (SUCCEEDED(hr) && + (! gcwx86.IsN2XProxy((IUnknown *)*ppIPSF))) + { + ((IUnknown *)*ppIPSF)->Release(); + hr = REGDB_E_CLASSNOTREG; + } + } +#endif + AssertOutPtrIface(hr, *ppIPSF); + } + } + +#if DBG==1 + // if the fake NonNDR flag is set and its the test interface, then + // trick the code into thinking this is a nonNDR proxy. This is to + // enable simpler testing of an esoteric feature. + + if (gfFakeNonNDR && IsEqualIID(riid, IID_ICube)) + { + *pfNonNDR = TRUE; + } +#endif + + ComDebOut((DEB_MARSHAL, + "CStdMarshal::GetPSFactory this:%x pIPSF:%x fNonNDR:%x hr:%x\n", + this, *ppIPSF, *pfNonNDR, hr)); + + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CreateProxy, private +// +// Synopsis: creates an interface proxy for the given interface +// +// Returns: [ppv] - interface of type riid, AddRef'd +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::CreateProxy(REFIID riid, IRpcProxyBuffer **ppProxy, + void **ppv, BOOL *pfNonNDR) +{ + TRACECALL(TRACE_MARSHAL, "CreateProxy"); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::CreateProxy this:%x riid:%I\n", this, &riid)); + AssertValid(); + Win4Assert(ClientSide()); + Win4Assert(ppProxy != NULL); + ASSERT_LOCK_HELD + + // get the controlling IUnknown of this object + IUnknown *punkCtrl = _pStdId->GetCtrlUnk(); + Win4Assert(punkCtrl != NULL); + + if (InlineIsEqualGUID(riid, IID_IUnknown)) + { + // there is no proxy for IUnknown so we handle that case here + punkCtrl->AddRef(); + *ppv = (void **)punkCtrl; + *ppProxy = NULL; + *pfNonNDR = FALSE; + return S_OK; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + // now construct the proxy for the interface + IPSFactoryBuffer *pIPSF = NULL; + HRESULT hr = GetPSFactory(riid, NULL, FALSE, &pIPSF, pfNonNDR); + + if (SUCCEEDED(hr)) + { + // got the class factory, now create an instance + hr = pIPSF->CreateProxy(punkCtrl, riid, ppProxy, ppv); + AssertOutPtrIface(hr, *ppProxy); + pIPSF->Release(); + } + + ASSERT_LOCK_RELEASED + LOCK + + ComDebOut((DEB_MARSHAL, + "CStdMarshal::CreateProxy this:%x pProxy:%x pv:%x fNonNDR:%x hr:%x\n", + this, *ppProxy, *ppv, *pfNonNDR, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CreateStub, private +// +// Synopsis: creates an interface stub and adds it to the IPID table +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::CreateStub(REFIID riid, IRpcStubBuffer **ppStub, + void **ppv, BOOL *pfNonNDR) +{ + TRACECALL(TRACE_MARSHAL, "CreateStub"); + ComDebOut((DEB_MARSHAL, + "CStdMarshal::CreateStub this:%x riid:%I\n", this, &riid)); + AssertValid(); + Win4Assert(ServerSide()); + Win4Assert(ppStub != NULL); + ASSERT_LOCK_HELD + + // get the IUnknown of the object + IUnknown *punkObj = _pStdId->GetServer(); + Win4Assert(punkObj != NULL); + + if (InlineIsEqualGUID(riid, IID_IUnknown)) + { + // there is no stub for IUnknown so we handle that here + *ppv = (void *)punkObj; + *ppStub = NULL; + *pfNonNDR = FALSE; + return S_OK; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + // make sure the object supports the given interface, so we dont + // waste a bunch of effort creating a stub if the interface is + // not supported. + + IUnknown *pUnkIf = NULL; + HRESULT hr; +#ifdef WX86OLE + if (gcwx86.IsN2XProxy(punkObj)) + { + // If we are creating a stub for an object that is living on the + // x86 side then we need to set the StubInvoke flag to allow QI + // to thunk the custom interface QI. + gcwx86.SetStubInvokeFlag((BOOL)1); + } +#endif + hr = punkObj->QueryInterface(riid, (void **)&pUnkIf); + AssertOutPtrIface(hr, pUnkIf); + + if (SUCCEEDED(hr)) + { + // now construct the stub for the interface + IPSFactoryBuffer *pIPSF = NULL; + hr = GetPSFactory(riid, pUnkIf, TRUE, &pIPSF, pfNonNDR); + + if (SUCCEEDED(hr)) + { + // got the class factory, now create an instance + hr = pIPSF->CreateStub(riid, punkObj, ppStub); + AssertOutPtrIface(hr, *ppStub); + pIPSF->Release(); + } + + if (SUCCEEDED(hr)) + { + // remember the interface pointer + *ppv = (void *)pUnkIf; + } + else + { + // error, release the interface and return NULL + pUnkIf->Release(); + *ppv = NULL; + } + } + + ASSERT_LOCK_RELEASED + LOCK + + ComDebOut((DEB_MARSHAL, + "CStdMarshal::CreateStub this:%x pStub:%x pv:%x fNonNDR:%x hr:%x\n", + this, *ppStub, *ppv, *pfNonNDR, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: FindIPIDEntry, private +// +// Synopsis: Finds an IPIDEntry, chained off this object, with the +// given riid. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::FindIPIDEntry(REFIID riid, IPIDEntry **ppEntry) +{ + ComDebOut((DEB_OXID,"CStdMarshal::FindIPIDEntry ppEntry:%x riid:%I\n", + ppEntry, &riid)); + ASSERT_LOCK_HELD + + IPIDEntry *pEntry = _pFirstIPID; + while (pEntry) + { + if (InlineIsEqualGUID(riid, pEntry->iid)) + { + *ppEntry = pEntry; + return S_OK; + } + + pEntry = pEntry->pNextOID; // get next entry in object chain + } + + ASSERT_LOCK_HELD + return E_NOINTERFACE; +} + +//+------------------------------------------------------------------- +// +// Member: FindIPIDEntryByIPID, private +// +// Synopsis: returns the IPIDEntry ptr for the given IPID +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::FindIPIDEntryByIPID(REFIPID ripid, IPIDEntry **ppEntry) +{ + ASSERT_LOCK_HELD + + IPIDEntry *pEntry = _pFirstIPID; + while (pEntry) + { + if (InlineIsEqualGUID(pEntry->ipid, ripid)) + { + *ppEntry = pEntry; + return S_OK; + } + + pEntry = pEntry->pNextOID; // get next entry in object chain + } + + ASSERT_LOCK_HELD + return E_NOINTERFACE; +} + +//+------------------------------------------------------------------- +// +// Member: FindIPIDEntryByInterface, internal +// +// Synopsis: returns the IPIDEntry ptr for the given proxy +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::FindIPIDEntryByInterface(void *pProxy, IPIDEntry **ppEntry) +{ + ASSERT_LOCK_HELD + + IPIDEntry *pEntry = _pFirstIPID; + *ppEntry = NULL; + while (pEntry) + { + if (pEntry->pv == pProxy) + { + *ppEntry = pEntry; + break; + } + + pEntry = pEntry->pNextOID; + } + + if (*ppEntry != NULL) + return S_OK; + else + return E_NOINTERFACE; +} + +//+------------------------------------------------------------------- +// +// Member: IncSrvIPIDCnt, protected +// +// Synopsis: increments the refcnt on the IPID entry, and optionally +// AddRefs the StdId. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::IncSrvIPIDCnt(IPIDEntry *pEntry, ULONG cRefs, + ULONG cPrivateRefs, SECURITYBINDING *pName, + DWORD mshlflags) +{ + ComDebOut((DEB_MARSHAL, "IncSrvIPIDCnt this:%x pIPID:%x cRefs:%x cPrivateRefs:%x\n", + this, pEntry, cRefs, cPrivateRefs)); + Win4Assert(ServerSide()); + Win4Assert(pEntry); + Win4Assert(cRefs > 0 || cPrivateRefs > 0); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + + if (cPrivateRefs != 0) + { + // Add a reference. + hr = gSRFTbl.IncRef( cPrivateRefs, pEntry->ipid, pName ); + + if (SUCCEEDED(hr)) + { + BOOL fNotify = (pEntry->cPrivateRefs == 0) ? TRUE : FALSE; + pEntry->cPrivateRefs += cPrivateRefs; + if (fNotify) + { + // this inc causes the count to go from zero to non-zero, so we + // inc the strong count on the stdid to hold it alive until this + // IPID is released. + IncStrongAndNotifyAct(pEntry, mshlflags); + } + } + } + + if (SUCCEEDED(hr)) + { + if (mshlflags & (MSHLFLAGS_TABLESTRONG | MSHLFLAGS_TABLEWEAK)) + { + // Table Marshal Case: inc the number of table marshals. + IncTableCnt(); + } + + if (mshlflags & (MSHLFLAGS_WEAK | MSHLFLAGS_TABLEWEAK)) + { + if (pEntry->cWeakRefs == 0) + { + // this inc causes the count to go from zero to non-zero, so we + // AddRef the stdid to hold it alive until this IPID is released. + + _pStdId->AddRef(); + } + pEntry->cWeakRefs += cRefs; + } + else + { + BOOL fNotify = (pEntry->cStrongRefs == 0) ? TRUE : FALSE; + pEntry->cStrongRefs += cRefs; + if (fNotify) + { + // this inc causes the count to go from zero to non-zero, so we + // inc the strong count on the stdid to hold it alive until this + // IPID is released. + IncStrongAndNotifyAct(pEntry, mshlflags); + } + } + } + + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: IncTableCnt, public +// +// Synopsis: increments the count of table marshals +// +// History: 9-Oct-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::IncTableCnt(void) +{ + ASSERT_LOCK_HELD + + // If something was marshaled for a table, we have to ignore + // rundowns until a subsequent RMD is called for it, at which + // time we start paying attention to rundowns again. Since there + // can be any number of table marshals, we have to refcnt them. + + _cTableRefs++; + _dwFlags |= SMFLAGS_IGNORERUNDOWN; +} + +//+------------------------------------------------------------------- +// +// Member: IncStrongAndNotifyAct, private +// +// Synopsis: notifies the activation code when this interface refcnt +// goes from 0 to non-zero and the activation code asked to be +// notified, and also increments the strong refcnt. +// +// History: 21-Apr-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::IncStrongAndNotifyAct(IPIDEntry *pEntry, DWORD mshlflags) +{ + ASSERT_LOCK_HELD + + // inc the strong count on the stdid to hold it alive until this + // IPIDEntry is released. + + _pStdId->IncStrongCnt(); + if (mshlflags & MSHLFLAGS_NOTIFYACTIVATION && + !(pEntry->dwFlags & IPIDF_NOTIFYACT)) + { + // the activation code asked to be notified when the refcnt + // on this interface goes positive, and when it reaches + // zero again. Set a flag so we remember to notify + // activation when the strong reference reference count + // goes back down to zero. + pEntry->dwFlags |= IPIDF_NOTIFYACT; + + UNLOCK + ASSERT_LOCK_RELEASED + BOOL fOK = NotifyActivation(TRUE, (IUnknown *)(pEntry->pv)); + ASSERT_LOCK_RELEASED + LOCK + + if (!fOK) + { + // call failed, so dont bother notifying + pEntry->dwFlags &= ~IPIDF_NOTIFYACT; + } + } +} + +//+------------------------------------------------------------------- +// +// Member: DecSrvIPIDCnt, protected +// +// Synopsis: decrements the refcnt on the IPID entry, and optionally +// Releases the StdId. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::DecSrvIPIDCnt(IPIDEntry *pEntry, ULONG cRefs, + ULONG cPrivateRefs, SECURITYBINDING *pName, + DWORD mshlflags) +{ + ComDebOut((DEB_MARSHAL, "DecSrvIPIDCnt this:%x pIPID:%x cRefs:%x cPrivateRefs:%x\n", + this, pEntry, cRefs, cPrivateRefs)); + Win4Assert(ServerSide()); + Win4Assert(pEntry); + Win4Assert(cRefs > 0 || cPrivateRefs > 0); + ASSERT_LOCK_HELD + + // Note: we dont care about holding the LOCK over the Release call since + // the guy who called us is holding a ref to the StdId, so this Release + // wont cause us to go away. + + if (mshlflags & (MSHLFLAGS_TABLESTRONG | MSHLFLAGS_TABLEWEAK)) + { + // Table Marshal Case: dec the number of table marshals. + DecTableCnt(); + } + + if (mshlflags & (MSHLFLAGS_WEAK | MSHLFLAGS_TABLEWEAK)) + { + Win4Assert(pEntry->cWeakRefs >= cRefs); + pEntry->cWeakRefs -= cRefs; + + if (pEntry->cWeakRefs == 0) + { + // this dec caused the count to go from non-zero to zero, so we + // Release the stdid since this IPID is no longer holding it alive. + _pStdId->Release(); + } + } + else + { + // Adjust the strong reference count. Don't let the caller release + // too many times. + + if (pEntry->cStrongRefs < cRefs) + { + ComDebOut((DEB_WARN,"DecSrvIPIDCnt too many releases. IPID entry: 0x%x Extra releases: 0x%x", + pEntry, cRefs-pEntry->cStrongRefs)); + cRefs = pEntry->cStrongRefs; + } + pEntry->cStrongRefs -= cRefs; + + if (pEntry->cStrongRefs == 0 && cRefs != 0) + { + // this dec caused the count to go from non-zero to zero, so we + // dec the strong count on the stdid since the public references + // on this IPID is no longer hold it alive. + + DecStrongAndNotifyAct(pEntry, mshlflags); + } + + // Adjust the secure reference count. Don't let the caller release + // too many times. + + if (pName != NULL) + { + cPrivateRefs = gSRFTbl.DecRef(cPrivateRefs, pEntry->ipid, pName); + } + else + { + cPrivateRefs = 0; + } + + Win4Assert( pEntry->cPrivateRefs >= cPrivateRefs ); + pEntry->cPrivateRefs -= cPrivateRefs; + + if (pEntry->cPrivateRefs == 0 && cPrivateRefs != 0) + { + // this dec caused the count to go from non-zero to zero, so we + // dec the strong count on the stdid since the private references + // on this IPID is no longer hold it alive. + + DecStrongAndNotifyAct(pEntry, mshlflags); + } + } + + ASSERT_LOCK_HELD +} + +//+------------------------------------------------------------------- +// +// Member: DecTableCnt, public +// +// Synopsis: decrements the count of table marshals +// +// History: 9-Oct-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::DecTableCnt(void) +{ + ASSERT_LOCK_HELD + + // If something was marshaled for a table, we have to ignore + // rundowns until a subsequent RMD is called for it, at which + // time we start paying attention to rundowns again. Since there + // can be any number of table marshals, we have to refcnt them. + // This is also used by CoLockObjectExternal. + + if (--_cTableRefs == 0) + { + // this was the last table marshal, so now we have to pay + // attention to rundown from normal clients, so that if all + // clients go away we cleanup. + + _dwFlags &= ~SMFLAGS_IGNORERUNDOWN; + } +} + +//+------------------------------------------------------------------- +// +// Member: DecStrongAndNotifyAct, private +// +// Synopsis: notifies the activation code if this interface has +// been released and the activation code asked to be +// notified, and also decrements the strong refcnt. +// +// History: 21-Apr-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::DecStrongAndNotifyAct(IPIDEntry *pEntry, DWORD mshlflags) +{ + ASSERT_LOCK_HELD + BOOL fNotifyAct = FALSE; + + if ((pEntry->dwFlags & IPIDF_NOTIFYACT) && + pEntry->cStrongRefs == 0 && + pEntry->cPrivateRefs == 0) + { + // the activation code asked to be notified when the refcnt + // on this interface reaches zero. Turn the flag off so we + // don't call twice. + pEntry->dwFlags &= ~IPIDF_NOTIFYACT; + fNotifyAct = TRUE; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (fNotifyAct) + { + NotifyActivation(FALSE, (IUnknown *)(pEntry->pv)); + } + + _pStdId->DecStrongCnt(mshlflags & MSHLFLAGS_KEEPALIVE); + + ASSERT_LOCK_RELEASED + LOCK +} + +//+------------------------------------------------------------------- +// +// Member: AddIPIDEntry, private +// +// Synopsis: Allocates and fills in an entry in the IPID table. +// The returned entry is not yet in the IPID chain. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::AddIPIDEntry(OXIDEntry *pOXIDEntry, IPID *pipid, + REFIID riid, CRpcChannelBuffer *pChnl, IUnknown *pUnkStub, + void *pv, IPIDEntry **ppEntry) +{ + ComDebOut((DEB_MARSHAL,"AddIPIDEntry this:%x pOXID:%x iid:%I pStub:%x pv:%x\n", + this, pOXIDEntry, &riid, pUnkStub, pv)); + ASSERT_LOCK_HELD + + // CODEWORK: while we released the lock to create the proxy or stub, + // the same interface could have been marshaled/unmarshaled. We should + // go check for duplicates now. This is just an optimization, not a + // requirement. + + // get a new entry in the IPID table. + IPIDEntry *pEntryNew = gIPIDTbl.FirstFree(); + + if (pEntryNew == NULL) + { + // no free slots and could not allocate more memory to grow + return E_OUTOFMEMORY; + } + + if (ServerSide()) + { + // create an IPID for this entry + DWORD *pdw = &pipid->Data1; + *pdw = gIPIDTbl.GetEntryIndex(pEntryNew); // IPID table index + *(pdw+1) = GetCurrentProcessId(); // current PID + *(pdw+2) = GetCurrentThreadId(); // current TID + *(pdw+3) = gIPIDSeqNum++; // process sequence # + } + + *ppEntry = pEntryNew; + + pEntryNew->ipid = *pipid; + pEntryNew->iid = riid; + pEntryNew->pChnl = pChnl; + pEntryNew->pStub = pUnkStub; + pEntryNew->pv = pv; + pEntryNew->dwFlags = ServerSide() ? IPIDF_SERVERENTRY : + IPIDF_DISCONNECTED | IPIDF_NOPING; + pEntryNew->cStrongRefs = 0; + pEntryNew->cWeakRefs = 0; + pEntryNew->cPrivateRefs = 0; + pEntryNew->pOXIDEntry = pOXIDEntry; + + ASSERT_LOCK_HELD + ComDebOut((DEB_MARSHAL,"AddIPIDEntry this:%x pIPIDEntry:%x ipid:%I\n", + this, pEntryNew, &pEntryNew->ipid)); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: ReleaseCliIPIDs, private +// +// Synopsis: walks the IPID table releasing the proxy/stub entries +// on the IPIDEntries associated with this Object. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdMarshal::ReleaseCliIPIDs(void) +{ + Win4Assert(ClientSide()); + + ASSERT_LOCK_RELEASED + LOCK + + // first thing we do is detach the chain of IPIDs from the CStdMarshal + // while holding the LOCK. Then release the lock and walk the chain + // releasing the proxy/stub pointers. Note there should not be any other + // pointers to any of these IPIDs, so it is OK to muck with their state. + + IPIDEntry *pFirstIPID = _pFirstIPID; + _pFirstIPID = NULL; + + UNLOCK + ASSERT_LOCK_RELEASED; + + IPIDEntry *pLastIPID; + IPIDEntry *pEntry = pFirstIPID; + + while (pEntry) + { + // mark the entry as vacant and disconnected. Note we dont put + // it back in the FreeList yet. We leave it chained to the other + // IPIDs in the list, and add the whole chain to the FreeList at + // the end. + + pEntry->dwFlags |= IPIDF_VACANT | IPIDF_DISCONNECTED; + + if (pEntry->pStub) + { + ComDebOut((DEB_MARSHAL,"ReleaseProxy pProxy:%x\n", pEntry->pStub)); + pEntry->pStub->Release(); + pEntry->pStub = NULL; + } + + pLastIPID = pEntry; + pEntry = pEntry->pNextOID; + } + + + if (pFirstIPID != NULL) + { + // now take the LOCK again and release all the IPIDEntries back into + // the IPIDTable in one fell swoop. + + ASSERT_LOCK_RELEASED + LOCK + + gIPIDTbl.ReleaseEntryList(pFirstIPID, pLastIPID); + + UNLOCK + ASSERT_LOCK_RELEASED + } + + ASSERT_LOCK_RELEASED +} + +//+------------------------------------------------------------------------ +// +// Member: CStdMarshal::LockClient/UnLockClient +// +// Synopsis: Locks the client side object during outgoing calls in order +// to prevent the object going away in a nested disconnect. +// +// Notes: UnLockClient is not safe in the freethreaded model. +// Fortunately pending disconnect can only be set in the +// apartment model on the client side. +// +// History: 12-Jun-95 Rickhi Created +// +//------------------------------------------------------------------------- +ULONG CStdMarshal::LockClient(void) +{ + Win4Assert(ClientSide()); + InterlockedIncrement(&_cNestedCalls); + return (_pStdId->GetCtrlUnk())->AddRef(); +} + +ULONG CStdMarshal::UnLockClient(void) +{ + Win4Assert(ClientSide()); + if ((InterlockedDecrement(&_cNestedCalls) == 0) && + (_dwFlags & SMFLAGS_PENDINGDISCONNECT)) + { + Disconnect(); + } + return (_pStdId->GetCtrlUnk())->Release(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::GetSecureRemUnk, public +// +// Synopsis: If the marshaller has its own remote unknown, use it. +// Otherwise use the OXID's remote unknown. +// +// History: 2-Apr-96 AlexMit Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::GetSecureRemUnk( IRemUnknown **ppSecureRemUnk, + OXIDEntry *pOXIDEntry ) +{ + ComDebOut((DEB_OXID, "CStdMarshal::GetSecureRemUnk ppRemUnk:%x\n", + ppSecureRemUnk)); + + ASSERT_LOCK_DONTCARE + + if (_pSecureRemUnk != NULL) + { + *ppSecureRemUnk = _pSecureRemUnk; + return S_OK; + } + else + { + return gOXIDTbl.GetRemUnk( pOXIDEntry, ppSecureRemUnk ); + } +} + + + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::LookupStub, private +// +// Synopsis: used by the channel to acquire the stub ptr for debugging +// +// History: 12-Jun-95 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::LookupStub(REFIID riid, IRpcStubBuffer **ppStub) +{ + AssertValid(); + Win4Assert(ServerSide()); + + ASSERT_LOCK_RELEASED + LOCK + + IPIDEntry *pEntry; + HRESULT hr = FindIPIDEntry(riid, &pEntry); + + if (SUCCEEDED(hr)) + { + *ppStub = (IRpcStubBuffer *)pEntry->pStub; + } + + UNLOCK + ASSERT_LOCK_RELEASED + return hr; +} + + +#if DBG==1 +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::GetOXID, private, debug +// +// Synopsis: returns the OXID for this object +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +REFMOXID CStdMarshal::GetMOXID(void) +{ + ASSERT_LOCK_HELD + + if (ServerSide()) + { + // local to this apartment, use the local OXID + return GetLocalOXIDEntry()->moxid; + } + else + { + Win4Assert(_pChnl); + return _pChnl->GetMOXID(); + } +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::DbgWalkIPIDs +// +// Synopsis: Validates that the state of all the IPIDs is consistent. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::DbgWalkIPIDs(void) +{ + IPIDEntry *pEntry = _pFirstIPID; + while (pEntry) + { + ValidateIPIDEntry(pEntry); + pEntry = pEntry->pNextOID; + } +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::AssertValid +// +// Synopsis: Validates that the state of the object is consistent. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::AssertValid() +{ + LOCK + Win4Assert((_dwFlags & ~(SMFLAGS_CLIENT_SIDE | SMFLAGS_REGISTEREDOID | + SMFLAGS_PENDINGDISCONNECT | SMFLAGS_DISCONNECTED | + SMFLAGS_FIRSTMARSHAL | SMFLAGS_HANDLER | SMFLAGS_WEAKCLIENT | + SMFLAGS_IGNORERUNDOWN | SMFLAGS_CLIENTMARSHALED | + SMFLAGS_NOPING | SMFLAGS_TRIEDTOCONNECT)) == 0); + + Win4Assert(_pStdId != NULL); + Win4Assert(IsValidInterface(_pStdId)); + + if (_pChnl != NULL) + { + Win4Assert(IsValidInterface(_pChnl)); + _pChnl->AssertValid(FALSE, FALSE); + } + + DbgWalkIPIDs(); + UNLOCK +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::AssertDisconnectPrevented, private +// +// Synopsis: Just ensures that no disconnects can/have arrived. +// +// History: 21-Sep-95 Rickhi Created +// +//+------------------------------------------------------------------- +void CStdMarshal::AssertDisconnectPrevented() +{ + ASSERT_LOCK_HELD + if (ServerSide()) + Win4Assert(!(_dwFlags & SMFLAGS_DISCONNECTED)); + Win4Assert(_cNestedCalls > 0); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::ValidateSTD +// +// Synopsis: Ensures that the STDOBJREF is valid +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::ValidateSTD(STDOBJREF *pStd) +{ + LOCK + + // validate the flags field + Win4Assert((pStd->flags & SORF_RSRVD_MBZ) == 0); + + // validate the OID + OID oid; + OIDFromMOID(_pStdId->GetOID(), &oid); + Win4Assert(pStd->oid == oid); + + if (ServerSide() || _pChnl != NULL) + { + // validate the OXID + OXID oxid; + OXIDFromMOXID(GetMOXID(), &oxid); + Win4Assert(pStd->oxid == oxid ); + } + + UNLOCK +} + +//+------------------------------------------------------------------- +// +// Function: DbgDumpSTD +// +// Synopsis: dumps a formated STDOBJREF to the debugger +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void DbgDumpSTD(STDOBJREF *pStd) +{ + ULARGE_INTEGER *puintOxid = (ULARGE_INTEGER *)&pStd->oxid; + ULARGE_INTEGER *puintOid = (ULARGE_INTEGER *)&pStd->oid; + + ComDebOut((DEB_MARSHAL, + "\n\tpStd:%x flags:%08x cPublicRefs:%08x\n\toxid: %08x %08x\n\t oid: %08x %08x\n\tipid:%I\n", + pStd, pStd->flags, pStd->cPublicRefs, puintOxid->HighPart, puintOxid->LowPart, + puintOid->HighPart, puintOid->LowPart, &pStd->ipid)); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::ValidateIPIDEntry +// +// Synopsis: Ensures that the IPIDEntry is valid +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::ValidateIPIDEntry(IPIDEntry *pEntry) +{ + // ask the table to validate the IPID entry + gIPIDTbl.ValidateIPIDEntry(pEntry, ServerSide(), _pChnl); +} + +//+------------------------------------------------------------------- +// +// Member: CStdMarshal::DbgDumpInterfaceList +// +// Synopsis: Prints the list of Interfaces on the object. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdMarshal::DbgDumpInterfaceList(void) +{ + ComDebOut((DEB_ERROR, "\tInterfaces left on object are:\n")); + LOCK + + // walk the IPID list printing the friendly name of each interface + IPIDEntry *pEntry = _pFirstIPID; + while (pEntry) + { + WCHAR wszName[MAX_PATH]; + GetInterfaceName(pEntry->iid, wszName); + ComDebOut((DEB_ERROR,"\t\t %ws\t cRefs:%x\n",wszName,pEntry->cStrongRefs)); + pEntry = pEntry->pNextOID; + } + + UNLOCK +} +#endif // DBG == 1 + +//+------------------------------------------------------------------- +// +// Function: RemoteQueryInterface, private +// +// Synopsis: call RemoteQueryInterface on remote server. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteQueryInterface(IRemUnknown *pRemUnk, IPIDEntry *pIPIDEntry, + USHORT cIIDs, IID *pIIDs, + REMQIRESULT **ppQiRes, BOOL fWeakClient) +{ + ComDebOut((DEB_MARSHAL, + "RemoteQueryInterface pIPIDEntry:%x cIIDs:%x, pIIDs:%x riid:%I\n", + pIPIDEntry, cIIDs, pIIDs, pIIDs)); + Win4Assert(pIPIDEntry->pOXIDEntry); // must have a resolved oxid + ASSERT_LOCK_HELD + + // set the IPID according to whether we want strong or weak + // references. It will only be weak if we are an OLE container + // and are talking to an embedding running on the same machine. + + IPID ipid = pIPIDEntry->ipid; + if (fWeakClient) + { + ipid.Data1 |= IPIDFLAG_WEAKREF; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + HRESULT hr = pRemUnk->RemQueryInterface(ipid, REM_ADDREF_CNT, + cIIDs, pIIDs, ppQiRes); + ASSERT_LOCK_RELEASED + LOCK + + ASSERT_LOCK_HELD + ComDebOut((DEB_MARSHAL, "RemoteQueryInterface hr:%x pQIRes:%x\n", + hr, *ppQiRes)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: RemoteAddRef, private +// +// Synopsis: calls the remote server to AddRef one of its interfaces +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteAddRef(IPIDEntry *pIPIDEntry, OXIDEntry *pOXIDEntry, + ULONG cStrongRefs, ULONG cSecureRefs) +{ + ComDebOut((DEB_MARSHAL, + "RemoteAddRef cRefs:%x cSecure:%x ipid:%I\n", + cStrongRefs, cSecureRefs, &pIPIDEntry->ipid)); + ASSERT_LOCK_HELD + + // if the object does not require pinging, it is also ignoring + // reference counts, so there is no need to go get more, just + // pretend like we did. + + if (pIPIDEntry->dwFlags & IPIDF_NOPING) + { + return S_OK; + } + + // get the IRemUnknown for the remote server + IRemUnknown *pRemUnk; + HRESULT hr = gOXIDTbl.GetRemUnk(pOXIDEntry, &pRemUnk); + + if (SUCCEEDED(hr)) + { + // call RemAddRef on the interface + REMINTERFACEREF rifRef; + rifRef.ipid = pIPIDEntry->ipid; + rifRef.cPublicRefs = cStrongRefs; + rifRef.cPrivateRefs = cSecureRefs; + + UNLOCK + ASSERT_LOCK_RELEASED + + HRESULT ignore; + hr = pRemUnk->RemAddRef(1, &rifRef, &ignore); + + ASSERT_LOCK_RELEASED + LOCK + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_MARSHAL, "RemoteAddRef hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: RemoteReleaseRifRef +// +// Synopsis: calls the remote server to release some IPIDs +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteReleaseRifRef(OXIDEntry *pOXIDEntry, + USHORT cRifRef, REMINTERFACEREF *pRifRef) +{ + Win4Assert(pRifRef); + ComDebOut((DEB_MARSHAL, + "RemoteRelease pOXID:%x cRifRef:%x pRifRef:%x cRefs:%x ipid:%I\n", + pOXIDEntry, cRifRef, pRifRef, pRifRef->cPublicRefs, &pRifRef->ipid)); + Win4Assert(pOXIDEntry); + ASSERT_LOCK_HELD + + HRESULT hr; + + if (IsSTAThread() && + FAILED(CanMakeOutCall(CALLCAT_SYNCHRONOUS, IID_IRundown))) + { + // the call control will not let this apartment model thread make + // the outgoing release call (cause we're inside an InputSync call) + // so we post ourselves a message to do it later. + + hr = PostReleaseRifRef(pOXIDEntry, cRifRef, pRifRef); + } + else + { + // get the IRemUnknown for the remote server + IRemUnknown *pRemUnk; + hr = gOXIDTbl.GetRemUnk(pOXIDEntry, &pRemUnk); + + if (SUCCEEDED(hr)) + { + UNLOCK + ASSERT_LOCK_RELEASED + hr = pRemUnk->RemRelease(cRifRef, pRifRef); + ASSERT_LOCK_RELEASED + LOCK + } + } + + ComDebOut((DEB_MARSHAL, "RemoteRelease hr:%x\n", hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: PostReleaseRifRef +// +// Synopsis: Post a message to ourself to call RemoteReleaseRifRef later. +// This is used to make a synchronous remote Release call when +// a Release is done inside of an InputSync call. The call is +// delayed until we are out of the InputSync call, since the +// call control wont allow a synch call inside an inputsync call. +// +// History: 05-Apr-96 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL PostReleaseRifRef(OXIDEntry *pOXIDEntry, + USHORT cRifRef, REMINTERFACEREF *pRifRef) +{ + Win4Assert(pRifRef); + ComDebOut((DEB_MARSHAL, + "PostRelease pOXID:%x cRifRef:%x pRifRef:%x cRefs:%x ipid:%I\n", + pOXIDEntry, cRifRef, pRifRef, pRifRef->cPublicRefs, &pRifRef->ipid)); + Win4Assert(pOXIDEntry); + ASSERT_LOCK_HELD + + OXIDEntry *pLocalOXIDEntry = NULL; + HRESULT hr = gOXIDTbl.GetLocalEntry(&pLocalOXIDEntry); + + if (SUCCEEDED(hr)) + { + // allocate a structure to hold the data and copy in the RifRef + // list, OXIDEntry, and count of entries. Inc the OXID RefCnt to + // ensure it stays alive until the posted message is processed. + + hr = E_OUTOFMEMORY; + ULONG cbRifRef = cRifRef * sizeof(REMINTERFACEREF); + ULONG cbAlloc = sizeof(POSTRELRIFREF) + (cbRifRef-1); + POSTRELRIFREF *pRelRifRef = (POSTRELRIFREF *) PrivMemAlloc(cbAlloc); + + if (pRelRifRef) + { + IncOXIDRefCnt(pOXIDEntry); // keep alive + pRelRifRef->pOXIDEntry = pOXIDEntry; + pRelRifRef->cRifRef = cRifRef; + memcpy(&pRelRifRef->arRifRef, pRifRef, cbRifRef); + + if (!PostMessage((HWND)pLocalOXIDEntry->hServerSTA, + WM_OLE_ORPC_RELRIFREF, + GetCurrentThreadId(), + (LPARAM)pRelRifRef)) + { + // Post failed, free the structure and report an error. + DecOXIDRefCnt(pOXIDEntry); + PrivMemFree(pRelRifRef); + hr = RPC_E_SYS_CALL_FAILED; + } + } + } + + ComDebOut((DEB_MARSHAL, "PostRelease hr:%x\n", hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: HandlePostReleaseRifRef +// +// Synopsis: Handles the ReleaseRifRef message that was posted to the +// current thread (by the current thread) in order to do a +// delayed remote release call. See PostReleaseRifRef above. +// +// History: 05-Apr-96 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL HandlePostReleaseRifRef(LPARAM param) +{ + Win4Assert(param); + ComDebOut((DEB_MARSHAL, "HandlePostRelease pRifRef:%x\n", param)); + POSTRELRIFREF *pRelRifRef = (POSTRELRIFREF *)param; + + ASSERT_LOCK_RELEASED + LOCK + + // simply make the real remote release call now, then release the + // reference we have on the OXIDEntry, and free the message buffer. + // If this call fails, dont try again, otherwise we could spin busy + // waiting. Instead, just let Rundown clean up the server. + + RemoteReleaseRifRef(pRelRifRef->pOXIDEntry, + pRelRifRef->cRifRef, + &pRelRifRef->arRifRef); + + DecOXIDRefCnt(pRelRifRef->pOXIDEntry); + + UNLOCK + ASSERT_LOCK_RELEASED + + PrivMemFree(pRelRifRef); + ComDebOut((DEB_MARSHAL, "HandlePostRelease hr:%x\n", S_OK)); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: RemoteChangeRef +// +// Synopsis: calls the remote server to convert interface refereces +// from strong to weak or vise versa. This behaviour is +// required to support silent updates in the OLE container / +// link / embedding scenarios. +// +// Notes: This functionality is not exposed in FreeThreaded apps +// or in remote apps. The implication being that the container +// must be on the same machine as the embedding. +// +// History: 20-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CStdMarshal::RemoteChangeRef(BOOL fLock, BOOL fLastUnlockReleases) +{ + ComDebOut((DEB_MARSHAL, "RemoteChangeRef \n")); + Win4Assert(ClientSide()); + Win4Assert(IsSTAThread()); // not allowed in MTA Apartment + ASSERT_LOCK_RELEASED + + // must be at least 1 proxy already connected in order to be able + // to do this. We cant just ASSERT that's true because we were not + // holding the lock on entry. + + LOCK + HRESULT hr = PreventDisconnect(); + + // A previous version of OLE set the object to weak even it it was + // currently disconnected, and it remembered that it was weak and set + // any new interfaces that it later accquired to weak. I emulate that + // behaviour here. + + if (fLock) + _dwFlags &= ~SMFLAGS_WEAKCLIENT; + else + _dwFlags |= SMFLAGS_WEAKCLIENT; + + + if (SUCCEEDED(hr)) + { + REMINTERFACEREF *pRifRefAlloc = (REMINTERFACEREF *) + _alloca(_cIPIDs * sizeof(REMINTERFACEREF)); + REMINTERFACEREF *pRifRef = pRifRefAlloc; + + DWORD cSecure = gCapabilities & EOAC_SECURE_REFS ? 1 : 0; + USHORT cIIDs = 0; + OXIDEntry *pOXIDEntry = NULL; + IPIDEntry *pNextIPID = _pFirstIPID; + + while (pNextIPID) + { + if (!(pNextIPID->dwFlags & IPIDF_DISCONNECTED)) + { + if (pOXIDEntry == NULL) + { + // This is the first connected IPID we encountered. + // Get its OXID entry and make sure it is for a server + // process on the current machine. + + if (!(pNextIPID->pOXIDEntry->dwFlags & + OXIDF_MACHINE_LOCAL)) + { + // OXID is for a remote process. Abandon this call. + Win4Assert(cIIDs == 0); // skip call below + Win4Assert(pOXIDEntry == NULL); // dont dec below + Win4Assert(hr == S_OK); // report success + break; // exit while loop + } + + // Remember the OXID and AddRef it to keep it alive + // over the duration of the call. + + pOXIDEntry = pNextIPID->pOXIDEntry; + IncOXIDRefCnt(pOXIDEntry); + } + + pRifRef->ipid = pNextIPID->ipid; + + if (!fLock && pNextIPID->cStrongRefs > 0) + { + pRifRef->cPublicRefs = pNextIPID->cStrongRefs; + pRifRef->cPrivateRefs = pNextIPID->cPrivateRefs; + pNextIPID->cWeakRefs += pNextIPID->cStrongRefs; + pNextIPID->cStrongRefs = 0; + pNextIPID->cPrivateRefs = 0; + + pRifRef++; + cIIDs++; + } + else if (fLock && pNextIPID->cStrongRefs == 0) + { + pRifRef->cPublicRefs = pNextIPID->cWeakRefs; + pRifRef->cPrivateRefs = cSecure; + pNextIPID->cStrongRefs += pNextIPID->cWeakRefs; + pNextIPID->cWeakRefs = 0; + pNextIPID->cPrivateRefs = cSecure; + + pRifRef++; + cIIDs++; + } + } + + // get next IPIDentry for this object + pNextIPID = pNextIPID->pNextOID; + } + + if (cIIDs != 0) + { + // we have looped filling in the IPID list, and there are + // entries in the list. go call the server now. First, set up + // the flags, then reset the RifRef pointer since we trashed + // it while walking the list above. + + DWORD dwFlags = (fLock) ? IRUF_CONVERTTOSTRONG : IRUF_CONVERTTOWEAK; + if (fLastUnlockReleases) + dwFlags |= IRUF_DISCONNECTIFLASTSTRONG; + + hr = RemoteChangeRifRef(pOXIDEntry, dwFlags, cIIDs, pRifRefAlloc); + } + + if (pOXIDEntry) + { + // release the OXIDEntry + DecOXIDRefCnt(pOXIDEntry); + } + } + else + { + // A previous implementation of OLE returned S_OK if the object was + // disconnected. I emulate that behaviour here. + + hr = S_OK; + } + + DbgWalkIPIDs(); + UNLOCK + ASSERT_LOCK_RELEASED + + // this will handle any Disconnect that came in while we were busy. + hr = HandlePendingDisconnect(hr); + + ComDebOut((DEB_MARSHAL, "RemoteChangeRef hr:%x\n", hr)); + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: RemoteChangeRifRef +// +// Synopsis: calls the remote server to release some IPIDs +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteChangeRifRef(OXIDEntry *pOXIDEntry, DWORD dwFlags, + USHORT cRifRef, REMINTERFACEREF *pRifRef) +{ + Win4Assert(pRifRef); + ComDebOut((DEB_MARSHAL, + "RemoteChangeRifRef pOXID:%x cRifRef:%x pRifRef:%x cRefs:%x ipid:%I\n", + pOXIDEntry, cRifRef, pRifRef, pRifRef->cPublicRefs, &(pRifRef->ipid))); + Win4Assert(pOXIDEntry); + ASSERT_LOCK_HELD + + // get the IRemUnknown for the remote server + IRemUnknown *pRemUnk; + HRESULT hr = gOXIDTbl.GetRemUnk(pOXIDEntry, &pRemUnk); + + if (SUCCEEDED(hr)) + { + UNLOCK + ASSERT_LOCK_RELEASED + hr = ((IRemUnknown2 *)pRemUnk)->RemChangeRef(dwFlags, cRifRef, pRifRef); + ASSERT_LOCK_RELEASED + LOCK + } + + ComDebOut((DEB_MARSHAL, "RemoteChangeRifRef hr:%x\n", hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: RemoteReleaseStdObjRef +// +// Synopsis: calls the remote server to release an ObjRef +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteReleaseStdObjRef(STDOBJREF *pStd, OXIDEntry *pOXIDEntry) +{ + ComDebOut((DEB_MARSHAL, "RemoteReleaseStdObjRef pStd:%x\n pOXIDEntry:%x", + pStd, pOXIDEntry)); + ASSERT_LOCK_HELD + + REMINTERFACEREF rifRef; + rifRef.ipid = pStd->ipid; + rifRef.cPublicRefs = pStd->cPublicRefs; + rifRef.cPrivateRefs = 0; + + // incase we get disconnected while in the RemRelease call + // we need to extract the OXIDEntry and AddRef it. + + IncOXIDRefCnt(pOXIDEntry); + RemoteReleaseRifRef(pOXIDEntry, 1, &rifRef); + DecOXIDRefCnt(pOXIDEntry); + + ComDebOut((DEB_MARSHAL, "RemoteReleaseStdObjRef hr:%x\n", S_OK)); + ASSERT_LOCK_HELD + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: RemoteReleaseObjRef +// +// Synopsis: calls the remote server to release an ObjRef +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL RemoteReleaseObjRef(OBJREF &objref) +{ + return RemoteReleaseStdObjRef(&ORSTD(objref).std, GetOXIDFromObjRef(objref)); +} + +//+------------------------------------------------------------------- +// +// Function: GetOXIDFromObjRef, private +// +// Synopsis: extracts the OXID from the OBJREF. +// +// History: 09-Jan-96 Rickhi Created. +// +//-------------------------------------------------------------------- +OXIDEntry *GetOXIDFromObjRef(OBJREF &objref) +{ + // TRICK: Internally we use the saResAddr.size field as the ptr + // to the OXIDEntry. See ReadObjRef and FillObjRef. + + OXIDEntry *pOXIDEntry = (objref.flags & OBJREF_STANDARD) + ? *(OXIDEntry **)&ORSTD(objref).saResAddr + : *(OXIDEntry **)&ORHDL(objref).saResAddr; + + Win4Assert(pOXIDEntry); + return pOXIDEntry; +} + +//+------------------------------------------------------------------- +// +// Function: WriteObjRef, private +// +// Synopsis: Writes the objref into the stream +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL WriteObjRef(IStream *pStm, OBJREF &objref, DWORD dwDestCtx) +{ + ASSERT_LOCK_RELEASED + + ULONG cbToWrite = (objref.flags & OBJREF_STANDARD) + ? (2*sizeof(ULONG)) + sizeof(IID) + sizeof(STDOBJREF) + : (2*sizeof(ULONG)) + sizeof(IID) + sizeof(STDOBJREF) + sizeof(CLSID); + + // write the fixed-sized part of the OBJREF into the stream + HRESULT hr = pStm->Write(&objref, cbToWrite, NULL); + + if (SUCCEEDED(hr)) + { + // write the resolver address into the stream. + // TRICK: Internally we use the saResAddr.size field as the ptr + // to the OXIDEntry. See ReadObjRef and FillObjRef. + + DUALSTRINGARRAY *psa; + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + + if (pOXIDEntry->pMIDEntry != gpLocalMIDEntry || + dwDestCtx == MSHCTX_DIFFERENTMACHINE) + { + // the interface is for a remote server, or it is going to a + // remote client, therefore, marshal the resolver strings + psa = pOXIDEntry->pMIDEntry->Node.psaKey; + Win4Assert(psa->wNumEntries != 0); + } + else + { + // the interface is for an OXID local to this machine and + // the interface is not going to a remote client, marshal an + // empty string (we pay attention to this in ReadObjRef) + psa = &saNULL; + } + + // These string bindings always come from the object exporter + // who has already padded the size to 8 bytes. + hr = pStm->Write(psa, SASIZE(psa->wNumEntries), NULL); + + ComDebOut((DEB_MARSHAL,"WriteObjRef psa:%x\n", psa)); + } + + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: ReadObjRef, private +// +// Synopsis: Reads the objref from the stream +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL ReadObjRef(IStream *pStm, OBJREF &objref) +{ + ASSERT_LOCK_RELEASED + + // read the signature, flags, and iid fields of the objref so we know + // what kind of objref we are dealing with and how big it is. + + HRESULT hr = StRead(pStm, &objref, 2*sizeof(ULONG) + sizeof(IID)); + + if (SUCCEEDED(hr)) + { + if ((objref.signature != OBJREF_SIGNATURE) || + (objref.flags & OBJREF_RSRVD_MBZ) || + (objref.flags == 0)) + { + // the objref signature is bad, or one of the reserved + // bits in the flags is set, or none of the required bits + // in the flags is set. the objref cant be interpreted so + // fail the call. + + Win4Assert(!"Invalid Objref Flags"); + return RPC_E_INVALID_OBJREF; + } + + // compute the size of the remainder of the objref and + // include the size fields for the resolver string array + + STDOBJREF *pStd = &ORSTD(objref).std; + DUALSTRINGARRAY *psa; + ULONG cbToRead; + + if (objref.flags & OBJREF_STANDARD) + { + cbToRead = sizeof(STDOBJREF) + sizeof(ULONG); + psa = &ORSTD(objref).saResAddr; + } + else if (objref.flags & OBJREF_HANDLER) + { + cbToRead = sizeof(STDOBJREF) + sizeof(CLSID) + sizeof(ULONG); + psa = &ORHDL(objref).saResAddr; + } + else if (objref.flags & OBJREF_CUSTOM) + { + cbToRead = sizeof(CLSID) + 2*sizeof(DWORD); // clsid + cbExtension + size + psa = NULL; + } + + // read the rest of the (fixed sized) objref from the stream + hr = StRead(pStm, pStd, cbToRead); + + if (SUCCEEDED(hr)) + { + if (psa != NULL) + { + // Non custom interface. Make sure the resolver string array + // has some sensible values. + if (psa->wNumEntries != 0 && + psa->wSecurityOffset >= psa->wNumEntries) + { + hr = RPC_E_INVALID_OBJREF; + } + } + else + { + // custom marshaled interface + if (ORCST(objref).cbExtension != 0) + { + // skip past the extensions since we currently dont + // know about any extension types. + LARGE_INTEGER dlibMove; + dlibMove.LowPart = ORCST(objref).cbExtension; + dlibMove.HighPart = 0; + hr = pStm->Seek(dlibMove, STREAM_SEEK_CUR, NULL); + } + } + } + + if (SUCCEEDED(hr) && psa) + { + // Non custom interface. The data that follows is a variable + // sized string array. Allocate memory for it and then read it. + + DbgDumpSTD(pStd); + DUALSTRINGARRAY *psaNew; + + cbToRead = psa->wNumEntries * sizeof(WCHAR); + if (cbToRead == 0) + { + // server must be local to this machine, just get the local + // resolver strings and use them to resolve the OXID + psaNew = gpsaLocalResolver; + } + else + { + // allocate space to read the strings + psaNew = (DUALSTRINGARRAY *) _alloca(cbToRead + sizeof(ULONG)); + if (psaNew != NULL) + { + // update the size fields and read in the rest of the data + psaNew->wSecurityOffset = psa->wSecurityOffset; + psaNew->wNumEntries = psa->wNumEntries; + + hr = StRead(pStm, psaNew->aStringArray, cbToRead); + } + else + { + psa->wNumEntries = 0; + psa->wSecurityOffset = 0; + hr = E_OUTOFMEMORY; + + // seek the stream past what we should have read, ignore + // seek errors, since the OOM takes precedence. + + LARGE_INTEGER libMove; + libMove.LowPart = cbToRead; + libMove.HighPart = 0; + pStm->Seek(libMove, STREAM_SEEK_CUR, 0); + } + } + + // TRICK: internally we want to keep the ObjRef a fixed size + // structure, even though we have variable sized data. To do + // this i use the saResAddr.size field of the ObjRef as a ptr + // to the OXIDEntry. We pay attention to this in FillObjRef, + // WriteObjRef and FreeObjRef. + + if (SUCCEEDED(hr)) + { + // resolve the OXID. + ASSERT_LOCK_RELEASED + LOCK + OXIDEntry *pOXIDEntry = NULL; + hr = gResolver.ClientResolveOXID(pStd->oxid, + psaNew, &pOXIDEntry); + UNLOCK + ASSERT_LOCK_RELEASED + *((void **) psa) = pOXIDEntry; + } + else + { + *((void **) psa) = NULL; + } + } + } + + ComDebOut((DEB_MARSHAL,"ReadObjRef hr:%x objref:%x\n", hr, &objref)); + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: FreeObjRef, private +// +// Synopsis: Releases an objref that was read in from a stream via +// ReadObjRef. +// +// History: 20-Feb-95 Rickhi Created. +// +// Notes: Anybody who calls ReadObjRef should call this guy to +// free the objref. This decrements the refcnt on the +// embedded pointer to the OXIDEntry. +// +//-------------------------------------------------------------------- +INTERNAL_(void) FreeObjRef(OBJREF &objref) +{ + if (objref.flags & (OBJREF_STANDARD | OBJREF_HANDLER)) + { + // TRICK: Internally we use the saResAddr.size field as the ptr to + // the OXIDEntry. See ReadObjRef, WriteObjRef and FillObjRef. + + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + + LOCK + Win4Assert(pOXIDEntry); + DecOXIDRefCnt(pOXIDEntry); + UNLOCK + } +} + +//+------------------------------------------------------------------- +// +// Function: MakeFakeObjRef, private +// +// Synopsis: Invents an OBJREF that can be unmarshaled in this process. +// The objref is partially fact (the OXIDEntry) and partially +// fiction (the OID). +// +// History: 16-Jan-96 Rickhi Created. +// +// Notes: This is used by MakeSCMProxy and GetRemUnk. Note that +// the pOXIDEntry is not AddRef'd here because the OBJREF +// created is only short-lived the callers guarantee it's +// lifetime, so FreeObjRef need not be called. +// +//-------------------------------------------------------------------- +INTERNAL MakeFakeObjRef(OBJREF &objref, OXIDEntry *pOXIDEntry, + REFIPID ripid, REFIID riid) +{ + // first, invent an OID since this could fail. + + STDOBJREF *pStd = &ORSTD(objref).std; + HRESULT hr = gResolver.ServerGetReservedID(&pStd->oid); + + if (SUCCEEDED(hr)) + { + pStd->flags = SORF_NOPING | SORF_FREETHREADED; + pStd->cPublicRefs = 1; + pStd->ipid = ripid; + OXIDFromMOXID(pOXIDEntry->moxid, &pStd->oxid); + + // TRICK: Internally we use the saResAddr.size field as the ptr to + // the OXIDEntry. See ReadObjRef, WriteObjRef and FillObjRef. + + OXIDEntry **ppOXIDEntry = (OXIDEntry **) &ORSTD(objref).saResAddr; + *ppOXIDEntry = pOXIDEntry; + + objref.signature = OBJREF_SIGNATURE; + objref.flags = OBJREF_STANDARD; + objref.iid = riid; + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: MakeCallableFromAnyApt, private +// +// Synopsis: set SORF_FREETHREADED in OBJREF so unmarshaled proxy +// can be called from any apartment. +// +// History: 16-Jan-96 Rickhi Created. +// +//-------------------------------------------------------------------- +void MakeCallableFromAnyApt(OBJREF &objref) +{ + STDOBJREF *pStd = &ORSTD(objref).std; + pStd->flags |= SORF_FREETHREADED; +} + +//+------------------------------------------------------------------- +// +// Function: FindStdMarshal, private +// +// Synopsis: Finds the CStdMarshal for the OID read from the stream +// +// Arguements: [objref] - object reference +// [ppStdMshl] - CStdMarshal returned, AddRef'd +// +// Algorithm: Read the objref, get the OID. If we already have an identity +// for this OID, use that, otherwise either create an identity +// object, or create a handler (which in turn will create the +// identity). The identity inherits CStdMarshal. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +INTERNAL FindStdMarshal(OBJREF &objref, CStdMarshal **ppStdMshl) +{ + ComDebOut((DEB_MARSHAL, + "FindStdMarshal objref:%x ppStdMshl:%x\n", &objref, ppStdMshl)); + + HRESULT hr = CO_E_OBJNOTCONNECTED; + CStdIdentity *pStdId = NULL; + + if (ChkIfLocalOID(objref, &pStdId)) + { + if (pStdId) + { + hr = S_OK; + } + else + { + hr = CO_E_OBJNOTCONNECTED; + } + } + else + { + STDOBJREF *pStd = &ORSTD(objref).std; + ComDebOut((DEB_MARSHAL, "poid: %x\n", &pStd->oid)); + + ASSERT_LOCK_RELEASED + LOCK + + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + + // OXID is for different apartment, check the identity table for + // an existing OID. + + MOID moid; + MOIDFromOIDAndMID(pStd->oid, pOXIDEntry->pMIDEntry->mid, &moid); + + hr = LookupIDFromID(moid, TRUE, &pStdId); + + if (FAILED(hr)) + { + CStdIdentity *pStdIdPrev = NULL; + BOOL fDuplicate = FALSE; + + if (objref.flags & OBJREF_STANDARD) + { + // create an instance of the identity for this OID. We want + // to be holding the lock while we do this since it wont + // exercise any app code. + + hr = CreateIdentityHandler(NULL, pStd->flags, + IID_IStdIdentity, (void **)&pStdId); + AssertOutPtrIface(hr, pStdId); + + if (SUCCEEDED(hr)) + { + // set the identity while holding the lock. The result is + // checked below and we release if this fails. + + hr = pStdId->SetOID(moid); + Win4Assert(pStdIdPrev == NULL); + } + } + else + { + // create an instance of the handler. the handler will + // aggregate in the identity, but will pass GUID_NULL for + // the OID so that the identity is not set in the table yet. + + Win4Assert(!(ORHDL(objref).std.flags & SORF_FREETHREADED)); + + // dont want to hold the lock while creating the handler + // since this involves running app code and calling the + // SCM etc. + + UNLOCK + ASSERT_LOCK_RELEASED + + hr = CoCreateInstance(ORHDL(objref).clsid, NULL, + CLSCTX_INPROC_HANDLER, + IID_IStdIdentity, (void **)&pStdId); + + AssertOutPtrIface(hr, pStdId); + + ASSERT_LOCK_RELEASED + LOCK + + // look for the OID in the table again, since it may have + // been added while we released the lock to create the + // handler. + + if (SUCCEEDED(LookupIDFromID(moid, TRUE, &pStdIdPrev))) + { + // object was unmarshaled while we released the lock + // to create the handler, so we will use the existing one. + // since we are releasing app code, we need to release the + // lock. + + fDuplicate = TRUE; + } + else if (SUCCEEDED(hr)) + { + // set the OID now while we are holding the lock. + hr = pStdId->SetOID(moid); + Win4Assert(pStdIdPrev == NULL); + } + } + + if (pStdId && (FAILED(hr) || fDuplicate)) + { + Win4Assert( (FAILED(hr) && (pStdIdPrev == NULL)) || + (fDuplicate && (pStdIdPrev != NULL)) ); + UNLOCK + ASSERT_LOCK_RELEASED + + pStdId->Release(); + pStdId = pStdIdPrev; + + ASSERT_LOCK_RELEASED + LOCK + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + } + + *ppStdMshl = (CStdMarshal *)pStdId; + AssertOutPtrIface(hr, *ppStdMshl); + + ComDebOut((DEB_MARSHAL, + "FindStdMarshal pStdMshl:%x hr:%x\n", *ppStdMshl, hr)); + return hr; +} + +//+------------------------------------------------------------------------ +// +// Function: CompleteObjRef, public +// +// Synopsis: Fills in the missing fields of an OBJREF from a STDOBJREF +// and resolves the OXID. Also sets fLocal to TRUE if the +// object was marshaled in this apartment. +// +// History: 22-Jan-96 Rickhi Created +// +//------------------------------------------------------------------------- +HRESULT CompleteObjRef(OBJREF &objref, OXID_INFO &oxidInfo, REFIID riid, BOOL *pfLocal) +{ + // tweak the objref so we can call ReleaseMarshalObjRef or UnmarshalObjRef + objref.signature = OBJREF_SIGNATURE; + objref.flags = OBJREF_STANDARD; + objref.iid = riid; + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + ASSERT_LOCK_RELEASED + LOCK + + OXIDEntry *pOXIDEntry = NULL; + MIDEntry *pMIDEntry; + hr = GetLocalMIDEntry(&pMIDEntry); + + if (SUCCEEDED(hr)) + { + hr = FindOrCreateOXIDEntry(ORSTD(objref).std.oxid, + oxidInfo, + FOCOXID_NOREF, + gpsaLocalResolver, + gLocalMid, + pMIDEntry, + &pOXIDEntry); + } + + if (SUCCEEDED(hr)) + { + OXIDEntry **ppOXIDEntry = (OXIDEntry **) &ORSTD(objref).saResAddr; + *ppOXIDEntry = pOXIDEntry; + + *pfLocal = (pOXIDEntry == GetLocalOXIDEntry()); + } + + UNLOCK + ASSERT_LOCK_RELEASED + return hr; +} + diff --git a/private/ole32/com/dcomrem/marshal.hxx b/private/ole32/com/dcomrem/marshal.hxx new file mode 100644 index 000000000..54d57d202 --- /dev/null +++ b/private/ole32/com/dcomrem/marshal.hxx @@ -0,0 +1,388 @@ +//+------------------------------------------------------------------- +// +// File: marshal.hxx +// +// Contents: class for standard interface marshaling +// +// Classes: CStdMarshal +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +#ifndef _MARSHAL_HXX_ +#define _MARSHAL_HXX_ + +#include <ipidtbl.hxx> // CIPIDTable +#include <remunk.h> // IRemUnknown, REMINTERFACEREF +#include <locks.hxx> + +// convenient mappings +#define ORCST(objref) objref.u_objref.u_custom +#define ORSTD(objref) objref.u_objref.u_standard +#define ORHDL(objref) objref.u_objref.u_handler + + +// bits that must be zero in the flags fields +#define OBJREF_RSRVD_MBZ ~(OBJREF_STANDARD | OBJREF_HANDLER | OBJREF_CUSTOM) + +#define SORF_RSRVD_MBZ ~(SORF_NOPING | SORF_OXRES1 | SORF_OXRES2 | \ + SORF_OXRES3 | SORF_OXRES4 | SORF_OXRES5 | \ + SORF_OXRES6 | SORF_OXRES7 | SORF_OXRES8) + + +// Internal Uses of the reserved SORF_OXRES flags. + +// SORF_TBLWEAK is needed so that RMD works correctly on TABLEWEAK +// marshaling, so it is ignored by unmarshalers. Therefore, we use one of +// the bits reserved for the object exporter that must be ignored by +// unmarshalers. +// +// SORF_WEAKREF is needed for container weak references, when handling +// an IRemUnknown::RemQueryInterface on a weak interface. This is a strictly +// local (windows) machine protocol, so we use a reserved bit. +// +// SORF_NONNDR is needed for interop of 16bit custom (non-NDR) marshalers +// with 32bit, since the 32bit guys want to use MIDL (NDR) to talk to other +// 32bit processes and remote processes, but the custom (non-NDR) format to +// talk to local 16bit guys. In particular, this is to support OLE Automation. +// +// SORF_FREETHREADED is needed when we create a proxy to the SCM interface +// in the apartment model. All apartments can use the same proxy so we avoid +// the test for calling on the correct thread. + +#define SORF_TBLWEAK SORF_OXRES1 // (table) weak reference +#define SORF_WEAKREF SORF_OXRES2 // (normal) weak reference +#define SORF_NONNDR SORF_OXRES3 // stub does not use NDR marshaling +#define SORF_FREETHREADED SORF_OXRES4 // proxy may be used on any thread + + +// new MARSHAL FLAG constants. +const DWORD MSHLFLAGS_WEAK = 8; +const DWORD MSHLFLAGS_KEEPALIVE = 32; + +// definitions to simplify coding +const DWORD MSHLFLAGS_TABLE = MSHLFLAGS_TABLESTRONG | MSHLFLAGS_TABLEWEAK; + +const DWORD MSHLFLAGS_USER_MASK = MSHLFLAGS_NORMAL | + MSHLFLAGS_TABLEWEAK | + MSHLFLAGS_TABLESTRONG | + MSHLFLAGS_NOPING; + +const DWORD MSHLFLAGS_ALL = MSHLFLAGS_NORMAL | // 0x00 + MSHLFLAGS_TABLEWEAK | // 0x01 + MSHLFLAGS_TABLESTRONG | // 0x02 + MSHLFLAGS_NOPING | // 0x04 + MSHLFLAGS_WEAK | // 0x08 + MSHLFLAGS_KEEPALIVE | // 0x20 + MSHLFLAGS_NOTIFYACTIVATION; // 0x8000000 + +// forward class declarations +class CStdIdentity; +class CStdMarshal; +class CRpcChannelBuffer; + +extern IMarshal *gpStdMarshal; + + +// internal subroutines used by CStdMarshal and CoUnmarshalInterface +INTERNAL ReadObjRef (IStream *pStm, OBJREF &objref); +INTERNAL WriteObjRef(IStream *pStm, OBJREF &objref, DWORD dwDestCtx); +INTERNAL MakeFakeObjRef(OBJREF &objref, OXIDEntry *pOXIDEntry, REFIPID ipid, REFIID riid); +INTERNAL_(void) FreeObjRef(OBJREF &objref); +INTERNAL_(OXIDEntry *)GetOXIDFromObjRef(OBJREF &objref); + +INTERNAL RemoteQueryInterface(IRemUnknown *pRemUnk, IPIDEntry *pIPIDEntry, + USHORT cIIDs, IID *pIIDs, + REMQIRESULT **ppQiRes, BOOL fWeakClient); +INTERNAL RemoteAddRef(IPIDEntry *pIPIDEntry, OXIDEntry *pOXIDEntry, ULONG cStrongRefs, ULONG cSecureRefs); +INTERNAL RemoteReleaseObjRef(OBJREF &objref); +INTERNAL RemoteReleaseStdObjRef(STDOBJREF *pStd, OXIDEntry *pOXIDEntry); +INTERNAL RemoteReleaseRifRef(OXIDEntry *pOXIDEntry, USHORT cRifRef, + REMINTERFACEREF *pRifRef); +INTERNAL PostReleaseRifRef(OXIDEntry *pOXIDEntry, USHORT cRifRef, + REMINTERFACEREF *pRifRef); +INTERNAL HandlePostReleaseRifRef(LPARAM param); +INTERNAL RemoteChangeRifRef(OXIDEntry *pOXIDEntry, DWORD dwFlags, + USHORT cRifRef, REMINTERFACEREF *pRifRef); +INTERNAL FindStdMarshal(OBJREF &objref, CStdMarshal **ppStdMshl); + +#if DBG==1 +void DbgDumpSTD(STDOBJREF *pStd); +#else +inline void DbgDumpSTD(STDOBJREF *pStd) {}; +#endif + + +// Definition of values for dwFlags field of CStdMarshal +typedef enum tagSMFLAGS +{ + SMFLAGS_CLIENT_SIDE = 0x01, // object is local to this process + SMFLAGS_PENDINGDISCONNECT = 0x02, // disconnect is pending + SMFLAGS_REGISTEREDOID = 0x04, // OID is registered with resolver + SMFLAGS_DISCONNECTED = 0x08, // really disconnected + SMFLAGS_FIRSTMARSHAL = 0x10, // first time marshalled + SMFLAGS_HANDLER = 0x20, // object has a handler + SMFLAGS_WEAKCLIENT = 0x40, // client has weak ref to server + SMFLAGS_IGNORERUNDOWN = 0x80, // dont rundown this object + SMFLAGS_CLIENTMARSHALED = 0x100,// client-side has re-marshaled object + SMFLAGS_NOPING = 0x200,// this object is not pinged + SMFLAGS_TRIEDTOCONNECT = 0x400 // attempted ConnectIPIDEntry +} SMFLAGS; + + +//+---------------------------------------------------------------- +// +// structure: SQIResult +// +// synopsis: structure used for QueryRemoteInterfaces +// +//+---------------------------------------------------------------- +typedef struct tagSQIResult +{ + void *pv; // interface pointer + HRESULT hr; // result of the QI call +} SQIResult; + + +//+---------------------------------------------------------------- +// +// Class: CStdMarshal, private +// +// Purpose: Provides standard marshaling of interface pointers. +// +// History: 20-Feb-95 Rickhi Created +// +//----------------------------------------------------------------- +class CStdMarshal : public IMarshal +{ +public: + CStdMarshal(); + ~CStdMarshal(); + void Init(IUnknown *pUnk, CStdIdentity *pstdID, + REFCLSID rclsidHandler, DWORD dwFlags); + + + // IMarshal - IUnknown taken from derived classes + STDMETHOD(GetUnmarshalClass)(REFIID riid, LPVOID pv, DWORD dwDestCtx, + LPVOID pvDestCtx, DWORD mshlflags, LPCLSID pClsid); + STDMETHOD(GetMarshalSizeMax)(REFIID riid, LPVOID pv, DWORD dwDestCtx, + LPVOID pvDestCtx, DWORD mshlflags, LPDWORD pSize); + STDMETHOD(MarshalInterface)(LPSTREAM pStm, REFIID riid, LPVOID pv, + DWORD dwDestCtx, LPVOID pvDestCtx, DWORD mshlflags); + STDMETHOD(UnmarshalInterface)(LPSTREAM pStm, REFIID riid, LPVOID *ppv); + STDMETHOD(ReleaseMarshalData)(LPSTREAM pStm); + STDMETHOD(DisconnectObject)(DWORD dwReserved); + + + // used by coapi's for unmarshaling/releasing + HRESULT MarshalObjRef(OBJREF &objref, REFIID riid, LPVOID pv, DWORD mshlflags); + HRESULT MarshalIPID(REFIID riid, ULONG cRefs, DWORD mshlflags, IPIDEntry **ppEntry); + HRESULT UnmarshalObjRef(OBJREF &objref, void **ppv); + HRESULT ReleaseMarshalObjRef(OBJREF &objref); + + + // used by client side StdIdentity to make calls to the remote server + HRESULT QueryRemoteInterfaces(USHORT cIIDs, IID *pIIDs, SQIResult *pQIRes); + BOOL InstantiatedProxy(REFIID riid, void **ppv, HRESULT *phr); + BOOL RemIsConnected(void); + void Disconnect(void); + void ReconnectProxies(void); + HRESULT FindIPIDEntry(REFIID riid, IPIDEntry **ppEntry); + void SetMarshalTime(void) { _dwMarshalTime = GetCurrentTime() ;} + void SetNoPing(void) { _dwFlags |= SMFLAGS_NOPING; } + HRESULT RemoteChangeRef(BOOL fLock, BOOL fLastUnlockReleases); + + // used by CRpcChannelBuffer + HRESULT LookupStub(REFIID riid, IRpcStubBuffer **ppStub); + ULONG LockClient(void); + ULONG UnLockClient(void); + void LockServer(void); + void UnLockServer(void); + + // used by CRemoteUnknown + HRESULT PreventDisconnect(); + HRESULT PreventPendingDisconnect(); + HRESULT HandlePendingDisconnect(HRESULT hr); + HRESULT IncSrvIPIDCnt(IPIDEntry *pEntry, ULONG cRefs, ULONG cPrivateRefs, + SECURITYBINDING *pName, DWORD mshlflags); + void DecSrvIPIDCnt(IPIDEntry *pEntry, ULONG cRefs, ULONG cPrivateRefs, + SECURITYBINDING *pName, DWORD mshlflags); + BOOL CanRunDown(DWORD iNow); + void FillSTD(STDOBJREF *pStd, ULONG cRefs, DWORD mshlflags, IPIDEntry *pEntry); + IPIDEntry *GetConnectedIPID(); + HRESULT GetSecureRemUnk( IRemUnknown **, OXIDEntry * ); + void SetSecureRemUnk( IRemUnknown *pSecure ) { _pSecureRemUnk = pSecure; } + BOOL CheckSecureRemUnk() { return _pSecureRemUnk != NULL; } + + // used by CoLockObjectExternal + void IncTableCnt(void); + void DecTableCnt(void); + + // used by CClientSecurity + HRESULT FindIPIDEntryByInterface( void * pProxy, IPIDEntry ** ppEntry ); + HRESULT PrivateCopyProxy( IUnknown *pProxy, IUnknown **ppProxy ); + +#if DBG==1 + void DbgDumpInterfaceList(void); +#else + void DbgDumpInterfaceList(void) {} +#endif + + friend INTERNAL ReleaseMarshalObjRef(OBJREF &objref); + +private: + + HRESULT FirstMarshal(IUnknown *pUnk, DWORD mshlflags); + HRESULT CreateChannel(OXIDEntry *pOXIDEntry, DWORD dwFlags, REFIPID ripid, + REFIID riid, CRpcChannelBuffer **ppChnl); + + + // Internal methods to find or create interface proxies or stubs + HRESULT CreateProxy(REFIID riid, IRpcProxyBuffer **ppProxy, void **ppv, BOOL *pfNonNDR); + HRESULT CreateStub(REFIID riid, IRpcStubBuffer **ppStub, void **ppv, BOOL *pfNonNDR); + HRESULT GetPSFactory(REFIID riid, IUnknown *pUnkWow, BOOL fServer, IPSFactoryBuffer **ppIPSF, BOOL *pfNonNDR); + + + // IPID Table Manipulation subroutines + HRESULT UnmarshalIPID(REFIID riid, STDOBJREF *pStd, OXIDEntry *pOXIDEntry, void **ppv); + HRESULT FindIPIDEntryByIPID(REFIPID ripid, IPIDEntry **ppEntry); + HRESULT MakeSrvIPIDEntry(REFIID riid, IPIDEntry **ppEntry); + HRESULT MakeCliIPIDEntry(REFIID riid, STDOBJREF *pStd, OXIDEntry *pOXIDEntry, IPIDEntry **ppEntry); + HRESULT ConnectIPIDEntry(STDOBJREF *pStd, OXIDEntry *pOXIDEntry, IPIDEntry *pEntry); + HRESULT AddIPIDEntry(OXIDEntry *pOXIDEntry, IPID *pipid, REFIID riid, + CRpcChannelBuffer *pChnl, IUnknown *pUnkStub, + void *pv, IPIDEntry **ppEntry); + void DisconnectCliIPIDs(void); + void DisconnectSrvIPIDs(void); + void ReleaseCliIPIDs(void); + void IncStrongAndNotifyAct(IPIDEntry *pEntry, DWORD mshlflags); + void DecStrongAndNotifyAct(IPIDEntry *pEntry, DWORD mshlflags); + + + + // reference counting routines + HRESULT GetNeededRefs(STDOBJREF *pStd, OXIDEntry *pOXIDEntry, IPIDEntry *pEntry); + HRESULT RemQIAndUnmarshal(USHORT cIIDs, IID* pIIDs, SQIResult *pQIRes); + void FillObjRef(OBJREF &objref, ULONG cRefs, DWORD mshlflags, IPIDEntry *pEntry); + + BOOL ClientSide() { return (_dwFlags & SMFLAGS_CLIENT_SIDE); } + BOOL ServerSide() { return !(_dwFlags & SMFLAGS_CLIENT_SIDE); } + +#if DBG==1 + void AssertValid(); + void AssertDisconnectPrevented(); + void ValidateSTD(STDOBJREF *pStd); + void ValidateIPIDEntry(IPIDEntry *pEntry); + void DbgWalkIPIDs(); + REFMOXID GetMOXID(void); +#else + void AssertValid() {} + void AssertDisconnectPrevented() {} + void ValidateSTD(STDOBJREF *pStd) {} + void ValidateIPIDEntry(IPIDEntry *pEntry) {} + void DbgWalkIPIDs() {} +#endif + + + DWORD _dwFlags; // flags info (see SMFLAGS) + LONG _cIPIDs; // count of IPIDs in this object + IPIDEntry *_pFirstIPID; // first IPID of this object + CStdIdentity *_pStdId; // standard identity + CRpcChannelBuffer *_pChnl; // channel ptr + CLSID _clsidHandler; // clsid of handler (if needed) + LONG _cNestedCalls; // count of nested calls + LONG _cTableRefs; // count of table marshals + DWORD _dwMarshalTime; // tick count when last marshalled + IRemUnknown *_pSecureRemUnk; // remunk with app specified security +}; + + +//+------------------------------------------------------------------------ +// +// Member: CStdMarshal::CanRunDown +// +// Synopsis: determines if it is OK to rundown this object, based on +// the current time and the marshaled state of the object. +// +// History: 24-Aug-95 Rickhi Created +// +//------------------------------------------------------------------------- + +// time period of one ping, used to determine if OK to rundown OID +extern DWORD giPingPeriod; + +inline BOOL CStdMarshal::CanRunDown(DWORD iNow) +{ + ASSERT_LOCK_HELD + + // Make sure the interface hasn't been marshalled since it + // was last pinged. This calculation handles the wrap case. + + if (!(_dwFlags & (SMFLAGS_IGNORERUNDOWN | SMFLAGS_NOPING)) && + (iNow - _dwMarshalTime >= giPingPeriod)) + { + Win4Assert(_cTableRefs == 0); + ComDebOut((DEB_MARSHAL, "Running Down Object this:%x\n", this)); + return TRUE; + } + + return FALSE; +} + +//+------------------------------------------------------------------------ +// +// Member: CStdMarshal::LockServer/UnLockServer +// +// Synopsis: Locks the server side object during incoming calls in order +// to prevent the object going away in a nested disconnect. +// +// History: 12-Jun-95 Rickhi Created +// +//------------------------------------------------------------------------- +inline void CStdMarshal::LockServer(void) +{ + Win4Assert(ServerSide()); + ASSERT_LOCK_HELD + + AddRef(); + InterlockedIncrement(&_cNestedCalls); +} + +inline void CStdMarshal::UnLockServer(void) +{ + Win4Assert(ServerSide()); + ASSERT_LOCK_RELEASED + + if ((InterlockedDecrement(&_cNestedCalls) == 0) && + (_dwFlags & SMFLAGS_PENDINGDISCONNECT)) + { + // a disconnect was pending, do that now. + Disconnect(); + } + + Release(); +} + +//+------------------------------------------------------------------------ +// +// Member: CStdMarshal::GetConnectedIPID +// +// Synopsis: Finds the first connected IPID entry. +// +// History: 10-Apr-96 AlexMit Plagerized +// +//------------------------------------------------------------------------- +inline IPIDEntry *CStdMarshal::GetConnectedIPID() +{ + Win4Assert( _pFirstIPID != NULL ); + IPIDEntry *pIPIDEntry = _pFirstIPID; + + // Find an IPID entry that has an OXID pointer. + while (pIPIDEntry->dwFlags & IPIDF_DISCONNECTED) + { + pIPIDEntry = pIPIDEntry->pNextOID; + } + Win4Assert( pIPIDEntry != NULL ); + return pIPIDEntry; +} +#endif // _MARSHAL_HXX_ diff --git a/private/ole32/com/dcomrem/orpc_dbg.c b/private/ole32/com/dcomrem/orpc_dbg.c new file mode 100644 index 000000000..ef410ce1c --- /dev/null +++ b/private/ole32/com/dcomrem/orpc_dbg.c @@ -0,0 +1,643 @@ +//-------------------------------------------------------------------------- +// ORPC_DBG.C (tabs 4) +// +// !!!!!!!!! !!!!!!!!! NOTE NOTE NOTE NOTE !!!!!!!!! !!!!!!!!!! +// +// SEND MAIL TO SANJAYS IF YOU MODIFY THIS FILE! +// WE MUST KEEP OLE AND LANGUAGES IN SYNC! +// +// !!!!!!!!! !!!!!!!!! NOTE NOTE NOTE NOTE !!!!!!!!! !!!!!!!!!! +// +// Created 08-Oct-1993 by Mike Morearty. The master copy of this file +// is in the LANGAPI project owned by the Languages group. +// +// Helper functions for OLE RPC debugging. +//-------------------------------------------------------------------------- + +#include <windows.h> +#ifndef _CHICAGO_ +#include <tchar.h> +#endif + +#include "orpc_dbg.h" + +static TCHAR tszAeDebugName[] = TEXT("AeDebug"); +static TCHAR tszAutoName[] = TEXT("Auto"); +static TCHAR tszOldAutoName[] = TEXT("OldAuto"); +static TCHAR tszDebugObjectRpcEnabledName[] = +#ifdef _CHICAGO_ + "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\DebugObjectRPCEnabled"; +#else + TEXT("SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\DebugObjectRPCEnabled"); +#endif + +// Emit the ORPC signature into the bytestream of the function +#define ORPC_EMIT_SIGNATURE() 'M', 'A', 'R', 'B', + +// Emit a LONG into the bytestream +#define ORPC_EMIT_LONG(l) \ + ((l >> 0) & 0xFF), \ + ((l >> 8) & 0xFF), \ + ((l >> 16) & 0xFF), \ + ((l >> 24) & 0xFF), + +// Emit a WORD into the bytestream +#define ORPC_EMIT_WORD(w) \ + ((w >> 0) & 0xFF), \ + ((w >> 8) & 0xFF), + +// Emit a BYTE into the bytestream +#define ORPC_EMIT_BYTE(b) \ + b, + +// Emit a GUID into the bytestream +#define ORPC_EMIT_GUID(l, w1, w2, b1, b2, b3, b4, b5, b6, b7, b8) \ + ORPC_EMIT_LONG(l) \ + ORPC_EMIT_WORD(w1) ORPC_EMIT_WORD(w2) \ + ORPC_EMIT_BYTE(b1) ORPC_EMIT_BYTE(b2) \ + ORPC_EMIT_BYTE(b3) ORPC_EMIT_BYTE(b4) \ + ORPC_EMIT_BYTE(b5) ORPC_EMIT_BYTE(b6) \ + ORPC_EMIT_BYTE(b7) ORPC_EMIT_BYTE(b8) + +BYTE rgbClientGetBufferSizeSignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0x9ED14F80, 0x9673, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +BYTE rgbClientFillBufferSignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0xDA45F3E0, 0x9673, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +BYTE rgbClientNotifySignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0x4F60E540, 0x9674, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +BYTE rgbServerNotifySignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0x1084FA00, 0x9674, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +BYTE rgbServerGetBufferSizeSignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0x22080240, 0x9674, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +BYTE rgbServerFillBufferSignature[] = +{ + ORPC_EMIT_SIGNATURE() + ORPC_EMIT_GUID(0x2FC09500, 0x9674, 0x101A, 0xB0, 0x7B, + 0x00, 0xDD, 0x01, 0x11, 0x3F, 0x11) + ORPC_EMIT_LONG(0) +}; + +// Macro to deal with assigning refiid for both C and C++. +#if defined(__cplusplus) +#define ASSIGN_REFIID(orpc_all, iid) ((orpc_all).refiid = &iid) +#else +#define ASSIGN_REFIID(orpc_all, iid) ((orpc_all).refiid = iid) +#endif + +#pragma code_seg(".orpc") + +//-------------------------------------------------------------------------- +// SzSubStr() +// +// Find str2 in str2 +//-------------------------------------------------------------------------- + +static LPTSTR SzSubStr(LPTSTR str1, LPTSTR str2) +{ + CharLower(str1); + +#ifdef _CHICAGO_ + return strstr(str1, str2); +#else + return _tcsstr(str1, str2); +#endif +} + +//-------------------------------------------------------------------------- +// DebugORPCSetAuto() +// +// Sets the "Auto" value in the "AeDebug" key to "1", and saves info +// necessary to restore the previous value later. +//-------------------------------------------------------------------------- + +BOOL WINAPI DebugORPCSetAuto(VOID) +{ + HKEY hkey; + TCHAR rgtchDebugger[256]; // 256 is the length NT itself uses for this + TCHAR rgtchAuto[256]; + TCHAR rgtchOldAuto[2]; // don't need to get the whole thing + + // If the "DebugObjectRPCEnabled" key does not exist, then do not + // cause any notifications + if (RegOpenKey(HKEY_LOCAL_MACHINE, tszDebugObjectRpcEnabledName, &hkey)) + return FALSE; + RegCloseKey(hkey); + + // If the AeDebug debugger string does not exist, or if it contains + // "drwtsn32" anywhere in it, then don't cause any notifications, + // because Dr. Watson is not capable of fielding OLE notifications. + if (!GetProfileString(tszAeDebugName, TEXT("Debugger"), TEXT(""), + rgtchDebugger, sizeof(rgtchDebugger)) || + SzSubStr(rgtchDebugger, TEXT("drwtsn32")) != NULL) + { + return FALSE; + } + + // Must ensure that the "Auto" value in the AeDebug registry key + // is set to "1", so that the embedded INT 3 below will cause the + // debugger to be automatically spawned if it doesn't already + // exist. + + // Get old "Auto" value + GetProfileString(tszAeDebugName, tszAutoName, TEXT(""), + rgtchAuto, sizeof(rgtchAuto)); + + // If "OldAuto" already existed, then it's probably left over from + // a previous invocation of the debugger, so don't overwrite it. + // Otherwise, copy "Auto" value to "OldAuto" + if (!GetProfileString(tszAeDebugName, tszOldAutoName, TEXT(""), + rgtchOldAuto, sizeof(rgtchOldAuto))) + { + if (!WriteProfileString(tszAeDebugName, tszOldAutoName, rgtchAuto)) + return FALSE; + } + + // Change "Auto" value to "1" + if (!WriteProfileString(tszAeDebugName, tszAutoName, TEXT("1"))) + return FALSE; + + return TRUE; +} + +//-------------------------------------------------------------------------- +// DebugORPCRestoreAuto() +// +// Restores the previous value of the "Auto" value in the AeDebug key. +//-------------------------------------------------------------------------- + +VOID WINAPI DebugORPCRestoreAuto(VOID) +{ + TCHAR rgtchAuto[256]; + + // Restore old Auto value (or delete it if it didn't exist before). + // Very minor bug here: if "Auto" was previously "", then we will + // now delete it. That's not a big deal though, as an empty "Auto" + // and a nonexistent one have the same effect. + GetProfileString(tszAeDebugName, tszOldAutoName, TEXT(""), + rgtchAuto, sizeof(rgtchAuto)); + + WriteProfileString(tszAeDebugName, tszAutoName, + rgtchAuto[0] ? rgtchAuto : NULL); + + // Delete OldAuto value + WriteProfileString(tszAeDebugName, tszOldAutoName, NULL); +} + + // This pragma is necessary in case the compiler chooses not to inline these +// functions (e.g. in a debug build, when optimizations are off). + +#pragma code_seg(".orpc") + +__inline DWORD WINAPI OrpcBreakpointFilter( + LPEXCEPTION_POINTERS lpExcptPtr, + BOOL *lpAeDebugAttached ) \ +{ + BOOL fAeDebugAttached = FALSE; + DWORD dwRet; + + if ( lpExcptPtr->ExceptionRecord->ExceptionCode == EXCEPTION_ORPC_DEBUG ) + { + if ( UnhandledExceptionFilter(lpExcptPtr) == EXCEPTION_CONTINUE_SEARCH ) + { + // It is important that we don't return EXCEPTION_CONTINUE_SEARCH. + // This is because there might an handler up the stack which could + // handle this exception. Just set the flag indicating that a + // debugger is now attached. + + fAeDebugAttached = TRUE; + } + dwRet = EXCEPTION_EXECUTE_HANDLER; + } + else + { + // Not one of our exceptions. + dwRet = EXCEPTION_CONTINUE_SEARCH; + } + + if ( lpAeDebugAttached != NULL ) + (*lpAeDebugAttached) = fAeDebugAttached; + + return dwRet; +} + +ULONG WINAPI DebugORPCClientGetBufferSize( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) +{ + ULONG cbBuffer = 0; + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + + if (!fHookEnabled) + return 0; // We should be able to assert that this never happens. + + orpc_all.pSignature = rgbClientGetBufferSizeSignature; + orpc_all.pMessage = pMessage; + orpc_all.reserved = reserved; + orpc_all.pUnkProxyMgr = pUnkProxyMgr; + orpc_all.lpcbBuffer = &cbBuffer; + ASSIGN_REFIID(orpc_all, iid); + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), NULL)) + { + // this just goes down to the to the return. + } + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ClientGetBufferSize(lpOrpcAll); +#else + lpIntf->lpVtbl->ClientGetBufferSize(lpIntf, lpOrpcAll); +#endif + + } + + return cbBuffer; +} + +//-------------------------------------------------------------------------- + +void WINAPI DebugORPCClientFillBuffer( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) +{ + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + + if (!fHookEnabled) + return; // We should be able to assert that this never happens + + orpc_all.pSignature = rgbClientFillBufferSignature; + + orpc_all.pMessage = pMessage; + orpc_all.reserved = reserved; + orpc_all.pUnkProxyMgr = pUnkProxyMgr; + ASSIGN_REFIID(orpc_all, iid); + + orpc_all.pvBuffer = pvBuffer; + orpc_all.cbBuffer = cbBuffer; + + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), NULL)) + { + // this just returns. + } + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ClientFillBuffer(lpOrpcAll); +#else + lpIntf->lpVtbl->ClientFillBuffer(lpIntf, lpOrpcAll); +#endif + } +} + +//-------------------------------------------------------------------------- + +// This special value is to ensure backward compatibility with VC 2.0. +// It is not exposed in the header files. The behavior if this is the value +// in the first four bytes of the debug packet, should be identical to +// ORPC_DEBUG_ALWAYS. + +#define ORPC_COMPATIBILITY_CODE (0x4252414DL) + +void WINAPI DebugORPCClientNotify( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + HRESULT hresult, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) +{ + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + BOOL fRethrow = FALSE; + + // First check to see if the debugger on the other side + // wants us to notify this side if the hook is not enabled. + if (!fHookEnabled) + { + if (cbBuffer >= 4) + { + LONG orpcCode = *(LONG *)pvBuffer; + if ( orpcCode == ORPC_DEBUG_IF_HOOK_ENABLED) + return; // No notification in this case. + } + } + + orpc_all.pSignature = rgbClientNotifySignature; + + orpc_all.pMessage = pMessage; + orpc_all.reserved = reserved; + orpc_all.pUnkProxyMgr = pUnkProxyMgr; + orpc_all.hresult = hresult; + ASSIGN_REFIID(orpc_all, iid); + + orpc_all.pvBuffer = pvBuffer; + orpc_all.cbBuffer = cbBuffer; + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + if (DebugORPCSetAuto()) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), &fRethrow)) + { + // Fall through. + } + + if (fRethrow) + { + // At this point we are sure that a debugger is attached + // so we raise this exception outside of a __try block. + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + + DebugORPCRestoreAuto(); + } + + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ClientNotify(lpOrpcAll); +#else + lpIntf->lpVtbl->ClientNotify(lpIntf, lpOrpcAll); +#endif + } + +} + +//-------------------------------------------------------------------------- + +void WINAPI DebugORPCServerNotify( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) + +{ + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + BOOL fRethrow = FALSE; + + // First check to see if the debugger on the other side + // wants us to notify this side if the hook is not enabled. + if (!fHookEnabled) + { + if (cbBuffer >= 4) + { + LONG orpcCode = *(LONG *)pvBuffer; + if ( orpcCode == ORPC_DEBUG_IF_HOOK_ENABLED) + return; // No notification in this case. + } + } + + orpc_all.pSignature = rgbServerNotifySignature; + + orpc_all.pMessage = pMessage; + orpc_all.pChannel = pChannel; + orpc_all.pInterface = pInterface; + orpc_all.pUnkObject = pUnkObject; + ASSIGN_REFIID(orpc_all, iid); + + orpc_all.pvBuffer = pvBuffer; + orpc_all.cbBuffer = cbBuffer; + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + if (DebugORPCSetAuto()) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), &fRethrow)) + { + // Fall through + } + + if (fRethrow) + { + // At this point we are sure that a debugger is attached + // so we raise this exception outside of a __try block. + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + + DebugORPCRestoreAuto(); + } + + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ServerNotify(lpOrpcAll); +#else + lpIntf->lpVtbl->ServerNotify(lpIntf, lpOrpcAll); +#endif + } + +} + +//-------------------------------------------------------------------------- + +ULONG WINAPI DebugORPCServerGetBufferSize( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) + +{ + ULONG cbBuffer = 0; + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + + if (!fHookEnabled) + return 0; // We should be able to assert that this never happens. + + orpc_all.pSignature = rgbServerGetBufferSizeSignature; + + orpc_all.pMessage = pMessage; + orpc_all.pChannel = pChannel; + orpc_all.pInterface = pInterface; + orpc_all.pUnkObject = pUnkObject; + orpc_all.lpcbBuffer = &cbBuffer; + ASSIGN_REFIID(orpc_all, iid); + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), NULL)) + { + // this just goes down to the return. + } + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ServerGetBufferSize(lpOrpcAll); +#else + lpIntf->lpVtbl->ServerGetBufferSize(lpIntf, lpOrpcAll); +#endif + } + + return cbBuffer; +} + +//-------------------------------------------------------------------------- + +void WINAPI DebugORPCServerFillBuffer( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled) +{ + ORPC_DBG_ALL orpc_all = {0}; + ORPC_DBG_ALL * lpOrpcAll = &orpc_all; + + if (!fHookEnabled) + return; // We should be able to assert that this never happens. + + orpc_all.pSignature = rgbServerFillBufferSignature; + + orpc_all.pMessage = pMessage; + orpc_all.pChannel = pChannel; + orpc_all.pInterface = pInterface; + orpc_all.pUnkObject = pUnkObject; + ASSIGN_REFIID(orpc_all, iid); + + orpc_all.pvBuffer = pvBuffer; + orpc_all.cbBuffer = cbBuffer; + + if ( lpInitArgs == NULL || lpInitArgs->lpIntfOrpcDebug == NULL ) + { + // Do Orpc debug notification using an exception. + __try + { + RaiseException(EXCEPTION_ORPC_DEBUG, 0, 1, (LPDWORD)&lpOrpcAll); + } + __except(OrpcBreakpointFilter(GetExceptionInformation(), NULL)) + { + // this just returns. + } + } + else + { + IOrpcDebugNotify __RPC_FAR *lpIntf = lpInitArgs->lpIntfOrpcDebug; + + // call the appropriate method in the registered interface + // ( this is typically used by in-proc debuggers) +#if defined(__cplusplus) && !defined(CINTERFACE) + lpIntf->ServerFillBuffer(lpOrpcAll); +#else + lpIntf->lpVtbl->ServerFillBuffer(lpIntf, lpOrpcAll); +#endif + } +} + +// WARNING: there is no way to "pop" to the previously active code_seg: +// this will revert to what the code seg was when compilation began. +#pragma code_seg() + + diff --git a/private/ole32/com/dcomrem/orpc_dbg.h b/private/ole32/com/dcomrem/orpc_dbg.h new file mode 100644 index 000000000..62b512df7 --- /dev/null +++ b/private/ole32/com/dcomrem/orpc_dbg.h @@ -0,0 +1,219 @@ +//-------------------------------------------------------------------------- +// ORPC_DBG.H (tabs 4) +// +// !!!!!!!!! !!!!!!!!! NOTE NOTE NOTE NOTE !!!!!!!!! !!!!!!!!!! +// +// SEND MAIL TO SANJAYS IF YOU MODIFY THIS FILE! +// WE MUST KEEP OLE AND LANGUAGES IN SYNC! +// +// !!!!!!!!! !!!!!!!!! NOTE NOTE NOTE NOTE !!!!!!!!! !!!!!!!!!! +// +// Created 07-Oct-1993 by Mike Morearty. The master copy of this file +// is in the LANGAPI project owned by the Languages group. +// +// Macros and functions for OLE RPC debugging. For a detailed explanation, +// see OLE2DBG.DOC. +// +//-------------------------------------------------------------------------- + + +#ifndef __ORPC_DBG__ +#define __ORPC_DBG__ + +//-------------------------------------------------------------------------- +// Public: +//-------------------------------------------------------------------------- + +// This structure is the information packet which OLE sends the debugger +// when it is notifying it about an OLE debug event. The first field in this +// structure points to the signature which identifies the type of the debug +// notification. The consumer of the notification can then get the relevant +// information from the struct members. Note that for each OLE debug notification +// only a subset of the struct members are meaningful. + + +typedef struct ORPC_DBG_ALL +{ + BYTE * pSignature; + RPCOLEMESSAGE * pMessage; + const IID * refiid; + IRpcChannelBuffer * pChannel; + IUnknown * pUnkProxyMgr; + void * pInterface; + IUnknown * pUnkObject; + HRESULT hresult; + void * pvBuffer; + ULONG cbBuffer; + ULONG * lpcbBuffer; + void * reserved; +} ORPC_DBG_ALL; + +typedef ORPC_DBG_ALL __RPC_FAR *LPORPC_DBG_ALL; + +// Interface definition for IOrpcDebugNotify + +typedef interface IOrpcDebugNotify IOrpcDebugNotify; + +typedef IOrpcDebugNotify __RPC_FAR * LPORPCDEBUGNOTIFY; + +#if defined(__cplusplus) && !defined(CINTERFACE) + + interface IOrpcDebugNotify : public IUnknown + { + public: + virtual VOID __stdcall ClientGetBufferSize (LPORPC_DBG_ALL) = 0; + virtual VOID __stdcall ClientFillBuffer (LPORPC_DBG_ALL) = 0; + virtual VOID __stdcall ClientNotify (LPORPC_DBG_ALL) = 0; + virtual VOID __stdcall ServerNotify (LPORPC_DBG_ALL) = 0; + virtual VOID __stdcall ServerGetBufferSize (LPORPC_DBG_ALL) = 0; + virtual VOID __stdcall ServerFillBuffer (LPORPC_DBG_ALL) = 0; + }; + +#else /* C style interface */ + + typedef struct IOrpcDebugNotifyVtbl + { + HRESULT ( __stdcall __RPC_FAR *QueryInterface )( + IOrpcDebugNotify __RPC_FAR * This, + /* [in] */ REFIID riid, + /* [out] */ void __RPC_FAR *__RPC_FAR *ppvObject); + + ULONG ( __stdcall __RPC_FAR *AddRef )( + IOrpcDebugNotify __RPC_FAR * This); + + ULONG ( __stdcall __RPC_FAR *Release )( + IOrpcDebugNotify __RPC_FAR * This); + + VOID ( __stdcall __RPC_FAR *ClientGetBufferSize)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + VOID ( __stdcall __RPC_FAR *ClientFillBuffer)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + VOID ( __stdcall __RPC_FAR *ClientNotify)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + VOID ( __stdcall __RPC_FAR *ServerNotify)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + VOID ( __stdcall __RPC_FAR *ServerGetBufferSize)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + VOID ( __stdcall __RPC_FAR *ServerFillBuffer)( + IOrpcDebugNotify __RPC_FAR * This, + LPORPC_DBG_ALL lpOrpcDebugAll); + + } IOrpcDebugNotifyVtbl; + + interface IOrpcDebugNotify + { + CONST_VTBL struct IOrpcDebugNotifyVtbl __RPC_FAR *lpVtbl; + }; + +#endif + +// This is the structure that is passed by the debugger to OLE when it enables ORPC +// debugging. +typedef struct ORPC_INIT_ARGS +{ + IOrpcDebugNotify __RPC_FAR * lpIntfOrpcDebug; + void * pvPSN; // contains ptr to Process Serial No. for Mac ORPC debugging. + DWORD dwReserved1; // For future use, must be 0. + DWORD dwReserved2; +} ORPC_INIT_ARGS; + +typedef ORPC_INIT_ARGS __RPC_FAR * LPORPC_INIT_ARGS; + +// Function pointer prototype for the "DllDebugObjectRPCHook" function. +typedef BOOL (WINAPI* ORPCHOOKPROC)(BOOL, LPORPC_INIT_ARGS); + +// The first four bytes in the debug specific packet are interpreted by the +// ORPC debug layer. The valid values are the ones defined below. + +#define ORPC_DEBUG_ALWAYS (0x00000000L) // Notify always. +#define ORPC_DEBUG_IF_HOOK_ENABLED (0x00000001L) // Notify only if hook enabled. + + +// This exception code indicates that the exception is really an +// ORPC debug notification. + +#define EXCEPTION_ORPC_DEBUG (0x804f4c45) + + +//-------------------------------------------------------------------------------------- +// Private: Declarations below this point are related to the implementation and should +// be removed from the distributable version of the header file. +//-------------------------------------------------------------------------------------- + + +// Helper routines to set & restore the "Auto" value in the registry + +BOOL WINAPI DebugORPCSetAuto(VOID); +VOID WINAPI DebugORPCRestoreAuto(VOID); + + ULONG WINAPI DebugORPCClientGetBufferSize( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +void WINAPI DebugORPCClientFillBuffer( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +void WINAPI DebugORPCClientNotify( + RPCOLEMESSAGE * pMessage, + REFIID iid, + void * reserved, + IUnknown * pUnkProxyMgr, + HRESULT hresult, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +void WINAPI DebugORPCServerNotify( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +ULONG WINAPI DebugORPCServerGetBufferSize( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +void WINAPI DebugORPCServerFillBuffer( + RPCOLEMESSAGE * pMessage, + REFIID iid, + IRpcChannelBuffer * pChannel, + void * pInterface, + IUnknown * pUnkObject, + void * pvBuffer, + ULONG cbBuffer, + LPORPC_INIT_ARGS lpInitArgs, + BOOL fHookEnabled); + +#endif // __ORPC_DBG__ diff --git a/private/ole32/com/dcomrem/pgalloc.cxx b/private/ole32/com/dcomrem/pgalloc.cxx new file mode 100644 index 000000000..55976b35f --- /dev/null +++ b/private/ole32/com/dcomrem/pgalloc.cxx @@ -0,0 +1,317 @@ +//+----------------------------------------------------------------------- +// +// File: pagealloc.cxx +// +// Contents: Special fast allocator to allocate fixed-sized entities. +// +// Classes: CPageAllocator +// +// History: 02-Feb-96 Rickhi Created +// +// Notes: All synchronization is the responsibility of the caller. +// +// CODEWORK: faster list managment +// free empty pages +// +//------------------------------------------------------------------------- +#include <ole2int.h> +#include <pgalloc.hxx> // class def'n +#include <locks.hxx> // LOCK/UNLOCK + + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::Initialize, public +// +// Synopsis: Initializes the page allocator. +// +// Notes: Instances of this class must be static since this +// function does not init all members to 0. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CPageAllocator::Initialize(LONG cbPerEntry, LONG cEntriesPerPage) +{ + ASSERT_LOCK_HELD + ComDebOut((DEB_PAGE, + "CPageAllocator::Initialize cbPerEntry:%x cEntriesPerPage:%x\n", + cbPerEntry, cEntriesPerPage)); + + Win4Assert(cbPerEntry >= sizeof(PageEntry)); + Win4Assert(cEntriesPerPage > 0); + + _cbPerEntry = cbPerEntry; + _cEntriesPerPage = cEntriesPerPage; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::Cleanup, public +// +// Synopsis: Cleanup the page allocator. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CPageAllocator::Cleanup() +{ + ComDebOut((DEB_PAGE, "CPageAllocator::Cleanup\n")); + ASSERT_LOCK_HELD + + if (_pPageListStart) + { + PageEntry **pPagePtr = _pPageListStart; + while (pPagePtr < _pPageListEnd) + { + // release each page of the table + PrivMemFree(*pPagePtr); + pPagePtr++; + } + + // release the page list + PrivMemFree(_pPageListStart); + + // reset the pointers so re-initialization is not needed + _cPages = 0; + _pPageListStart = NULL; + _pPageListEnd = NULL; + _pFirstFreeEntry = NULL; + } + + ASSERT_LOCK_HELD +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::AllocEntry, public +// +// Synopsis: Finds the first available entry in the table and returns +// a ptr to it. Returns NULL if no space is available and it +// cant grow the list. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +PageEntry *CPageAllocator::AllocEntry() +{ + ComDebOut((DEB_PAGE, "CPageAllocator::AllocEntry\n")); + ASSERT_LOCK_HELD + + if (_pFirstFreeEntry == NULL) + { + // no free entries, grow the list + Grow(); + + if (_pFirstFreeEntry == NULL) + { + // unable to allocate more + return NULL; + } + } + + // get the ptr to return and update the _pFirstFree to the next + // available entry + + PageEntry *pEntry = _pFirstFreeEntry; + _pFirstFreeEntry = pEntry->pNext; + + ASSERT_LOCK_HELD + ComDebOut((DEB_PAGE, "CPageAllocator::AllocEntry pEntry:%x\n", pEntry)); + return pEntry; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::ReleaseEntry, private +// +// Synopsis: returns an entry on the free list. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CPageAllocator::ReleaseEntry(PageEntry *pEntry) +{ + ComDebOut((DEB_PAGE, "CPageAllocator::ReleaseEntry pEntry:%x\n", pEntry)); + Win4Assert(pEntry); + ASSERT_LOCK_HELD + + // chain it on the free list + pEntry->pNext = _pFirstFreeEntry; + _pFirstFreeEntry = pEntry; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::ReleaseEntryList, private +// +// Synopsis: returns a list of entries to the free list. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CPageAllocator::ReleaseEntryList(PageEntry *pFirst, PageEntry *pLast) +{ + ComDebOut((DEB_PAGE, + "CPageAllocator::ReleaseEntryList pFirst:%x pLast:%x\n", + pFirst, pLast)); + Win4Assert(pFirst); + Win4Assert(pLast); + ASSERT_LOCK_HELD + + // update the free list + pLast->pNext = _pFirstFreeEntry; + _pFirstFreeEntry = pFirst; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::Grow, private +// +// Synopsis: Grows the table to allow for more Entries. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +void CPageAllocator::Grow() +{ + Win4Assert(_pFirstFreeEntry == NULL); + ASSERT_LOCK_HELD + + // allocate a new page + LONG cbPerPage = _cbPerEntry * _cEntriesPerPage; + PageEntry *pNewPage = (PageEntry *) PrivMemAlloc(cbPerPage); + + if (pNewPage == NULL) + { + return; + } + +#if DBG==1 + // clear the page (only needed in debug) + memset(pNewPage, 0, cbPerPage); +#endif + + // compute size of current page list + LONG cbCurListSize = _cPages * sizeof(PageEntry *); + + // allocate a new page list to hold the new page ptr. + PageEntry **pNewList = (PageEntry **) PrivMemAlloc(cbCurListSize + + sizeof(PageEntry *)); + if (pNewList) + { + // copy old page list into the new page list + memcpy(pNewList, _pPageListStart, cbCurListSize); + + // set the new page ptr entry + *(pNewList + _cPages) = pNewPage; + _cPages ++; + + // replace old page list with the new page list + PrivMemFree(_pPageListStart); + _pPageListStart = pNewList; + _pPageListEnd = pNewList + _cPages; + + + // update the first free entry ptr and link all the new entries + // together in a linked list. + + _pFirstFreeEntry = pNewPage; + + PageEntry *pNextFreeEntry = pNewPage; + PageEntry *pLastFreeEntry = (PageEntry *)(((BYTE *)pNewPage) + cbPerPage - _cbPerEntry); + + while (pNextFreeEntry < pLastFreeEntry) + { + pNextFreeEntry->pNext = (PageEntry *)((BYTE *)pNextFreeEntry + _cbPerEntry); + pNextFreeEntry = pNextFreeEntry->pNext; + } + + // last entry has an pNextFree of NULL (end of list) + pLastFreeEntry->pNext = NULL; + } + else + { + // release the allocated page. + PrivMemFree(pNewPage); + } + + ComDebOut((DEB_PAGE, "CPageAllocator::Grow _pPageListStart:%x _pPageListEnd:%x _pFirstFreeEntry:%x\n", + _pPageListStart, _pPageListEnd, _pFirstFreeEntry)); +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::GetEntryIndex, public +// +// Synopsis: Converts a PageEntry ptr into an index. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +LONG CPageAllocator::GetEntryIndex(PageEntry *pEntry) +{ + for (LONG index=0; index<_cPages; index++) + { + PageEntry *pPage = *(_pPageListStart + index); // get page ptr + if (pEntry >= pPage) + { + if (pEntry < (PageEntry *) ((BYTE *)pPage + (_cEntriesPerPage * _cbPerEntry))) + { + // found the page that the entry lives on, compute the index of + // the page and the index of the entry within the page. + return (index << PAGETBL_PAGESHIFT) + + ((BYTE *)pEntry - (BYTE *)pPage) / _cbPerEntry; + } + } + } + + // not found + return -1; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::IsValidIndex, private +// +// Synopsis: determines if the given DWORD provides a legal index +// into the PageTable. +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +BOOL CPageAllocator::IsValidIndex(LONG index) +{ + // make sure the index is not negative, otherwise the shift will do + // sign extension. check for valid page and valid offset within page + if ( (index >= 0) && + ((index >> PAGETBL_PAGESHIFT) < _cPages) && + ((index & PAGETBL_PAGEMASK) < _cEntriesPerPage) ) + return TRUE; + + // Don't print errors during shutdown. + if (_cPages != 0) + ComDebOut((DEB_ERROR, "IsValidIndex: Invalid PageTable Index:%x\n", index)); + return FALSE; +} + +//+------------------------------------------------------------------------ +// +// Member: CPageAllocator::GetEntryPtr, public +// +// Synopsis: Converts an entry index into an entry pointer +// +// History: 02-Feb-95 Rickhi Created +// +//------------------------------------------------------------------------- +PageEntry *CPageAllocator::GetEntryPtr(LONG index) +{ + Win4Assert(index >= 0); + Win4Assert(_cPages != 0); + Win4Assert(IsValidIndex(index)); + + PageEntry *pEntry = _pPageListStart[index >> PAGETBL_PAGESHIFT]; + pEntry = (PageEntry *) ((BYTE *)pEntry + + ((index & PAGETBL_PAGEMASK) * _cbPerEntry)); + return pEntry; +} diff --git a/private/ole32/com/dcomrem/pgalloc.hxx b/private/ole32/com/dcomrem/pgalloc.hxx new file mode 100644 index 000000000..1c1c709d4 --- /dev/null +++ b/private/ole32/com/dcomrem/pgalloc.hxx @@ -0,0 +1,96 @@ +//+----------------------------------------------------------------------- +// +// File: pagealloc.hxx +// +// Contents: Special fast allocator to allocate fixed-sized entities. +// +// Classes: CPageAllocator +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +#ifndef _PAGEALLOC_HXX_ +#define _PAGEALLOC_HXX_ + + +//+------------------------------------------------------------------------ +// +// struct: PageEntry. This is one entry in the page alloctor. +// +//+------------------------------------------------------------------------ +typedef struct tagPageEntry +{ + struct tagPageEntry *pNext; // next page in list + struct tagPageEntry *pPrev; // prev page in list +} PageEntry; + + +// Page Table constants for Index manipulation. +// The high 16bits of the PageEntry index provides the index to the page +// where the PageEntry is located. The lower 16bits provides the index +// within the page where the PageEntry is located. + +#define PAGETBL_PAGESHIFT 16 +#define PAGETBL_PAGEMASK 0x0000ffff + + +//+------------------------------------------------------------------------ +// +// class: CPageAllocator +// +// Synopsis: special fast allocator for fixed-sized entities. +// +// Notes: The table has two-levels. The top level is an array of ptrs +// to "pages" of entries. Each "page" is an array of entries +// of a given size (specified at init time). This allows us to +// grow the table by adding a new "page" and extending the top +// level by one more pointer, while allowing the existing entries +// to remain at the same address throughout their life times. +// +// A 32bit entry index can be computed for any entry. It consists +// if two 16bit indices, one for the page pointer index, and +// and one for the entry index on the page. There is also a +// function to compute the entry address from its index. +// +// This allocator is used for various internal DCOM tables. +// The main points are to keep related data close together +// to reduce working set, minimize allocation time, allow +// verifiable handles (indexs) that can be passed outside, and +// to make debugging easier (since all data is kept in tables +// its easier to find in the debugger). +// +// Tables using instances of this allocator are: +// CMIDTable COXIDTable CIPIDTable CRIFTable +// +// History: 02-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +class CPageAllocator +{ +public: + PageEntry *AllocEntry(); // return ptr to first free entry + void ReleaseEntry(PageEntry *); // return an entry to the free list + void ReleaseEntryList(PageEntry *pFirst, PageEntry *pLast); + + LONG GetEntryIndex(PageEntry *pEntry); + BOOL IsValidIndex(LONG iEntry); // TRUE if index is valid + PageEntry *GetEntryPtr(LONG iEntry); // return ptr based on index + + // initialize the table + void Initialize(LONG cbPerEntry, LONG cEntryPerPage); + void Cleanup(); // cleanup the table + +private: + + void Grow(); // grows the table + + LONG _cPages; // count of pages in the page list + PageEntry **_pPageListStart; // ptr to start of page list + PageEntry **_pPageListEnd; // ptr to end of page list + PageEntry *_pFirstFreeEntry; // ptr to first free page entry + + LONG _cbPerEntry; // count of bytes in a single page entry + LONG _cEntriesPerPage; // # of page entries in a page +}; + +#endif // _PAGEALLOC_HXX_ diff --git a/private/ole32/com/dcomrem/remoteu.cxx b/private/ole32/com/dcomrem/remoteu.cxx new file mode 100644 index 000000000..3c58b8599 --- /dev/null +++ b/private/ole32/com/dcomrem/remoteu.cxx @@ -0,0 +1,710 @@ +//+------------------------------------------------------------------- +// +// File: remoteu.cxx +// +// Copyright (c) 1996-1996, Microsoft Corp. All rights reserved. +// +// Contents: Remote Unknown object implementation +// +// Classes: CRemoteUnknown +// +// History: 23-Feb-95 AlexMit Created +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <remoteu.hxx> // CRemoteUnknown +#include <ipidtbl.hxx> // COXIDTable, CIPIDTable +#include <stdid.hxx> // CStdIdentity +#include <channelb.hxx> // CRpcChannelBuffer +#include <resolver.hxx> // giPingPeriod +#include <security.hxx> // FromLocalSystem + +CRemoteUnknown *gpMTARemoteUnknown = NULL; + +const WCHAR *gLocalName = L"\\\\\\Thread to thread"; + + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::CRemoteUnknown, public +// +// Synopsis: ctor for the CRemoteUnknown +// +// History: 22-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +CRemoteUnknown::CRemoteUnknown(HRESULT &hr, IPID *pipid) : + _pStdId(NULL) +{ + ASSERT_LOCK_HELD + + // Marshal the remote unknown and rundown, no pinging needed. Note + // that we just marshal the IRundown interfaces since it inherits + // from IRemUnknown. This lets us use the same IPID for both + // interfaces. Also, we use the Internal version of MarshalObjRef in + // order to prevent registering the OID in the OIDTable. This allows + // us to receive Release calls during IDTableThreadUninitialize since + // we wont get cleaned up in the middle of that function. It also allows + // us to lazily create the OIDTable. + + UNLOCK // release the LOCK because MarshalObjRef expects it unlocked. + + OBJREF objref; + hr = MarshalInternalObjRef(objref, IID_IRundown, this, MSHLFLAGS_NOPING, + (void **)&_pStdId); + + LOCK + + // regardless of errors, put this object in TLS or the global. If we + // got an error marshaling, COIXIDTable::ReleaseLocalEntry still will be + // able to find us to cleanup properly. + + COleTls tls; + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + { + // Store the pRemUnk in TLS so we can clean it up on CoUninitialize. + tls->pRemoteUnk = this; + } + else + { + // store the pRemUnk in the global for the MTA apartment + gpMTARemoteUnknown = this; + } + + + if (SUCCEEDED(hr)) + { + // return the IPID to the caller, and release any allocated resources + // since all we wanted was the infrastructure, not the objref itself. + + *pipid = ORSTD(objref).std.ipid; + FreeObjRef(objref); + } + + ComDebOut((DEB_MARSHAL, + "CRemoteUnk::CRemoteUnk this:%x pStdId:%x hr:%x\n", this, _pStdId, hr)); +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::~CRemoteUnknown, public +// +// Synopsis: dtor for the CRemoteUnknown +// +// History: 22-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +CRemoteUnknown::~CRemoteUnknown() +{ + ASSERT_LOCK_HELD + + if (_pStdId) + { + UNLOCK // DisconnectObject expects lock to be released + + // disconnect the standard identity and release it + _pStdId->DisconnectObject(0); + _pStdId->Release(); + + LOCK + } + + ComDebOut((DEB_MARSHAL, "CRemoteUnk::~CRemoteUnk this:%x\n", this)); +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::QueryInterface, public +// +// Synopsis: returns supported interfaces +// +// History: 22-Feb-95 AlexMit Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::QueryInterface(REFIID riid, void **ppv) +{ + if (IsEqualIID(riid, IID_IRundown) || // more common than IUnknown + IsEqualIID(riid, IID_IRemUnknown) || + IsEqualIID(riid, IID_IUnknown)) + { + *ppv = (IRundown *) this; + // no need to AddRef since we dont refcount this object + return S_OK; + } + + *ppv = NULL; + return E_NOINTERFACE; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::AddRef, public +// +// Synopsis: increment reference count +// +// History: 23-Feb-95 AlexMit Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CRemoteUnknown::AddRef(void) +{ + return 1; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::Release, public +// +// Synopsis: decrement reference count +// +// History: 23-Feb-95 AlexMit Created +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CRemoteUnknown::Release(void) +{ + return 1; +} + +//+------------------------------------------------------------------- +// +// Function: GetIPIDEntry, private +// +// Synopsis: find the IPIDEntry given an IPID +// +// History: 23-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +IPIDEntry *GetIPIDEntry(REFIPID ripid) +{ + IPIDEntry *pEntry= gIPIDTbl.LookupIPID(ripid); + + if (pEntry && !(pEntry->dwFlags & IPIDF_DISCONNECTED)) + { + return pEntry; + } + + return NULL; +} + +//+------------------------------------------------------------------- +// +// Function: GetStdIdFromIPID, private +// +// Synopsis: find the stdid from the ipid +// +// History: 23-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +CStdIdentity *GetStdIdFromIPID(REFIPID ripid) +{ + IPIDEntry *pEntry = GetIPIDEntry(ripid); + + if (pEntry) + { + return pEntry->pChnl->GetStdId(); + } + + return NULL; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::RemQueryInterface, public +// +// Synopsis: returns supported interfaces +// +// History: 22-Feb-95 AlexMit Created +// +// Notes: Remote calls to QueryInterface for this OXID arrive here. +// This routine looks up the object and calls MarshalIPID on +// it for each interface requested. +// +//-------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::RemQueryInterface(REFIPID ripid, ULONG cRefs, + USHORT cIids, IID *iids, REMQIRESULT **ppQIResults) +{ + ComDebOut((DEB_MARSHAL, + "CRemUnknown::RemQueryInterface this:%x ipid:%I cRefs:%x cIids:%x iids:%x ppQIResults:%x\n", + this, &ripid, cRefs, cIids, iids, ppQIResults)); + + // init the out parameters + *ppQIResults = NULL; + + // validate the input parameters + if (cIids == 0) + { + return E_INVALIDARG; + } + + // allocate space for the return parameters + REMQIRESULT *pQIRes = (REMQIRESULT *)CoTaskMemAlloc(cIids * + sizeof(REMQIRESULT)); + + if (pQIRes == NULL) + { + return E_OUTOFMEMORY; + } + + // Remember whether the IPID is for a strong or a weak reference, + // then clear the strong/weak bit so that GetIPIDEntry will find + // the IPID. It is safe to mask off this bit because we are the + // server for this IPID and we know it's format. + + DWORD mshlflags = MSHLFLAGS_NORMAL; + DWORD sorfflags = SORF_NULL; + + if (ripid.Data1 & IPIDFLAG_WEAKREF) + { + mshlflags = MSHLFLAGS_WEAK; + sorfflags = SORF_WEAKREF; + ((IPID &)(ripid)).Data1 &= ~IPIDFLAG_WEAKREF; // overcome the const + } + + + ASSERT_LOCK_RELEASED + LOCK + + CStdIdentity *pStdId = GetStdIdFromIPID(ripid); + if (pStdId == NULL) + { + UNLOCK + ASSERT_LOCK_RELEASED + + CoTaskMemFree(pQIRes); + return RPC_E_INVALID_OBJECT; + } + + USHORT cFails = 0; + HRESULT hr = pStdId->PreventDisconnect(); + + if (SUCCEEDED(hr)) + { + *ppQIResults = pQIRes; + + for (USHORT i=0; i < cIids; i++, pQIRes++) + { + // marshal each interface that was requested + + IPIDEntry *pIPIDEntry; + pQIRes->hResult = pStdId->MarshalIPID(iids[i], cRefs, mshlflags, + &pIPIDEntry); + if (SUCCEEDED(pQIRes->hResult)) + { + pStdId->FillSTD(&pQIRes->std, cRefs, mshlflags, pIPIDEntry); + pQIRes->std.flags |= sorfflags; + } + else + { + // on failure, the STDOBJREF must be NULL + memset(&pQIRes->std, 0, sizeof(pQIRes->std)); + cFails++; + } + } + } + else + { + CoTaskMemFree(pQIRes); + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (cFails > 0) + { + hr = (cFails == cIids) ? E_NOINTERFACE : S_FALSE; + } + + // handle any disconnects that came in while we were marshaling + // the requested interfaces. + hr = pStdId->HandlePendingDisconnect(hr); + + + ComDebOut((DEB_MARSHAL, + "CRemUnknown::RemQueryInterface this:%x pQIRes:%x hr:%x\n", + this, *ppQIResults, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::GetSecBinding +// +// Synopsis: Get the security binding of the caller +// +// History: 21-Feb-96 AlexMit Created +// +//-------------------------------------------------------------------- +HRESULT CRemoteUnknown::GetSecBinding( SECURITYBINDING **pSecBind ) +{ + HRESULT hr; + DWORD lAuthnSvc; + DWORD lAuthzSvc; + DWORD lAuthnLevel; + const WCHAR *pPrivs; + DWORD lLen; + + hr = CoQueryClientBlanket( &lAuthnSvc, &lAuthzSvc, NULL, + &lAuthnLevel, NULL, (void **) &pPrivs, NULL ); + if (FAILED(hr)) + return hr; + + // For thread to thread calls, make up a privilege name. + if (pPrivs == NULL && LocalCall()) + pPrivs = gLocalName; + else if (lAuthnLevel == RPC_C_AUTHN_LEVEL_NONE || + lAuthnLevel < gAuthnLevel || + pPrivs == NULL || + pPrivs[0] == 0) + return E_INVALIDARG; + + lLen = lstrlenW( pPrivs ) * sizeof(WCHAR); + *pSecBind = (SECURITYBINDING *) PrivMemAlloc( + sizeof(SECURITYBINDING) + lLen ); + if (*pSecBind != NULL) + { + // BUGBUG - Sometimes mswmsg returns authn svc 0. + if (lAuthnSvc == RPC_C_AUTHN_NONE) + lAuthnSvc = RPC_C_AUTHN_WINNT; + + (*pSecBind)->wAuthnSvc = (USHORT) lAuthnSvc; + if (lAuthzSvc == RPC_C_AUTHZ_NONE) + (*pSecBind)->wAuthzSvc = COM_C_AUTHZ_NONE; + else + (*pSecBind)->wAuthzSvc = (USHORT) lAuthzSvc; + memcpy( &(*pSecBind)->aPrincName, pPrivs, lLen+2 ); + return S_OK; + } + else + return E_OUTOFMEMORY; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::RemAddRef, public +// +// Synopsis: increment reference count +// +// History: 22-Feb-95 AlexMit Created +// +// Description: Remote calls to AddRef for this OXID arrive +// here. This routine just looks up the correct remote +// remote handler and asks it to do the work. +// +//-------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::RemAddRef(unsigned short cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[], + HRESULT *pResults) +{ + // Adjust the reference count for each entry. + + ASSERT_LOCK_RELEASED + LOCK + + HRESULT hr = S_OK; + HRESULT hr2; + SECURITYBINDING *pSecBind = NULL; + REMINTERFACEREF *pNext = InterfaceRefs; + + for (USHORT i=0; i < cInterfaceRefs; i++, pNext++) + { + // Get the IPIDEntry for the specified IPID. + IPIDEntry *pEntry = GetIPIDEntry(pNext->ipid); + if (!pEntry) + { + // Don't assert on failure. The server can disconnect and go away + // while clients exist. + pResults[i] = hr = CO_E_OBJNOTREG; + continue; + } + + // get the stdmarshal identity + CStdIdentity *pStdId = pEntry->pChnl->GetStdId(); + + if (pStdId) + { + ComDebOut((DEB_MARSHAL, + "CRemUnknown::RemAddRef pEntry:%x cCur:%x cAdd:%x cStdId:%x ipid:%I\n", pEntry, + pEntry->cStrongRefs, pNext->cPublicRefs, pStdId->GetRC(), &pNext->ipid)); + + Win4Assert(pNext->cPublicRefs > 0 || + pNext->cPrivateRefs > 0); + + // Lookup security info the first time an entry asks for + // secure references. + if (pNext->cPrivateRefs != 0 && pSecBind == NULL) + { + hr2 = GetSecBinding( &pSecBind ); + if (FAILED(hr2)) + { + hr = pResults[i] = hr2; + continue; + } + } + + hr2 = pStdId->IncSrvIPIDCnt(pEntry, pNext->cPublicRefs, + pNext->cPrivateRefs, pSecBind, + MSHLFLAGS_NORMAL); + if (FAILED(hr2)) + hr = pResults[i] = hr2; + else + pResults[i] = S_OK; + } + else + hr = pResults[i] = CO_E_OBJNOTREG; + } + + UNLOCK + ASSERT_LOCK_RELEASED + + PrivMemFree( pSecBind ); + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::RemRelease, public +// +// Synopsis: decrement reference count +// +// History: 22-Feb-95 AlexMit Created +// +// Description: Remote calls to Release for this OXID arrive +// here. This routine just looks up the correct remote +// remote handler and asks it to do the work. +// +//-------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::RemRelease(unsigned short cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[]) +{ + REMINTERFACEREF *pNext = InterfaceRefs; + SECURITYBINDING *pSecBind = NULL; + + ASSERT_LOCK_RELEASED + LOCK + + // Adjust the reference count for each entry. + for (USHORT i=0; i < cInterfaceRefs; i++, pNext++) + { + // Get the entry for the requested IPID. Remember whether this + // is an IPID for a strong or a weak reference, then clear the + // strong/weak bit so that GetIPIDEntry will find the IPID. + + DWORD mshlflags = (InterfaceRefs[i].ipid.Data1 & IPIDFLAG_WEAKREF) + ? MSHLFLAGS_WEAK : MSHLFLAGS_NORMAL; + + InterfaceRefs[i].ipid.Data1 &= ~IPIDFLAG_WEAKREF; + IPIDEntry *pEntry = GetIPIDEntry(InterfaceRefs[i].ipid); + + if (pEntry) + { + // Get the entry for the requested IPID. + CStdIdentity *pStdId = pEntry->pChnl->GetStdId(); + + if (pStdId) + { + + // Get the client's security binding on the first entry + // that releases secure references. + if (pNext->cPrivateRefs > 0 && pSecBind == NULL) + { + GetSecBinding( &pSecBind ); + if (pSecBind == NULL) + continue; + } + pStdId->AddRef(); + + ComDebOut((DEB_MARSHAL, + "CRemUnknown::RemRelease pEntry:%x cCur:%x cStdId:%x cRel:%x mshlflags:%x ipid:%I\n", pEntry, + (mshlflags == MSHLFLAGS_WEAK) ? pEntry->cWeakRefs : pEntry->cStrongRefs, + pStdId->GetRC(), pNext->cPublicRefs, mshlflags, &pNext->ipid)); + + Win4Assert(pNext->cPublicRefs > 0 || pNext->cPrivateRefs > 0); + + // Prevent a disconnect from occuring while releasing the + // interface since we have to yield the ORPC lock. + HRESULT hr = pStdId->PreventDisconnect(); + + if (SUCCEEDED(hr)) + { + pStdId->DecSrvIPIDCnt(pEntry, pNext->cPublicRefs, + pNext->cPrivateRefs, pSecBind, + mshlflags); + } + + // do the final release of the object while not holding + // the lock, since it may call into the server. + + UNLOCK + ASSERT_LOCK_RELEASED + + // This will handle any Disconnect that came in while we were + // busy. Ignore error codes since we are releasing. + pStdId->HandlePendingDisconnect(hr); + + pStdId->Release(); + + ASSERT_LOCK_RELEASED + LOCK + } + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + PrivMemFree( pSecBind ); + return S_OK; +} + +//+------------------------------------------------------------------------- +// +// Member: CRemoteUnknown::RemChangeRefs, public +// +// Synopsis: Change an interface reference from strong/weak or vice versa. +// +// History: 08-Nov-95 Rickhi Created +// +// Note: It is safe for this routine to ignore private refcounts +// becuase it is only called locally hence we own the client +// implementation and can guarantee they are zero. +// +//-------------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::RemChangeRef(ULONG flags, USHORT cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[]) +{ + ASSERT_LOCK_RELEASED + LOCK + + // figure out the flags to pass to the Inc/DecSrvIPIDCnt + BOOL fMakeStrong = flags & IRUF_CONVERTTOSTRONG; + DWORD IncFlags = fMakeStrong ? MSHLFLAGS_NORMAL : MSHLFLAGS_WEAK; + DWORD DecFlags = fMakeStrong ? MSHLFLAGS_WEAK : MSHLFLAGS_NORMAL; + DecFlags |= (flags & IRUF_DISCONNECTIFLASTSTRONG) ? 0 : MSHLFLAGS_KEEPALIVE; + + CStdIdentity *pStdId = NULL; + + for (USHORT i=0; i < cInterfaceRefs; i++) + { + // Get the entry for the specified IPID. + IPIDEntry *pEntry = GetIPIDEntry(InterfaceRefs[i].ipid); + + if (pEntry) + { + // find the StdId for this IPID. We assume that the client + // only gives us IPIDs for the same object, so first time + // we find a StdId we remember it and AddRef it. This is a safe + // assumption cause the client is local to this machine (ie + // we wrote the client). + + CStdIdentity *pStdIdTmp = pEntry->pChnl->GetStdId(); + + if (pStdIdTmp != NULL) + { + if (pStdId == NULL) + { + pStdId = pStdIdTmp; + pStdId->AddRef(); + } + + // We assume that all IPIDs are for the same object. We + // just verify that here. + + if (pStdId == pStdIdTmp) + { + // tweak the reference counts + pStdId->IncSrvIPIDCnt( + pEntry, InterfaceRefs[i].cPublicRefs, + fMakeStrong ? InterfaceRefs[i].cPrivateRefs : 0, + NULL, IncFlags); + pStdId->DecSrvIPIDCnt( + pEntry, InterfaceRefs[i].cPublicRefs, + fMakeStrong ? 0 : InterfaceRefs[i].cPrivateRefs, + NULL, DecFlags); + } + } + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (pStdId) + { + // release the AddRef (if any) we did above + pStdId->Release(); + } + + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CRemoteUnknown::RundownOid, public +// +// Synopsis: Tell the server that no clients are using an object +// +// History: 25 May 95 AlexMit Created +// +// Description: Lookup each OID in the IDTable. If found and not +// recently marshaled, call DisconnectObject on it. +// +//-------------------------------------------------------------------- +STDMETHODIMP CRemoteUnknown::RundownOid(ULONG cOid, OID aOid[], + unsigned char afOkToRundown[]) +{ + DWORD iNow = GetCurrentTime(); + + ASSERT_LOCK_RELEASED + + if (IsCallerLocalSystem()) + { + LOCK + for (ULONG i = 0; i < cOid; i++) + { + afOkToRundown[i] = TRUE; + + MOID moid; + MOIDFromOIDAndMID(aOid[i], gLocalMid, &moid); + + CStdIdentity *pStdId; + HRESULT hr = LookupIDFromID(moid, TRUE, &pStdId); + + if (SUCCEEDED(hr)) + { + afOkToRundown[i] = pStdId->CanRunDown(iNow); + + UNLOCK + ASSERT_LOCK_RELEASED + + if (afOkToRundown[i] == TRUE) + { + pStdId->DisconnectObject( 0 ); + } + pStdId->Release(); + + ASSERT_LOCK_RELEASED + LOCK + } + else + { + // need to look at the set of pre-registered OIDs to ensure + // we dont run these down before we use them. + + afOkToRundown[i] = gResolver.ServerCanRundownOID(aOid[i]); + } + } + UNLOCK + } + + // Rather then being rude and returning access denied, tell the caller + // that all the objects have been released. + else + { + ComDebOut((DEB_ERROR, "Invalid user called CRemoteUnknown::RundownOid" )); + for (ULONG i = 0; i < cOid; i++) + afOkToRundown[i] = TRUE; + } + + ASSERT_LOCK_RELEASED + return S_OK; +} diff --git a/private/ole32/com/dcomrem/remoteu.hxx b/private/ole32/com/dcomrem/remoteu.hxx new file mode 100644 index 000000000..d74984fa5 --- /dev/null +++ b/private/ole32/com/dcomrem/remoteu.hxx @@ -0,0 +1,93 @@ +//+------------------------------------------------------------------- +// +// File: remoteu.hxx +// +// Contents: Remote Unknown class definition +// +// Classes: CRemoteUnknown +// +// Functions: +// +// History: 23-Feb-95 AlexMit Created +// +// Notes: Each server has one remote unknown object per OXID. +// Each client OXID has a table of proxies to OXIDs referenced +// by the client OXID. The table includes a pointer +// to the remote unknown for the client (if it has one). +// Entries in the table are reference counted. +// An OXID references a thread in the apartment model and +// a process in the free threaded model. +// +//-------------------------------------------------------------------- +#ifndef __REMOTEU__ +#define __REMOTEU__ + +#include <obase.h> +#include <remunk.h> +#include <odeth.h> + +// forward declaration +class CStdIdentity; + +// we set the top bit in the first dword of an IPID to flag the IPID as +// holding weak references, so that RemRelease and RemQueryInterface between +// an OLE container and an embedded object works as desired. Note that this +// this is strictly a same-machine protocol, it is not part of the published +// DCOM protocol spec. + +#define IPIDFLAG_WEAKREF 0x80000000 + + +//+------------------------------------------------------------------------- +// +// Class: CRemoteUnknown +// +// Purpose: Pass remote IUnknown calls and rundowns to the correct +// local standard identity. +// +// History: 23-Feb-95 AlexMit Created +// +//-------------------------------------------------------------------------- +class CRemoteUnknown : public IRundown, public CPrivAlloc +{ +public: + CRemoteUnknown(HRESULT &hr, IPID *pipid); + ~CRemoteUnknown(); + + // IUnknown + STDMETHOD (QueryInterface) ( REFIID riid, LPVOID FAR* ppvObj); + STDMETHOD_(ULONG,AddRef) ( void ); + STDMETHOD_(ULONG,Release) ( void ); + + // IRemUnknown + STDMETHOD(RemQueryInterface) ( REFIPID ripid, + ULONG cRefs, + unsigned short cIids, + IID *iids, + REMQIRESULT **ppQIResults); + + STDMETHOD(RemAddRef) ( unsigned short cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[], + HRESULT *pResults ); + STDMETHOD(RemRelease) ( unsigned short cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[] ); + + // IRemUnknown2 + STDMETHOD(RemChangeRef) ( unsigned long flags, + unsigned short cInterfaceRefs, + REMINTERFACEREF InterfaceRefs[]); + + // IRundown + STDMETHOD(RundownOid) ( ULONG cOid, + OID aOid[], + unsigned char afOkToRundown[] ); +private: + HRESULT GetSecBinding( SECURITYBINDING **pSecBind ); + + CStdIdentity *_pStdId; // stdid for this object +}; + +// remote unknown pointer for MTA Apartment. +extern CRemoteUnknown *gpMTARemoteUnknown; + +#endif // __REMOTEU__ diff --git a/private/ole32/com/dcomrem/resolver.cxx b/private/ole32/com/dcomrem/resolver.cxx new file mode 100644 index 000000000..94959df81 --- /dev/null +++ b/private/ole32/com/dcomrem/resolver.cxx @@ -0,0 +1,2959 @@ +//+------------------------------------------------------------------- +// +// File: resolver.cxx +// +// Contents: class implementing interface to RPC OXID/PingServer +// resolver process. Only one instance per process. +// +// Classes: CRpcResolver +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <resolver.hxx> // CRpcResolver +#include <service.hxx> // GetStringBindings +#include <locks.hxx> // LOCK/UNLOCK etc +#include <security.hxx> // GetCallAuthnLevel +#include <marshal.hxx> // GetOXIDFromObjRef +#include <sobjact.hxx> // CObjServer + + +// global instance of OXID resolver +CRpcResolver gResolver; + +// static members of CRpcResolver + +handle_t CRpcResolver::_hRpc = NULL; // binding handle to resolver +PHPROCESS CRpcResolver::_ph = NULL; // context handle to resolver +HANDLE CRpcResolver::_hThrd = NULL; // worker thread handle +HANDLE CRpcResolver::_hEventOXID = NULL; // event for registering threads +DWORD CRpcResolver::_dwFlags = 0; // flags +DWORD CRpcResolver::_dwSleepPeriod = 0; // worker thread sleep period +ULONG CRpcResolver::_cReservedOidsAvail = 0; +ULONGLONG CRpcResolver::_OidNextReserved = 0; +ULONG CRpcResolver::_cOidsToAdd = 0; // # OIDs to add next call +ULONG CRpcResolver::_cOidsToRemove = 0; // # OIDs to remove next call +ULONG CRpcResolver::_cPreRegOidsAvail = 0; // # Pre-Regist'd OIDs available +OID CRpcResolver::_arPreRegOids[MAX_PREREGISTERED_OIDS]; + +IDSCM * CRpcResolver::_pSCMSTA = NULL; // single-threaded scm proxy +IDSCM * CRpcResolver::_pSCMMTA = NULL; // multi-threaded scm proxy +LPWSTR CRpcResolver::_pwszWinstaDesktop = NULL; + +DWORD CRpcResolver::_dwProcessSignature = 0; +BOOL CRpcResolver::_bDynamicSecurity = FALSE; + +// List of OIDs to register/ping/revoke with the resolver used +// for lazy/batch client-side OID processing. + +SOIDRegistration CRpcResolver::_ClientOIDRegList = {{{NULL, NULL},}, + 0, 0, NULL, + &_ClientOIDRegList, + &_ClientOIDRegList}; +// MID (machine ID) of local machine +MID gLocalMid; + +// Ping period in milliseconds. +DWORD giPingPeriod; + +// string binding to the resolver +const WCHAR *pwszResolverBindString = L"ncalrpc:[epmapper,Security=Impersonation Dynamic False]"; + +// String arrays for the SCM process. These are used to tell the interface +// marshaling code the protocol and endpoint of the SCM process. + +#ifdef _CHICAGO_ +typedef struct tagSCMSA +{ + unsigned short wNumEntries; // Number of entries in array. + unsigned short wSecurityOffset; // Offset of security info. + WCHAR awszStringArray[26]; +} SCMSA; + +SCMSA saSCM = {26, 25, L"mswmsg:[endpoint mapper]\0" }; + +#else + +typedef struct tagSCMSA +{ + unsigned short wNumEntries; // Number of entries in array. + unsigned short wSecurityOffset; // Offset of security info. + WCHAR awszStringArray[60]; +} SCMSA; + +// The last 4 characters in the string define the security bindings. +// \0xA is RPC_C_AUTHN_WINNT +// \0xFFFF is COM_C_AUTHZ_NONE +// \0 is an empty principle name +SCMSA saSCM = {57, 56, L"ncalrpc:[epmapper,Security=Impersonation Dynamic False]\0\xA\xFFFF\0"}; +#endif + +DWORD GetThreadWinstaDesktop( WCHAR ** ppwszWinstaDesktop ); + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::Cleanup, public +// +// Synopsis: cleanup the resolver state. Called by ProcessUninitialze. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CRpcResolver::Cleanup() +{ + ASSERT_LOCK_HELD + + // release our context handle + if (_ph != NULL) + { + RpcSmDestroyClientContext(&_ph); + _ph = NULL; + } + + // release regular handle + if (_hRpc) + { + RpcBindingFree(&_hRpc); + _hRpc = NULL; + } + + // Release the string bindings for the local object exporter. + if (gpsaLocalResolver) + { + MIDL_user_free(gpsaLocalResolver); + gpsaLocalResolver = NULL; + } + + // empty the OIDRegList. Any SOIDRegistration records have already + // been deleted by the gClientRegisteredOIDs list cleanup code. + + _ClientOIDRegList.pPrevList = &_ClientOIDRegList; + _ClientOIDRegList.pNextList = &_ClientOIDRegList; + _cOidsToAdd = 0; + _cOidsToRemove = 0; + + // zero the count of pre-registered oids since all pre-registered + // Oids are for our old OXID value. + + _cPreRegOidsAvail = 0; + + // close the event handle (if any) + if (_hEventOXID) + { + CloseHandle(_hEventOXID); + _hEventOXID = NULL; + } + + if (_pwszWinstaDesktop != NULL) + { + PrivMemFree(_pwszWinstaDesktop); + _pwszWinstaDesktop = NULL; + } + + _bDynamicSecurity = FALSE; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ReleaseSCMProxy, public +// +// Synopsis: cleanup the resolver state. Called by ProcessUninitialze. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CRpcResolver::ReleaseSCMProxy() +{ + if (_pSCMSTA != NULL) + { + _pSCMSTA->Release(); + _pSCMSTA = NULL; + } + + if (_pSCMMTA != NULL) + { + _pSCMMTA->Release(); + _pSCMMTA = NULL; + } + + if (gpMTAObjServer != NULL) + { + delete gpMTAObjServer; + gpMTAObjServer = NULL; + } +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::RetryRPC, private +// +// Synopsis: determine if we need to retry the RPC call due to +// the resolver being too busy. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +BOOL CRpcResolver::RetryRPC(RPC_STATUS sc) +{ + if (sc != RPC_S_SERVER_TOO_BUSY) + return FALSE; + + // give the resolver time to run, then try again. + Sleep(100); + + // CODEWORK: this is currently an infinite loop. Should we limit it? + return TRUE; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::CheckStatus, private +// +// Synopsis: Checks the status code of an Rpc call, prints a debug +// ERROR message if failed, and maps the failed status code +// into an HRESULT. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::CheckStatus(RPC_STATUS sc) +{ + if (sc != RPC_S_OK) + { + ComDebOut((DEB_ERROR, "OXID Resolver Failure sc:%x\n", sc)); + sc = HRESULT_FROM_WIN32(sc); + } + + return sc; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::GetConnection, public +// +// Synopsis: connects to the resolver process +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::GetConnection() +{ + ComDebOut((DEB_OXID,"CRpcResolver::GetConnection\n")); + + HRESULT hr; + COleTls tls(hr); + if (FAILED(hr)) + { + return hr; + } + + RPC_STATUS sc = RPC_S_OK; + + LOCK + + if (_ph == NULL) + { + sc = RpcBindingFromStringBinding((LPWSTR)pwszResolverBindString, &_hRpc); + ComDebErr(sc != RPC_S_OK, "Resolver Binding Failed.\n"); + + if (sc == RPC_S_OK) + { + OID oidBase; + DWORD fConnectFlags; + + do + { + // call the resolver to get a context handle + sc = Connect(_hRpc, + &_ph, + &giPingPeriod, + &gpsaLocalResolver, + &gLocalMid, + MAX_RESERVED_OIDS, + &oidBase, + &fConnectFlags, + (WCHAR **) &gLegacySecurity, + &gAuthnLevel, + &gImpLevel, + &gServerSvcListLen, + &gServerSvcList, + &gClientSvcListLen, + &gClientSvcList, + &(tls->dwApartmentID), + &gdwScmProcessID, + &_dwProcessSignature); + } while (RetryRPC(sc)); + + if (sc == RPC_S_OK) + { + gDisableDCOM = fConnectFlags & CONNECT_DISABLEDCOM; + if (fConnectFlags & CONNECT_MUTUALAUTH) + gCapabilities = EOAC_MUTUAL_AUTH; + else + gCapabilities = EOAC_NONE; + if (fConnectFlags & CONNECT_SECUREREF) + gCapabilities |= EOAC_SECURE_REFS; + + // remember the reserved OID base. + _OidNextReserved = oidBase; + _cReservedOidsAvail = MAX_RESERVED_OIDS; + + // Mark the security data as initialized. + gGotSecurityData = TRUE; + if (IsWOWProcess()) + { + gDisableDCOM = TRUE; + } + + // Convert the ping period from seconds to milliseconds. + giPingPeriod *= 1000; + Win4Assert(gpsaLocalResolver->wNumEntries != 0); + + // compute the sleep period for the registration worker thread + // (which is 1/6th the ping period). The ping period may differ + // on debug and retail builds. +#if DBG==1 + // shorter time period to enable testing + _dwSleepPeriod = 5000; +#else + _dwSleepPeriod = giPingPeriod / 6; +#endif + } + else + { + ComDebOut((DEB_OXID, "Resolver Connect Failed sc:%x\n", sc)); + RpcBindingFree(&_hRpc); + _hRpc = NULL; + Win4Assert(gpsaLocalResolver == NULL); + Win4Assert(_ph == NULL); + } + } + } + + if ( (sc == RPC_S_OK) && (_pwszWinstaDesktop == NULL)) + sc = SetWinstaDesktop(); + + UNLOCK + + hr = CheckStatus(sc); + ComDebErr(hr != S_OK, "GetConnection Failed.\n"); + ComDebOut((DEB_OXID,"CRpcResolver::GetConnection hr:%x\n", hr)); + return hr; +} + +//+-------------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerGetReservedMOID, public +// +// Synopsis: Get an OID that does not need to be pinged. +// +// History: 06-Nov-95 Rickhi Created. +// +//---------------------------------------------------------------------------- +HRESULT CRpcResolver::ServerGetReservedMOID(MOID *pmoid) +{ + ComDebOut((DEB_OXID,"ServerGetReservedMOID\n")); + + OID oid; + HRESULT hr = ServerGetReservedID(&oid); + + MOIDFromOIDAndMID(oid, gLocalMid, pmoid); + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ServerGetReservedMOID hr:%x moid:%I\n", pmoid)); + return hr; +} + +//+-------------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerGetReservedID, public +// +// Synopsis: Get an ID that does not need to be pinged. +// +// History: 06-Nov-95 Rickhi Created. +// +//---------------------------------------------------------------------------- +HRESULT CRpcResolver::ServerGetReservedID(OID *pid) +{ + ComDebOut((DEB_OXID,"ServerGetReservedID\n")); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + + if (_cReservedOidsAvail == 0) + { + // go get more reserved OIDs from the ping server + UNLOCK + ASSERT_LOCK_RELEASED + + OID OidBase; + + do + { + hr = ::AllocateReservedIds( + _hRpc, // Rpc binding handle + MAX_RESERVED_OIDS, // count of OIDs requested + &OidBase); // place to hold base id + + } while (RetryRPC(hr)); + + // map Rpc status if necessary + hr = CheckStatus(hr); + + ASSERT_LOCK_RELEASED + LOCK + + if (SUCCEEDED(hr)) + { + // copy into global state. Dont have to worry about two threads + // getting more simultaneously, since these OIDs are expendable. + + _cReservedOidsAvail = MAX_RESERVED_OIDS; + _OidNextReserved = OidBase; + } + } + + if (SUCCEEDED(hr)) + { + *pid = _OidNextReserved; + _OidNextReserved++; + _cReservedOidsAvail--; + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ServerGetReservedID hr:%x id:%08x %08x\n", *pid)); + return hr; +} + +//+-------------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerGetPreRegMOID, public +// +// Synopsis: Get an OID that has been pre-registered with the Ping +// Server. +// +// History: 06-Nov-95 Rickhi Created. +// +// Notes: careful. The oids are dispensed in reverse order [n]-->[0], so the +// unused ones are from [0]-->[cPreRegOidsAvail-1]. ServerCanRundownOID +// depends on this behavior. +// +//---------------------------------------------------------------------------- +HRESULT CRpcResolver::ServerGetPreRegMOID(MOID *pmoid) +{ + ComDebOut((DEB_OXID,"ServerGetPreRegMOID\n")); + ASSERT_LOCK_HELD + + // Get the local OXID. This cant fail because the local + // entry was pre-created in ChannelThreadInitialize. + + OXIDEntry *pOXIDEntry; + HRESULT hr = gOXIDTbl.GetLocalEntry(&pOXIDEntry); + Win4Assert(SUCCEEDED(hr)); + + COleTls tls; + if (!(tls->dwFlags & OLETLS_APARTMENTTHREADED)) + { + // in MTA Apartment, use the global list and global count. + + if (_cPreRegOidsAvail == 0) + { + hr = ServerAllocMoreOIDs(&_cPreRegOidsAvail, _arPreRegOids, + pOXIDEntry); + } + + if (SUCCEEDED(hr)) + { + _cPreRegOidsAvail--; + MOIDFromOIDAndMID(_arPreRegOids[_cPreRegOidsAvail], + gLocalMid, pmoid); + } + } + else + { + // In STA Apartment, the pre-registered OIDs are kept per apartment + // in a list off of tls. + + if (tls->cPreRegOidsAvail == 0) + { + if (tls->pPreRegOids == NULL) + { + // first time for this thread. Allocate a list to hold + // the pre-registered oids. + + tls->pPreRegOids = (OID *)PrivMemAlloc(MAX_PREREGISTERED_OIDS * + sizeof(OID)); + if (tls->pPreRegOids == NULL) + { + hr = E_OUTOFMEMORY; + } + } + + if (SUCCEEDED(hr)) + { + hr = ServerAllocMoreOIDs(&tls->cPreRegOidsAvail, + tls->pPreRegOids, pOXIDEntry); + } + } + + if (SUCCEEDED(hr)) + { + tls->cPreRegOidsAvail--; + MOIDFromOIDAndMID(tls->pPreRegOids[tls->cPreRegOidsAvail], + gLocalMid, pmoid); + } + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ServerGetPreRegMOID hr:%x moid:%I\n", hr, pmoid)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerCanRundownOID, public +// +// Synopsis: Determine if OK to rundown the specified OID. +// +// History: 06-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +BOOL CRpcResolver::ServerCanRundownOID(REFOID roid) +{ + ComDebOut((DEB_OXID,"ServerCanRundownOID poid:%x\n", &roid)); + ASSERT_LOCK_HELD + + // look in the list of unused pre-registered OIDs to see if the + // OID is in there. If so, we dont want to run it down yet so + // return FALSE, otherwise return TRUE + + BOOL fRundown = TRUE; // assume not found + + ULONG cPreRegOidsAvail = _cPreRegOidsAvail; + OID *pPreRegOids = &_arPreRegOids[0]; + + COleTls tls; + + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + { + cPreRegOidsAvail = tls->cPreRegOidsAvail; + pPreRegOids = tls->pPreRegOids; + } + + // carefull. The oids are dispensed in reverse order (ie [n]-->[0]) + // so when checking for unused ones check in forward order + // [0]-->[cPreRegOidsAvail-1] + + for (ULONG i=0; i<cPreRegOidsAvail; i++, pPreRegOids++) + { + if (roid == *pPreRegOids) + { + // found the oid in the list of unused ones. Dont run it down. + fRundown = FALSE; + break; + } + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ServerCanRundownOID fRundown:%x\n", fRundown)); + return fRundown; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::WaitForOXIDEntry, private +// +// Synopsis: waits until an OXIDEntry is not busy +// +// History: 06-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::WaitForOXIDEntry(OXIDEntry *pOXIDEntry) +{ + ASSERT_LOCK_HELD + + if (pOXIDEntry->dwFlags & OXIDF_REGISTERINGOIDS) + { + // some other thread is busy registering OIDs for this OXID + // so lets wait for it to finish. This should only happen in + // the MTA apartment. + Win4Assert(IsMTAThread()); + + if (_hEventOXID == NULL) + { + _hEventOXID = CreateEvent(NULL, FALSE, FALSE, NULL); + if (_hEventOXID == NULL) + { + return HRESULT_FROM_WIN32(GetLastError()); + } + } + + // count one more waiter + pOXIDEntry->cWaiters++; + + do + { + // release the lock before we block so the other thread can wake + // us up when it returns. + UNLOCK + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_WARN,"WaitForOXIDEntry wait on hEvent:%x\n", _hEventOXID)); + DWORD rc = WaitForSingleObject(_hEventOXID, INFINITE); + Win4Assert(rc == WAIT_OBJECT_0); + + ASSERT_LOCK_RELEASED + LOCK + + } while (pOXIDEntry->dwFlags & OXIDF_REGISTERINGOIDS); + + // one less waiter + pOXIDEntry->cWaiters--; + } + + // mark the entry as busy by us + pOXIDEntry->dwFlags |= OXIDF_REGISTERINGOIDS; + + ASSERT_LOCK_HELD + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::CheckForWaiters, private +// +// Synopsis: wakes up any threads waiting for this OXIDEntry +// +// History: 06-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CRpcResolver::CheckForWaiters(OXIDEntry *pOXIDEntry) +{ + ASSERT_LOCK_HELD + + if (pOXIDEntry->cWaiters > 0) + { + // some other thread is busy waiting for the current thread to + // finish registering so signal him that we are done. + + Win4Assert(_hEventOXID != NULL); + ComDebOut((DEB_TRACE,"CheckForWaiters signalling hEvent:%x\n", _hEventOXID)); + SetEvent(_hEventOXID); + } + + // mark the entry as no longer busy by us + pOXIDEntry->dwFlags &= ~OXIDF_REGISTERINGOIDS; + + ASSERT_LOCK_HELD +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerAllocMoreOIDs, private +// +// Synopsis: register Object ID with the local ping server +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ServerAllocMoreOIDs(ULONG *pcPreRegOidsAvail, + OID *parPreRegOids, + OXIDEntry *pOXIDEntry) +{ + ComDebOut((DEB_OXID,"ServerAllocMoreOIDs\n")); + ASSERT_LOCK_HELD + Win4Assert(_ph != NULL); + + // wait until no other threads are calling ServerAllocOIDs + HRESULT hr = WaitForOXIDEntry(pOXIDEntry); + + if (SUCCEEDED(hr)) + { + if (*pcPreRegOidsAvail == 0) + { + // need to really go get more + hr = ServerAllocOIDs(pOXIDEntry, + pcPreRegOidsAvail, + parPreRegOids); + } + + // wakeup any waiters + CheckForWaiters(pOXIDEntry); + } + + ComDebOut((DEB_OXID, "ServerAllocMoreOIDs hr:%x\n", hr)); + ComDebErr(hr != S_OK, "ServerAllocMoreOIDs Failed.\n"); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerAllocOIDs, private +// +// Synopsis: allocate Object IDs from the local ping server +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ServerAllocOIDs(OXIDEntry *pOXIDEntry, + ULONG *pcPreRegOidsAvail, + OID *parPreRegOids) +{ + HRESULT hr; + + // make up a list of pre-registered OIDs on our stack so multiple + // threads executing here simultaneously are not a problem. + + ULONG cOidsToAllocate = MAX_PREREGISTERED_OIDS; + OID arNewOidList[MAX_PREREGISTERED_OIDS]; + + if (!(pOXIDEntry->dwFlags & OXIDF_REGISTERED)) + { + // have not yet registered the OXID, so go do that at the same time + // we allocate OIDs. + + hr = ServerRegisterOXID(pOXIDEntry, &cOidsToAllocate, arNewOidList); + } + else + { + // just need to allocate more OIDs. + + OXID oxid; + OXIDFromMOXID(pOXIDEntry->moxid, &oxid); + + UNLOCK + ASSERT_LOCK_RELEASED + + do + { + hr = ::ServerAllocateOIDs( + _hRpc, // Rpc binding handle + _ph, // context handle + &oxid, // OXID of server + cOidsToAllocate, // count of OIDs requested + arNewOidList, // array of reserved oids + &cOidsToAllocate);// count actually allocated + + } while (RetryRPC(hr)); + + // map Rpc status if necessary + hr = CheckStatus(hr); + + ASSERT_LOCK_RELEASED + LOCK + } + + if (SUCCEEDED(hr)) + { + // copy the newly created OIDs into the list in whatever space + // is still available, since some other thread could have come + // along and pre-registered OIDs simultaneously (in MTA apartment + // only). The OIDs that are not copied will be lost and + // eventually the resolver will run them down. This should be + // relatively rare. + + LONG cToCopy = min(cOidsToAllocate, + MAX_PREREGISTERED_OIDS - *pcPreRegOidsAvail); + + memcpy(parPreRegOids + *pcPreRegOidsAvail, + arNewOidList, + sizeof(OID) * cToCopy); + + *pcPreRegOidsAvail += cToCopy; + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerRegisterOXID, public +// +// Synopsis: allocate an OXID and Object IDs with the local ping server +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ServerRegisterOXID(OXIDEntry *pOXIDEntry, + ULONG *pcOidsToAllocate, + OID arNewOidList[]) +{ + ComDebOut((DEB_OXID, "ServerRegisterOXID TID:%x\n", GetCurrentThreadId())); + ASSERT_LOCK_HELD + + // OXID has not yet been registered with the resolver, do that + // now along with pre-registering a bunch of OIDs. + + // make sure we have the local binding and security strings + HRESULT hr = StartListen(); + ComDebErr(hr != S_OK, "StartListen Failed.\n"); + + if (hr == S_OK) + { + OXID_INFO oxidInfo; + oxidInfo.dwTid = pOXIDEntry->dwTid; + oxidInfo.dwPid = pOXIDEntry->dwPid; + oxidInfo.ipidRemUnknown = pOXIDEntry->ipidRundown; + oxidInfo.dwAuthnHint = gAuthnLevel; + oxidInfo.psa = NULL; + + + DUALSTRINGARRAY *psaSB = gpsaCurrentProcess; // string bindings + DUALSTRINGARRAY *psaSC = gpsaSecurity; // security bindings + + if (_dwFlags & ORF_STRINGSREGISTERED) + { + // already registered these once, dont need to do it again. + psaSB = NULL; + psaSC = NULL; + } + + OXID oxid; + + ComDebOut((DEB_OXID,"ServerRegisterOXID oxidInfo:%x psaSB:%x psaSC:%x\n", + &oxidInfo, psaSB, psaSC)); + + UNLOCK + ASSERT_LOCK_RELEASED + + do + { + hr = ::ServerAllocateOXIDAndOIDs( + _hRpc, // Rpc binding handle + _ph, // context handle + &oxid, // OXID of server + IsSTAThread(), // fApartment Threaded + *pcOidsToAllocate, // count of OIDs requested + arNewOidList, // array of reserved oids + pcOidsToAllocate, // count actually allocated + &oxidInfo, // OXID_INFO to register + psaSB, // string bindings for process + psaSC); // security bindings for process + + } while (RetryRPC(hr)); + + // map Rpc status if necessary + hr = CheckStatus(hr); + + ASSERT_LOCK_RELEASED + LOCK + + if (hr == S_OK) + { + // mark the OXID as registered with the resolver, and replace + // the (temporarily zero) oxid with the real one the resolver + // returned to us. + + pOXIDEntry->dwFlags |= OXIDF_REGISTERED; + MOXIDFromOXIDAndMID(oxid, gLocalMid, &pOXIDEntry->moxid); + } + } + + ComDebOut((DEB_OXID, "ServerRegisterOXID hr:%x\n", hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ServerFreeOXID, public +// +// Synopsis: frees an OXID and associated OIDs that were pre-registered +// with the local ping server +// +// History: 20-Jan-96 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ServerFreeOXID(OXIDEntry *pOXIDEntry) +{ + ComDebOut((DEB_OXID, "ServerFreeOXID TID:%x\n", GetCurrentThreadId())); + ASSERT_LOCK_HELD + + if (!(pOXIDEntry->dwFlags & OXIDF_REGISTERED)) + { + // OXID was never registered, just return + return S_OK; + } + + // Free any pre-registered OIDs since these are registered for the + // current OXID. We get a new OXID if the thread is re-initialized. + // Set the ptr and count of Oids to de-register. + + ULONG cOids; + OID *pOids; + + COleTls tls; + if (!(tls->dwFlags & OLETLS_APARTMENTTHREADED)) + { + pOids = _arPreRegOids; + cOids = _cPreRegOidsAvail; + _cPreRegOidsAvail = 0; + } + else + { + cOids = tls->cPreRegOidsAvail; + tls->cPreRegOidsAvail = 0; + + pOids = tls->pPreRegOids; + tls->pPreRegOids = NULL; + } + + // extract the OXID and mark the OXIDEntry as no longer registered + OXID oxid; + OXIDFromMOXID(pOXIDEntry->moxid, &oxid); + pOXIDEntry->dwFlags &= ~OXIDF_REGISTERED; + + + UNLOCK + ASSERT_LOCK_RELEASED + + // call the resolver. + HRESULT hr; + + do + { + Win4Assert(_ph != NULL); + + hr = ::ServerFreeOXIDAndOIDs( + _hRpc, // Rpc binding handle + _ph, // context handle + oxid, // OXID of server + cOids, // count of OIDs to de-register + pOids); // ptr to OIDs to de-register + + } while (RetryRPC(hr)); + + ASSERT_LOCK_RELEASED + LOCK + + // map Rpc status if necessary + hr = CheckStatus(hr); + + if (tls->dwFlags & OLETLS_APARTMENTTHREADED) + { + // delete the space allocated for the pre-registered OIDs + PrivMemFree(pOids); + } + + ComDebOut((DEB_OXID, "ServerFreeOXID hr:%x\n", hr)); + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ClientResolveOXID, public +// +// Synopsis: Resolve client-side OXID and returns the OXIDEntry, AddRef'd. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ClientResolveOXID(REFOXID roxid, + DUALSTRINGARRAY *psaResolver, + OXIDEntry **ppOXIDEntry) +{ + ComDebOut((DEB_OXID,"ClientResolveOXID oxid:%08x %08x psa:%x\n", + roxid, psaResolver)); + ASSERT_LOCK_HELD + RPC_STATUS sc = RPC_S_OK; + + *ppOXIDEntry = NULL; + + // Look for a MID entry for the resolver. if we cant find it + // then we know we dont have an OXIDEntry for the oxid. + + DWORD dwHash; + MIDEntry *pMIDEntry = gMIDTbl.LookupMID(psaResolver, &dwHash); + if (pMIDEntry) + { + // found the MID, now look for the OXID + *ppOXIDEntry = gOXIDTbl.LookupOXID(roxid, pMIDEntry->mid); + } + + if (*ppOXIDEntry == NULL) + { + // didn't find the OXIDEntry in the table so we need to resolve it. + + UNLOCK + ASSERT_LOCK_RELEASED + + MID mid; + OXID_INFO oxidInfo; + oxidInfo.psa = NULL; + + do + { + Win4Assert(_ph != NULL); + + sc = ::ClientResolveOXID( + _hRpc, // Rpc binding handle + _ph, // context handle + (OXID *)&roxid, // OXID of server + psaResolver, // resolver binging strings + IsSTAThread(), // fApartment threaded + // GetCallAuthnLevel(), CODEWORK: someday + &oxidInfo, // resolver info returned + &mid); // mid for the machine + + } while (RetryRPC(sc)); + + ASSERT_LOCK_RELEASED + LOCK + + // map Rpc status if necessary + sc = CheckStatus(sc); + + if (SUCCEEDED(sc)) + { + // create an OXIDEntry. + sc = FindOrCreateOXIDEntry(roxid, oxidInfo, FOCOXID_REF, + psaResolver, + mid, pMIDEntry, ppOXIDEntry); + + // free the returned string bindings + MIDL_user_free(oxidInfo.psa); + } + } + + if (pMIDEntry) + { + DecMIDRefCnt(pMIDEntry); + } + + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ClientResolveOXID hr:%x pOXIDEntry:%x\n", + sc, *ppOXIDEntry)); + return sc; +} + +//+------------------------------------------------------------------- +// +// Function: FillLocalOXIDInfo +// +// Synopsis: Fills in a OXID_INFO structure for the current apartment. +// Used by the Drag & Drop code to register with the resolver. +// +// History: 20-Feb-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT FillLocalOXIDInfo(OBJREF &objref, OXID_INFO &oxidInfo) +{ + // extract the OXIDEntry from the objref + OXIDEntry *pOXIDEntry = GetOXIDFromObjRef(objref); + Win4Assert(pOXIDEntry); + + // fill in the fields of the OXID_INFO structure. + oxidInfo.dwTid = pOXIDEntry->dwTid; + oxidInfo.dwPid = pOXIDEntry->dwPid; + oxidInfo.ipidRemUnknown = pOXIDEntry->ipidRundown; + oxidInfo.dwAuthnHint = RPC_C_AUTHN_LEVEL_NONE; + + HRESULT hr = GetStringBindings(&oxidInfo.psa); + ComDebErr(hr != S_OK, "GetStringBindings Failed.\n"); + return (hr); +} + +//+------------------------------------------------------------------- +// +// Function: AddToList / RemoveFromList +// +// Synopsis: adds or removes an SOIDRegistration entry to/from +// a doubly linked list. +// +// History: 30-Oct-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void AddToList(SOIDRegistration *pOIDReg, SOIDRegistration* pOIDListHead) +{ + pOIDReg->pPrevList = pOIDListHead; + pOIDListHead->pNextList->pPrevList = pOIDReg; + pOIDReg->pNextList = pOIDListHead->pNextList; + pOIDListHead->pNextList = pOIDReg; +} + +void RemoveFromList(SOIDRegistration *pOIDReg) +{ + pOIDReg->pPrevList->pNextList = pOIDReg->pNextList; + pOIDReg->pNextList->pPrevList = pOIDReg->pPrevList; + pOIDReg->pPrevList = pOIDReg; + pOIDReg->pNextList = pOIDReg; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ClientRegisterOIDWithPingServer +// +// Synopsis: registers an OID with the Ping Server if it has +// not already been registered. +// +// History: 30-Oct-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ClientRegisterOIDWithPingServer(REFOID roid, + OXIDEntry *pOXIDEntry) +{ + ComDebOut((DEB_OXID, "ClientRegisterOIDWithPingServer poid:%x\n", &roid)); + ASSERT_LOCK_HELD + AssertValid(); + HRESULT hr = S_OK; + + // make a MOID from the OID + MOID moid; + MOIDFromOIDAndMID(roid, pOXIDEntry->pMIDEntry->mid, &moid); + + + // see if this OID already has a client-side registration + // record created by another apartment in this process. + + DWORD iHash = gClientRegisteredOIDs.Hash(moid); + SOIDRegistration *pOIDReg = (SOIDRegistration *) + gClientRegisteredOIDs.Lookup(iHash, moid); + + if (pOIDReg == NULL) + { + // not yet registered with resolver, create a new entry and + // add it to the hash table and to the List of items to register + // with the Resolver. + + // make sure we have a worker thread ready to do the register + // at some point in the future. + hr = EnsureWorkerThread(); + + if (SUCCEEDED(hr)) + { + hr = E_OUTOFMEMORY; + pOIDReg = new SOIDRegistration; + + if (pOIDReg) + { + pOIDReg->cRefs = 1; + pOIDReg->pPrevList = pOIDReg; + pOIDReg->pNextList = pOIDReg; + pOIDReg->pOXIDEntry = pOXIDEntry; + + gClientRegisteredOIDs.Add(iHash, moid, (SUUIDHashNode *)pOIDReg); + + pOIDReg->flags = ROIDF_REGISTER; + AddToList(pOIDReg, &_ClientOIDRegList); + _cOidsToAdd++; + + hr = S_OK; + } + } + } + else + { + // already have a record for this OID, inc the refcnt + pOIDReg->cRefs++; + + if (pOIDReg->cRefs == 1) + { + // re-using an entry that had a count of zero, so it must have + // been going to be deregistered or pinged. + Win4Assert((pOIDReg->flags == ROIDF_PING) || + (pOIDReg->flags == ROIDF_DEREGISTER)); + + _cOidsToRemove--; + + if (pOIDReg->flags & ROIDF_PING) + { + // was only going to be pinged, now must be added. + pOIDReg->flags |= ROIDF_REGISTER; + } + else + { + // was going to be unregistered, already registered so does + // not need to be on the registration list anymmore + + Win4Assert(pOIDReg->flags & ROIDF_DEREGISTER); + pOIDReg->flags = 0; + RemoveFromList(pOIDReg); + } + } + } + + AssertValid(); + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ClientRegisterOIDWithPingServer pOIDReg:%x hr:%x\n", + pOIDReg, hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ClientDeRegisterOIDWithPingServer +// +// Synopsis: de-registers an OID that has previously been registered +// with the Ping Server +// +// History: 30-Oct-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ClientDeRegisterOIDFromPingServer(REFMOID rmoid, + BOOL fMarshaled) +{ + ComDebOut((DEB_OXID,"ClientDeRegisterOIDWithPingServer rmoid:%I\n", &rmoid)); + ASSERT_LOCK_HELD + AssertValid(); + + // find the OID in the hash table. it better still be there! + + DWORD iHash = gClientRegisteredOIDs.Hash(rmoid); + SOIDRegistration *pOIDReg = (SOIDRegistration *) + gClientRegisteredOIDs.Lookup(iHash, rmoid); + Win4Assert(pOIDReg != NULL); + Win4Assert((pOIDReg->flags == ROIDF_REGISTER) || + (pOIDReg->flags == (ROIDF_REGISTER | ROIDF_PING)) || + (pOIDReg->flags == 0)); + + if (-- pOIDReg->cRefs == 0) + { + // this was the last registration of the OID in this process. + + if (pOIDReg->flags & ROIDF_REGISTER) + { + // still on the Register list, have not yet told the Ping Server + // about this OID so dont have to do anything unless it was + // client-side marshaled. + + if (fMarshaled || pOIDReg->flags & ROIDF_PING) + { + // object was marshaled by the client. Still need to tell + // the Ping Server to ping the OID then forget about it. + + pOIDReg->flags = ROIDF_PING; + _cOidsToRemove++; + + // make sure we have a worker thread ready to do the deregister + // at some point in the future. Not much we can do about an + // error here. If transient, then a thread will most likely + // be created later. + EnsureWorkerThread(); + } + else + { + // dont need this record any longer. remove from chain + // and delete the record. + + RemoveFromList(pOIDReg); + _cOidsToAdd--; + gClientRegisteredOIDs.Remove((SHashChain *)pOIDReg); + delete pOIDReg; + } + } + else + { + // must already be registered with the resolver. now need to + // deregister it so put it on the Registration list for delete. + + pOIDReg->flags = ROIDF_DEREGISTER; + AddToList(pOIDReg, &_ClientOIDRegList); + _cOidsToRemove++; + + // make sure we have a worker thread ready to do the deregister + // at some point in the future. Not much we can do about an + // error here. If transient, then a thread will most likely + // be created later. + EnsureWorkerThread(); + } + } + + AssertValid(); + ASSERT_LOCK_HELD + ComDebOut((DEB_OXID,"ClientDeRegisterOIDWithPingServer pOIDReg:%x hr:%x\n", + pOIDReg, S_OK)); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::ClientBulkUpdateOIDWithPingServer +// +// Synopsis: registers/deregisters/pings any OIDs waiting to be +// sent to the ping server. +// +// History: 30-Oct-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::ClientBulkUpdateOIDWithPingServer(void) +{ + ComDebOut((DEB_OXID, "ClientBulkUpdateOIDWithPingServer\n")); + ASSERT_LOCK_HELD + AssertValid(); + Win4Assert(_cOidsToAdd + _cOidsToRemove != 0); + + // Copy the counters so we can reset them before we make the call. + // Allocate space for the Add, Status, and Remove lists to send to the + // ping server, and remember the start address so we can free the + // memory later. Compute the address of the other lists within the + // one allocated memory block. + + ULONG cOidsToAdd = _cOidsToAdd; + ULONG cOidsToRemove = _cOidsToRemove; + ULONG cOxidsToRemove = gOXIDTbl.NumOxidsToRemove(); + + ULONG cBytesToAlloc = (cOidsToAdd * (sizeof(OXID_OID_PAIR)+sizeof(ULONG))) + + (cOidsToRemove * sizeof(OID_MID_PAIR)) + + (cOxidsToRemove * sizeof(OXID_REF)); + + OXID_OID_PAIR *pOidsToAdd = (OXID_OID_PAIR *)PrivMemAlloc(cBytesToAlloc); + if (pOidsToAdd == NULL) + { + // cant allocate memory. Leave the registration lists alone for + // now, this may be a transient problem and we can handle the + // registration later (unless of course the problem persists and + // our object is run down!). + + UNLOCK + ASSERT_LOCK_RELEASED + ComDebOut((DEB_ERROR, "ClientBulkUpdate OOM\n")); + return E_OUTOFMEMORY; + } + + OXID_OID_PAIR *pOidsToAddStart = pOidsToAdd; + LONG *pStatusOfAdds = (LONG *) (&pOidsToAdd[cOidsToAdd]); + OID_MID_PAIR *pOidsToRemove = (OID_MID_PAIR *)(&pStatusOfAdds[cOidsToAdd]); + OXID_REF *pOxidsToRemove = (OXID_REF *) (&pOidsToRemove[cOidsToRemove]); + + + // loop through each OID registration records in the list filling in + // the Add and Remove lists. Pinged OIDs are placed in both lists. + + while (_ClientOIDRegList.pNextList != &_ClientOIDRegList) + { + // get the entry and remove it from the registration list + SOIDRegistration *pOIDReg = _ClientOIDRegList.pNextList; + RemoveFromList(pOIDReg); + + // reset the state flags before we begin + DWORD dwFlags = pOIDReg->flags; + pOIDReg->flags = 0; + + if (dwFlags & (ROIDF_REGISTER | ROIDF_PING)) + { + // register the OID with the ping server + MIDFromMOXID (pOIDReg->pOXIDEntry->moxid, &pOidsToAdd->mid); + OXIDFromMOXID(pOIDReg->pOXIDEntry->moxid, &pOidsToAdd->oxid); + OIDFromMOID (pOIDReg->Node.key, &pOidsToAdd->oid); + + pOidsToAdd++; + _cOidsToAdd--; + } + + if (dwFlags == ROIDF_DEREGISTER || dwFlags == ROIDF_PING) + { + // deregister the OID with the ping server + // Node.key is the OID+MID so extract each part + MIDFromMOID(pOIDReg->Node.key, &pOidsToRemove->mid); + OIDFromMOID(pOIDReg->Node.key, &pOidsToRemove->oid); + + pOidsToRemove++; + _cOidsToRemove--; + + // dont need the entry any more since there are no more + // users of it. remove from hash table and delete it. + gClientRegisteredOIDs.Remove((SHashChain *)pOIDReg); + delete pOIDReg; + } + } + + // Ask the OXID table to fill in the list of OXIDs to remove. + gOXIDTbl.GetOxidsToRemove( pOxidsToRemove, &cOxidsToRemove ); + + // make sure we got all the entries and that our counters work correctly. + Win4Assert(_cOidsToAdd == 0); + Win4Assert(_cOidsToRemove == 0); + AssertValid(); + + UNLOCK + ASSERT_LOCK_RELEASED + + // reset the OidsToRemove list pointer since we mucked with it above. + pOidsToRemove = (OID_MID_PAIR *) (&pStatusOfAdds[cOidsToAdd]); + + RPC_STATUS sc; + + do + { + // call the Resolver. + sc = BulkUpdateOIDs(_hRpc, // Rpc binding handle + _ph, // context handle + cOidsToAdd, // #oids to add + pOidsToAddStart, // ptr to oids to add + pStatusOfAdds, // status of adds + cOidsToRemove, // #oids to remove + pOidsToRemove, // ptr to oids to remove + 0, 0, // ptr to oids to free + cOxidsToRemove, // #oxids to remove + pOxidsToRemove); // ptr to oxids to remove + + } while (RetryRPC(sc)); + + // map status if necessary + sc = CheckStatus(sc); + + // CODEWORK: reset the status flags for any OIDs not successfully added + // to the resolver. + + // release the memory allocated above + PrivMemFree(pOidsToAddStart); + +#if DBG==1 + LOCK + AssertValid(); + UNLOCK +#endif + ASSERT_LOCK_RELEASED + ComDebOut((DEB_OXID, "ClientBulkUpdateOIDWithPingServer hr:%x\n", S_OK)); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::EnsureWorkerThread +// +// Synopsis: Make sure there is a worker thread. Create one if +// necessary. +// +// History: 06-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +HRESULT CRpcResolver::EnsureWorkerThread(void) +{ + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + if (_hThrd == NULL) + { + // no worker thread currently exists, try to create one. First, make + // sure that we have a connection to the resolver. + + hr = GetConnection(); + + if (SUCCEEDED(hr)) + { + DWORD dwThrdId; + _hThrd = CreateThread(NULL, 0, + WorkerThreadLoop, + 0, 0, &dwThrdId); + if (_hThrd) + { + // although the handle is closed, it is NOT nulled until + // the worker thread exits. That is the signal that there + // is no more worker thread and we may need to allocate + // another one. + + CloseHandle(_hThrd); + } + else + { + // unable to create worker thread + hr = HRESULT_FROM_WIN32(GetLastError()); + ComDebOut((DEB_ERROR,"Create Resolver worker thread hr:%x\n",hr)); + } + } + } + + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::WorkerThreadLoop +// +// Synopsis: Worker thread for doing lazy/bulk OID registration +// with the ping server. +// +// History: 06-Nov-95 Rickhi Created. +// +//-------------------------------------------------------------------- +DWORD _stdcall CRpcResolver::WorkerThreadLoop(void *param) +{ + // First thing we need to do is LoadLibrary ourselves in order to + // prevent our code from going away while this worker thread exists. + // The library will be freed when this thread exits. + + HINSTANCE hInst = LoadLibrary(L"OLE32.DLL"); + + while (TRUE) + { + // sleep for a while to let the OIDs batch up in the registration list + Sleep(_dwSleepPeriod); + + ASSERT_LOCK_RELEASED + LOCK + + if (_cOidsToAdd == 0 && _cOidsToRemove == 0) + { + // There is no work to do. Exit this thread. If we need to + // register more oids later we will spin up another thread. + + _hThrd = NULL; + UNLOCK + break; + } + + ASSERT_LOCK_HELD + ClientBulkUpdateOIDWithPingServer(); + ASSERT_LOCK_RELEASED + } + + // Simultaneously free our Dll and exit our thread. This allows us to + // keep our Dll around incase a remote call was is progress and the + // worker thread is still blocked on the call, and allows us to cleanup + // properly when all threads are done with the code. + + ASSERT_LOCK_RELEASED + FreeLibraryAndExitThread(hInst, 0); + + // compiler wants a return value + return 0; +} + +#if DBG==1 +//+------------------------------------------------------------------- +// +// Member: CRpcResolver::AssertValid +// +// Synopsis: validates the state of this object +// +// History: 30-Oct-95 Rickhi Created. +// +//-------------------------------------------------------------------- +void CRpcResolver::AssertValid(void) +{ + ASSERT_LOCK_HELD + + Win4Assert((_cOidsToAdd & 0xf0000000) == 0x00000000); + Win4Assert((_cOidsToRemove & 0xf0000000) == 0x00000000); + + if (_cOidsToAdd == 0 && _cOidsToRemove == 0) + { + // make sure the Reg list is empty. + Win4Assert(_ClientOIDRegList.pPrevList == &_ClientOIDRegList); + Win4Assert(_ClientOIDRegList.pNextList == &_ClientOIDRegList); + } + else + { + // make sure we have a worker thread. we cant assert because + // we could be OOM trying to create the thread. + if (_hThrd == NULL) + { + ComDebOut((DEB_WARN, "No Resolver Worked Thread\n")); + } + + // make sure the Reg list is consistent with the counters + ULONG cAdd = 0; + ULONG cRemove = 0; + + SOIDRegistration *pOIDReg = _ClientOIDRegList.pNextList; + while (pOIDReg != &_ClientOIDRegList) + { + // make sure the flags are valid + Win4Assert(pOIDReg->flags == ROIDF_REGISTER || + pOIDReg->flags == ROIDF_DEREGISTER || + pOIDReg->flags == ROIDF_PING || + pOIDReg->flags == (ROIDF_PING | ROIDF_REGISTER)); + + if (pOIDReg->flags & (ROIDF_REGISTER | ROIDF_PING)) + { + // OID is to be registered + cAdd++; + } + + if (pOIDReg->flags == ROIDF_DEREGISTER || + pOIDReg->flags == ROIDF_PING) + { + // OID is to be deregistered + cRemove++; + } + + pOIDReg = pOIDReg->pNextList; + } + + Win4Assert(cAdd == _cOidsToAdd); + Win4Assert(cRemove == _cOidsToRemove); + } + + ASSERT_LOCK_HELD +} +#endif + +//+------------------------------------------------------------------------ +// +// Function: MakeSCMProxy, public +// +// Synopsis: Creates an OXIDEntry and a proxy for the SCM Activation +// Interface. +// +// History: 14 Apr 95 AlexMit Created +// +//------------------------------------------------------------------------- +INTERNAL MakeSCMProxy(DUALSTRINGARRAY *psaSCM, REFIID riid, void **ppSCM) +{ + ComDebOut((DEB_OXID, "MakeSCMProxy psaSCM:%x ppSCM:%x\n", psaSCM, ppSCM)); + Win4Assert(gdwScmProcessID != 0); + + // Init out parameter + *ppSCM = NULL; + + // Make a fake OXIDEntry for the SCM. + OXID_INFO oxidInfo; + oxidInfo.dwTid = 0; + oxidInfo.dwPid = gdwScmProcessID; + oxidInfo.ipidRemUnknown = GUID_NULL; + oxidInfo.psa = psaSCM; + oxidInfo.dwAuthnHint = RPC_C_AUTHN_LEVEL_NONE; + + LOCK + + OXIDEntry *pOXIDEntry; + MIDEntry *pMIDEntry; + + HRESULT hr = GetLocalMIDEntry(&pMIDEntry); // not AddRef'd + + if (SUCCEEDED(hr)) + { + // Make a fake OXID for the SCM. We can use any ID that the resolver + // hands out as the OXID for the SCM. + + OXID oxid; + hr = gResolver.ServerGetReservedID(&oxid); + + if (SUCCEEDED(hr)) + { + hr = gOXIDTbl.AddEntry(oxid, &oxidInfo, pMIDEntry, &pOXIDEntry); + } + + if (SUCCEEDED(hr)) + { + // Make an object reference for the SCM. The oid and ipid dont + // matter, except the OID must be machine-unique. + + IPID ipidTmp; + UuidCreate(&ipidTmp); // fake the IPID + + OBJREF objref; + hr = MakeFakeObjRef(objref, pOXIDEntry, ipidTmp, riid); + + if (SUCCEEDED(hr)) + { + // now unmarshal the objref to create a proxy to the SCM. + // use the internal form to reduce initialization time. + UNLOCK + hr = UnmarshalInternalObjRef(objref, ppSCM); + + if (SUCCEEDED(hr) && gImpLevel != RPC_C_IMP_LEVEL_IMPERSONATE) + { + // Make sure SCM can impersonate us. + hr = CoSetProxyBlanket( (IUnknown *) *ppSCM, + RPC_C_AUTHN_WINNT, + RPC_C_AUTHZ_NONE, NULL, + RPC_C_AUTHN_LEVEL_CONNECT, + RPC_C_IMP_LEVEL_IMPERSONATE, + NULL, EOAC_NONE ); + + if (FAILED(hr)) + { + ((IUnknown *) (*ppSCM))->Release(); + *ppSCM = NULL; + } + } + + LOCK + } + + // release the reference to the OXIDEntry from AddEntry, since + // UnmarshalInternalObjRef added another one if it was successful. + DecOXIDRefCnt(pOXIDEntry); + } + } + + UNLOCK + ComDebOut((DEB_OXID, "MakeSCMProxy hr:%x *ppSCM:%x\n", hr, *ppSCM)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::BindToSCMProxy +// +// Synopsis: Get a proxy to the SCM Activation interface. +// +// History: 19-May-95 Rickhi Created +// +// Notes: The SCM activation interface is an ORPC interface so that +// apartment model apps can receive callbacks and do cancels +// while activating object servers. +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::BindToSCMProxy() +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::BindToSCMProxy")); + + // since we are calling out on this thread, we have to ensure that the + // call control is set up for this thread. + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + COleStaticLock lck(gmxsOleMisc); + + if (IsSTAThread()) + { + if (_pSCMSTA == NULL) + { + // Make a proxy to the SCM + hr = MakeSCMProxy((DUALSTRINGARRAY *)&saSCM, IID_IDSCM, (void **) &_pSCMSTA); + } + } + else + { + if (_pSCMMTA == NULL) + { + // Make a proxy to the SCM + hr = MakeSCMProxy((DUALSTRINGARRAY *)&saSCM, IID_IDSCM, (void **) &_pSCMMTA); + } + } + + ComDebOut((SUCCEEDED(hr) ? DEB_SCM : DEB_ERROR, + "CCoScm::BindToSCMProxy for IDSCM returns %x.\n", hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::NotifyStarted +// +// Synopsis: Notify the SCM that a class has been started +// +// Arguments: [rclsid] - class started +// [dwFlags] - whether class is multiple use or not. +// +// History: 19-May-92 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::NotifyStarted( + RegInput *pRegIn, + RegOutput **ppRegOut) +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::NotifyStarted")); + + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ServerRegisterClsid( + _hRpc, + _ph, + pwszWinstaDesktop, + pRegIn, + ppRegOut, + &rpcstat ); + + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + ComDebOut(( (hr == S_OK) ? DEB_SCM : DEB_ERROR, + "Class Registration returned %x", hr)); + + if (rpcstat != RPC_S_OK) + { + hr = HRESULT_FROM_WIN32(rpcstat); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::NotifyStopped +// +// Synopsis: Notify the SCM that the server is stopped. +// +// History: 19-May-92 Ricksa Created +// +//-------------------------------------------------------------------------- +void CRpcResolver::NotifyStopped( + REFCLSID rclsid, + DWORD dwReg) +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::NotifyStopped")); + + error_status_t rpcstat; + + RevokeClasses revcls; + revcls.dwSize = 1; + revcls.revent[0].clsid = rclsid; + revcls.revent[0].dwReg = dwReg; + + do + { + ServerRevokeClsid(_hRpc, _ph, &revcls, &rpcstat); + + } while (RetryRPC(rpcstat)); +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::GetClassObject +// +// Synopsis: Send a get object request to the SCM +// +// Arguments: [rclsid] - class id for class object +// [dwCtrl] - type of server required +// [ppIFDClassObj] - marshaled buffer for class object +// [ppwszDllToLoad] - DLL name to use for server +// +// Returns: S_OK +// +// History: 20-May-93 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::GetClassObject( + REFCLSID rclsid, + DWORD dwContext, + IID *pIID, + COSERVERINFO *pServerInfo, + MInterfacePointer **ppIFDClassObj, + DWORD *pdwDllServerType, + WCHAR **ppwszDllToLoad) +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetClassObject")); + + HRESULT hr; + ACTIVATION_INFO ActivationInfo; + OXID OxidServer; + DUALSTRINGARRAY * pssaServerObjectResolverBindings; + OXID_INFO OxidInfo; + MID LocalMidOfRemote; + OXIDEntry * pOxidEntry; + LPWSTR pwszWinstaDesktop; + + hr = BindToSCMProxy(); + if (FAILED(hr)) + return hr; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + ActivationInfo.Clsid = &rclsid; + ActivationInfo.pServerInfo = pServerInfo; + ActivationInfo.pwszWinstaDesktop = pwszWinstaDesktop; + ActivationInfo.ClsContext = dwContext; + ActivationInfo.ProcessSignature = _dwProcessSignature; + ActivationInfo.bDynamicSecurity = _bDynamicSecurity; + + pssaServerObjectResolverBindings = 0; + OxidInfo.psa = 0; + pOxidEntry = 0; + + hr = GetSCM()->SCMGetClassObject( + &ActivationInfo, + pIID, + IsSTAThread(), + &OxidServer, + &pssaServerObjectResolverBindings, + &OxidInfo, + &LocalMidOfRemote, + ppIFDClassObj ); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if ( FAILED(hr) || (OxidServer == 0) ) + { + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetClassObject hr:%x", hr)); + return hr; + } + + ASSERT_LOCK_RELEASED + LOCK + + hr = FindOrCreateOXIDEntry( + OxidServer, + OxidInfo, + FOCOXID_REF, + pssaServerObjectResolverBindings, + LocalMidOfRemote, + NULL, + &pOxidEntry ); + + CoTaskMemFree(OxidInfo.psa); + CoTaskMemFree(pssaServerObjectResolverBindings); + + // + // CODEWORK CODEWORK CODEWORK + // + // These comments also apply to CreateInstance and GetPersistentInstance + // methods. + // + // Releasing the OXID and reacquiring it makes me a little + // nervous. The Expired list is fairly short, so if multiple guys are doing + // this simultaneously, the entries could get lost. I guess this is not + // too bad since it should be rare and the local resolver will have it + // anyway, but I think there is a window where the local resolver could + // lose it too, forcing a complete roundtrip back to the server. + // + // A better mechanism may be to pass the iid and ppunk into this method + // and do the unmarshal inside it. We could improve performance by calling + // UnmarshalObjRef instead of putting a stream wrapper around the + // MInterfacePointer and then calling CoUnmarshalInterface. It would avoid + // looking up the OXIDEntry twice, and would avoid the race where we could + // lose the OXIDEntry off the expired list. It would require a small + // change in UnmarshalObjRef to deal with the custom marshal case. + // + + // + // Decrement our ref. The interface unmarshall will do a LookupOXID + // which will increment the count and move the OXIDEntry back to the + // InUse list. + // + if ( pOxidEntry ) + DecOXIDRefCnt(pOxidEntry); + + UNLOCK + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetClassObject hr:%x", hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::CreateInstance +// +// Synopsis: Send a create instance request to the SCM +// +// Arguments: +// +// Returns: S_OK +// +// History: 20-May-93 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::CreateInstance( + COSERVERINFO *pServerInfo, + CLSID *pClsid, + DWORD dwClsCtx, + DWORD dwCount, + IID *pIIDs, + MInterfacePointer **pRetdItfs, + HRESULT *pRetdHrs, + DWORD *pdwDllServerType, + OLECHAR **ppwszDllToLoad ) + +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::CreateInstance")); + + HRESULT hr; + ACTIVATION_INFO ActivationInfo; + OXID OxidServer; + DUALSTRINGARRAY * pssaServerObjectResolverBindings; + OXID_INFO OxidInfo; + MID LocalMidOfRemote; + OXIDEntry * pOxidEntry; + LPWSTR pwszWinstaDesktop; + + hr = BindToSCMProxy(); + if (FAILED(hr)) + return hr; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + ActivationInfo.Clsid = pClsid; + ActivationInfo.pServerInfo = pServerInfo; + ActivationInfo.pwszWinstaDesktop = pwszWinstaDesktop; + ActivationInfo.ClsContext = dwClsCtx; + ActivationInfo.ProcessSignature = _dwProcessSignature; + ActivationInfo.bDynamicSecurity = _bDynamicSecurity; + + pssaServerObjectResolverBindings = 0; + OxidInfo.psa = 0; + pOxidEntry = 0; + + hr = GetSCM()->SCMCreateInstance( + &ActivationInfo, + dwCount, + pIIDs, + IsSTAThread(), + &OxidServer, + &pssaServerObjectResolverBindings, + &OxidInfo, + &LocalMidOfRemote, + pRetdItfs, + pRetdHrs ); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if ( FAILED(hr) || (OxidServer == 0) ) + { + ComDebOut((DEB_ACTIVATE, "CRpcResolver::CreateInstance hr:%x", hr)); + return hr; + } + + ASSERT_LOCK_RELEASED + LOCK + + hr = FindOrCreateOXIDEntry( + OxidServer, + OxidInfo, + FOCOXID_REF, + pssaServerObjectResolverBindings, + LocalMidOfRemote, + NULL, + &pOxidEntry ); + + CoTaskMemFree(OxidInfo.psa); + CoTaskMemFree(pssaServerObjectResolverBindings); + + // + // Decrement our ref. The interface unmarshall will do a LookupOXID + // which will increment the count and move the OXIDEntry back to the + // InUse list. + // + if ( pOxidEntry ) + DecOXIDRefCnt(pOxidEntry); + + UNLOCK + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_ACTIVATE, "CRpcResolver::CreateInstance hr:%x", hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::GetPersistentInstance +// +// Synopsis: Send a get object request to the SCM +// +//GAJGAJ - fix this comment block +// Arguments: [rclsid] - class id for class object +// [dwCtrl] - type of server required +// [ppIFDClassObj] - marshaled buffer for class object +// [ppwszDllToLoad] - DLL name to use for server +// +// Returns: S_OK +// +// History: 20-May-93 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::GetPersistentInstance( + COSERVERINFO * pServerInfo, + CLSID *pClsid, + DWORD dwClsCtx, + DWORD grfMode, + BOOL bFileWasOpened, + OLECHAR *pwszName, + MInterfacePointer *pstg, + DWORD dwCount, + IID *pIIDs, + BOOL * FoundInROT, + MInterfacePointer **pRetdItfs, + HRESULT *pRetdHrs, + DWORD *pdwDllServerType, + OLECHAR **ppwszDllToLoad ) +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetPersistentInstance")); + + HRESULT hr; + ACTIVATION_INFO ActivationInfo; + OXID OxidServer; + DUALSTRINGARRAY * pssaServerObjectResolverBindings; + OXID_INFO OxidInfo; + MID LocalMidOfRemote; + OXIDEntry * pOxidEntry; + LPWSTR pwszWinstaDesktop; + + hr = BindToSCMProxy(); + if (FAILED(hr)) + return hr; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + ActivationInfo.Clsid = pClsid; + ActivationInfo.pServerInfo = pServerInfo; + ActivationInfo.pwszWinstaDesktop = pwszWinstaDesktop; + ActivationInfo.ClsContext = dwClsCtx; + ActivationInfo.ProcessSignature = _dwProcessSignature; + ActivationInfo.bDynamicSecurity = _bDynamicSecurity; + + pssaServerObjectResolverBindings = 0; + OxidInfo.psa = 0; + pOxidEntry = 0; + + hr = GetSCM()->SCMGetPersistentInstance( + &ActivationInfo, + pwszName, + pstg, + grfMode, + bFileWasOpened, + dwCount, + pIIDs, + IsSTAThread(), + &OxidServer, + &pssaServerObjectResolverBindings, + &OxidInfo, + &LocalMidOfRemote, + FoundInROT, + pRetdItfs, + pRetdHrs ); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if ( FAILED(hr) || (OxidServer == 0) ) + { + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetPersistentInstance hr:%x",hr)); + return hr; + } + + ASSERT_LOCK_RELEASED + LOCK + + hr = FindOrCreateOXIDEntry( + OxidServer, + OxidInfo, + FOCOXID_REF, + pssaServerObjectResolverBindings, + LocalMidOfRemote, + NULL, + &pOxidEntry ); + + CoTaskMemFree(OxidInfo.psa); + CoTaskMemFree(pssaServerObjectResolverBindings); + + // + // Decrement our ref. The interface unmarshall will do a LookupOXID + // which will increment the count and move the OXIDEntry back to the + // InUse list. + // + if ( pOxidEntry ) + DecOXIDRefCnt(pOxidEntry); + + UNLOCK + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_ACTIVATE, "CRpcResolver::GetPersistentInstance hr:%x", hr)); + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotRegister +// +// Synopsis: Register an object in the ROT +// +// Arguments: [pmkeqbuf] - moniker compare buffer +// [pifdObject] - marshaled interface for object +// [pifdObjectName] - marshaled moniker +// [pfiletime] - file time of last change +// [dwProcessID] - +// [psrkRegister] - output of registration +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotRegister( + MNKEQBUF *pmkeqbuf, + InterfaceData *pifdObject, + InterfaceData *pifdObjectName, + FILETIME *pfiletime, + DWORD dwProcessID, + WCHAR *pwszServerExe, + SCMREGKEY *psrkRegister) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ::IrotRegister( + _hRpc, + _ph, + pwszWinstaDesktop, + pmkeqbuf, + pifdObject, + pifdObjectName, + pfiletime, + dwProcessID, + pwszServerExe, + psrkRegister, + &rpcstat); + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotRevoke +// +// Synopsis: Call to SCM to revoke object from the ROT +// +// Arguments: [psrkRegister] - moniker compare buffer +// [fServerRevoke] - whether server for object is revoking +// [pifdObject] - where to put marshaled object +// [pifdName] - where to put marshaled moniker +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotRevoke( + SCMREGKEY *psrkRegister, + BOOL fServerRevoke, + InterfaceData **ppifdObject, + InterfaceData **ppifdName) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + + do + { + hr = ::IrotRevoke( + _hRpc, + psrkRegister, + fServerRevoke, + ppifdObject, + ppifdName, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotIsRunning +// +// Synopsis: Call to SCM to determine if object is in the ROT +// +// Arguments: [pmkeqbuf] - moniker compare buffer +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotIsRunning(MNKEQBUF *pmkeqbuf) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ::IrotIsRunning( + _hRpc, + _ph, + pwszWinstaDesktop, + pmkeqbuf, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotGetObject +// +// Synopsis: Call to SCM to determine if object is in the ROT +// +// Arguments: [dwProcessID] - process ID for object we want +// [pmkeqbuf] - moniker compare buffer +// [psrkRegister] - registration ID in SCM +// [pifdObject] - marshaled interface for the object +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotGetObject( + DWORD dwProcessID, + MNKEQBUF *pmkeqbuf, + SCMREGKEY *psrkRegister, + InterfaceData **pifdObject) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ::IrotGetObject( + _hRpc, + _ph, + pwszWinstaDesktop, + dwProcessID, + pmkeqbuf, + psrkRegister, + pifdObject, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotNoteChangeTime +// +// Synopsis: Call to SCM to set time of change for object in the ROT +// +// Arguments: [psrkRegister] - SCM registration ID +// [pfiletime] - time of change +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotNoteChangeTime( + SCMREGKEY *psrkRegister, + FILETIME *pfiletime) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + + do + { + hr = ::IrotNoteChangeTime( + _hRpc, + psrkRegister, + pfiletime, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotGetTimeOfLastChange +// +// Synopsis: Call to SCM to get time changed of object in the ROT +// +// Arguments: [pmkeqbuf] - moniker compare buffer +// [pfiletime] - where to put time of last change +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotGetTimeOfLastChange( + MNKEQBUF *pmkeqbuf, + FILETIME *pfiletime) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ::IrotGetTimeOfLastChange( + _hRpc, + _ph, + pwszWinstaDesktop, + pmkeqbuf, + pfiletime, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::IrotEnumRunning +// +// Synopsis: Call to SCM to enumerate running objects in the ROT +// +// Arguments: [ppMkIFList] - output pointer to array of marshaled monikers +// +// Returns: S_OK +// +// History: 28-Jan-95 Ricksa Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::IrotEnumRunning(MkInterfaceList **ppMkIFList) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat = RPC_S_OK; + WCHAR * pwszWinstaDesktop; + + hr = GetWinstaDesktop( &pwszWinstaDesktop ); + + if ( FAILED(hr) ) + return hr; + + do + { + hr = ::IrotEnumRunning( + _hRpc, + _ph, + pwszWinstaDesktop, + ppMkIFList, + &rpcstat); + + } while (RetryRPC(rpcstat)); + + if ( pwszWinstaDesktop != _pwszWinstaDesktop ) + PrivMemFree( pwszWinstaDesktop ); + + if (rpcstat != RPC_S_OK) + { + hr = CO_E_SCM_RPC_FAILURE; + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::UpdateShrdTbls +// +// Synopsis: Ask the SCM to update the shared memory tables. +// +// Arguments: none +// +// History: 11-July-94 Rickhi Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::UpdateShrdTbls(void) +{ + ComDebOut((DEB_ACTIVATE, "CRpcResolver::UpdateShrdTbls")); + + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat; + + do + { + hr = ::UpdateShrdTbls(_hRpc, &rpcstat); + + } while (RetryRPC(rpcstat)); + + + ComDebOut(( (hr == S_OK) ? DEB_SCM : DEB_ERROR, + "UpdateShrdTbls returned %x\n", hr)); + + if (rpcstat != RPC_S_OK) + { + return HRESULT_FROM_WIN32(rpcstat); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::GetThreadID +// +// Synopsis: Get unique thread id from SCM. +// +// Arguments: [pThreadID] - Pointer to returned thread ID. +// +// History: 22-Jan-96 Rickhi Created +//-------------------------------------------------------------------------- +void CRpcResolver::GetThreadID( DWORD * pThreadID ) +{ + HRESULT hr; + + *pThreadID = 0; + + hr = GetConnection(); + if ( FAILED(hr) ) + return; + + // + // If GetConnection does the initial connect to the SCM/OR then + // our apartment thread id, which is aliased by pThreadID, will be set. + // + if ( *pThreadID != 0 ) + return; + + error_status_t rpcstat; + + do + { + ::GetThreadID( _hRpc, pThreadID, &rpcstat ); + } while (RetryRPC(rpcstat)); +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::UpdateActivationSettings +// +// Synopsis: Tells rpcss to re-read default activation keys/values. +// Used by OLE test team. +// +// Arguments: none +// +//-------------------------------------------------------------------------- +void CRpcResolver::UpdateActivationSettings() +{ + HRESULT hr; + + hr = GetConnection(); + if ( FAILED(hr) ) + return; + + error_status_t rpcstat; + + do + { + ::UpdateActivationSettings( _hRpc, &rpcstat ); + } while (RetryRPC(rpcstat)); +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::RegisterWindowPropInterface +// +// Synopsis: Register window property interface with the SCM +// +// Arguments: +// +// History: 22-Jan-96 Rickhi Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::RegisterWindowPropInterface(HWND hWnd, STDOBJREF *pStd, + OXID_INFO *pOxidInfo, + DWORD *pdwCookie) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat; + + do + { + hr = ::RegisterWindowPropInterface(_hRpc, (DWORD) hWnd, + pStd, pOxidInfo, pdwCookie, &rpcstat); + } while (RetryRPC(rpcstat)); + + ComDebOut(( (hr == S_OK) ? DEB_SCM : DEB_ERROR, + "RegisterWindowPropInterface returned %x\n", hr)); + + if (rpcstat != RPC_S_OK) + { + return HRESULT_FROM_WIN32(rpcstat); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::RegisterWindowPropInterface +// +// Synopsis: Get (and possibly Revoke) window property interface +// registration with the SCM. +// +// Arguments: +// +// History: 22-Jan-96 Rickhi Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::GetWindowPropInterface(HWND hWnd, DWORD dwCookie, BOOL fRevoke, + STDOBJREF *pStd, OXID_INFO *pOxidInfo) +{ + // Bind to the SCM if that hasn't already happened + HRESULT hr = GetConnection(); + if (FAILED(hr)) + return hr; + + error_status_t rpcstat; + + do + { + hr = ::GetWindowPropInterface(_hRpc, (DWORD) hWnd, dwCookie, fRevoke, + pStd, pOxidInfo, &rpcstat); + } while (RetryRPC(rpcstat)); + + ComDebOut(( (hr == S_OK) ? DEB_SCM : DEB_ERROR, + "GetWindowPropInterface returned %x\n", hr)); + + if (rpcstat != RPC_S_OK) + { + return HRESULT_FROM_WIN32(rpcstat); + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::SetWinstaDesktop +// +// Purpose: Sets the default winsta\desktop string we'll use for this +// process. +// +// Returns: Success code. +// +// History: Nov 96 DKays Created +// +//-------------------------------------------------------------------------- +DWORD CRpcResolver::SetWinstaDesktop() +{ + return GetThreadWinstaDesktop( &_pwszWinstaDesktop ); +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::GetWinstaDesktop +// +// Purpose: Gets the winsta\desktop string to use for an activation call. +// +// Returns: Success code. +// +// History: Nov 96 DKays Created +// +//-------------------------------------------------------------------------- +HRESULT CRpcResolver::GetWinstaDesktop( WCHAR ** ppwszWinstaDesktop ) +{ + DWORD Status; + + *ppwszWinstaDesktop = 0; + + if ( ! _bDynamicSecurity ) + { + *ppwszWinstaDesktop = _pwszWinstaDesktop; + return S_OK; + } + + Status = GetThreadWinstaDesktop( ppwszWinstaDesktop ); + return HRESULT_FROM_WIN32( Status ); +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::GetDynamicSecurity +// +// Purpose: Get the dynamic security setting for the process. +// +// Returns: TRUE or FALSE +// +// History: Nov 96 DKays Created +// +//-------------------------------------------------------------------------- +BOOL CRpcResolver::GetDynamicSecurity() +{ + return _bDynamicSecurity; +} + +//+------------------------------------------------------------------------- +// +// Member: CRpcResolver::SetDynamicSecurity +// +// Purpose: Set the dynamic security setting for the process to TRUE. +// +// Returns: None. +// +// History: Nov 96 DKays Created +// +//-------------------------------------------------------------------------- +void CRpcResolver::SetDynamicSecurity() +{ + _bDynamicSecurity = TRUE; +} + +//+------------------------------------------------------------------------- +// +// Method: GetThreadWinstaDesktop +// +// Purpose: Get the winsta\desktop string for the calling thread. +// +// Returns: Success code. +// +// History: Nov-96 DKays Created +// +//-------------------------------------------------------------------------- +DWORD GetThreadWinstaDesktop( WCHAR ** ppwszWinstaDesktop ) +{ + HWINSTA hWinsta; + HDESK hDesk; + WCHAR wszWinsta[32]; + WCHAR wszDesktop[32]; + LPWSTR pwszWinsta; + LPWSTR pwszDesktop; + DWORD WinstaSize; + DWORD DesktopSize; + DWORD Length; + BOOL Status; + DWORD Result; + + *ppwszWinstaDesktop = 0; + + hWinsta = GetProcessWindowStation(); + + if ( ! hWinsta ) + return GetLastError(); + + hDesk = GetThreadDesktop(GetCurrentThreadId()); + + if ( ! hDesk ) + return GetLastError(); + + pwszWinsta = wszWinsta; + pwszDesktop = wszDesktop; + + Length = sizeof(wszWinsta); + + Status = GetUserObjectInformation( + hWinsta, + UOI_NAME, + pwszWinsta, + Length, + &Length ); + + if ( ! Status ) + { + Result = GetLastError(); + if ( Result != ERROR_INSUFFICIENT_BUFFER ) + goto WinstaDesktopExit; + + pwszWinsta = (LPWSTR)PrivMemAlloc( Length ); + if ( ! pwszWinsta ) + { + Result = ERROR_OUTOFMEMORY; + goto WinstaDesktopExit; + } + + Status = GetUserObjectInformation( + hWinsta, + UOI_NAME, + pwszWinsta, + Length, + &Length ); + + if ( ! Status ) + { + Result = GetLastError(); + goto WinstaDesktopExit; + } + } + + Length = sizeof(wszDesktop); + + Status = GetUserObjectInformation( + hDesk, + UOI_NAME, + pwszDesktop, + Length, + &Length ); + + if ( ! Status ) + { + Result = GetLastError(); + if ( Result != ERROR_INSUFFICIENT_BUFFER ) + goto WinstaDesktopExit; + + pwszDesktop = (LPWSTR)PrivMemAlloc( Length ); + if ( ! pwszDesktop ) + { + Result = ERROR_OUTOFMEMORY; + goto WinstaDesktopExit; + } + + Status = GetUserObjectInformation( + hDesk, + UOI_NAME, + pwszDesktop, + Length, + &Length ); + + if ( ! Status ) + { + Result = GetLastError(); + goto WinstaDesktopExit; + } + } + + *ppwszWinstaDesktop = (WCHAR *) + PrivMemAlloc( (lstrlenW(pwszWinsta) + 1 + lstrlenW(pwszDesktop) + 1) * sizeof(WCHAR) ); + + if ( *ppwszWinstaDesktop ) + { + lstrcpyW( *ppwszWinstaDesktop, pwszWinsta ); + lstrcatW( *ppwszWinstaDesktop, L"\\" ); + lstrcatW( *ppwszWinstaDesktop, pwszDesktop ); + Result = S_OK; + } + else + { + Result = ERROR_OUTOFMEMORY; + } + +WinstaDesktopExit: + + if ( pwszWinsta != wszWinsta ) + PrivMemFree( pwszWinsta ); + + if ( pwszDesktop != wszDesktop ) + PrivMemFree( pwszDesktop ); + + return Result; +} + +//+------------------------------------------------------------------------- +// +// Method: ScmGetThreadId +// +// Purpose: Stupid helper method so gResolver is not used in +// com\class subdir. +// +//-------------------------------------------------------------------------- +void ScmGetThreadId( DWORD * pThreadID ) +{ + gResolver.GetThreadID( pThreadID ); +} + +//+--------------------------------------------------------------------- +// +// Function: UpdateDCOMSettings +// +// Synopsis: Calls rpcss to re-read the default activation keys/values. +// +//---------------------------------------------------------------------- +STDAPI_(void) UpdateDCOMSettings(void) +{ + gResolver.UpdateActivationSettings(); +} + diff --git a/private/ole32/com/dcomrem/resolver.hxx b/private/ole32/com/dcomrem/resolver.hxx new file mode 100644 index 000000000..3232a019e --- /dev/null +++ b/private/ole32/com/dcomrem/resolver.hxx @@ -0,0 +1,269 @@ +//+------------------------------------------------------------------- +// +// File: resolver.hxx +// +// Contents: class implementing interface to RPC OXID/PingServer +// resolver and OLE SCM process. +// +// Classes: CRpcResolver +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +#ifndef _RESOLVER_HXX_ +#define _RESOLVER_HXX_ + +#include <lclor.h> +#include <ipidtbl.hxx> // gOXIDTbl +#include <hash.hxx> // CHashTable +#include <iface.h> +#include <scm.h> +#include <irot.h> +#include <dscm.h> + +// Client-Side OID registration record. Created for each client-side OID +// that needs to be registered with the Resolver. Exists so we can lazily +// register the OID and because the resolver expects one register/deregister +// per process, not per apartment. + +typedef struct tagSOIDRegistration +{ + SUUIDHashNode Node; // hash node + USHORT cRefs; // # apartments registered this OID + USHORT flags; // state flags + OXIDEntry *pOXIDEntry;// OXID of server for this OID + struct tagSOIDRegistration *pPrevList; // prev ptr for list + struct tagSOIDRegistration *pNextList; // next ptr for list +} SOIDRegistration; + + +// bit values for SOIDRegistration flags field +typedef enum tagROIDFLAG +{ + ROIDF_REGISTER = 0x01, // Register OID with Ping Server + ROIDF_PING = 0x02, // Ping (ie Register & DeRegister) OID + ROIDF_DEREGISTER = 0x04 // DeRegister OID with Ping Server +} ROIDFLAG; + +// number of server-side OIDs to pre-register or reserve with the resolver +#define MAX_PREREGISTERED_OIDS 10 +#define MAX_RESERVED_OIDS 10 + + +// bit values for Resolver _dwFlags field +typedef enum tagORFLAG +{ + ORF_STRINGSREGISTERED = 0x01 // string bindings registerd with resolver +} ORFLAG; + +//+------------------------------------------------------------------- +// +// Class: CRpcResolver +// +// Purpose: Provides an interface to OXID Resolver/PingServer process. +// There is only one instance of this class in the process. +// +// History: 20-Feb-95 Rickhi Created +// +//-------------------------------------------------------------------- +class CRpcResolver : public CPrivAlloc +{ +public: + HRESULT ServerGetPreRegMOID(MOID *pmoid); + HRESULT ServerGetReservedMOID(MOID *pmoid); + HRESULT ServerGetReservedID(OID *pid); + HRESULT ServerFreeOXID(OXIDEntry *pOXIDEntry); + + BOOL ServerCanRundownOID(REFOID roid); + + HRESULT ClientResolveOXID(REFOXID roxid, + DUALSTRINGARRAY *psaResolver, + OXIDEntry **ppOXIDEntry); + + HRESULT ClientRegisterOIDWithPingServer(REFOID roid, + OXIDEntry *pOXIDEntry); + + HRESULT ClientDeRegisterOIDFromPingServer(REFMOID roid, + BOOL fMarshaled); + + HRESULT NotifyStarted( + RegInput *pRegIn, + RegOutput **ppRegOut); + + void NotifyStopped( + REFCLSID rclsid, + DWORD dwReg); + + HRESULT GetClassObject( + REFCLSID rclsid, + DWORD dwCtrl, + IID *pIID, + COSERVERINFO *pServerInfo, + MInterfacePointer **ppIFDClassObj, + DWORD *pdwDllServerType, + WCHAR **ppwszDllToLoad); + + HRESULT CreateInstance( + COSERVERINFO *pServerInfo, + CLSID *pClsid, + DWORD dwClsCtx, + DWORD dwCount, + IID *pIIDs, + MInterfacePointer **pRetdItfs, + HRESULT *pRetdHrs, + DWORD *pdwDllServerType, + OLECHAR **ppwszDllToLoad ); + + HRESULT GetPersistentInstance( + COSERVERINFO * pServerInfo, + CLSID *pClsid, + DWORD dwClsCtx, + DWORD grfMode, + BOOL bFileWasOpened, + OLECHAR *pwszName, + MInterfacePointer *pstg, + DWORD dwCount, + IID *pIIDs, + BOOL * FoundInROT, + MInterfacePointer **pRetdItfs, + HRESULT *pRetdHrs, + DWORD *pdwDllServerType, + OLECHAR **ppwszDllToLoad ); + + HRESULT IrotRegister( + MNKEQBUF *pmkeqbuf, + InterfaceData *pifdObject, + InterfaceData *pifdObjectName, + FILETIME *pfiletime, + DWORD dwProcessID, + WCHAR *pwszServerExe, + SCMREGKEY *pdwRegister); + + HRESULT IrotRevoke( + SCMREGKEY *psrkRegister, + BOOL fServerRevoke, + InterfaceData **pifdObject, + InterfaceData **pifdName); + + HRESULT IrotIsRunning( + MNKEQBUF *pmkeqbuf); + + HRESULT IrotGetObject( + DWORD dwProcessID, + MNKEQBUF *pmkeqbuf, + SCMREGKEY *psrkRegister, + InterfaceData **pifdObject); + + HRESULT IrotNoteChangeTime( + SCMREGKEY *psrkRegister, + FILETIME *pfiletime); + + HRESULT IrotGetTimeOfLastChange( + MNKEQBUF *pmkeqbuf, + FILETIME *pfiletime); + + HRESULT IrotEnumRunning( + MkInterfaceList **ppMkIFList); + + HRESULT UpdateShrdTbls(void); + + void GetThreadID( DWORD * pThreadID ); + + void UpdateActivationSettings(); + + HRESULT RegisterWindowPropInterface( + HWND hWnd, + STDOBJREF *pStd, + OXID_INFO *pOxidInfo, + DWORD *pdwCookie); + + HRESULT GetWindowPropInterface( + HWND hWnd, + DWORD dwCookie, + BOOL fRevoke, + STDOBJREF *pStd, + OXID_INFO *pOxidInfo); + + HRESULT GetConnection(); + HRESULT BindToSCMProxy(); + void ReleaseSCMProxy(); + + DWORD SetWinstaDesktop(); + HRESULT GetWinstaDesktop( WCHAR ** ppwszWinstaDesktop ); + + BOOL GetDynamicSecurity(); + void SetDynamicSecurity(); + + void Cleanup(); + +private: + +#if DBG==1 + void AssertValid(void); +#else + void AssertValid(void) {}; +#endif + + HRESULT EnsureWorkerThread(void); + DWORD _stdcall WorkerThreadLoop(void *param); + HRESULT ClientBulkUpdateOIDWithPingServer(void); + + HRESULT WaitForOXIDEntry(OXIDEntry *pEntry); + void CheckForWaiters(OXIDEntry *pEntry); + HRESULT ServerAllocMoreOIDs(ULONG *pcPreRegOidsAvail, + OID *parPreRegOidsAvail, + OXIDEntry *pEntry); + HRESULT ServerAllocOIDs(OXIDEntry *pEntry, + ULONG *pcPreRegOidsAvail, + OID *parPreRegOidsAvail); + + HRESULT ServerRegisterOXID(OXIDEntry *pOXIDEntry, + ULONG *pcOidsToAllocate, + OID arNewOidList[]); + + HRESULT CheckStatus(RPC_STATUS sc); + BOOL RetryRPC(RPC_STATUS sc); + IDSCM *GetSCM() { return (IsSTAThread()) ? _pSCMSTA : _pSCMMTA; } + + static handle_t _hRpc; // rpc binding handle to resolver + static PHPROCESS _ph; // context handle to resolver + static HANDLE _hThrd; // handle of worker thread (if any) + static HANDLE _hEventOXID; // event for registering threads + static DWORD _dwFlags; // flags + static DWORD _dwSleepPeriod; // worker thread sleep period + + // reserved sequence of OIDs (for no-ping marshals) + static ULONG _cReservedOidsAvail; + static ULONGLONG _OidNextReserved; + + // pre-registered OIDs (for objects that need to be pinged) + static ULONG _cPreRegOidsAvail; + static OID _arPreRegOids[MAX_PREREGISTERED_OIDS]; + + static ULONG _cOidsToAdd; // # of OIDs to register with resolver + static ULONG _cOidsToRemove; // # of OIDs to deregister with resolver + + static SOIDRegistration _ClientOIDRegList; + + static IDSCM * _pSCMSTA; // Single-threaded SCM proxy + static IDSCM * _pSCMMTA; // Multi-threaded SCM proxy + static LPWSTR _pwszWinstaDesktop; + + static DWORD _dwProcessSignature; + + static BOOL _bDynamicSecurity; +}; + +extern MID gLocalMid; // MID for current machine +extern OXID gScmOXID; // OXID for the SCM + +// global ptr to the one instance of this class +extern CRpcResolver gResolver; + +// Ping period in milliseconds. +extern DWORD giPingPeriod; + +// table of OIDs client-registered for pinging +extern CUUIDHashTable gClientRegisteredOIDs; + +#endif // _RESOLVER_HXX_ diff --git a/private/ole32/com/dcomrem/riftbl.cxx b/private/ole32/com/dcomrem/riftbl.cxx new file mode 100644 index 000000000..0f39860ef --- /dev/null +++ b/private/ole32/com/dcomrem/riftbl.cxx @@ -0,0 +1,567 @@ +//+------------------------------------------------------------------------ +// +// File: riftbl.cxx +// +// Contents: RIF (Registered Interfaces) Table. +// +// Classes: CRIFTable +// +// History: 12-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +#include <ole2int.h> +#include <riftbl.hxx> // class definition +#include <locks.hxx> // LOCK/UNLOCK +#include <channelb.hxx> // ThreadInvoke + + +// number of Registered Interface Entries per allocator page +#define RIFS_PER_PAGE 32 + +// global RIF table +CRIFTable gRIFTbl; + + +//+------------------------------------------------------------------------ +// +// Vector Table: All calls on registered interfaces are dispatched through +// this table to ThreadInvoke, which subsequently dispatches to the +// appropriate interface stub. All calls on COM interfaces are dispatched +// on method #0 so the table only needs to be 1 entry long. +// +//+------------------------------------------------------------------------ + +const RPC_DISPATCH_FUNCTION vector[] = +{ + (void (_stdcall *) (struct ::_RPC_MESSAGE *)) ThreadInvoke, +}; + +const RPC_DISPATCH_TABLE gDispatchTable = +{ + sizeof(vector)/sizeof(RPC_DISPATCH_FUNCTION), + (RPC_DISPATCH_FUNCTION *)&vector, 0 +}; + + +//+------------------------------------------------------------------------ +// +// Interface Templates. When we register an interface with the RPC runtime, +// we allocate an structure, copy one of these templates in (depending on +// whether we want client side or server side) and then set the interface +// IID to the interface being registered. +// +// We hand-register the RemUnknown interface because we normally marshal its +// derived verion (IRundown), yet expect calls on IRemUnknown. +// +//+------------------------------------------------------------------------ + +const RPC_SERVER_INTERFACE gServerIf = +{ + sizeof(RPC_SERVER_INTERFACE), + {0x69C09EA0, 0x4A09, 0x101B, 0xAE, 0x4B, 0x08, 0x00, 0x2B, 0x34, 0x9A, 0x02, + {0, 0}}, + {0x8A885D04, 0x1CEB, 0x11C9, 0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60, + {2, 0}}, + (RPC_DISPATCH_TABLE *)&gDispatchTable, 0, 0, 0 +}; + +const RPC_CLIENT_INTERFACE gClientIf = +{ + sizeof(RPC_CLIENT_INTERFACE), + {0x69C09EA0, 0x4A09, 0x101B, 0xAE, 0x4B, 0x08, 0x00, 0x2B, 0x34, 0x9A, 0x02, + {0, 0}}, + {0x8A885D04, 0x1CEB, 0x11C9, 0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60, + {2, 0}}, + 0, 0, 0, 0 +}; + +const RPC_SERVER_INTERFACE gRemUnknownIf = +{ + sizeof(RPC_SERVER_INTERFACE), + {0x00000131, 0x0000, 0x0000, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46, + {0, 0}}, + {0x8A885D04, 0x1CEB, 0x11C9, 0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60, + {2, 0}}, + (RPC_DISPATCH_TABLE *)&gDispatchTable, 0, 0, 0 +}; + + +//+------------------------------------------------------------------------ +// +// Registered Interface hash table buckets. This is defined as a global +// so that we dont have to run any code to initialize the hash table. +// +//+------------------------------------------------------------------------ +SHashChain RIFBuckets[23] = +{ + {&RIFBuckets[0], &RIFBuckets[0]}, + {&RIFBuckets[1], &RIFBuckets[1]}, + {&RIFBuckets[2], &RIFBuckets[2]}, + {&RIFBuckets[3], &RIFBuckets[3]}, + {&RIFBuckets[4], &RIFBuckets[4]}, + {&RIFBuckets[5], &RIFBuckets[5]}, + {&RIFBuckets[6], &RIFBuckets[6]}, + {&RIFBuckets[7], &RIFBuckets[7]}, + {&RIFBuckets[8], &RIFBuckets[8]}, + {&RIFBuckets[9], &RIFBuckets[9]}, + {&RIFBuckets[10], &RIFBuckets[10]}, + {&RIFBuckets[11], &RIFBuckets[11]}, + {&RIFBuckets[12], &RIFBuckets[12]}, + {&RIFBuckets[13], &RIFBuckets[13]}, + {&RIFBuckets[14], &RIFBuckets[14]}, + {&RIFBuckets[15], &RIFBuckets[15]}, + {&RIFBuckets[16], &RIFBuckets[16]}, + {&RIFBuckets[17], &RIFBuckets[17]}, + {&RIFBuckets[18], &RIFBuckets[18]}, + {&RIFBuckets[19], &RIFBuckets[19]}, + {&RIFBuckets[20], &RIFBuckets[20]}, + {&RIFBuckets[21], &RIFBuckets[21]}, + {&RIFBuckets[22], &RIFBuckets[22]} +}; + + +//+------------------------------------------------------------------- +// +// Function: CleanupRIFEntry +// +// Synopsis: Call the RIFTable to cleanup an entry. This is called +// by the hash table cleanup code. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CleanupRIFEntry(SHashChain *pNode) +{ + gRIFTbl.UnRegisterInterface((RIFEntry *)pNode); +} + +//+------------------------------------------------------------------------ +// +// Member: CRIFTable::Initialize, public +// +// Synopsis: Initialize the Registered Interface Table +// +// History: 12-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CRIFTable::Initialize() +{ + ComDebOut((DEB_CHANNEL, "CRIFTable::Initialize\n")); + ASSERT_LOCK_HELD + _HashTbl.Initialize(RIFBuckets); + _palloc.Initialize(sizeof(RIFEntry), RIFS_PER_PAGE); +} + +//+------------------------------------------------------------------------ +// +// Member: CRIFTable::Cleanup, public +// +// Synopsis: Cleanup the Registered Interface Table. +// +// History: 12-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +void CRIFTable::Cleanup() +{ + ComDebOut((DEB_CHANNEL, "CRIFTable::Cleanup\n")); + ASSERT_LOCK_HELD + _HashTbl.Cleanup(CleanupRIFEntry); + _palloc.Cleanup(); +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::GetClientInterfaceInfo, public +// +// Synopsis: returns the interface info for a given interface +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +RPC_CLIENT_INTERFACE *CRIFTable::GetClientInterfaceInfo(REFIID riid) +{ + DWORD iHash = _HashTbl.Hash(riid); + RIFEntry *pRIFEntry = (RIFEntry *) _HashTbl.Lookup(iHash, riid); + Win4Assert(pRIFEntry); // must already be registered + Win4Assert(pRIFEntry->pCliInterface); + return pRIFEntry->pCliInterface; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::RegisterInterface, public +// +// Synopsis: returns the proxy stub clsid of the specified interface, +// and adds an entry to the registered interface hash table +// if needed. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::RegisterInterface(REFIID riid, BOOL fServer, CLSID *pClsid) +{ + ComDebOut((DEB_CHANNEL, "CRIFTable::RegisterInterface riid:%I\n", &riid)); + ASSERT_LOCK_RELEASED + LOCK + + // look for the interface in the table. + RIFEntry *pRIFEntry; + HRESULT hr = GetPSClsid(riid, pClsid, &pRIFEntry); + + if (pRIFEntry) + { + if (fServer) + { + if (pRIFEntry->pSrvInterface == NULL) + { + hr = RegisterServerInterface(pRIFEntry, riid); + } + } + else if (pRIFEntry->pCliInterface == NULL) + { + hr = RegisterClientInterface(pRIFEntry, riid); + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + ComDebOut((DEB_CHANNEL, + "CRIFTable::RegisterInterface hr:%x clsid:%I\n", hr, pClsid)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::RegisterClientInterface, private +// +// Synopsis: Register with the RPC runtime a client RPC interface +// structure for the given IID. The IID must not already +// be registered. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::RegisterClientInterface(RIFEntry *pRIFEntry, REFIID riid) +{ + ComDebOut((DEB_CHANNEL, + "CRIFTable::RegisterClientInterface pRIFEntry:%x\n", pRIFEntry)); + Win4Assert(pRIFEntry->pCliInterface == NULL); + ASSERT_LOCK_HELD + + HRESULT hr = E_OUTOFMEMORY; + pRIFEntry->pCliInterface = (RPC_CLIENT_INTERFACE *) + PrivMemAlloc(sizeof(RPC_CLIENT_INTERFACE)); + + if (pRIFEntry->pCliInterface != NULL) + { + memcpy(pRIFEntry->pCliInterface, &gClientIf, sizeof(gClientIf)); + pRIFEntry->pCliInterface->InterfaceId.SyntaxGUID = riid; + hr = S_OK; + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::RegisterServerInterface, private +// +// Synopsis: Register with the RPC runtime a server RPC interface +// structure for the given IID. The IID must not already +// be registered +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::RegisterServerInterface(RIFEntry *pRIFEntry, REFIID riid) +{ + ComDebOut((DEB_CHANNEL, + "CRIFTable::RegisterServerInterface pRIFEntry:%x\n", pRIFEntry)); + Win4Assert(pRIFEntry->pSrvInterface == NULL); + ASSERT_LOCK_HELD + + HRESULT hr = E_OUTOFMEMORY; + pRIFEntry->pSrvInterface = (RPC_SERVER_INTERFACE *) + PrivMemAlloc(sizeof(RPC_SERVER_INTERFACE)); + + if (pRIFEntry->pSrvInterface != NULL) + { + hr = S_OK; + memcpy(pRIFEntry->pSrvInterface, &gServerIf, sizeof(gServerIf)); + pRIFEntry->pSrvInterface->InterfaceId.SyntaxGUID = riid; + + RPC_STATUS sc = RpcServerRegisterIfEx(pRIFEntry->pSrvInterface, NULL, + NULL, + RPC_IF_AUTOLISTEN | RPC_IF_OLE, + 0xffff, GetAclFn()); + if (sc != RPC_S_OK) + { + ComDebOut((DEB_ERROR, + "RegisterServerInterface %I failed:0x%x.\n", &riid, sc)); + + PrivMemFree(pRIFEntry->pSrvInterface); + pRIFEntry->pSrvInterface = NULL; + hr = HRESULT_FROM_WIN32(sc); + } + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::UnRegisterInterface +// +// Synopsis: UnRegister with the RPC runtime a server RPC interface +// structure for the given IID. This is called by +// CUUIDHashTable::Cleanup during CoUninitialize. Also +// delete the interface structures. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +void CRIFTable::UnRegisterInterface(RIFEntry *pRIFEntry) +{ + if (pRIFEntry->pSrvInterface) + { + // server side entry exists, unregister the interface with RPC. + // Note that this can result in calls being dispatched so we + // have to release the lock around the call. + + UNLOCK + ASSERT_LOCK_RELEASED + + RpcServerUnregisterIf(pRIFEntry->pSrvInterface, 0, 1); + PrivMemFree(pRIFEntry->pSrvInterface); + + ASSERT_LOCK_RELEASED + LOCK + + pRIFEntry->pSrvInterface = NULL; + } + + PrivMemFree(pRIFEntry->pCliInterface); + + _palloc.ReleaseEntry((PageEntry *)pRIFEntry); +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::GetPSClsid, public +// +// Synopsis: Finds the RIFEntry in the table for the given riid, and +// adds an entry if one is not found. Called by CoGetPSClsid +// and by CRIFTable::RegisterInterface. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::GetPSClsid(REFIID riid, CLSID *pclsid, RIFEntry **ppEntry) +{ + ComDebOut((DEB_CHANNEL, + "CRIFTable::GetPSClsid riid:%I pclsid:%x\n", &riid, pclsid)); + ASSERT_LOCK_HELD + HRESULT hr = S_OK; + + // look for the interface in the table. + DWORD iHash = _HashTbl.Hash(riid); + RIFEntry *pRIFEntry = (RIFEntry *) _HashTbl.Lookup(iHash, riid); + + if (pRIFEntry == NULL) + { + // no entry exists for this interface, add one. Dont hold + // the lock over a call to the SCM. + + UNLOCK + ASSERT_LOCK_RELEASED + hr = wCoGetPSClsid(riid, pclsid); + ASSERT_LOCK_RELEASED + LOCK + + // now that we are holding the lock again, do another lookup incase + // some other thread came it while the lock was released. + + pRIFEntry = (RIFEntry *) _HashTbl.Lookup(iHash, riid); + + if (pRIFEntry == NULL && SUCCEEDED(hr)) + { + hr = AddEntry(*pclsid, riid, iHash, &pRIFEntry); + } + } + else + { + // found an entry, return the clsid + *pclsid = pRIFEntry->psclsid; + } + + *ppEntry = pRIFEntry; + + ASSERT_LOCK_HELD + ComDebOut((DEB_CHANNEL, "CRIFTable::RegisterPSClsid pRIFEntry:%x\n", pRIFEntry)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::RegisterPSClsid, public +// +// Synopsis: Adds an entry to the table. Used by CoRegisterPSClsid +// so that applications can add a temporary entry that only +// affects the local process without having to muck with +// the system registry. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::RegisterPSClsid(REFIID riid, REFCLSID rclsid) +{ + ComDebOut((DEB_CHANNEL, + "CRIFTable::RegisterPSClsid rclsid:%I riid:%I\n", &rclsid, &riid)); + + HRESULT hr = S_OK; + ASSERT_LOCK_RELEASED + LOCK + + // look for the interface in the table. + DWORD iHash = _HashTbl.Hash(riid); + RIFEntry *pRIFEntry = (RIFEntry *) _HashTbl.Lookup(iHash, riid); + + if (pRIFEntry == NULL) + { + // no entry exists for this interface, add one. + hr = AddEntry(rclsid, riid, iHash, &pRIFEntry); + } + else + { + // found an entry, update the clsid + pRIFEntry->psclsid = rclsid; + } + + UNLOCK + ASSERT_LOCK_RELEASED + ComDebOut((DEB_CHANNEL, "CRIFTable::RegisterPSClsid hr:%x\n", hr)); + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CRIFTable::AddEntry, private +// +// Synopsis: allocates and entry, fills in the values, and adds it +// to the hash table. +// +// History: 12-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CRIFTable::AddEntry(REFCLSID rclsid, REFIID riid, + DWORD iHash, RIFEntry **ppRIFEntry) +{ + ASSERT_LOCK_HELD + RIFEntry *pRIFEntry = (RIFEntry *) _palloc.AllocEntry(); + + if (pRIFEntry) + { + pRIFEntry->psclsid = rclsid; + pRIFEntry->pSrvInterface = NULL; + pRIFEntry->pCliInterface = NULL; + *ppRIFEntry = pRIFEntry; + + // add to the hash table + _HashTbl.Add(iHash, riid, &pRIFEntry->HashNode); + + ComDebOut((DEB_CHANNEL, + "Added RIFEntry riid:%I pRIFEntry\n", &riid, pRIFEntry)); + return S_OK; + } + + ASSERT_LOCK_HELD + return E_OUTOFMEMORY; +} + +//+------------------------------------------------------------------- +// +// Function: CoRegisterPSClsid, public +// +// Synopsis: registers a IID->PSCLSID mapping that applies only within +// the current process. Can be used by code downloaded over +// a network to do custom interface marshaling without having +// to muck with the system registry. +// +// Algorithm: validate the parameters then add an entry to the RIFTable. +// +// History: 15-Apr-96 Rickhi Created +// +//-------------------------------------------------------------------- +STDAPI CoRegisterPSClsid(REFIID riid, REFCLSID rclsid) +{ + ComDebOut((DEB_MARSHAL, + "CoRegisterPSClsid riid:%I rclsid:%I\n", &riid, &rclsid)); + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + hr = E_INVALIDARG; + + if ((&riid != NULL) && (&rclsid != NULL) && + IsValidPtrIn(&riid, sizeof(riid)) && + IsValidPtrIn(&rclsid, sizeof(rclsid))) + { + ASSERT_LOCK_RELEASED + + hr = gRIFTbl.RegisterPSClsid(riid, rclsid); + + ASSERT_LOCK_RELEASED + } + + return hr; +} + +//+------------------------------------------------------------------------- +// +// Function: CoGetPSClsid, public +// +// Synopsis: returns the proxystub clsid associated with the specified +// interface IID. +// +// Arguments: [riid] - the interface iid to lookup +// [lpclsid] - where to return the clsid +// +// Returns: S_OK if successfull +// REGDB_E_IIDNOTREG if interface is not registered. +// REGDB_E_READREGDB if any other error +// +// Algorithm: First it looks in the local RIFTable for a matching IID. If +// no entry is found, the RIFTable looks in the shared memory +// table (NT only), and if not found and the table is FULL, it +// will look in the registry itself. +// +// History: 07-Apr-94 Rickhi rewrite +// +//-------------------------------------------------------------------------- +STDAPI CoGetPSClsid(REFIID riid, CLSID *pclsid) +{ + ComDebOut((DEB_MARSHAL, "CoGetPSClsid riid:%I pclsid:%x\n", &riid, pclsid)); + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + hr = E_INVALIDARG; + + if ((&riid != NULL) && + IsValidPtrIn(&riid, sizeof(riid)) && + IsValidPtrOut(pclsid, sizeof(*pclsid))) + { + ASSERT_LOCK_RELEASED + LOCK + + RIFEntry *pRIFEntry; + hr = gRIFTbl.GetPSClsid(riid, pclsid, &pRIFEntry); + + UNLOCK + ASSERT_LOCK_RELEASED + } + + return hr; +} diff --git a/private/ole32/com/dcomrem/riftbl.hxx b/private/ole32/com/dcomrem/riftbl.hxx new file mode 100644 index 000000000..a0a0b7987 --- /dev/null +++ b/private/ole32/com/dcomrem/riftbl.hxx @@ -0,0 +1,79 @@ +//+------------------------------------------------------------------------ +// +// File: riftbl.hxx +// +// Contents: RIF (registered interface) table. +// +// Classes: CRIFTable +// +// History: 12-Feb-96 Rickhi Created +// +//------------------------------------------------------------------------- +#ifndef _RIFTBL_HXX_ +#define _RIFTBL_HXX_ + +#include <pgalloc.hxx> // CPageAllocator +#include <hash.hxx> // CUUIDHashTable + + +//+------------------------------------------------------------------------ +// +// Struct: RIFEntry - Registered Interface Entry +// +// This structure defines an Entry in the RIF table. There is one RIF +// table for the entire process. There is one RIFEntry per interface +// the current process is using (client side or server side). +// +//------------------------------------------------------------------------- +typedef struct tagRIFEntry +{ + SUUIDHashNode HashNode; // hash chain and key (IID) + CLSID psclsid; // proxy stub clsid + RPC_SERVER_INTERFACE *pSrvInterface; // ptr to server interface + RPC_CLIENT_INTERFACE *pCliInterface; // ptr tp client interface +} RIFEntry; + + +//+------------------------------------------------------------------------ +// +// class: CRIFTable +// +// Synopsis: Hash table of registered interfaces. +// +// History: 12-Feb-96 Rickhi Created +// +// Notes: Entries are kept in a hash table keyed by the IID. Entries +// are allocated via the page-based allocator. There is one +// global instance of this table per process (gRIFTbl). +// +//------------------------------------------------------------------------- +class CRIFTable +{ +public: + void Initialize(); + void Cleanup(); + + HRESULT RegisterInterface(REFIID riid, BOOL fServer, CLSID *pClsid); + RPC_CLIENT_INTERFACE *GetClientInterfaceInfo(REFIID riid); + + HRESULT RegisterPSClsid(REFIID riid, REFCLSID rclsid); + HRESULT GetPSClsid(REFIID riid, CLSID *pclsid, RIFEntry **ppEntry); + + void UnRegisterInterface(RIFEntry *pRIFEntry); + +private: + + HRESULT RegisterClientInterface(RIFEntry *pRIFEntry, REFIID riid); + HRESULT RegisterServerInterface(RIFEntry *pRIFEntry, REFIID riid); + HRESULT AddEntry(REFCLSID rclsid, REFIID riid, DWORD iHash, RIFEntry **ppRIFEntry); + + CUUIDHashTable _HashTbl; // interface lookup hash table + CPageAllocator _palloc; // page allocator +}; + + +// global externs +extern CRIFTable gRIFTbl; +extern const RPC_SERVER_INTERFACE gRemUnknownIf; + +#endif // _RIFTBL_HXX_ diff --git a/private/ole32/com/dcomrem/rpcspy.hxx b/private/ole32/com/dcomrem/rpcspy.hxx new file mode 100644 index 000000000..3cf829ec7 --- /dev/null +++ b/private/ole32/com/dcomrem/rpcspy.hxx @@ -0,0 +1,204 @@ +//+--------------------------------------------------------------------------- +// +// Microsoft Windows +// Copyright (C) Microsoft Corporation, 1992 - 1994. +// +// File: rpcspy.hxx +// +// Contents: A primitive rpc spy with output to debug terminal +// +// Classes: +// +// Functions: +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +// Note: Can be turned on via CairOle InfoLelevel mask 0x08000000 +// +//---------------------------------------------------------------------------- + +#ifndef _RPCSPY_HXX_ +#define _RPCSPY_HXX_ + +#if DBG==1 +// +// switch on to trace rpc calls +// by setting CairoleInfoLevel = DEB_USER1; +// +// +#define NESTING_SPACES 32 +#define SPACES_PER_LEVEL 2 +static char achSpaces[NESTING_SPACES+1] = " "; +WORD wlevel = 0; +char tabs[128]; + +//+--------------------------------------------------------------------------- +// +// Method: PushLevel +// +// Synopsis: +// +// Arguments: (none) +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +//---------------------------------------------------------------------------- +void PushLevel() +{ + wlevel++; +} +//+--------------------------------------------------------------------------- +// +// Method: PopLevel +// +// Synopsis: +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +//---------------------------------------------------------------------------- +void PopLevel() +{ + if (wlevel) + wlevel--; +} + +//+--------------------------------------------------------------------------- +// +// Method: NestingSpaces +// +// Synopsis: +// +// Arguments: [psz] -- +// +// Returns: +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +// Notes: +// +//---------------------------------------------------------------------------- +void NestingSpaces(char *psz) +{ + int iSpaces, i; + + iSpaces = wlevel * SPACES_PER_LEVEL; + + while (iSpaces > 0) + { + i = min(iSpaces, NESTING_SPACES); + memcpy(psz, achSpaces, i); + psz += i; + *psz = 0; + iSpaces -= i; + } +} + + +//+--------------------------------------------------------------------------- +// +// Method: GetTabs +// +// Synopsis: +// +// Arguments: (none) +// +// Returns: +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +// Notes: +// +//---------------------------------------------------------------------------- +LPSTR GetTabs() +{ + static char ach[256]; + char *psz; + + sprintf(ach, "%2d:", wlevel); + psz = ach+strlen(ach); + + if (sizeof(ach)/SPACES_PER_LEVEL <= wlevel) + { + strcpy(psz, "..."); + } + else + { + NestingSpaces(psz); + } + return ach; +} + + +typedef enum +{ + CALLIN_BEGIN =1, + CALLIN_TRACE, + CALLIN_ERROR, + CALLIN_END, + CALLOUT_BEGIN, + CALLOUT_TRACE, + CALLOUT_ERROR, + CALLOUT_END +} RPCSPYMODE; + + +//+--------------------------------------------------------------------------- +// +// Method: RpcSpyOutput +// +// Synopsis: +// +// Arguments: [mode] -- in or out call +// [iid] -- interface id +// [dwMethod] -- called method +// [hres] -- hresult of finished call +// +// Returns: +// +// History: 3-31-95 JohannP (Johann Posch) Created +// +// Notes: +// +//---------------------------------------------------------------------------- +void RpcSpyOutput(RPCSPYMODE mode , REFIID iid, DWORD dwMethod, HRESULT hres) +{ + switch (mode) + { + case CALLIN_BEGIN: + CairoleDebugOut((DEB_RPCSPY,"%s <<< %lx, %d \n",GetTabs(), iid.Data1, dwMethod)); + PushLevel(); + break; + case CALLIN_TRACE: + break; + case CALLIN_ERROR: + break; + case CALLIN_END: + PopLevel(); + CairoleDebugOut((DEB_RPCSPY,"%s === %lx, %d (%lx) \n",GetTabs(), iid.Data1, dwMethod, hres)); + break; + case CALLOUT_BEGIN: + CairoleDebugOut((DEB_RPCSPY,"%s >>> %lx, %d \n",GetTabs(), iid.Data1, dwMethod)); + PushLevel(); + break; + case CALLOUT_TRACE: + break; + case CALLOUT_ERROR: + CairoleDebugOut((DEB_RPCSPY,"%s !!! %lx, %d, error:%lx \n",GetTabs(), iid.Data1, dwMethod, hres)); + break; + case CALLOUT_END: + PopLevel(); + CairoleDebugOut((DEB_RPCSPY,"%s +++ %lx, %d (%lx) \n",GetTabs(), iid.Data1, dwMethod, hres)); + break; + } +} + +#define RpcSpy(x) RpcSpyOutput x + +#else + +#define RpcSpy(x) + +#endif // DBG==1 + + +#endif // _RPCSPY_HXX_ diff --git a/private/ole32/com/dcomrem/security.cxx b/private/ole32/com/dcomrem/security.cxx new file mode 100644 index 000000000..de7f7e024 --- /dev/null +++ b/private/ole32/com/dcomrem/security.cxx @@ -0,0 +1,3060 @@ +//+------------------------------------------------------------------- +// +// File: security.cxx +// +// Copyright (c) 1996-1996, Microsoft Corp. All rights reserved. +// +// Contents: Classes for channel security +// +// Classes: CClientSecurity, CServerSecurity +// +// History: 11 Oct 95 AlexMit Created +// +//-------------------------------------------------------------------- + +#include <ole2int.h> +#include <locks.hxx> +#include <security.hxx> +#include <channelb.hxx> +#include <ipidtbl.hxx> +#include <resolver.hxx> +#include <service.hxx> +#include <oleext.h> +#include <stream.hxx> + +#ifdef _CHICAGO_ +#include <apiutil.h> +#include <wksta.h> +#endif + +#ifdef DCOM_SECURITY +/**********************************************************************/ +// Definitions. + +// Versions of the permissions in the registry. +const WORD COM_PERMISSION_SECDESC = 1; +const WORD COM_PERMISSION_ACCCTRL = 2; + +// Guess length of user name. +const DWORD SIZEOF_NAME = 80; + +// This leaves space for 8 sub authorities. Currently NT only uses 6 and +// Cairo uses 7. +const DWORD SIZEOF_SID = 44; + +// This leaves space for 2 access allowed ACEs in the ACL. +const DWORD SIZEOF_ACL = sizeof(ACL) + 2 * sizeof(ACCESS_ALLOWED_ACE) + + 2 * SIZEOF_SID; + +const DWORD SIZEOF_TOKEN_USER = sizeof(TOKEN_USER) + SIZEOF_SID; + +const SID LOCAL_SYSTEM_SID = {SID_REVISION, 1, {0,0,0,0,0,5}, + SECURITY_LOCAL_SYSTEM_RID }; + +const DWORD NUM_SEC_PKG = 8; + +const DWORD ACCESS_CACHE_LEN = 5; + +const DWORD VALID_INIT_FLAGS = EOAC_SECURE_REFS | EOAC_MUTUAL_AUTH | + EOAC_ACCESS_CONTROL | EOAC_APPID | EOAC_DYNAMIC; + +// Remove this for NT 5.0 when we link to oleext.lib +const IID IID_IAccessControl = {0xEEDD23E0,0x8410,0x11CE,{0xA1,0xC3,0x08,0x00,0x2B,0x2B,0x8D,0x8F}}; + +// Stores results of AccessCheck. +typedef struct +{ + BOOL fAccess; + DWORD lHash; + SID sid; +} SAccessCache; + +// Header in access permission key. +typedef struct +{ + WORD wVersion; + WORD wPad; + GUID gClass; +} SPermissionHeader; + +#ifdef _CHICAGO_ +typedef unsigned + (*NetWkstaGetInfoFn) ( const char FAR * pszServer, + short sLevel, + char FAR * pbBuffer, + unsigned short cbBuffer, + unsigned short FAR * pcbTotalAvail ); +#endif + +/**********************************************************************/ +// Externals. + +EXTERN_C const IID IID_IObjServer; + + +/**********************************************************************/ +// Prototypes. +void CacheAccess ( SID *pSid, BOOL fAccess ); +BOOL CacheAccessCheck ( SID *pSid, BOOL *pAccess ); +HRESULT CopySecDesc ( SECURITY_DESCRIPTOR *pOrig, + SECURITY_DESCRIPTOR **pCopy ); +HRESULT FixupAccessControl ( SECURITY_DESCRIPTOR **pSD, DWORD cbSD ); +HRESULT FixupSecurityDescriptor( SECURITY_DESCRIPTOR **pSD, DWORD cbSD ); +HRESULT GetLegacySecDesc ( SECURITY_DESCRIPTOR **, DWORD * ); +HRESULT GetRegistrySecDesc ( HKEY, WCHAR *pValue, + SECURITY_DESCRIPTOR **pSD, DWORD * ); +DWORD HashSid ( SID * ); +BOOL IsLocalAuthnService ( USHORT wAuthnService ); +HRESULT MakeSecDesc ( SECURITY_DESCRIPTOR **, DWORD * ); +HRESULT DefaultAuthnServices ( void ); +HRESULT RegisterAuthnServices ( DWORD cbSvc, SOLE_AUTHENTICATION_SERVICE * ); + +#ifndef _CHICAGO_ +HRESULT LookupPrincName ( WCHAR ** ); +#else +HRESULT LookupPrincName( + USHORT *pwAuthnServices, + ULONG cAuthnServices, + WCHAR **pPrincName + ); +#endif // _CHICAGO_ + +/**********************************************************************/ +// Globals. + +// These variables hold the default authentication information. +DWORD gAuthnLevel = RPC_C_AUTHN_LEVEL_NONE; +DWORD gImpLevel = RPC_C_IMP_LEVEL_IDENTIFY; +DWORD gCapabilities = EOAC_NONE; +SECURITYBINDING *gLegacySecurity = NULL; + +// These variables define a list of security providers OLE clients can +// use and a list OLE servers can use. +USHORT *gClientSvcList = NULL; +DWORD gClientSvcListLen = 0; +USHORT *gServerSvcList = NULL; +DWORD gServerSvcListLen = 0; + +// gDisableDCOM is read from the registry by CRpcResolver::GetConnection. +// If TRUE, all machine remote calls will be failed. It is set TRUE in WOW. +BOOL gDisableDCOM = FALSE; + +// Set TRUE when CRpcResolver::GetConnection initializes the previous globals. +BOOL gGotSecurityData = FALSE; + +// The security descriptor to check when new connections are established. +// gAccessControl and gSecDesc will not both be nonNULL at the same time. +IAccessControl *gAccessControl = NULL; +SECURITY_DESCRIPTOR *gSecDesc = NULL; + +// The security string array. If gDefaultService is TRUE, compute the +// security string array the first time a remote protocol sequence is +// registered. +DUALSTRINGARRAY *gpsaSecurity = NULL; +BOOL gDefaultService = FALSE; + +// The security descriptor to check in RundownOID. +SECURITY_DESCRIPTOR *gRundownSD = NULL; + +// Don't map any of the generic bits to COM_RIGHTS_EXECUTE or any other bit. +GENERIC_MAPPING gMap = { 0, 0, 0, 0 }; +PRIVILEGE_SET gPriv = { 1, 0 }; + +// Cache of results of calls to AccessCheck. +SAccessCache *gAccessCache[ACCESS_CACHE_LEN] = {NULL, NULL, NULL, NULL, NULL}; +DWORD gMostRecentAccess = 0; + + +//+------------------------------------------------------------------- +// +// Function: CacheAccess +// +// Synopsis: Store the results of the access check in the cache. +// +//-------------------------------------------------------------------- +void CacheAccess( SID *pSid, BOOL fAccess ) +{ + SAccessCache *pNew; + DWORD cbSid; + + ASSERT_LOCK_RELEASED + LOCK + + // Allocate a new record. + cbSid = GetLengthSid( pSid ); + pNew = (SAccessCache *) PrivMemAlloc( sizeof(SAccessCache) + cbSid - + sizeof(SID) ); + + // Initialize the record. + if (pNew != NULL) + { + pNew->fAccess = fAccess; + pNew->lHash = HashSid( pSid ); + memcpy( &pNew->sid, pSid, cbSid ); + + // Free the old record and insert the new. + gMostRecentAccess += 1; + if (gMostRecentAccess >= ACCESS_CACHE_LEN) + gMostRecentAccess = 0; + PrivMemFree( gAccessCache[gMostRecentAccess] ); + gAccessCache[gMostRecentAccess] = pNew; + } + + UNLOCK + ASSERT_LOCK_RELEASED +} + +//+------------------------------------------------------------------- +// +// Function: CacheAccessCheck +// +// Synopsis: Look for the specified SID in the cache. If found, +// return the results of the cached access check. +// +//-------------------------------------------------------------------- +BOOL CacheAccessCheck( SID *pSid, BOOL *pAccess ) +{ + DWORD i; + DWORD lHash = HashSid( pSid ); + DWORD j; + BOOL fFound = FALSE; + SAccessCache *pSwap; + + ASSERT_LOCK_RELEASED + LOCK + + // Look for the SID. + j = gMostRecentAccess; + for (i = 0; i < ACCESS_CACHE_LEN; i++) + { + if (gAccessCache[j] != NULL && + gAccessCache[j]->lHash == lHash && + EqualSid( pSid, &gAccessCache[j]->sid )) + { + // Move this entry to the head. + fFound = TRUE; + *pAccess = gAccessCache[j]->fAccess; + pSwap = gAccessCache[gMostRecentAccess]; + gAccessCache[gMostRecentAccess] = gAccessCache[j]; + gAccessCache[j] = pSwap; + break; + } + if (j == 0) + j = ACCESS_CACHE_LEN - 1; + else + j -= 1; + } + + UNLOCK + ASSERT_LOCK_RELEASED + return fFound; +} + +//+------------------------------------------------------------------- +// +// Member: CClientSecurity::CopyProxy, public +// +// Synopsis: Create a new IPID entry for the specified IID. +// +//-------------------------------------------------------------------- +STDMETHODIMP CClientSecurity::CopyProxy( IUnknown *pProxy, IUnknown **ppCopy ) +{ + // Make sure TLS is initialized on this thread. + HRESULT hr; + COleTls tls(hr); + if (FAILED(hr)) + return hr; + + // Ask the marshaller to copy the proxy. + return _pStdId->PrivateCopyProxy( pProxy, ppCopy ); +} + +//+------------------------------------------------------------------- +// +// Member: CClientSecurity::QueryBlanket, public +// +// Synopsis: Get the binding handle for a proxy. Query RPC for the +// authentication information for that handle. +// +//-------------------------------------------------------------------- +STDMETHODIMP CClientSecurity::QueryBlanket( + IUnknown *pProxy, + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + void **pAuthInfo, + DWORD *pCapabilities ) +{ + HRESULT hr; + IPIDEntry *pIpid; + RPC_STATUS sc; + DWORD iLen; + OLECHAR *pCopy; + handle_t hHandle; + IRemUnknown *pRemUnk = NULL; + RPC_SECURITY_QOS sQos; + + ASSERT_LOCK_RELEASED + LOCK + + // Initialize all out parameters to default values. + if (pServerPrincName != NULL) + *pServerPrincName = NULL; + if (pAuthnLevel != NULL) + *pAuthnLevel = RPC_C_AUTHN_LEVEL_PKT_PRIVACY; + if (pImpLevel != NULL) + *pImpLevel = RPC_C_IMP_LEVEL_IMPERSONATE; + if (pAuthnSvc != NULL) + *pAuthnSvc = RPC_C_AUTHN_WINNT; + if (pAuthInfo != NULL) + *pAuthInfo = NULL; + if (pAuthzSvc != NULL) + *pAuthzSvc = RPC_C_AUTHZ_NONE; + if (pCapabilities != NULL) + *pCapabilities = EOAC_NONE; + + // For IUnknown just call QueryBlanket on the IRemUnknown of + // the IPID or the OXID. + if (_pStdId->GetCtrlUnk() == pProxy) + { + pIpid = _pStdId->GetConnectedIPID(); + hr = _pStdId->GetSecureRemUnk( &pRemUnk, pIpid->pOXIDEntry ); + if (pRemUnk != NULL) + { + UNLOCK + hr = CoQueryProxyBlanket( pRemUnk, pAuthnSvc, pAuthzSvc, + pServerPrincName, pAuthnLevel, + pImpLevel, pAuthInfo, pCapabilities ); + LOCK + } + } + + // Find the right IPID entry. + else + { + hr = _pStdId->FindIPIDEntryByInterface( pProxy, &pIpid ); + if (SUCCEEDED(hr)) + { + // Disallow server entries. + if (pIpid->dwFlags & IPIDF_SERVERENTRY) + hr = E_INVALIDARG; + + // No security for disconnected proxies. + else if (pIpid->dwFlags & IPIDF_DISCONNECTED) + hr = RPC_E_DISCONNECTED; + + // If it is local, use the default values for everything but the + // impersonation level. + else if (pIpid->pChnl->ProcessLocal()) + { + if (pImpLevel != NULL) + *pImpLevel = pIpid->pChnl->GetImpLevel(); + } + + // Otherwise ask RPC. + else + { + hr = pIpid->pChnl->GetHandle( &hHandle ); + + if (SUCCEEDED(hr)) + { + sc = RpcBindingInqAuthInfoExW( hHandle, + pServerPrincName, pAuthnLevel, + pAuthnSvc, pAuthInfo, + pAuthzSvc, + RPC_C_SECURITY_QOS_VERSION, + &sQos ); + + // RPC sometimes sets out parameters on error. + if (sc != RPC_S_OK) + { + if (pServerPrincName != NULL) + *pServerPrincName = NULL; + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + } + else + { + // Return the impersonation level and capabilities. + if (pImpLevel != NULL) + *pImpLevel = sQos.ImpersonationType; + if (pCapabilities != NULL) + if (sQos.Capabilities & RPC_C_QOS_CAPABILITIES_MUTUAL_AUTH) + *pCapabilities = EOAC_MUTUAL_AUTH; + else + *pCapabilities = EOAC_NONE; + + // Reallocate the principle name using the OLE memory allocator. + if (pServerPrincName != NULL && *pServerPrincName != NULL) + { + iLen = lstrlenW( *pServerPrincName ) + 1; + pCopy = (OLECHAR *) CoTaskMemAlloc( iLen * sizeof(OLECHAR) ); + if (pCopy != NULL) + memcpy( pCopy, *pServerPrincName, iLen*sizeof(USHORT) ); + else + hr = E_OUTOFMEMORY; + RpcStringFree( pServerPrincName ); + *pServerPrincName = pCopy; + } + } + } + } + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CClientSecurity::SetBlanket, public +// +// Synopsis: Get the binding handle for a proxy. Call RPC to set the +// authentication information for that handle. +// +//-------------------------------------------------------------------- +STDMETHODIMP CClientSecurity::SetBlanket( + IUnknown *pProxy, + DWORD AuthnSvc, + DWORD AuthzSvc, + OLECHAR *pServerPrincName, + DWORD AuthnLevel, + DWORD ImpLevel, + void *pAuthInfo, + DWORD Capabilities ) +{ + HRESULT hr; + IPIDEntry *pIpid; + RPC_STATUS sc; + BOOL fSuccess; + HANDLE hToken = NULL; + HANDLE hProcess; + handle_t hHandle; + IRemUnknown *pRemUnk; + IRemUnknown *pSecureRemUnk = NULL; + RPC_SECURITY_QOS sQos; + SECURITY_IMPERSONATION_LEVEL eDuplicate; + DWORD dwOpen; + + ASSERT_LOCK_RELEASED + + // IUnknown is special. Set the security on IRemUnknown instead. + if (_pStdId->GetCtrlUnk() == pProxy) + { + // Make sure the identity has its own copy of the OXID's + // IRemUnknown. + if (!_pStdId->CheckSecureRemUnk()) + { + // This will get the remote unknown from the OXID. + LOCK + pIpid = _pStdId->GetConnectedIPID(); + hr = _pStdId->GetSecureRemUnk( &pRemUnk, pIpid->pOXIDEntry ); + if (SUCCEEDED(hr)) + { + UNLOCK + hr = CoCopyProxy( pRemUnk, (IUnknown **) &pSecureRemUnk ); + LOCK + if (SUCCEEDED(hr)) + { + // Remote Unknown proxies are not supposed to ref count + // the OXID. + pIpid->pOXIDEntry->cRefs -= 1; + + // Only keep the proxies if no one else made a copy + // while this thread was making a copy. + if (!_pStdId->CheckSecureRemUnk()) + _pStdId->SetSecureRemUnk( pSecureRemUnk ); + else + { + pSecureRemUnk->Release(); + hr = _pStdId->GetSecureRemUnk( &pSecureRemUnk, NULL ); + } + } + } + UNLOCK + } + else + hr = _pStdId->GetSecureRemUnk( &pSecureRemUnk, NULL ); + + // Call SetBlanket on the copy of IRemUnknown. + if (pSecureRemUnk != NULL) + hr = CoSetProxyBlanket( pSecureRemUnk, AuthnSvc, AuthzSvc, + pServerPrincName, AuthnLevel, + ImpLevel, pAuthInfo, Capabilities ); + } + + else + { + // Find the right IPID entry. + LOCK + hr = _pStdId->FindIPIDEntryByInterface( pProxy, &pIpid ); + if (SUCCEEDED(hr)) + { + // Disallow server entries. + if (pIpid->dwFlags & IPIDF_SERVERENTRY) + hr = E_INVALIDARG; + + // No security for disconnected proxies. + else if (pIpid->dwFlags & IPIDF_DISCONNECTED) + hr = RPC_E_DISCONNECTED; + + else if (pIpid->pChnl->ProcessLocal()) + { + // Local calls can use no authn service or winnt. + if (AuthnSvc != RPC_C_AUTHN_NONE && + AuthnSvc != RPC_C_AUTHN_WINNT) + hr = E_INVALIDARG; + + // Make sure the authentication level is not invalid. + else if ((AuthnSvc == RPC_C_AUTHN_NONE && + AuthnLevel != RPC_C_AUTHN_LEVEL_NONE) || + (AuthnSvc == RPC_C_AUTHN_WINNT && + AuthnLevel > RPC_C_AUTHN_LEVEL_PKT_PRIVACY)) + hr = E_INVALIDARG; + + // No authorization services are supported locally. + else if (AuthzSvc != RPC_C_AUTHZ_NONE) + hr = E_INVALIDARG; + + // You cannot supply credentials locally. + else if (pAuthInfo != NULL) + hr = E_INVALIDARG; + + // Impersonation is not legal yet. + else if (ImpLevel != RPC_C_IMP_LEVEL_IMPERSONATE && + ImpLevel != RPC_C_IMP_LEVEL_IDENTIFY) + hr = E_INVALIDARG; + + // No capabilities are supported yet. + else if (Capabilities != EOAC_NONE) + hr = E_INVALIDARG; + + // Don't do delegation for NT 4.0 +#ifndef _SOME_FUTURE_PRODUCT_ + pIpid->pChnl->SetAuthnLevel( AuthnLevel ); + pIpid->pChnl->SetImpLevel( ImpLevel ); +#else + + // Save the user token if the app asked for security. + else if (AuthnLevel != RPC_C_AUTHN_LEVEL_NONE) + { + if (ImpLevel == RPC_C_IMP_LEVEL_IMPERSONATE) + { + eDuplicate = SecurityImpersonation; + dwOpen = TOKEN_IMPERSONATE; + } + else + { + eDuplicate = SecurityIdentification; + dwOpen = TOKEN_QUERY; + } + fSuccess = OpenThreadToken( GetCurrentThread(), dwOpen, + TRUE, &hToken ); + hr = GetLastError(); + + // If the application is not impersonating, no thread token + // will be present. Get the process token instead. + if (!fSuccess && hr == ERROR_NO_TOKEN) + { + fSuccess = OpenProcessToken( GetCurrentProcess(), + TOKEN_DUPLICATE, &hProcess ); + if (fSuccess) + { + fSuccess = DuplicateToken( hProcess, eDuplicate, + &hToken ); + CloseHandle( hProcess ); + } + } + if (fSuccess) + { + hToken = pIpid->pChnl->SwapSecurityToken( hToken ); + pIpid->pChnl->SetAuthnLevel( AuthnLevel ); + pIpid->pChnl->SetImpLevel( ImpLevel ); + hr = S_OK; + } + else + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + hToken = pIpid->pChnl->SwapSecurityToken( NULL ); + } + CloseHandle( hToken ); + } + + // If there was an old token, toss it. + else if (pIpid->pChnl->GetSecurityToken() != NULL) + { + hToken = pIpid->pChnl->SwapSecurityToken( NULL ); + CloseHandle( hToken ); + pIpid->pChnl->SetAuthnLevel( AuthnLevel ); + pIpid->pChnl->SetImpLevel( ImpLevel ); + } +#endif // !_SOME_FUTURE_PRODUCT_ + } + + // If it is remote, tell RPC. + else + { + // Validate the capabilities. + if (Capabilities & ~ EOAC_MUTUAL_AUTH) + hr = E_INVALIDARG; + else + hr = pIpid->pChnl->GetHandle( &hHandle ); + + if (SUCCEEDED(hr)) + { +#ifdef _CHICAGO_ + // If the principal name is not known, the server must be + // NT. Replace the principal name in that case + // because a NULL principal name is a flag for some + // Chicago security hack. + if (pServerPrincName == NULL && + AuthnSvc == RPC_C_AUTHN_WINNT && + (pIpid->pOXIDEntry->dwFlags & OXIDF_MACHINE_LOCAL) == 0) + pServerPrincName = L"Default"; +#endif // _CHICAGO_ + + // Suspend any outstanding impersonation and ignore failures. + COleTls tls(hr); + BOOL resume = FALSE; + if (SUCCEEDED(hr)) + SuspendImpersonate( tls->pCallContext, &resume ); + else + hr = S_OK; + + sQos.Version = RPC_C_SECURITY_QOS_VERSION; + sQos.IdentityTracking = RPC_C_QOS_IDENTITY_STATIC; + sQos.ImpersonationType = ImpLevel; + sQos.Capabilities = (Capabilities & EOAC_MUTUAL_AUTH) ? + RPC_C_QOS_CAPABILITIES_MUTUAL_AUTH : RPC_C_QOS_CAPABILITIES_DEFAULT; + sc = RpcBindingSetAuthInfoExW( hHandle, + pServerPrincName, AuthnLevel, + AuthnSvc, pAuthInfo, AuthzSvc, + &sQos ); + + // Resume any outstanding impersonation. + ResumeImpersonate( tls->pCallContext, resume ); + + if (sc != RPC_S_OK) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + else + pIpid->pChnl->SetAuthnLevel( AuthnLevel ); + } + } + } + + UNLOCK + } + ASSERT_LOCK_RELEASED + + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CheckAccessControl +// +// Synopsis: Call the access control and ask it to check access. +// +//-------------------------------------------------------------------- +RPC_STATUS CheckAccessControl( RPC_IF_HANDLE pIid, void *pContext ) +{ + HRESULT hr; + TRUSTEE_W sTrustee; + CServerSecurity sSecurity; + IUnknown *pSave; + BOOL fAccess = FALSE; + COleTls tls(hr); +#if DBG == 1 + char *pFailure = ""; +#endif + + sTrustee.ptstrName = NULL; + if (FAILED(hr)) + { +#if DBG == 1 + pFailure = "Bad TLS: 0x%x\n"; +#endif + } + + else + { +#ifdef _CHICAGO_ + // On Chicago RpcBindingInqAuthClientW doesn't work locally. Since + // IObjServer is the only interface that uses security locally on + // Chicago, allow it if the call is local. + if (pIid == NULL) + return RPC_S_OK; + else if ((*(IID *) pIid) == IID_IObjServer) + { +#if DBG == 1 + pFailure = "IObjServer can't be called remotely: 0x%x\n"; +#endif + hr = E_ACCESSDENIED; + } +#else + // Since IObjServer always uses dynamic impersonation, allow access here. + // It will be checked later in CheckObjactAccess. + if (pIid != NULL && *((IID *) pIid) == IID_IObjServer) + return RPC_S_OK; +#endif + + if (SUCCEEDED(hr)) + { + // Get the trustee name. + hr = RpcBindingInqAuthClientW( NULL, + (void **) &sTrustee.ptstrName, + NULL, NULL, NULL, NULL ); + + if (hr == RPC_S_OK) + { + // Save the security context in TLS. + pSave = tls->pCallContext; + tls->pCallContext = &sSecurity; + + // Check access. + sTrustee.pMultipleTrustee = NULL; + sTrustee.MultipleTrusteeOperation = NO_MULTIPLE_TRUSTEE; + sTrustee.TrusteeForm = TRUSTEE_IS_NAME; + sTrustee.TrusteeType = TRUSTEE_IS_USER; + hr = gAccessControl->IsAccessAllowed( &sTrustee, NULL, + COM_RIGHTS_EXECUTE, &fAccess ); +#if DBG==1 + if (FAILED(hr)) + pFailure = "IsAccessAllowed failed: 0x%x\n"; +#endif + if (SUCCEEDED(hr) && !fAccess) + { + hr = E_ACCESSDENIED; +#if DBG==1 + pFailure = "IAccessControl does not allow user access.\n"; +#endif + } + + // Restore the security context. + tls->pCallContext = pSave; + } +#if DBG == 1 + else + pFailure = "RpcBindingInqAuthClientW failed: 0x%x\n"; +#endif + } + } + +#if DBG==1 + if (FAILED(hr)) + { + ComDebOut(( DEB_WARN, "***** ACCESS DENIED *****\n" )); + ComDebOut(( DEB_WARN, pFailure, hr )); + + // Print the user name. + if (sTrustee.ptstrName != NULL) + ComDebOut(( DEB_WARN, "User: %ws\n", sTrustee.ptstrName )); + } +#endif + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CheckAcl +// +// Synopsis: Impersonate and do an AccessCheck against the global ACL. +// +//-------------------------------------------------------------------- +RPC_STATUS CheckAcl( RPC_IF_HANDLE pIid, void *pContext ) +{ + RPC_STATUS sc; + BOOL fAccess = FALSE; + BOOL fSuccess; + DWORD lGrantedAccess; + DWORD lSetLen = sizeof(gPriv); + HANDLE hToken; + DWORD i; + DWORD lSize = SIZEOF_TOKEN_USER; + TOKEN_USER *pTokenInfo = (TOKEN_USER *) _alloca( lSize ); + SID *pSid = NULL; +#if DBG==1 + char *pFailure = ""; +#endif + + // Since IObjServer always uses dynamic impersonation, allow access here. + // It will be checked later in CheckObjactAccess. + if (pIid != NULL && *((IID *) pIid) == IID_IObjServer) + return RPC_S_OK; + + // Impersonate. + sc = RpcImpersonateClient( NULL ); + + if (sc == RPC_S_OK) + { + // Open the thread token. + fSuccess = OpenThreadToken( GetCurrentThread(), TOKEN_READ, + TRUE, &hToken ); + + // Revert. + RpcRevertToSelf(); + + if (fSuccess) + { + // Get the SID and see if its cached. + if (GetTokenInformation( hToken, TokenUser, pTokenInfo, + lSize, &lSize )) + { + pSid = (SID *) pTokenInfo->User.Sid; + fSuccess = CacheAccessCheck( pSid, &fAccess ); + if (fSuccess) + { + CloseHandle( hToken ); + if (fAccess) + return RPC_S_OK; + else + return RPC_E_ACCESS_DENIED; + } + } + + // Access check. + fSuccess = AccessCheck( gSecDesc, hToken, COM_RIGHTS_EXECUTE, + &gMap, &gPriv, &lSetLen, &lGrantedAccess, + &fAccess ); + if (fSuccess) + CacheAccess( pSid, fAccess ); + + if (!fAccess) + { + sc = RPC_E_ACCESS_DENIED; +#if DBG==1 + pFailure = "Security descriptor does not allow user access.\n"; +#endif + } +#if DBG==1 + if (!fSuccess) + pFailure = "Bad security descriptor"; +#endif + CloseHandle( hToken ); + } + else + { + sc = GetLastError(); +#if DBG==1 + pFailure = "Could not open thread token: 0x%x\n"; +#endif + } + } +#if DBG==1 + else + pFailure = "Could not impersonate client: 0x%x\n"; +#endif + +#if DBG==1 + if (sc != 0) + { + ComDebOut(( DEB_WARN, "***** ACCESS DENIED *****\n" )); + ComDebOut(( DEB_WARN, pFailure, sc )); + + // Print the user name. + WCHAR *pClient; + if (0 == RpcBindingInqAuthClient( NULL, (void **) &pClient, NULL, + NULL, NULL, NULL ) && + pClient != NULL) + ComDebOut(( DEB_WARN, "User: %ws\n", pClient )); + + // Print the user sid. + ComDebOut(( DEB_WARN, "Security Descriptor 0x%x\n", gSecDesc )); + if (pSid != NULL) + { + ComDebOut(( DEB_WARN, "SID:\n" )); + ComDebOut(( DEB_WARN, " Revision: 0x%02x\n", pSid->Revision )); + ComDebOut(( DEB_WARN, " SubAuthorityCount: 0x%x\n", pSid->SubAuthorityCount )); + ComDebOut(( DEB_WARN, " IdentifierAuthority: 0x%02x%02x%02x%02x%02x%02x\n", + pSid->IdentifierAuthority.Value[0], + pSid->IdentifierAuthority.Value[1], + pSid->IdentifierAuthority.Value[2], + pSid->IdentifierAuthority.Value[3], + pSid->IdentifierAuthority.Value[4], + pSid->IdentifierAuthority.Value[5] )); + for (DWORD i = 0; i < pSid->SubAuthorityCount; i++) + ComDebOut(( DEB_WARN, " SubAuthority[%d]: 0x%08x\n", i, + pSid->SubAuthority[i] )); + } + else + ComDebOut(( DEB_WARN, " Unknown\n" )); + } +#endif + return sc; +} + +//+------------------------------------------------------------------- +// +// Function: CheckObjactAccess, private +// +// Synopsis: Determine whether caller has permission to make call. +// +// Notes: Since IObjServer uses dynamic delegation, we have to allow +// all calls to IObjServer through the normal security (which only +// checks access on connect) and check them manually. +// +//-------------------------------------------------------------------- +BOOL CheckObjactAccess() +{ + RPC_IF_CALLBACK_FN *pAccess; + + // Get the access check function. + pAccess = GetAclFn(); + + // Check access. Lie about the IID since the check access functions + // won't fail calls to IID_IObjServer. + if (pAccess != NULL) + return pAccess( NULL, NULL ) == S_OK; + else + return TRUE; +} + +//+------------------------------------------------------------------- +// +// Function: CoCopyProxy, public +// +// Synopsis: Copy a proxy. +// +//-------------------------------------------------------------------- +WINOLEAPI CoCopyProxy( + IUnknown *pProxy, + IUnknown **ppCopy ) +{ + HRESULT hr; + IClientSecurity *pickle; + // Ask the proxy for IClientSecurity. + hr = ((IUnknown *) pProxy)->QueryInterface( IID_IClientSecurity, + (void **) &pickle ); + if (FAILED(hr)) + return hr; + + // Ask IClientSecurity to do the copy. + hr = pickle->CopyProxy( pProxy, ppCopy ); + pickle->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoGetCallContext +// +// Synopsis: Get an interface that supplies contextual information +// about the call. Currently only IServerSecurity. +// +//-------------------------------------------------------------------- +WINOLEAPI CoGetCallContext( REFIID riid, void **ppInterface ) +{ + HRESULT hr; + COleTls tls(hr); + + if (SUCCEEDED(hr)) + { + // Fail if there is no call context. + if (tls->pCallContext == NULL) + return RPC_E_CALL_COMPLETE; + + // Look up the requested interface. + return tls->pCallContext->QueryInterface( riid, ppInterface ); + } + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoImpersonateClient +// +// Synopsis: Get the server security for the current call and ask it +// to do an impersonation. +// +//-------------------------------------------------------------------- +WINOLEAPI CoImpersonateClient() +{ + HRESULT hr; + IServerSecurity *pSS; + + // Get the IServerSecurity. + hr = CoGetCallContext( IID_IServerSecurity, (void **) &pSS ); + if (FAILED(hr)) + return hr; + + // Ask IServerSecurity to do the impersonate. + hr = pSS->ImpersonateClient(); + pSS->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoInitializeSecurity, public +// +// Synopsis: Set the values to use for automatic security. This API +// can only be called once so it does not need to be thread +// safe. +// +//-------------------------------------------------------------------- +WINOLEAPI CoInitializeSecurity( + PSECURITY_DESCRIPTOR pVoid, + LONG cAuthSvc, + SOLE_AUTHENTICATION_SERVICE *asAuthSvc, + void *pReserved1, + DWORD dwAuthnLevel, + DWORD dwImpLevel, + void *pReserved2, + DWORD dwCapabilities, + void *pReserved3 ) +{ + HRESULT hr = S_OK; + DWORD i; + SECURITY_DESCRIPTOR *pSecDesc = (SECURITY_DESCRIPTOR *) pVoid; + SECURITY_DESCRIPTOR *pCopySecDesc = NULL; + IAccessControl *pAccessControl = NULL; + BOOL fFreeSecDesc = FALSE; + SOLE_AUTHENTICATION_SERVICE sAuthSvc; + + // Fail if OLE is not initialized or TLS cannot be allocated. + if (!IsApartmentInitialized()) + return CO_E_NOTINITIALIZED; + + // Make sure the security data is available. + if (!gGotSecurityData) + { + hr = gResolver.GetConnection(); + if (FAILED(hr)) + return hr; + Win4Assert(gGotSecurityData); + } + + // Make sure only one of the flags defining the pVoid parameter is set. + if ((dwCapabilities & (EOAC_APPID | EOAC_ACCESS_CONTROL)) == + (EOAC_APPID | EOAC_ACCESS_CONTROL)) + return E_INVALIDARG; + + // If the appid flag is set, read the registry security. + if (dwCapabilities & EOAC_APPID) + { + // Get a security descriptor from the registry. + if (gAuthnLevel != RPC_C_AUTHN_LEVEL_NONE) + { + hr = GetLegacySecDesc( &pSecDesc, &dwCapabilities ); + if (FAILED(hr)) + return hr; + fFreeSecDesc = TRUE; + } + + // Fix up the security binding. + if (gLegacySecurity != NULL) + { + cAuthSvc = 1; + asAuthSvc = &sAuthSvc; + sAuthSvc.dwAuthnSvc = gLegacySecurity->wAuthnSvc; + sAuthSvc.dwAuthzSvc = gLegacySecurity->wAuthzSvc; + sAuthSvc.pPrincipalName = NULL; + if (sAuthSvc.dwAuthzSvc == COM_C_AUTHZ_NONE) + sAuthSvc.dwAuthzSvc = RPC_C_AUTHZ_NONE; + } + else + cAuthSvc = 0xFFFFFFFF; + + // Initialize remaining parameters. + pReserved1 = NULL; + dwAuthnLevel = gAuthnLevel; + dwImpLevel = gImpLevel; + pReserved2 = NULL; + pReserved3 = NULL; + dwCapabilities |= gCapabilities; + } + + // Fail if called too late, recalled, or called with bad parameters. + if (dwImpLevel > RPC_C_IMP_LEVEL_DELEGATE || + dwAuthnLevel > RPC_C_AUTHN_LEVEL_PKT_PRIVACY || + pReserved1 != NULL || + pReserved2 != NULL || + pReserved3 != NULL || + (dwCapabilities & ~VALID_INIT_FLAGS)) + { + hr = E_INVALIDARG; + goto Error; + } + if ((dwCapabilities & EOAC_SECURE_REFS) && + dwAuthnLevel == RPC_C_AUTHN_LEVEL_NONE) + { + hr = E_INVALIDARG; + goto Error; + } + + // Validate the pointers. + if (pSecDesc != NULL) + if (dwCapabilities & EOAC_ACCESS_CONTROL) + { + if (!IsValidPtrIn( pSecDesc, 4 )) + { + hr = E_INVALIDARG; + goto Error; + } + } + else if (!IsValidPtrIn( pSecDesc, sizeof(SECURITY_DESCRIPTOR) )) + { + hr = E_INVALIDARG; + goto Error; + } + if (cAuthSvc != 0 && cAuthSvc != -1 && + !IsValidPtrOut( asAuthSvc, sizeof(SOLE_AUTHENTICATION_SERVICE) * cAuthSvc )) + { + hr = E_INVALIDARG; + goto Error; + } + + LOCK + + if (gpsaSecurity != NULL) + hr = RPC_E_TOO_LATE; + + if (SUCCEEDED(hr)) + { + // If the app doesn't want security, don't set up a security + // descriptor. + if (dwAuthnLevel == RPC_C_AUTHN_LEVEL_NONE) + { + // Check for some more invalid parameters. + if (pSecDesc != NULL) + hr = E_INVALIDARG; + } + + // Check whether security is done with ACLs or IAccessControl. + else if (dwCapabilities & EOAC_ACCESS_CONTROL) + { + if (pSecDesc == NULL) + hr = E_INVALIDARG; + else + hr = ((IUnknown *) pSecDesc)->QueryInterface( + IID_IAccessControl, (void **) &pAccessControl ); + } + + else + { +#ifdef _CHICAGO_ + if (pSecDesc != NULL) + hr = E_INVALIDARG; +#else + // If specified, copy the security descriptor. + if (pSecDesc != NULL) + hr = CopySecDesc( pSecDesc, &pCopySecDesc ); +#endif + } + } + + if (SUCCEEDED(hr)) + { + // Delay the registration of authentication services if the caller + // isn't picky. + if (cAuthSvc == -1) + { + gpsaSecurity = (DUALSTRINGARRAY *) PrivMemAlloc( SASIZE(4) ); + if (gpsaSecurity != NULL) + { + gDefaultService = TRUE; + gpsaSecurity->wNumEntries = 4; + gpsaSecurity->wSecurityOffset = 2; + memset( gpsaSecurity->aStringArray, 0, 4*sizeof(WCHAR) ); + } + else + hr = E_OUTOFMEMORY; + } + + // Otherwise, register the ones the caller specified. + else + hr = RegisterAuthnServices( cAuthSvc, asAuthSvc ); + } + + // If everything succeeded, change the globals. + if (SUCCEEDED(hr)) + { + // Save the defaults. + gAuthnLevel = dwAuthnLevel; + gImpLevel = dwImpLevel; + gCapabilities = dwCapabilities; + gSecDesc = pCopySecDesc; + gAccessControl = pAccessControl; + if ( dwCapabilities & EOAC_DYNAMIC ) + gResolver.SetDynamicSecurity(); + } + + // Otherwise free any memory allocated. + else + { + PrivMemFree( pCopySecDesc ); + } + UNLOCK + + // If anything was allocated for app id security, free it. +Error: + if (fFreeSecDesc && pSecDesc != NULL) + if (dwCapabilities & EOAC_ACCESS_CONTROL) + ((IAccessControl *) pSecDesc)->Release(); + else + PrivMemFree( pSecDesc ); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CopySecDesc +// +// Synopsis: Copy a security descriptor. +// +// Notes: The function does not copy the SACL because we do not do +// auditing. +// +//-------------------------------------------------------------------- +HRESULT CopySecDesc( SECURITY_DESCRIPTOR *pOrig, SECURITY_DESCRIPTOR **pCopy ) +{ + SID *pOwner; + SID *pGroup; + ACL *pDacl; + ULONG cSize; + ULONG cOwner; + ULONG cGroup; + ULONG cDacl; + + // Assert if there is a new revision for the security descriptor or + // ACL. +#if DBG== 1 + if (pOrig->Revision != SECURITY_DESCRIPTOR_REVISION) + ComDebOut(( DEB_ERROR, "Someone made a new security descriptor revision without telling me." )); + if (pOrig->Dacl != NULL) + Win4Assert( pOrig->Dacl->AclRevision == ACL_REVISION || + !"Someone made a new acl revision without telling me." ); +#endif + + // Validate the security descriptor and ACL. + if (pOrig->Revision != SECURITY_DESCRIPTOR_REVISION || + (pOrig->Control & SE_SELF_RELATIVE) != 0 || + pOrig->Owner == NULL || + pOrig->Group == NULL || + pOrig->Sacl != NULL || + (pOrig->Dacl != NULL && pOrig->Dacl->AclRevision != ACL_REVISION)) + return E_INVALIDARG; + + // Figure out how much memory to allocate for the copy and allocate it. + cOwner = GetLengthSid( pOrig->Owner ); + cGroup = GetLengthSid( pOrig->Group ); + cDacl = pOrig->Dacl == NULL ? 0 : pOrig->Dacl->AclSize; + cSize = sizeof(SECURITY_DESCRIPTOR) + cOwner + cGroup + cDacl; + *pCopy = (SECURITY_DESCRIPTOR *) PrivMemAlloc( cSize ); + if (*pCopy == NULL) + return E_OUTOFMEMORY; + + // Get pointers to each of the parts of the security descriptor. + pOwner = (SID *) (*pCopy + 1); + pGroup = (SID *) (((char *) pOwner) + cOwner); + if (pOrig->Dacl != NULL) + pDacl = (ACL *) (((char *) pGroup) + cGroup); + else + pDacl = NULL; + + // Copy each piece. + **pCopy = *pOrig; + memcpy( pOwner, pOrig->Owner, cOwner ); + memcpy( pGroup, pOrig->Group, cGroup ); + if (pDacl != NULL) + memcpy( pDacl, pOrig->Dacl, pOrig->Dacl->AclSize ); + (*pCopy)->Owner = pOwner; + (*pCopy)->Group = pGroup; + (*pCopy)->Dacl = pDacl; + (*pCopy)->Sacl = NULL; + + // Check the security descriptor. +#if DBG==1 + if (!IsValidSecurityDescriptor( *pCopy )) + { + Win4Assert( !"COM Created invalid security descriptor." ); + return GetLastError(); + } +#endif + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: CoQueryAuthenticationServices, public +// +// Synopsis: Return a list of the registered authentication services. +// +//-------------------------------------------------------------------- +WINOLEAPI CoQueryAuthenticationServices( DWORD *pcAuthSvc, + SOLE_AUTHENTICATION_SERVICE **asAuthSvc ) +{ + DWORD i; + DWORD lNum = 0; + USHORT *pNext; + HRESULT hr = S_OK; + + ASSERT_LOCK_RELEASED + LOCK + + // Count the number of services in the security string array. + if (gpsaSecurity != NULL) + { + pNext = &gpsaSecurity->aStringArray[gpsaSecurity->wSecurityOffset]; + while (*pNext != 0) + { + lNum++; + pNext += lstrlenW(pNext)+1; + } + } + + // Return nothing if there are no authentication services. + *pcAuthSvc = lNum; + if (lNum == 0) + { + *asAuthSvc = NULL; + goto exit; + } + + // Allocate a list of pointers. + *asAuthSvc = (SOLE_AUTHENTICATION_SERVICE *) + CoTaskMemAlloc( lNum * sizeof(void *) ); + if (*asAuthSvc == NULL) + { + hr = E_OUTOFMEMORY; + goto exit; + } + + // Initialize it. + for (i = 0; i < lNum; i++) + (*asAuthSvc)[i].pPrincipalName = NULL; + + // Fill in one SOLE_AUTHENTICATION_SERVICE record per service + pNext = &gpsaSecurity->aStringArray[gpsaSecurity->wSecurityOffset]; + for (i = 0; i < lNum; i++) + { + (*asAuthSvc)[i].dwAuthnSvc = *(pNext++); + (*asAuthSvc)[i].dwAuthzSvc = *(pNext++); + (*asAuthSvc)[i].hr = S_OK; + + // Allocate memory for the principal name string. + (*asAuthSvc)[i].pPrincipalName = (OLECHAR *) + CoTaskMemAlloc( (lstrlenW(pNext)+1)*sizeof(OLECHAR) ); + if ((*asAuthSvc)[i].pPrincipalName == NULL) + { + hr = E_OUTOFMEMORY; + break; + } + + lstrcpyW( (*asAuthSvc)[i].pPrincipalName, pNext ); + pNext += lstrlenW(pNext) + 1; + } + + // Clean up if there wasn't enough memory. + if (FAILED(hr)) + { + for (i = 0; i < lNum; i++) + CoTaskMemFree( (*asAuthSvc)[i].pPrincipalName ); + CoTaskMemFree( *asAuthSvc ); + *asAuthSvc = NULL; + *pcAuthSvc = 0; + } + +exit: + UNLOCK + ASSERT_LOCK_RELEASED + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoQueryClientBlanket +// +// Synopsis: Get the authentication settings the client used to call +// the server. +// +//-------------------------------------------------------------------- +WINOLEAPI CoQueryClientBlanket( + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + RPC_AUTHZ_HANDLE *pPrivs, + DWORD *pCapabilities ) +{ + HRESULT hr; + IServerSecurity *pSS; + + // Get the IServerSecurity. + hr = CoGetCallContext( IID_IServerSecurity, (void **) &pSS ); + if (FAILED(hr)) + return hr; + + // Ask IServerSecurity to do the query. + hr = pSS->QueryBlanket( pAuthnSvc, pAuthzSvc, pServerPrincName, + pAuthnLevel, pImpLevel, pPrivs, pCapabilities ); + + pSS->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoQueryProxyBlanket, public +// +// Synopsis: Get the authentication settings from a proxy. +// +//-------------------------------------------------------------------- +WINOLEAPI CoQueryProxyBlanket( + IUnknown *pProxy, + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + RPC_AUTH_IDENTITY_HANDLE *pAuthInfo, + DWORD *pCapabilities ) +{ + HRESULT hr; + IClientSecurity *pickle; + + // Ask the proxy for IClientSecurity. + hr = ((IUnknown *) pProxy)->QueryInterface( IID_IClientSecurity, + (void **) &pickle ); + if (FAILED(hr)) + return hr; + + // Ask IClientSecurity to do the query. + hr = pickle->QueryBlanket( pProxy, pAuthnSvc, pAuthzSvc, pServerPrincName, + pAuthnLevel, pImpLevel, pAuthInfo, + pCapabilities ); + pickle->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoRevertToSelf +// +// Synopsis: Get the server security for the current call and ask it +// to revert. +// +//-------------------------------------------------------------------- +WINOLEAPI CoRevertToSelf() +{ + HRESULT hr; + IServerSecurity *pSS; + + // Get the IServerSecurity. + hr = CoGetCallContext( IID_IServerSecurity, (void **) &pSS ); + if (FAILED(hr)) + return hr; + + // Ask IServerSecurity to do the revert. + hr = pSS->RevertToSelf(); + pSS->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoSetProxyBlanket, public +// +// Synopsis: Set the authentication settings for a proxy. +// +//-------------------------------------------------------------------- +WINOLEAPI CoSetProxyBlanket( + IUnknown *pProxy, + DWORD dwAuthnSvc, + DWORD dwAuthzSvc, + OLECHAR *pServerPrincName, + DWORD dwAuthnLevel, + DWORD dwImpLevel, + RPC_AUTH_IDENTITY_HANDLE pAuthInfo, + DWORD dwCapabilities ) +{ + HRESULT hr; + IClientSecurity *pickle; + + // Ask the proxy for IClientSecurity. + hr = ((IUnknown *) pProxy)->QueryInterface( IID_IClientSecurity, + (void **) &pickle ); + if (FAILED(hr)) + return hr; + + // Ask IClientSecurity to do the set. + hr = pickle->SetBlanket( pProxy, dwAuthnSvc, dwAuthzSvc, pServerPrincName, + dwAuthnLevel, dwImpLevel, pAuthInfo, + dwCapabilities ); + pickle->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: CoSwitchCallContext +// +// Synopsis: Replace the call context object in TLS. Return the old +// context object. This API is used by custom marshallers +// to support security. +// +//-------------------------------------------------------------------- +WINOLEAPI CoSwitchCallContext( IUnknown *pNewObject, IUnknown **ppOldObject ) +{ + HRESULT hr; + COleTls tls(hr); + + if (SUCCEEDED(hr)) + { + *ppOldObject = tls->pCallContext; + tls->pCallContext = pNewObject; + } + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::AddRef, public +// +// Synopsis: Adds a reference to an interface +// +// Note: This is created in the stack so its reference count is ignored. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CServerSecurity::AddRef() +{ + InterlockedIncrement( (long *) &_iRefCount ); + return _iRefCount; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::CServerSecurity, public +// +// Synopsis: Construct a server security for a remote call. +// +//-------------------------------------------------------------------- +CServerSecurity::CServerSecurity() +{ + _iRefCount = 1; + _pChannel = NULL; + _pHandle = NULL; + _iFlags = 0; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::CServerSecurity, public +// +// Synopsis: Construct a server security for a call. +// +//-------------------------------------------------------------------- +CServerSecurity::CServerSecurity( CChannelCallInfo *call ) +{ + _iRefCount = 1; + if (call->iFlags & CF_PROCESS_LOCAL) + { + _pChannel = call->pChannel; + _pHandle = NULL; + _iFlags = SS_PROCESS_LOCAL; + } + else + { + _pChannel = NULL; + _pHandle = (handle_t *) call->pmessage->reserved1; + _iFlags = 0; + } +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::EndCall, public +// +// Synopsis: Clears the stored binding handle because the call +// this object represents is over. +// +//-------------------------------------------------------------------- +void CServerSecurity::EndCall() +{ + // Revert if the app forgot to. + RevertToSelf(); + _iFlags |= SS_CALL_DONE; + _pChannel = NULL; + _pHandle = NULL; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::ImpersonateClient, public +// +// Synopsis: Calls RPC to impersonate for the stored binding handle. +// +//-------------------------------------------------------------------- +STDMETHODIMP CServerSecurity::ImpersonateClient() +{ +#ifdef _CHICAGO_ + return E_NOTIMPL; +#else + + HRESULT hr = S_OK; + RPC_STATUS sc; + BOOL fSuccess; + HANDLE hProcess; + HANDLE hToken; + SECURITY_IMPERSONATION_LEVEL eDuplicate; + + // If the call is over, fail this request. + if (_iFlags & SS_CALL_DONE) + hr = RPC_E_CALL_COMPLETE; + + // For process local calls, ask the channel to impersonate. + else if (_iFlags & SS_PROCESS_LOCAL) + { + if (_pChannel->GetSecurityToken() == NULL) + { + // Determine what rights to duplicate the token with. + if (_pChannel->GetImpLevel() == RPC_C_IMP_LEVEL_IMPERSONATE) + eDuplicate = SecurityImpersonation; + else + eDuplicate = SecurityIdentification; + + // If the channel doesn't have a token, use the process token. + if (OpenProcessToken( GetCurrentProcess(), + TOKEN_DUPLICATE, + &hProcess )) + { + if (DuplicateToken( hProcess, eDuplicate, &hToken )) + { + if (!SetThreadToken( NULL, hToken )) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + + // If the channel still doesn't have a token, save this one. + LOCK + if (_pChannel->GetSecurityToken() == NULL) + _pChannel->SwapSecurityToken( hToken ); + else + CloseHandle( hToken ); + UNLOCK + } + else + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + CloseHandle( hProcess ); + } + else + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + } + else + { + fSuccess = SetThreadToken( NULL, _pChannel->GetSecurityToken() ); + if (!fSuccess) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + } + } + + // For process remote calls, ask RPC to impersonate. + else + { + sc = RpcImpersonateClient( _pHandle ); + if (sc != RPC_S_OK) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + } + + if (SUCCEEDED(hr)) + _iFlags |= SS_IMPERSONATING; + return hr; +#endif +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::IsImpersonating, public +// +// Synopsis: Return TRUE if ImpersonateClient has been called. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(BOOL) CServerSecurity::IsImpersonating() +{ +#ifdef _CHICAGO_ + return FALSE; +#else + return _iFlags & SS_IMPERSONATING; +#endif +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::QueryBlanket, public +// +// Synopsis: Calls RPC to return the authentication information +// for the stored binding handle. +// +//-------------------------------------------------------------------- +STDMETHODIMP CServerSecurity::QueryBlanket( + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + void **pPrivs, + DWORD *pCapabilities ) +{ + HRESULT hr = S_OK; + RPC_STATUS sc; + DWORD iLen; + OLECHAR *pCopy; + + // Initialize the out parameters. Currently the impersonation level + // and capabilities can not be determined. + if (pPrivs != NULL) + *((void **) pPrivs) = NULL; + if (pServerPrincName != NULL) + *pServerPrincName = NULL; + if (pAuthnSvc != NULL) + *pAuthnSvc = RPC_C_AUTHN_WINNT; + if (pAuthnLevel != NULL) + *pAuthnLevel = RPC_C_AUTHN_LEVEL_PKT_PRIVACY; + if (pImpLevel != NULL) + *pImpLevel = RPC_C_IMP_LEVEL_ANONYMOUS; + if (pAuthzSvc != NULL) + *pAuthzSvc = RPC_C_AUTHZ_NONE; + if (pCapabilities != NULL) + *pCapabilities = EOAC_NONE; + + // If the call is over, fail this request. + if (_iFlags & SS_CALL_DONE) + hr = RPC_E_CALL_COMPLETE; + + // For process local calls, use the defaults. Otherwise ask RPC. + else if ((_iFlags & SS_PROCESS_LOCAL) == 0) + { + sc = RpcBindingInqAuthClientW( _pHandle, pPrivs, pServerPrincName, + pAuthnLevel, pAuthnSvc, pAuthzSvc ); + + // Sometimes RPC sets out parameters in error cases. + if (sc != RPC_S_OK) + { + if (pServerPrincName != NULL) + *pServerPrincName = NULL; + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + } + else if (pServerPrincName != NULL && *pServerPrincName != NULL) + { + // Reallocate the principle name using the OLE memory allocator. + iLen = lstrlenW( *pServerPrincName ); + pCopy = (OLECHAR *) CoTaskMemAlloc( (iLen+1) * sizeof(OLECHAR) ); + if (pCopy != NULL) + lstrcpyW( pCopy, *pServerPrincName ); + else + hr = E_OUTOFMEMORY; + RpcStringFree( pServerPrincName ); + *pServerPrincName = pCopy; + } + } + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::QueryInterface, public +// +// Synopsis: Returns a pointer to the requested interface. +// +//-------------------------------------------------------------------- +STDMETHODIMP CServerSecurity::QueryInterface( REFIID riid, LPVOID FAR* ppvObj) +{ + if (IsEqualIID(riid, IID_IUnknown) || + IsEqualIID(riid, IID_IServerSecurity)) + { + *ppvObj = (IServerSecurity *) this; + } + else + { + *ppvObj = NULL; + return E_NOINTERFACE; + } + + AddRef(); + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::Release, public +// +// Synopsis: Releases an interface +// +// Note: This is created in the stack so its reference count is ignored. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CServerSecurity::Release() +{ + ULONG lRef = _iRefCount - 1; + + if (InterlockedDecrement( (long*) &_iRefCount ) == 0) + { + Win4Assert( !"Illegal release of IServerSecurity." ); + delete this; + return 0; + } + else + { + return lRef; + } +} + +//+------------------------------------------------------------------- +// +// Member: CServerSecurity::RevertToSelf, public +// +// Synopsis: If ImpersonateClient was called, then either ask RPC to +// revert or NULL the thread token ourself. +// +//-------------------------------------------------------------------- +HRESULT CServerSecurity::RevertToSelf() +{ +#ifdef _CHICAGO_ + return S_OK; +#else + HRESULT hr = RPC_S_OK; + RPC_STATUS sc; + BOOL fSuccess; + + if (_iFlags & SS_IMPERSONATING) + { + // Ask win32 to revert for process local calls. + _iFlags &= ~SS_IMPERSONATING; + if (_iFlags & SS_PROCESS_LOCAL) + { + fSuccess = SetThreadToken( NULL, NULL ); + if (!fSuccess) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + } + + // Ask RPC to revert for process remote calls. + else + { + sc = RpcRevertToSelfEx( _pHandle ); + if (sc != RPC_S_OK) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + } + } + return hr; +#endif +} + +//+------------------------------------------------------------------- +// +// Function: DefaultAuthnServices, private +// +// Synopsis: Register authentication services with RPC and build +// a string array of authentication services and principal +// names. +// +//-------------------------------------------------------------------- +HRESULT DefaultAuthnServices() +{ + HRESULT hr = S_OK; + DWORD i; + WCHAR *pPrincName = NULL; + DWORD lNameLen; + USHORT *pNextString; + DUALSTRINGARRAY *pOld; + DWORD cBinding = gServerSvcListLen ? gServerSvcListLen : 1; + + ASSERT_LOCK_HELD + + // Return if the security bindings are already computed. + if (!gDefaultService) + return S_OK; + + // Only look up the current user name if the only security provider + // is not NTLMSSP since NTLMSSP doesn't do mutual auth. + if (gServerSvcListLen != 0 && + (gServerSvcListLen != 1 || gServerSvcList[0] != RPC_C_AUTHN_WINNT)) + { +#ifndef _CHICAGO_ + hr = LookupPrincName( &pPrincName ); + + if (SUCCEEDED(hr)) + lNameLen = lstrlenW( pPrincName ) + 1; +#else + hr = LookupPrincName( gServerSvcList, gServerSvcListLen, &pPrincName ); + if (SUCCEEDED(hr)) + lNameLen = lstrlenW( pPrincName ) + 1; + else + { + // BUGBUG: the whole PrincName mess still needs clean up + // especially given the state of msnsspc.dll + pPrincName = NULL; + hr = S_OK; + lNameLen = 1; + } +#endif // _CHICAGO_ + } + else + lNameLen = 1; + + if (SUCCEEDED(hr)) + { + // Allocate memory for the string array. + Win4Assert( gGotSecurityData ); + pOld = gpsaSecurity; + gpsaSecurity = (DUALSTRINGARRAY *) + PrivMemAlloc( sizeof(DUALSTRINGARRAY) + 2 * sizeof(WCHAR) + + cBinding * (sizeof(SECURITYBINDING) + + lNameLen*sizeof(WCHAR)) ); + if (gpsaSecurity != NULL) + { + // Fill in the array of security information. First two characters + // are NULLs to signal empty binding strings. + PrivMemFree( pOld ); + gDefaultService = FALSE; + gpsaSecurity->wSecurityOffset = 2; + gpsaSecurity->aStringArray[0] = 0; + gpsaSecurity->aStringArray[1] = 0; + pNextString = &gpsaSecurity->aStringArray[2]; + + for (i = 0; i < gServerSvcListLen; i++) + { + // Ignore errors since applications using automatic security + // may not care if they can't receive secure calls. + hr = RpcServerRegisterAuthInfoW( pPrincName, gServerSvcList[i], + NULL, NULL ); + if (hr == RPC_S_OK) + { + // Fill in authentication service, authorization service, + // and principal name. + *(pNextString++) = gServerSvcList[i]; + *(pNextString++) = COM_C_AUTHZ_NONE; + if (pPrincName == NULL) + *pNextString = 0; + else + memcpy( pNextString, pPrincName, lNameLen*sizeof(USHORT) ); + pNextString += lNameLen; + } + } + + // Add a final NULL. Special case an empty list which requires + // two NULLs. + *(pNextString++) = 0; + if (gServerSvcListLen == 0) + *(pNextString++) = 0; + gpsaSecurity->wNumEntries = (USHORT) + (pNextString-gpsaSecurity->aStringArray); + hr = S_OK; + } + else + { + hr = E_OUTOFMEMORY; + gpsaSecurity = pOld; + } + } + + PrivMemFree( pPrincName ); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: FixupAccessControl, internal +// +// Synopsis: Get the access control class id. Instantiate the access +// control class and load the data. +// +// Notes: The caller has already insured that the structure is +// at least as big as a SPermissionHeader structure. +// +//-------------------------------------------------------------------- +HRESULT FixupAccessControl( SECURITY_DESCRIPTOR **pSD, DWORD cbSD ) +{ + SPermissionHeader *pHeader; + IAccessControl *pControl = NULL; + IPersistStream *pPersist = NULL; + CNdrStream cStream( ((unsigned char *) *pSD) + sizeof(SPermissionHeader), + cbSD - sizeof(SPermissionHeader) ); + HRESULT hr; + + // Get the class id. + pHeader = (SPermissionHeader *) *pSD; + + // Instantiate the class. + hr = CoCreateInstance( pHeader->gClass, NULL, CLSCTX_INPROC_SERVER, + IID_IAccessControl, (void **) &pControl ); + + // Get IPeristStream + if (SUCCEEDED(hr)) + { + hr = pControl->QueryInterface( IID_IPersistStream, (void **) &pPersist ); + + // Load the stream. + if (SUCCEEDED(hr)) + hr = pPersist->Load( &cStream ); + } + + // Release resources. + if (pPersist != NULL) + pPersist->Release(); + if (SUCCEEDED(hr)) + { + PrivMemFree( *pSD ); + *pSD = (SECURITY_DESCRIPTOR *) pControl; + } + else if (pControl != NULL) + pControl->Release(); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: FixupSecurityDescriptor, internal +// +// Synopsis: Convert the security descriptor from self relative to +// absolute form and check for errors. +// +//-------------------------------------------------------------------- +HRESULT FixupSecurityDescriptor( SECURITY_DESCRIPTOR **pSD, DWORD cbSD ) +{ + // Fix up the security descriptor. + (*pSD)->Control &= ~SE_SELF_RELATIVE; + (*pSD)->Sacl = NULL; + if ((*pSD)->Dacl != NULL) + { + if (cbSD < sizeof(ACL) + sizeof(SECURITY_DESCRIPTOR) || + (ULONG) (*pSD)->Dacl > cbSD - sizeof(ACL)) + return REGDB_E_INVALIDVALUE; + (*pSD)->Dacl = (ACL *) (((char *) *pSD) + ((ULONG) (*pSD)->Dacl)); + if ((*pSD)->Dacl->AclSize + sizeof(SECURITY_DESCRIPTOR) > cbSD) + return REGDB_E_INVALIDVALUE; + } + + // Set up the owner and group SIDs. + if ((*pSD)->Group == 0 || ((ULONG) (*pSD)->Group) + sizeof(SID) > cbSD || + (*pSD)->Owner == 0 || ((ULONG) (*pSD)->Owner) + sizeof(SID) > cbSD) + return REGDB_E_INVALIDVALUE; + (*pSD)->Group = (SID *) (((BYTE *) *pSD) + (ULONG) (*pSD)->Group); + (*pSD)->Owner = (SID *) (((BYTE *) *pSD) + (ULONG) (*pSD)->Owner); + + // Check the security descriptor. +#if DBG==1 + if (!IsValidSecurityDescriptor( *pSD )) + return REGDB_E_INVALIDVALUE; +#endif + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: GetLegacySecDesc, internal +// +// Synopsis: Get a security descriptor for the current app. First, +// look under the app id for the current exe name. If that +// fails look up the default descriptor. If that fails, +// create one. +// +// Note: It is possible that the security descriptor size could change +// during the size computation. Add code to retry. +// +//-------------------------------------------------------------------- +HRESULT GetLegacySecDesc( SECURITY_DESCRIPTOR **pSD, DWORD *pCapabilities ) +{ + // Holds either Appid\{guid} or Appid\module_name. + WCHAR aKeyName[MAX_PATH+7]; + HRESULT hr; + HKEY hKey = NULL; + DWORD lSize; + WCHAR aModule[MAX_PATH]; + DWORD cModule; + DWORD i; + WCHAR aAppid[40]; // Hold a registry GUID. + DWORD lType; + + // If the flag EOAC_APPID is set, the security descriptor contains the + // app id. + if ((*pCapabilities & EOAC_APPID) && *pSD != NULL) + { + if (StringFromIID2( *((GUID *) *pSD), aAppid, sizeof(aAppid) ) == 0) + return RPC_E_UNEXPECTED; + *pSD = NULL; + + // Open the application id key. A GUID in the registry is stored. + // as a 38 character string. + lstrcpyW( aKeyName, L"AppID\\" ); + memcpy( &aKeyName[6], aAppid, 39*sizeof(WCHAR) ); + hr = RegOpenKeyEx( HKEY_CLASSES_ROOT, aKeyName, + NULL, KEY_READ, &hKey ); + + // Get the security descriptor from the registry. + if (hr == ERROR_SUCCESS) + hr = GetRegistrySecDesc( hKey, L"AccessPermission", pSD, + pCapabilities ); + else + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, hr ); + + } + + // Look up the app id from the exe name. + else + { + // Get the executable's name. Find the start of the file name. + cModule = GetModuleFileName( NULL, aModule, MAX_PATH ); + if (cModule >= MAX_PATH) + { + Win4Assert( !"Module name too long." ); + return RPC_E_UNEXPECTED; + } + for (i = cModule-1; i > 0; i--) + if (aModule[i] == '/' || + aModule[i] == '\\' || + aModule[i] == ':') + break; + if (i != 0) + i += 1; + + // Open the key for the EXE's module name. + lstrcpyW( aKeyName, L"AppID\\" ); + memcpy( &aKeyName[6], &aModule[i], (cModule - i + 1) * sizeof(WCHAR) ); + hr = RegOpenKeyEx( HKEY_CLASSES_ROOT, aKeyName, + NULL, KEY_READ, &hKey ); + + // Look for an application id. + if (hr == ERROR_SUCCESS) + { + lSize = sizeof(aAppid); + hr = RegQueryValueEx( hKey, L"AppID", NULL, &lType, + (unsigned char *) &aAppid, &lSize ); + RegCloseKey( hKey ); + hKey = NULL; + + // Open the application id key. A GUID in the registry is stored. + // as a 38 character string. + if (hr == ERROR_SUCCESS && lType == REG_SZ && + lSize == 39*sizeof(WCHAR)) + { + memcpy( &aKeyName[6], aAppid, 39*sizeof(WCHAR) ); + hr = RegOpenKeyEx( HKEY_CLASSES_ROOT, aKeyName, + NULL, KEY_READ, &hKey ); + + // Get the security descriptor from the registry. + if (hr == ERROR_SUCCESS) + { + hr = GetRegistrySecDesc( hKey, L"AccessPermission", pSD, + pCapabilities ); + if (SUCCEEDED(hr) || hr == REGDB_E_INVALIDVALUE) + goto cleanup; + RegCloseKey( hKey ); + hKey = NULL; + } + } + } + + // Open the default key. + hr = RegOpenKeyEx( HKEY_LOCAL_MACHINE, L"SOFTWARE\\Microsoft\\OLE", + NULL, KEY_READ, &hKey ); + + // Get the security descriptor from the registry. + if (hr == ERROR_SUCCESS) + { + hr = GetRegistrySecDesc( hKey, L"DefaultAccessPermission", pSD, + pCapabilities ); + + // If that failed, make one. + if (FAILED(hr) && hr != REGDB_E_INVALIDVALUE) + hr = MakeSecDesc( pSD, pCapabilities ); + } + else + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, hr ); + } + +cleanup: + + // Free the security descriptor memory if anything failed. + if (FAILED(hr)) + { + PrivMemFree( *pSD ); + *pSD = NULL; + } + + // Close the registry key. + if (hKey != NULL) + RegCloseKey( hKey ); + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: GetRegistrySecDesc, internal +// +// Synopsis: Convert a security descriptor from self relative to +// absolute form. Stuff in an owner and a group. +// +// Notes: +// REGDB_E_INVALIDVALUE is returned when there is something +// at the specified value, but it is not a security descriptor. +// +// The caller must free the security descriptor in both the +// success and failure cases. +// +// Codework: It would be nice to use the unicode APIs on NT. +// +//-------------------------------------------------------------------- +HRESULT GetRegistrySecDesc( HKEY hKey, WCHAR *pValue, + SECURITY_DESCRIPTOR **pSD, DWORD *pCapabilities ) + +{ + SID *pGroup; + SID *pOwner; + DWORD cbSD = 256; + DWORD lType; + HRESULT hr; + WORD wVersion; + + // Guess how much memory to allocate for the security descriptor. + *pSD = (SECURITY_DESCRIPTOR *) PrivMemAlloc( cbSD ); + if (*pSD == NULL) + { + hr = E_OUTOFMEMORY; + goto cleanup; + } + + // Find put how much memory to allocate for the security + // descriptor. + hr = RegQueryValueEx( hKey, pValue, NULL, &lType, + (unsigned char *) *pSD, &cbSD ); + if (hr != ERROR_SUCCESS && hr != ERROR_MORE_DATA) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, hr ); + goto cleanup; + } + if (lType != REG_BINARY || cbSD < sizeof(SECURITY_DESCRIPTOR)) + { + hr = REGDB_E_INVALIDVALUE; + goto cleanup; + } + + // If the first guess wasn't large enough, reallocate the memory. + if (hr == ERROR_MORE_DATA) + { + PrivMemFree( *pSD ); + *pSD = (SECURITY_DESCRIPTOR *) PrivMemAlloc( cbSD ); + if (*pSD == NULL) + { + hr = E_OUTOFMEMORY; + goto cleanup; + } + + // Read the security descriptor. + hr = RegQueryValueEx( hKey, pValue, NULL, &lType, + (unsigned char *) *pSD, &cbSD ); + if (hr != ERROR_SUCCESS) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, hr ); + goto cleanup; + } + if (lType != REG_BINARY || cbSD < sizeof(SECURITY_DESCRIPTOR)) + { + hr = REGDB_E_INVALIDVALUE; + goto cleanup; + } + } + + // Check the first DWORD to determine what type of data is in the + // registry value. + wVersion = *((WORD *) *pSD); +#ifndef _CHICAGO_ + if (wVersion == COM_PERMISSION_SECDESC) + hr = FixupSecurityDescriptor( pSD, cbSD ); + else +#endif + if (wVersion == COM_PERMISSION_ACCCTRL) + { + hr = FixupAccessControl( pSD, cbSD ); + if (SUCCEEDED(hr)) + *pCapabilities |= EOAC_ACCESS_CONTROL; + } + else + hr = REGDB_E_INVALIDVALUE; + +cleanup: + if (FAILED(hr)) + { + PrivMemFree( *pSD ); + *pSD = NULL; + } + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: HashSid +// +// Synopsis: Create a 32 bit hash of a SID. +// +//-------------------------------------------------------------------- +DWORD HashSid( SID *pSid ) +{ + DWORD lHash = 0; + DWORD cbSid = GetLengthSid( pSid ); + DWORD i; + unsigned char *pData = (unsigned char *) pSid; + + for (i = 0; i < cbSid; i++) + lHash = (lHash << 1) + *pData++; + return lHash; +} + +//+------------------------------------------------------------------- +// +// Function: InitializeSecurity, internal +// +// Synopsis: Called the first time the channel is used. If the app +// has not initialized security yet, this function sets +// up legacy security. +// +//-------------------------------------------------------------------- +HRESULT InitializeSecurity() +{ + HRESULT hr; + ASSERT_LOCK_HELD + + // Return if already initialized. + if (gpsaSecurity != NULL) + return S_OK; + + // Initialize. All parameters are ignored except the security descriptor + // since the capability is set to app id. + hr = CoInitializeSecurity( NULL, -1, NULL, NULL, RPC_C_AUTHN_LEVEL_NONE, + RPC_C_IMP_LEVEL_IDENTIFY, NULL, EOAC_APPID, + NULL ); + + // Convert confusing error codes. + if (hr == E_INVALIDARG) + hr = REGDB_E_INVALIDVALUE; + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: IsCallerLocalSystem +// +// Synopsis: Impersonate the caller and do an ACL check. The first +// time this function is called, create the ACL +// +//-------------------------------------------------------------------- +BOOL IsCallerLocalSystem() +{ + HRESULT hr = S_OK; + DWORD granted_access; + BOOL access; + HANDLE token; + DWORD privilege_size = sizeof(gPriv); + BOOL success; + SECURITY_DESCRIPTOR *pSecDesc = NULL; + DWORD lIgnore; + + ASSERT_LOCK_RELEASED + + // If the security descriptor does not exist, create it. + if (gRundownSD == NULL) + { + // Make the security descriptor. + hr = MakeSecDesc( &pSecDesc, &lIgnore ); + + // Save the security descriptor. + LOCK + if (gRundownSD == NULL) + gRundownSD = pSecDesc; + else + PrivMemFree( pSecDesc ); + UNLOCK + } + + // Impersonate. + if (SUCCEEDED(hr)) + hr = CoImpersonateClient(); + + // Get the thread token. + if (SUCCEEDED(hr)) + { + success = OpenThreadToken( GetCurrentThread(), TOKEN_READ, + TRUE, &token ); + if (!success) + hr = E_FAIL; + } + + // Check access. + if (SUCCEEDED(hr)) + { + success = AccessCheck( gRundownSD, token, COM_RIGHTS_EXECUTE, + &gMap, &gPriv, &privilege_size, + &granted_access, &access ); + if (!success || !access) + hr = E_FAIL; + CloseHandle( token ); + } + + // Just call revert since it detects whether or not the impersonate + // succeeded. + CoRevertToSelf(); + + ASSERT_LOCK_RELEASED + + return SUCCEEDED(hr); +} + +//+------------------------------------------------------------------- +// +// Function: IsLocalAuthnService +// +// Synopsis: Return TRUE is the specified authentication service is +// on the list of services this machine supports. +// +// NOTE: If we ever expect the list to be more then three items +// long, we can add code to sort it. +// +//-------------------------------------------------------------------- +BOOL IsLocalAuthnService( USHORT wAuthnService ) +{ + DWORD l; + + for (l = 0; l < gClientSvcListLen; l++) + if (gClientSvcList[l] == wAuthnService) + return TRUE; + return FALSE; +} + +#ifndef _CHICAGO_ +//+------------------------------------------------------------------- +// +// Function: LookupPrincName, private +// +// Synopsis: Open the process token and find the user's name. +// +//-------------------------------------------------------------------- +HRESULT LookupPrincName( WCHAR **pPrincName ) +{ + HRESULT hr = S_OK; + BYTE aMemory[SIZEOF_TOKEN_USER]; + TOKEN_USER *pTokenUser = (TOKEN_USER *) &aMemory; + HANDLE hToken = NULL; + DWORD lIgnore; + DWORD lNameLen = 80; + DWORD lDomainLen = 80; + WCHAR *pDomainName = NULL; + SID_NAME_USE sIgnore; + BOOL fSuccess; + + // Open the process's token. + *pPrincName = NULL; + if (OpenProcessToken( GetCurrentProcess(), TOKEN_QUERY, &hToken )) + { + + // Lookup SID of process token. + if (GetTokenInformation( hToken, TokenUser, pTokenUser, sizeof(aMemory), + &lIgnore )) + { + // Preallocate some memory. + *pPrincName = (WCHAR *) PrivMemAlloc( lNameLen*sizeof(WCHAR) ); + pDomainName = (WCHAR *) _alloca( lDomainLen*sizeof(WCHAR) ); + if (*pPrincName != NULL && pDomainName != NULL) + { + + // Find the user's name. + fSuccess = LookupAccountSidW( NULL, pTokenUser->User.Sid, + *pPrincName, &lNameLen, + pDomainName, &lDomainLen, + &sIgnore ); + + // If the call failed, try allocating more memory. + if (!fSuccess) + { + + // Allocate memory for the user's name. + PrivMemFree( *pPrincName ); + *pPrincName = (WCHAR *) PrivMemAlloc( lNameLen*sizeof(WCHAR) ); + pDomainName = (WCHAR *) _alloca( lDomainLen*sizeof(WCHAR) ); + if (*pPrincName != NULL && pDomainName != NULL) + { + + // Find the user's name. + if (!LookupAccountSidW( NULL, pTokenUser->User.Sid, + *pPrincName, &lNameLen, pDomainName, + &lDomainLen, &sIgnore )) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + } + else + hr = E_OUTOFMEMORY; + } + } + else + hr = E_OUTOFMEMORY; + } + else + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + CloseHandle( hToken ); + } + else + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); +#if DBG==1 + Win4Assert( !"Why did OpenProcessToken fail?" ); + OpenProcessToken( GetCurrentProcess(), TOKEN_QUERY, &hToken ); +#endif + } + + if (hr != S_OK) + { + PrivMemFree( *pPrincName ); + *pPrincName = NULL; + } + return hr; +} +#else // _CHICAGO_ +//+------------------------------------------------------------------- +// +// Function: LookupPrincName, private +// +// Synopsis: We have a service other than NTLMSSP. +// Find the first (!) such and find the user's name. +// +// BUGBUG: This is a hack until the principal name issue is properly +// sorted out. +// +//-------------------------------------------------------------------- +HRESULT LookupPrincName( + USHORT *pwAuthnServices, + ULONG cAuthnServices, + WCHAR **pPrincName + ) +{ + // assume failure lest thou be disappointed + RPC_STATUS status = RPC_S_INVALID_AUTH_IDENTITY; + + *pPrincName = NULL; + + for (ULONG i = 0; i < cAuthnServices; i++) + { + if (pwAuthnServices[i] != RPC_C_AUTHN_WINNT) + { + status = RpcServerInqDefaultPrincNameW( + pwAuthnServices[i], + pPrincName); + if (status == RPC_S_OK) + { + break; + } + } + } + + return HRESULT_FROM_WIN32(status); +} + +#endif // _CHICAGO_ + +#ifdef _CHICAGO_ +//+------------------------------------------------------------------- +// +// Function: MakeSecDesc, private +// +// Synopsis: Make an access control that allows the current user +// access. +// +// NOTE: NetWkstaGetInfo does not return the size needed unless the size +// in is zero. +// +//-------------------------------------------------------------------- +HRESULT MakeSecDesc( SECURITY_DESCRIPTOR **pSD, DWORD *pCapabilities ) +{ + HRESULT hr = S_OK; + IAccessControl *pAccess = NULL; + DWORD cTrustee; + WCHAR *pTrusteeW; + char *pTrusteeA; + DWORD cDomain; + DWORD cUser; + char *pBuffer; + struct wksta_info_10 *wi10; + USHORT cbBuffer; + HINSTANCE hMsnet; + NetWkstaGetInfoFn fnNetWkstaGetInfo; + ACTRL_ACCESSW sAccessList; + ACTRL_PROPERTY_ENTRYW sProperty; + ACTRL_ACCESS_ENTRY_LISTW sEntryList; + ACTRL_ACCESS_ENTRYW sEntry; + + // Load msnet32.dll + hMsnet = LoadLibraryA( "msnet32.dll" ); + if (hMsnet == NULL) + return MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + + // Get the function NetWkstaGetInfo. + fnNetWkstaGetInfo = (NetWkstaGetInfoFn) GetProcAddress( hMsnet, + (char *) 57 ); + if (fnNetWkstaGetInfo == NULL) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + goto cleanup; + } + + // Find out how much space to allocate for the domain and user names. + cbBuffer = 0; + fnNetWkstaGetInfo( NULL, 10, NULL, 0, &cbBuffer ); + pBuffer = (char *) _alloca( cbBuffer ); + + // Get the domain and user names. + hr = fnNetWkstaGetInfo( NULL, 10, pBuffer, cbBuffer, &cbBuffer ); + if (hr != 0) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, hr ); + goto cleanup; + } + + // Stick the user name and domain name in the same string. + wi10 = (struct wksta_info_10 *) pBuffer; + Win4Assert( wi10->wki10_logon_domain != NULL ); + Win4Assert( wi10->wki10_username != NULL ); + cDomain = lstrlenA( wi10->wki10_logon_domain ); + cUser = lstrlenA( wi10->wki10_username ); + pTrusteeA = (char *) _alloca( cDomain+cUser+2 ); + lstrcpyA( pTrusteeA, wi10->wki10_logon_domain ); + lstrcpyA( &pTrusteeA[cDomain+1], wi10->wki10_username ); + pTrusteeA[cDomain] = '\\'; + + // Find out how long the name is in Unicode. + cTrustee = MultiByteToWideChar( GetConsoleCP(), 0, pTrusteeA, + cDomain+cUser+2, NULL, 0 ); + + // Convert the name to Unicode. + pTrusteeW = (WCHAR *) _alloca( cTrustee * sizeof(WCHAR) ); + if (!MultiByteToWideChar( GetConsoleCP(), 0, pTrusteeA, + cDomain+cUser+2, pTrusteeW, cTrustee )) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + goto cleanup; + } + + // Create an AccessControl. + *pSD = NULL; + hr = CoCreateInstance( CLSID_DCOMAccessControl, NULL, + CLSCTX_INPROC_SERVER, + IID_IAccessControl, (void **) &pAccess ); + if (FAILED(hr)) + goto cleanup; + + // Give the current user access. + sAccessList.cEntries = 1; + sAccessList.pPropertyAccessList = &sProperty; + sProperty.lpProperty = NULL; + sProperty.pAccessEntryList = &sEntryList; + sProperty.fListFlags = 0; + sEntryList.cEntries = 1; + sEntryList.pAccessList = &sEntry; + sEntry.fAccessFlags = ACTRL_ACCESS_ALLOWED; + sEntry.Access = COM_RIGHTS_EXECUTE; + sEntry.ProvSpecificAccess = 0; + sEntry.Inheritance = NO_INHERITANCE; + sEntry.lpInheritProperty = NULL; + sEntry.Trustee.pMultipleTrustee = NULL; + sEntry.Trustee.MultipleTrusteeOperation = NO_MULTIPLE_TRUSTEE; + sEntry.Trustee.TrusteeForm = TRUSTEE_IS_NAME; + sEntry.Trustee.TrusteeType = TRUSTEE_IS_USER; + sEntry.Trustee.ptstrName = pTrusteeW; + hr = pAccess->GrantAccessRights( &sAccessList ); + +cleanup: + FreeLibrary( hMsnet ); + if (SUCCEEDED(hr)) + { + *pSD = (SECURITY_DESCRIPTOR *) pAccess; + *pCapabilities |= EOAC_ACCESS_CONTROL; + } + else if (pAccess != NULL) + pAccess->Release(); + return hr; +} + +#else +//+------------------------------------------------------------------- +// +// Function: MakeSecDesc, private +// +// Synopsis: Make a security descriptor that allows the current user +// and local system access. +// +// NOTE: Compute the length of the sids used rather then using constants. +// +//-------------------------------------------------------------------- +HRESULT MakeSecDesc( SECURITY_DESCRIPTOR **pSD, DWORD *pCapabilities ) +{ + HRESULT hr = S_OK; + ACL *pAcl; + DWORD lSidLen; + SID *pGroup; + SID *pOwner; + BYTE aMemory[SIZEOF_TOKEN_USER]; + TOKEN_USER *pTokenUser = (TOKEN_USER *) &aMemory; + HANDLE hToken = NULL; + DWORD lIgnore; + HANDLE hThread; + + Win4Assert( *pSD == NULL ); + + // Open the process's token. + if (!OpenProcessToken( GetCurrentProcess(), TOKEN_QUERY, &hToken )) + { + // If the thread has a token, remove it and try again. + if (!OpenThreadToken( GetCurrentThread(), TOKEN_IMPERSONATE, TRUE, + &hThread )) + { + Win4Assert( !"How can both OpenThreadToken and OpenProcessToken fail?" ); + return MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + } + if (!SetThreadToken( NULL, NULL )) + { + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + CloseHandle( hThread ); + return hr; + } + if (!OpenProcessToken( GetCurrentProcess(), TOKEN_QUERY, &hToken )) + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + SetThreadToken( NULL, hThread ); + CloseHandle( hThread ); + if (FAILED(hr)) + return hr; + } + + // Lookup SID of process token. + if (!GetTokenInformation( hToken, TokenUser, pTokenUser, sizeof(aMemory), + &lIgnore )) + goto last_error; + + // Compute the length of the SID. + lSidLen = GetLengthSid( pTokenUser->User.Sid ); + Win4Assert( lSidLen <= SIZEOF_SID ); + + // Allocate the security descriptor. + *pSD = (SECURITY_DESCRIPTOR *) PrivMemAlloc( + sizeof(SECURITY_DESCRIPTOR) + 2*lSidLen + SIZEOF_ACL ); + if (*pSD == NULL) + { + hr = E_OUTOFMEMORY; + goto cleanup; + } + pGroup = (SID *) (*pSD + 1); + pOwner = (SID *) (((BYTE *) pGroup) + lSidLen); + pAcl = (ACL *) (((BYTE *) pOwner) + lSidLen); + + // Initialize a new security descriptor. + if (!InitializeSecurityDescriptor(*pSD, SECURITY_DESCRIPTOR_REVISION)) + goto last_error; + + // Initialize a new ACL. + if (!InitializeAcl(pAcl, SIZEOF_ACL, ACL_REVISION2)) + goto last_error; + + // Allow the current user access. + if (!AddAccessAllowedAce( pAcl, ACL_REVISION2, COM_RIGHTS_EXECUTE, + pTokenUser->User.Sid)) + goto last_error; + + // Allow local system access. + if (!AddAccessAllowedAce( pAcl, ACL_REVISION2, COM_RIGHTS_EXECUTE, + (void *) &LOCAL_SYSTEM_SID )) + goto last_error; + + // Add a new ACL to the security descriptor. + if (!SetSecurityDescriptorDacl( *pSD, TRUE, pAcl, FALSE )) + goto last_error; + + // Set the group. + memcpy( pGroup, pTokenUser->User.Sid, lSidLen ); + if (!SetSecurityDescriptorGroup( *pSD, pGroup, FALSE )) + goto last_error; + + // Set the owner. + memcpy( pOwner, pTokenUser->User.Sid, lSidLen ); + if (!SetSecurityDescriptorOwner( *pSD, pOwner, FALSE )) + goto last_error; + + // Check the security descriptor. +#if DBG==1 + if (!IsValidSecurityDescriptor( *pSD )) + { + Win4Assert( !"COM Created invalid security descriptor." ); + goto last_error; + } +#endif + + goto cleanup; +last_error: + hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, GetLastError() ); + +cleanup: + if (hToken != NULL) + CloseHandle( hToken ); + if (FAILED(hr)) + { + PrivMemFree( *pSD ); + *pSD = NULL; + } + return hr; +} +#endif + +//+------------------------------------------------------------------- +// +// Function: RegisterAuthnServices, public +// +// Synopsis: Register the specified services. Build a security +// binding. +// +//-------------------------------------------------------------------- +HRESULT RegisterAuthnServices( DWORD cAuthSvc, + SOLE_AUTHENTICATION_SERVICE *asAuthSvc ) +{ + DWORD i; + RPC_STATUS sc; + USHORT wNumEntries = 0; + USHORT *pNext; + HRESULT hr; + DWORD lNameLen; + + ASSERT_LOCK_HELD + + // Register all the authentication services specified. + for (i = 0; i < cAuthSvc; i++) + { + sc = RpcServerRegisterAuthInfoW( asAuthSvc[i].pPrincipalName, + asAuthSvc[i].dwAuthnSvc, + NULL, NULL ); + + // If the registration failed, store the failure code. + if (sc != RPC_S_OK) + asAuthSvc[i].hr = MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + + // Otherwise determine how much space to reserve for it in the string + // array. + else + { + asAuthSvc[i].hr = S_OK; + if (asAuthSvc[i].pPrincipalName != NULL) + wNumEntries += lstrlenW( asAuthSvc[i].pPrincipalName ) + 3; + else + wNumEntries += 3; + } + } + if (wNumEntries == 0) + hr = RPC_E_NO_GOOD_SECURITY_PACKAGES; + else + // make room for the two NULLs that placehold for the empty + // string binding and the trailing NULL. + wNumEntries += 3; + + // If some services were registered, build a string array. + if (wNumEntries != 0) + { + gpsaSecurity = (DUALSTRINGARRAY *) PrivMemAlloc( + wNumEntries*sizeof(USHORT) + sizeof(DUALSTRINGARRAY) ); + if (gpsaSecurity == NULL) + hr = E_OUTOFMEMORY; + else + { + gpsaSecurity->wNumEntries = wNumEntries; + gpsaSecurity->wSecurityOffset = 2; + gpsaSecurity->aStringArray[0] = 0; + gpsaSecurity->aStringArray[1] = 0; + pNext = &gpsaSecurity->aStringArray[2]; + + for (i = 0; i < cAuthSvc; i++) + { + if (asAuthSvc[i].hr == S_OK) + { + // Fill in authentication service, authorization service, + // and principal name. + *(pNext++) = (USHORT) asAuthSvc[i].dwAuthnSvc; + *(pNext++) = (USHORT) (asAuthSvc[i].dwAuthzSvc == 0 ? + COM_C_AUTHZ_NONE : + asAuthSvc[i].dwAuthzSvc); + if (asAuthSvc[i].pPrincipalName != NULL) + { + lNameLen = lstrlenW( asAuthSvc[i].pPrincipalName ) + 1; + memcpy( pNext, asAuthSvc[i].pPrincipalName, + lNameLen*sizeof(USHORT) ); + pNext += lNameLen; + } + else + *(pNext++) = 0; + } + } + *pNext = 0; + + hr = S_OK; + } + } + + ASSERT_LOCK_HELD + return hr; +} + +//+------------------------------------------------------------------- +// +// Function: SetAuthnService, internal +// +// Synopsis: Determine the authentication information to set on a +// binding handle for a newly unmarshalled interface. +// The authentication level is the higher of the process +// default and the level in the interface. The +// impersonation level is the process default. If the +// authentication level is not zero, the function +// scans the list of authentication services in the +// interface looking for one this machine supports. +// +//-------------------------------------------------------------------- +HRESULT SetAuthnService( handle_t hHandle, OXID_INFO *pOxidInfo, + OXIDEntry *pOxid ) +{ + DWORD lAuthnLevel; + DWORD lAuthnSvc; + USHORT wNext; + USHORT wAuthzSvc; + RPC_STATUS sc; + WCHAR *pPrincipal; + RPC_SECURITY_QOS sQos; + + // Pick the highest authentication level between the process default + // and the interface hint. The constant RPC_C_AUTHN_LEVEL_DEFAULT + // has value zero and maps to connect. + if (gAuthnLevel == RPC_C_AUTHN_LEVEL_DEFAULT) + lAuthnLevel = RPC_C_AUTHN_LEVEL_CONNECT; + else + lAuthnLevel = gAuthnLevel; + if (pOxidInfo->dwAuthnHint == RPC_C_AUTHN_LEVEL_DEFAULT) + pOxidInfo->dwAuthnHint = RPC_C_AUTHN_LEVEL_CONNECT; + if (lAuthnLevel > pOxidInfo->dwAuthnHint) + lAuthnLevel = gAuthnLevel; + else + lAuthnLevel = pOxidInfo->dwAuthnHint; + + // For machine local servers, only set the authentication information if + // the impersonation level is not the default. + sQos.Version = RPC_C_SECURITY_QOS_VERSION; + sQos.IdentityTracking = RPC_C_QOS_IDENTITY_STATIC; + sQos.ImpersonationType = gImpLevel; + sQos.Capabilities = (gCapabilities & EOAC_MUTUAL_AUTH) ? + RPC_C_QOS_CAPABILITIES_MUTUAL_AUTH : RPC_C_QOS_CAPABILITIES_DEFAULT; + if (pOxid->dwFlags & OXIDF_MACHINE_LOCAL) + { + if (gImpLevel != RPC_C_IMP_LEVEL_IMPERSONATE) + { + sc = RpcBindingSetAuthInfoExW( hHandle, NULL, lAuthnLevel, + RPC_C_AUTHN_WINNT, NULL, + RPC_C_AUTHZ_NONE, &sQos ); + if (sc != RPC_S_OK) + return MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + else + return S_OK; + } + } + + // For machine remote servers, set the authentication information if any + // parameter differs from RPC's default. + else if (lAuthnLevel != RPC_C_AUTHN_LEVEL_NONE) + { + // Look through all the authentication services in the interface + // till we find one that works on this machine. + wNext = pOxidInfo->psa->wSecurityOffset; + while (wNext < pOxidInfo->psa->wNumEntries && + pOxidInfo->psa->aStringArray[wNext] != 0) + { + if (IsLocalAuthnService( pOxidInfo->psa->aStringArray[wNext] )) + { + // Set the authentication info on the binding handle. + pPrincipal = &pOxidInfo->psa->aStringArray[wNext+2]; + if (pPrincipal[0] == 0) + pPrincipal = NULL; + +#ifdef _CHICAGO_ + // If the principal name is not known, the server must be + // NT. Replace the principal name in that case + // because a NULL principal name is a flag for some + // Chicago security hack. + if (pPrincipal == NULL && + pOxidInfo->psa->aStringArray[wNext] == RPC_C_AUTHN_WINNT) + pPrincipal = L"Default"; +#endif // _CHICAGO_ + + wAuthzSvc = pOxidInfo->psa->aStringArray[wNext+1]; + if (wAuthzSvc == COM_C_AUTHZ_NONE) + wAuthzSvc = RPC_C_AUTHZ_NONE; + sc = RpcBindingSetAuthInfoExW( + hHandle, + pPrincipal, + lAuthnLevel, + pOxidInfo->psa->aStringArray[wNext], + NULL, + wAuthzSvc, + &sQos ); + if (sc != RPC_S_OK) + return MAKE_SCODE( SEVERITY_ERROR, FACILITY_WIN32, sc ); + else + return S_OK; + } + + // Skip to the next authentication service. + wNext += lstrlenW( &pOxidInfo->psa->aStringArray[wNext] ) + 1; + } + + // No valid authentication service was found. This is an error. + return RPC_E_NO_GOOD_SECURITY_PACKAGES; + } + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: UninitializeSecurity, internal +// +// Synopsis: Free resources allocated while initializing security. +// +//-------------------------------------------------------------------- +void UninitializeSecurity() +{ + DWORD i; + + ASSERT_LOCK_HELD + + PrivMemFree(gSecDesc); + PrivMemFree(gpsaSecurity); + PrivMemFree( gRundownSD ); +#ifndef SHRMEM_OBJEX + MIDL_user_free( gClientSvcList ); + MIDL_user_free( gServerSvcList ); + MIDL_user_free( gLegacySecurity ); +#else // SHRMEM_OBJEX + delete [] gClientSvcList; + delete [] gServerSvcList; + delete [] gLegacySecurity; +#endif // SHRMEM_OBJEX + for (i = 0; i < ACCESS_CACHE_LEN; i++) + { + PrivMemFree( gAccessCache[i] ); + gAccessCache[i] = NULL; + } + + if (gAccessControl != NULL) + gAccessControl->Release(); + + gAccessControl = NULL; + gSecDesc = NULL; + gAuthnLevel = RPC_C_AUTHN_LEVEL_NONE; + gImpLevel = RPC_C_IMP_LEVEL_IDENTIFY; + gCapabilities = EOAC_NONE; + gLegacySecurity = NULL; + gpsaSecurity = NULL; + gClientSvcList = NULL; + gServerSvcList = NULL; + gGotSecurityData = FALSE; + gRundownSD = NULL; + gDefaultService = FALSE; + gMostRecentAccess = 0; +} + +#endif diff --git a/private/ole32/com/dcomrem/security.hxx b/private/ole32/com/dcomrem/security.hxx new file mode 100644 index 000000000..96b48f760 --- /dev/null +++ b/private/ole32/com/dcomrem/security.hxx @@ -0,0 +1,239 @@ +//+------------------------------------------------------------------- +// +// File: security.hxx +// +// Contents: Classes for channel security +// +// Classes: CClientSecurity, CServerSecurity +// +// History: 11 Oct 95 AlexMit Created +// +//-------------------------------------------------------------------- +#ifndef _SECURITY_HXX_ +#define _SECURITY_HXX_ + +#include <chancont.hxx> + +//+---------------------------------------------------------------- +// Typedefs. +typedef enum +{ + SS_PROCESS_LOCAL = 0x1, // Client and server are in same process + SS_CALL_DONE = 0x2, // Call is complete, fail new calls to impersonate + SS_IMPERSONATING = 0x4 // Server has called impersonate +} EServerSecurity; + +//+---------------------------------------------------------------- +// +// Class: CClientSecurity, public +// +// Purpose: Provides security for proxies +// +//----------------------------------------------------------------- + +class CStdIdentity; + +class CClientSecurity : public IClientSecurity +{ + public: + CClientSecurity( CStdIdentity *pId ) { _pStdId = pId; } + ~CClientSecurity() {} + + STDMETHOD (QueryBlanket) + ( + IUnknown *pProxy, + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + void **pAuthInfo, + DWORD *pCapabilities + ); + + STDMETHOD (SetBlanket) + ( + IUnknown *pProxy, + DWORD AuthnSvc, + DWORD AuthzSvc, + OLECHAR *ServerPrincName, + DWORD AuthnLevel, + DWORD ImpLevel, + void *pAuthInfo, + DWORD Capabilities + ); + + STDMETHOD (CopyProxy) + ( + IUnknown *pProxy, + IUnknown **ppCopy + ); + + private: + CStdIdentity *_pStdId; +}; + +//+---------------------------------------------------------------- +// +// Class: CServerSecurity, public +// +// Purpose: Provides security for stubs +// +//----------------------------------------------------------------- + +class CRpcChannelBuffer; + +class CServerSecurity : public IServerSecurity +{ + public: + CServerSecurity( CChannelCallInfo * ); + CServerSecurity(); + ~CServerSecurity() {} + + STDMETHOD (QueryInterface) ( REFIID riid, LPVOID FAR* ppvObj); + STDMETHOD_(ULONG,AddRef) ( void ); + STDMETHOD_(ULONG,Release) ( void ); + STDMETHOD (QueryBlanket) + ( + DWORD *pAuthnSvc, + DWORD *pAuthzSvc, + OLECHAR **pServerPrincName, + DWORD *pAuthnLevel, + DWORD *pImpLevel, + void **pPrivs, + DWORD *pCapabilities + ); + STDMETHOD (ImpersonateClient)( void ); + STDMETHOD (RevertToSelf) ( void ); + STDMETHOD_(BOOL,IsImpersonating) (void); + + void EndCall(); + + private: + DWORD _iRefCount; + DWORD _iFlags; // See EServerSecurity + handle_t *_pHandle; // RPC server handle of call + CRpcChannelBuffer *_pChannel; // Channel of call +}; + +//+---------------------------------------------------------------- +// Prototypes. +RPC_STATUS CheckAccessControl ( RPC_IF_HANDLE pIid, void *pContext ); +RPC_STATUS CheckAcl ( RPC_IF_HANDLE pIid, void *pContext ); +HRESULT DefaultAuthnServices(); +HRESULT InitializeSecurity (); +BOOL IsCallerLocalSystem (); +HRESULT SetAuthnService ( handle_t, OXID_INFO *, OXIDEntry * ); +void UninitializeSecurity(); + +struct IAccessControl; + +extern IAccessControl *gAccessControl; +extern DWORD gAuthnLevel; +extern DWORD gCapabilities; +extern USHORT *gClientSvcList; +extern DWORD gClientSvcListLen; +extern BOOL gDisableDCOM; +extern BOOL gGotSecurityData; +extern DWORD gImpLevel; +extern SECURITYBINDING *gLegacySecurity; +extern DUALSTRINGARRAY *gpsaSecurity; +extern SECURITY_DESCRIPTOR *gSecDesc; +extern USHORT *gServerSvcList; +extern DWORD gServerSvcListLen; +extern BOOL gSetAuth; + + +//+------------------------------------------------------------------- +// +// Function: GetCallAuthnLevel, public +// +// Synopsis: Get the authentication level of the current call from TLS. +// If no calls are in progress and the level has not been +// set on this thread, use the process default instead. +// +//-------------------------------------------------------------------- +inline DWORD GetCallAuthnLevel() +{ + COleTls tls; + DWORD lAuthnLevel = tls->dwAuthnLevel; + if (lAuthnLevel == RPC_C_AUTHN_LEVEL_DEFAULT) + { + lAuthnLevel = tls->dwAuthnLevel = gAuthnLevel; + } + return lAuthnLevel; +} + +//+------------------------------------------------------------------- +// +// Function: ResumeImpersonate +// +// Synopsis: Query the context object for IServerSecurity. If the +// resume flag is set, call ImpersonateClient. +// +//-------------------------------------------------------------------- +inline void ResumeImpersonate( IUnknown *pContext, BOOL fResume ) +{ + IServerSecurity *pServer; + HRESULT result; + + if (pContext != NULL && fResume) + { + result = pContext->QueryInterface( IID_IServerSecurity, + (void **) &pServer ); + if (SUCCEEDED(result)) + { + pServer->ImpersonateClient(); + pServer->Release(); + } + } +} + +//+------------------------------------------------------------------- +// +// Function: SuspendImpersonate +// +// Synopsis: Query the context for IServerSecurity. If found, +// check to see if the call is impersonated. If it is, +// set pResume TRUE and call RevertToSelf. +// +//-------------------------------------------------------------------- +inline void SuspendImpersonate( IUnknown *pContext, BOOL *pResume ) +{ + IServerSecurity *pServer; + HRESULT result; + + *pResume = FALSE; + if (pContext != NULL) + { + result = pContext->QueryInterface( IID_IServerSecurity, + (void **) &pServer ); + if (SUCCEEDED(result)) + { + *pResume = pServer->IsImpersonating(); + if (*pResume) + pServer->RevertToSelf(); + pServer->Release(); + } + } +} + +//+------------------------------------------------------------------- +// +// Function: GetAclFn() +// +// Synopsis: If automatic security is turned on and the level is +// not none, return the function to do ACL checking. +// Otherwise return NULL. +// +//-------------------------------------------------------------------- +inline RPC_IF_CALLBACK_FN *GetAclFn() +{ + if (gSecDesc != NULL) + return CheckAcl; + else if (gAccessControl != NULL) + return CheckAccessControl; + else + return NULL; +} +#endif diff --git a/private/ole32/com/dcomrem/service.cxx b/private/ole32/com/dcomrem/service.cxx new file mode 100644 index 000000000..6b3f2f580 --- /dev/null +++ b/private/ole32/com/dcomrem/service.cxx @@ -0,0 +1,818 @@ +//+------------------------------------------------------------------- +// +// File: service.cxx +// +// Contents: APIs to simplify RPC setup +// +// Functions: +// +// History: 23-Nov-92 Rickhi +// 20-Feb-95 Rickhi Major Simplification for DCOM +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <service.hxx> // CRpcService +#include <orcb.h> // IOrCallback +#include <malloc.hxx> // MIDL_user_allocate +#include <locks.hxx> // LOCK/UNLOCK etc +#include <ipidtbl.hxx> // GetLocalEntry +#include <security.hxx> // gpsaSecurity +#include <channelb.hxx> // gRemUnknownIf + + +BOOL gSpeedOverMem = FALSE; // Trade memory for speed. +BOOL gfListening = FALSE; // Server is/isn't listening +BOOL gfDefaultStrings = FALSE; // Using precomputed string bindings +BOOL gfLrpc = FALSE; // Registered for ncalrpc +#ifdef _CHICAGO_ +BOOL gfMswmsg = FALSE; // Registered for mswmsg +#endif + +DWORD gdwEndPoint = 0; +DWORD gdwPsaMaxSize = 0; +DUALSTRINGARRAY *gpsaCurrentProcess = NULL; +const DWORD MAX_LOCAL_SB = 23; + +#ifndef _CHICAGO_ +SECURITY_DESCRIPTOR LrpcSecurityDescriptor; +BOOL fLrpcSDInitialized = FALSE; +#endif + +// interface structure for IRemUnknown +extern const RPC_SERVER_INTERFACE gRemUnknownIf; + + +#if DBG==1 +//+------------------------------------------------------------------- +// +// Function: DisplayAllStringBindings, private +// +// Synopsis: prints the stringbindings to the debugger +// +// Notes: This function requires the caller to hold gComLock. +// +// History: 23-Nov-93 Rickhi Created +// +//-------------------------------------------------------------------- +void DisplayAllStringBindings(void) +{ + ASSERT_LOCK_HELD + + if (gpsaCurrentProcess) + { + LPWSTR pwszNext = gpsaCurrentProcess->aStringArray; + LPWSTR pwszEnd = pwszNext + gpsaCurrentProcess->wSecurityOffset; + + while (pwszNext < pwszEnd) + { + ComDebOut((DEB_CHANNEL, "pSEp=%x %ws\n", pwszNext, pwszNext)); + pwszNext += lstrlenW(pwszNext) + 1; + } + } +} +#endif // DBG == 1 + + +//+------------------------------------------------------------------- +// +// Function: InitializeLrpcSecurity, private +// +// Synopsis: Create a DACL allowing all access to NCALRPC and MSWMSG +// endpoints. +// +//-------------------------------------------------------------------- +void InitializeLrpcSecurity() +{ +#ifndef _CHICAGO_ + if (!fLrpcSDInitialized) + { + // + // Since this is static storage, and we always initialize it + // to the same values, it does not need to be MT safe. + // + InitializeSecurityDescriptor(&LrpcSecurityDescriptor, + SECURITY_DESCRIPTOR_REVISION); + SetSecurityDescriptorDacl(&LrpcSecurityDescriptor, + TRUE, NULL, FALSE); + fLrpcSDInitialized = TRUE; + } +#endif +} + +//+------------------------------------------------------------------- +// +// Function: RegisterLrpc, private +// +// Synopsis: Register the ncalrpc transport. +// +//-------------------------------------------------------------------- +RPC_STATUS RegisterLrpc() +{ + RPC_STATUS sc; + WCHAR pwszEndPoint[12]; + + InitializeLrpcSecurity(); + + lstrcpyW( pwszEndPoint, L"OLE" ); + _ultow( gdwEndPoint, &pwszEndPoint[3], 16 ); + + // The second parameter is a hint that tells lrpc whether or not it + // can preallocate additional resources (threads). + sc = RpcServerUseProtseqEp(L"ncalrpc", + RPC_C_PROTSEQ_MAX_REQS_DEFAULT + 1, + pwszEndPoint, +#ifndef _CHICAGO_ + &LrpcSecurityDescriptor); +#else + NULL); +#endif + + // Assume that duplicate endpoint means we registered the endpoint and + // got unload and reloaded instead of it meaning someone else registered + // the endpoint. + if (sc == RPC_S_DUPLICATE_ENDPOINT) + { + gfLrpc = TRUE; + return RPC_S_OK; + } + else if (sc == RPC_S_OK) + { +#ifndef _CHICAGO_ + // Tell RPC to use this endpoint for mswmsg replies. + sc = I_RpcSetWMsgEndpoint( pwszEndPoint ); + if (sc == RPC_S_OK) +#endif + gfLrpc = TRUE; + } + return sc; +} + +#ifdef _CHICAGO_ +//+------------------------------------------------------------------- +// +// Function: RegisterMswmsg, private +// +// Synopsis: Register the mswmsg transport. +// +// Notes: The caller must hold gComLock. +// +//-------------------------------------------------------------------- +RPC_STATUS RegisterMswmsg() +{ + + RPC_STATUS sc; + WCHAR pwszEndPoint[12]; + + ASSERT_LOCK_HELD + + InitializeLrpcSecurity(); + + lstrcpyW( pwszEndPoint, L"MSG" ); + _ultow( gdwEndPoint, &pwszEndPoint[3], 16 ); + sc = RpcServerUseProtseqEp(L"mswmsg", + RPC_C_PROTSEQ_MAX_REQS_DEFAULT, + pwszEndPoint, + &LrpcSecurityDescriptor); + + // Assume that duplicate endpoint means we registered the endpoint and + // got unload and reloaded instead of it meaning someone else registered + // the endpoint. + if (sc == RPC_S_OK || sc == RPC_S_DUPLICATE_ENDPOINT) + { + gfMswmsg = TRUE; + return RPC_S_OK; + } + else + return sc; +} +#endif + +//+------------------------------------------------------------------- +// +// Function: CheckClientMswmsg, public +// +// Synopsis: For the MSWMSG transport, we must call RpcServerUseProtseqEp +// on the client side. +// +// Notes: The caller must hold gComLock. +// +// History: 27 Sept 95 AlexMit Created +// +//-------------------------------------------------------------------- +RPC_STATUS CheckClientMswmsg( WCHAR *pProtseq, DWORD *pFlags ) +{ + RPC_STATUS sc = RPC_S_OK; + + ASSERT_LOCK_HELD + + // Set the MSWMSG flag correctly. +#ifdef _CHICAGO_ + if (lstrlenW (pProtseq) >= 6 && + memcmp ( L"mswmsg", pProtseq, 6 * sizeof (WCHAR)) == 0) +#else + if (IsSTAThread() && (*pFlags & OXIDF_MACHINE_LOCAL)) +#endif + *pFlags |= OXIDF_MSWMSG; + + // Find out if the transport is MSWMSG. + if ((*pFlags & OXIDF_MSWMSG) +#ifdef _CHICAGO_ + && IsSTAThread() && !gfMswmsg +#endif + ) + { + // Get a unique number and convert it to a string endpoint. + if (gdwEndPoint == 0) + gdwEndPoint = CoGetCurrentProcess(); + if (gdwEndPoint == 0) + return E_FAIL; + + // Register mswmsg. +#ifdef _CHICAGO_ + sc = RegisterMswmsg(); +#else + sc = RegisterLrpc(); +#endif + } + + return sc; +} + +//+------------------------------------------------------------------- +// +// Function: GetLocalEndpoint, public +// +// Synopsis: Get the endpoint for the local protocol sequence +// for the local OXID. +// +// Notes: This function takes gComLock. +// +// History: 6 May 95 AlexMit Created +// +//-------------------------------------------------------------------- +LPWSTR GetLocalEndpoint() +{ + ComDebOut((DEB_MARSHAL,"Entering GetLocalEndpoint.\n")); + LPWSTR pwszLocalEndpoint = NULL; + LOCK + + StartListen(); + if (gfListening) + { + // OLEFFFFFFFF + // maximum 12 character including the null, 24 bytes. + pwszLocalEndpoint = (LPWSTR) PrivMemAlloc( 24 ); + + if (pwszLocalEndpoint != NULL) + { + Win4Assert( gdwEndPoint != 0 ); + lstrcpyW( pwszLocalEndpoint, L"OLE" ); + _ultow( gdwEndPoint, &pwszLocalEndpoint[3], 16 ); + } + } + + UNLOCK + ComDebOut((DEB_MARSHAL,"Leaving GetLocalEndpoint Endpoint: 0x%x\n", + pwszLocalEndpoint)); + return pwszLocalEndpoint; +} + +//+------------------------------------------------------------------- +// +// Function: DefaultStringBindings, private +// +// Synopsis: Create a string binding with entries for just ncalrpc +// and mswmsg +// +// Notes: This function requires the caller to hold gComLock. +// +//-------------------------------------------------------------------- +RPC_STATUS DefaultStringBindings() +{ + ULONG cChar; + + ASSERT_LOCK_HELD + + // If mswmsg has been used, reserve space for the string + // mswmsg:[MSGnnnnnnnn] +#ifdef _CHICAGO_ + if (gfMswmsg) + cChar = 22; + else +#endif + cChar = 0; + + // If ncalrpc has been used, reserve space for the string + // ncalrpc:[OLEnnnnnnnn] + if (gfLrpc) + cChar += 24; + + // Allocate memory. Reserve space for an empty security binding. + cChar += 3; + gpsaCurrentProcess = (DUALSTRINGARRAY *) PrivMemAlloc( SASIZE(cChar) ); + + // Give up if the allocation failed. + if (gpsaCurrentProcess == NULL) + return RPC_S_OUT_OF_RESOURCES; + + // If mswmsg has been used, make up a string for it. +#ifdef _CHICAGO_ + if (gfMswmsg) + { + lstrcpyW( gpsaCurrentProcess->aStringArray, L"mswmsg:[MSG" ); + _ultow( gdwEndPoint, &gpsaCurrentProcess->aStringArray[11], 16 ); + cChar = lstrlenW( gpsaCurrentProcess->aStringArray ); + gpsaCurrentProcess->aStringArray[cChar++] = L']'; + gpsaCurrentProcess->aStringArray[cChar++] = 0; + } + else +#endif + cChar = 0; + + // If ncalrpc has been used, make up a string for it. + if (gfLrpc) + { + lstrcpyW( &gpsaCurrentProcess->aStringArray[cChar], L"ncalrpc:[OLE" ); + _ultow( gdwEndPoint, &gpsaCurrentProcess->aStringArray[cChar+12], 16 ); + cChar += lstrlenW( &gpsaCurrentProcess->aStringArray[cChar] ); + gpsaCurrentProcess->aStringArray[cChar++] = L']'; + gpsaCurrentProcess->aStringArray[cChar++] = 0; + } + + // Stick on an empty security binding. + gpsaCurrentProcess->aStringArray[cChar++] = 0; + gpsaCurrentProcess->wSecurityOffset = (USHORT) cChar; + gpsaCurrentProcess->aStringArray[cChar++] = 0; + gpsaCurrentProcess->aStringArray[cChar++] = 0; + gpsaCurrentProcess->wNumEntries = (USHORT) cChar; + gfDefaultStrings = TRUE; + return RPC_S_OK; +} + +//+------------------------------------------------------------------- +// +// Function: InquireStringBindings, private +// +// Synopsis: Get and server binding handles from RPC and convert them +// into a string array. +// +// Notes: This function requires the caller to hold gComLock. +// +// History: 23 May 95 AlexMit Created +// +//-------------------------------------------------------------------- +BOOL InquireStringBindings( WCHAR *pProtseq ) +{ + ASSERT_LOCK_HELD + + BOOL fFound = FALSE; + DWORD cbProtseq; + RPC_BINDING_VECTOR *pBindVect = NULL; + RPC_STATUS sc = RpcServerInqBindings(&pBindVect); + + if (sc == S_OK) + { + LPWSTR *apwszFullStringBinding; + ULONG *aulStrLen; + ULONG ulTotalStrLen = MAX_LOCAL_SB; // Total string lengths + ULONG j = 0; // BindString we're using + + if (pProtseq != NULL) + cbProtseq = lstrlenW( pProtseq ) * sizeof(WCHAR); + else + cbProtseq = 0; + apwszFullStringBinding = (LPWSTR *) PrivMemAlloc( pBindVect->Count * + sizeof(LPWSTR) ); + aulStrLen = (ULONG *) PrivMemAlloc( pBindVect->Count * + sizeof(ULONG) ); + if (apwszFullStringBinding != NULL && + aulStrLen != NULL) + { + + // iterate over the handles to get the string bindings + // and dynamic endpoints for all available protocols. + + for (ULONG i=0; i<pBindVect->Count; i++) + { + LPWSTR pwszStringBinding = NULL; + apwszFullStringBinding[j] = NULL; + aulStrLen[j] = 0; + + sc = RpcBindingToStringBinding(pBindVect->BindingH[i], + &pwszStringBinding); + Win4Assert(sc == S_OK && "RpcBindingToStringBinding"); + + + if (sc == S_OK) + { + // Determine is this is the protseq we are looking for. + if (memcmp( pProtseq, pwszStringBinding, cbProtseq ) == 0) + fFound = TRUE; + + // Skip ncalrpc because rot needs to know the + // format of the ncalrpc endpoint. + if (lstrlenW (pwszStringBinding) >= 7 && + memcmp ( L"ncalrpc", pwszStringBinding, 7*sizeof(WCHAR)) != 0) + { + // record the string lengths for later. include room + // for the NULL terminator. + apwszFullStringBinding[j] = pwszStringBinding; + aulStrLen[j] = lstrlenW(apwszFullStringBinding[j])+1; + ulTotalStrLen += aulStrLen[j]; + j++; + } + else + { + RpcStringFree( &pwszStringBinding ); + } + } + } // for + + + // now that all the string bindings and endpoints have been + // accquired, allocate a DUALSTRINGARRAY large enough to hold them + // all and copy them into the structure. + + if (ulTotalStrLen > 0) + { + void *pNew = PrivMemAlloc( sizeof(DUALSTRINGARRAY) + + (ulTotalStrLen+1)*sizeof(WCHAR) ); + if (pNew) + { + PrivMemFree( gpsaCurrentProcess ); + gpsaCurrentProcess = (DUALSTRINGARRAY *) pNew; + LPWSTR pwszNext = gpsaCurrentProcess->aStringArray; + + // Copy in ncalrpc:[OLEnnnnnnnn] + if (gfLrpc) + { + lstrcpyW( pwszNext, L"ncalrpc:[OLE" ); + _ultow( gdwEndPoint, &pwszNext[12], 16 ); + lstrcatW( pwszNext, L"]" ); + pwszNext += lstrlenW(pwszNext) + 1; + } + + // copy in the strings + for (i=0; i<j; i++) + { + lstrcpyW(pwszNext, apwszFullStringBinding[i]); + pwszNext += aulStrLen[i]; + } + + // Add a second null to terminate the string binding + // set. Add a third and fourth null to create an empty + // security binding set. + + pwszNext[0] = 0; + pwszNext[1] = 0; + pwszNext[2] = 0; + + // Fill in the size fields. + gpsaCurrentProcess->wSecurityOffset = pwszNext - + gpsaCurrentProcess->aStringArray + 1; + gpsaCurrentProcess->wNumEntries = + gpsaCurrentProcess->wSecurityOffset + 2; + } + else + { + sc = RPC_S_OUT_OF_RESOURCES; + } + } + else + { + // no binding strings. this is an error. + ComDebOut((DEB_ERROR, "No Rpc ProtSeq/EndPoints Generated\n")); + sc = RPC_S_NO_PROTSEQS; + } + + // free the full string bindings we allocated above + for (i=0; i<j; i++) + { + // free the old strings + RpcStringFree(&apwszFullStringBinding[i]); + } + } + else + { + sc = RPC_S_OUT_OF_RESOURCES; + } + + // free the binding vector allocated above + RpcBindingVectorFree(&pBindVect); + PrivMemFree( apwszFullStringBinding ); + PrivMemFree( aulStrLen ); + } + +#if DBG==1 + // display our binding strings on the debugger + DisplayAllStringBindings(); +#endif + + return fFound; +} + +//+------------------------------------------------------------------- +// +// Function: StartListen, public +// +// Synopsis: this starts the Rpc service listening. this is required +// in order to marshal interfaces. it is executed lazily, +// that is, we dont start listening until someone tries to +// marshal a local object interface. this is done so we dont +// spawn a thread unnecessarily. +// +// Notes: This function takes gComLock. +// +// History: 23-Nov-93 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT StartListen() +{ + ComDebOut((DEB_MARSHAL,"[IN] StartListen.\n")); + ASSERT_LOCK_HELD + + RPC_STATUS sc = S_OK; + OXIDEntry *pOxid; + + if (!gfListening) + { + // Get a unique number and convert it to a string endpoint. + if (gdwEndPoint == 0) + gdwEndPoint = CoGetCurrentProcess(); + if (gdwEndPoint == 0) + return E_FAIL; + + // Register ncalrpc for free threaded and mswmsg for apartment. +#ifdef _CHICAGO_ + if (IsMTAThread()) + { +#endif + sc = RegisterLrpc(); +#ifdef _CHICAGO_ + } + else + { + sc = RegisterMswmsg(); + // BUGBUG - Register ncalrpc until SCM can call us on mswmsg. + if (sc == RPC_S_OK) + sc = RegisterLrpc(); + } +#endif + + if (sc == RPC_S_OK) + { + + // Register the Object Resolver Callback interface. + sc = RpcServerRegisterIfEx(_IOrCallback_ServerIfHandle, NULL, NULL, + RPC_IF_AUTOLISTEN, + 0xffff, GetAclFn()); + + if (sc == RPC_S_OK || sc == RPC_S_TYPE_ALREADY_REGISTERED) + { + // Register the IRemUnknown interface. We need to register this + // manually because CRemoteUnknown marshals IRundown which inherits + // IRemoteUnknown. The resolver calls on IRundown and external clients + // call on IRemoteUnknown. + + sc = RpcServerRegisterIfEx( + (RPC_SERVER_INTERFACE *)&gRemUnknownIf, NULL, NULL, + RPC_IF_AUTOLISTEN | RPC_IF_OLE, + 0xffff, GetAclFn() ); + + if (sc == RPC_S_OK || sc == RPC_S_TYPE_ALREADY_REGISTERED) + { + sc = DefaultStringBindings(); + } + } + } + + if (sc == RPC_S_OK) + { + gfListening = TRUE; + sc = S_OK; + } + else + { + sc = HRESULT_FROM_WIN32(sc); + } + } + + if (sc == RPC_S_OK && IsSTAThread()) + { + // Tell MSWMSG the window for each thread. + sc = gOXIDTbl.GetLocalEntry( &pOxid); + if (SUCCEEDED(sc)) + { + sc = I_RpcServerStartListening( (HWND) pOxid->hServerSTA ); + if (sc != RPC_S_OK) + sc = HRESULT_FROM_WIN32(sc); + } + } + + // If something failed, make sure everything gets cleaned up. + if (FAILED(sc)) + { + UNLOCK + UnregisterDcomInterfaces(); + LOCK + } + + ASSERT_LOCK_HELD + ComDebOut(((sc == S_OK) ? DEB_MARSHAL : DEB_ERROR, + "[OUT] StartListen hr: 0x%x\n", sc)); + return sc; +} + +//+------------------------------------------------------------------- +// +// Function: GetStringBindings, public +// +// Synopsis: Return an array of strings bindings for this process +// +// Notes: This function takes gComLock. +// +// History: 23-Nov-93 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT GetStringBindings( DUALSTRINGARRAY **psaStrings ) +{ + TRACECALL(TRACE_RPC, "GetStringBindings"); + ComDebOut((DEB_CHANNEL, "[IN] GetStringBindings\n")); + +#ifdef _CHICAGO_ + #error Register MSWMSG per thread. +#endif + + *psaStrings = NULL; + + LOCK + HRESULT hr = StartListen(); + if (SUCCEEDED(hr)) + { + hr = CopyStringArray(gpsaCurrentProcess, gpsaSecurity, psaStrings); + } + UNLOCK + + ComDebOut((DEB_CHANNEL, "[OUT] GetStringBindings hr:%x\n", hr)); + return hr; +} + + +//+------------------------------------------------------------------- +// +// Function: CopyStringArray, public +// +// Synopsis: Combines the string bindings from the first DUALSTRINGARRAY +// with the security bindings from the second DUALSTRINGARRAY +// (if present) into a new DUALSTRINGARRAY. +// +// History: 23-Nov-93 Rickhi Created +// +//-------------------------------------------------------------------- +HRESULT CopyStringArray(DUALSTRINGARRAY *psaStringBinding, + DUALSTRINGARRAY *psaSecurity, + DUALSTRINGARRAY **ppsaNew) +{ + // compute size of string bindings + USHORT lSizeSB = SASIZE(psaStringBinding->wNumEntries); + + // compute size of additional security strings + USHORT lSizeSC = (psaSecurity == NULL) ? 0 : + psaSecurity->wNumEntries - psaSecurity->wSecurityOffset; + + *ppsaNew = (DUALSTRINGARRAY *) PrivMemAlloc( lSizeSB + + lSizeSC * sizeof(USHORT)); + + if (*ppsaNew != NULL) + { + // copy in the string bindings + memcpy(*ppsaNew, psaStringBinding, lSizeSB); + + if (psaSecurity != NULL) + { + // copy in the security strings, and adjust the overall length. + memcpy(&(*ppsaNew)->aStringArray[psaStringBinding->wSecurityOffset], + &psaSecurity->aStringArray[psaSecurity->wSecurityOffset], + lSizeSC*sizeof(USHORT)); + + (*ppsaNew)->wNumEntries = psaStringBinding->wSecurityOffset + + lSizeSC; + } + return S_OK; + } + + return E_OUTOFMEMORY; +} + +//+------------------------------------------------------------------- +// +// Function: UnregisterDcomInterfaces +// +// Synopsis: Unregister the object resolver callback function and mark +// DCOM as no longer accepting remote calls. +// +// Notes: This function requires that the caller guarentee +// serialization without taking gComLock. +// +// History: 23-Nov-93 Rickhi Created +// +//-------------------------------------------------------------------- +SCODE UnregisterDcomInterfaces(void) +{ + ComDebOut((DEB_CHANNEL, "[IN] UnregisterDcomInterfaces\n")); + RPC_STATUS sc = RPC_S_OK; + ASSERT_LOCK_RELEASED + + if (gfListening) + { + // Unregister IOrCallback. This can result in calls being dispatched. + // Do not hold the lock around this call. + sc = RpcServerUnregisterIf(_IOrCallback_ServerIfHandle, 0, 1 ); + + // Unregister IRemUnknown. This can result in calls being dispatched. + // Do not hold the lock around this call. + sc = RpcServerUnregisterIf((RPC_SERVER_INTERFACE *)&gRemUnknownIf, 0, 1); + + gfListening = FALSE; + } + gSpeedOverMem = FALSE; + + if (sc != RPC_S_OK) + sc = HRESULT_FROM_WIN32(sc); + + ComDebOut((DEB_CHANNEL, "[OUT] UnregisterDcomInterfaces hr:%x\n", sc)); + return sc; +} + + +//+------------------------------------------------------------------- +// +// Function: UseProtseq +// +// Synopsis: Use the specified protseq and return a list of all string +// bindings. +// +// History: 25 May 95 AlexMit Created +// +//-------------------------------------------------------------------- +error_status_t _UseProtseq( handle_t hRpc, + wchar_t *pwstrProtseq, + DUALSTRINGARRAY **ppsaNewBindings, + DUALSTRINGARRAY **ppsaSecurity ) +{ + BOOL fInUse = FALSE; + RPC_STATUS sc = RPC_S_OK; + + ASSERT_LOCK_RELEASED + LOCK + + // Make sure security is initialized. + sc = DefaultAuthnServices(); + + // If we have never inquired string bindings, inquire them before doing + // anything else. + if (sc == RPC_S_OK && gfDefaultStrings) + { + fInUse = InquireStringBindings( pwstrProtseq ); + gfDefaultStrings = FALSE; + } + + if (sc == RPC_S_OK && !fInUse) + { + // Special case ncalrpc. + if (lstrcmpW( pwstrProtseq, L"ncalrpc" ) == 0) + sc = RegisterLrpc(); + + #ifdef _CHICAGO_ + // Special case mswmsg. + else if (lstrcmpW( pwstrProtseq, L"mswmsg" ) == 0) + { + if (!gfMswmsg) + sc = RegisterMswmsg(); + } + #endif + + // Register all other protocol sequences. + else + { + sc = RpcServerUseProtseq(pwstrProtseq, + RPC_C_PROTSEQ_MAX_REQS_DEFAULT, + NULL); + } + + if (sc != RPC_S_OK) + ComDebOut((DEB_CHANNEL, "Could not register protseq %ws: 0x%x\n", + pwstrProtseq, sc )); + + // Return the latest string bindings. Ignore failures. + InquireStringBindings( NULL ); + } + + // Generate a copy to return. + CopyStringArray( gpsaCurrentProcess, NULL, ppsaNewBindings ); + CopyStringArray( gpsaSecurity, NULL, ppsaSecurity ); + + UNLOCK + ASSERT_LOCK_RELEASED + return RPC_S_OK; +} diff --git a/private/ole32/com/dcomrem/service.hxx b/private/ole32/com/dcomrem/service.hxx new file mode 100644 index 000000000..8868025bc --- /dev/null +++ b/private/ole32/com/dcomrem/service.hxx @@ -0,0 +1,33 @@ +//+------------------------------------------------------------------- +// +// File: service.hxx +// +// Contents: APIs to simplify RPC setup +// +// History: 23-Nov-92 Rickhi Created +// +//-------------------------------------------------------------------- +#ifndef __SERVICE__ +#define __SERVICE__ + + +#define SASIZE(size) (sizeof(ULONG) + (size) * sizeof(WCHAR)) + +// Function Prototypes. +STDAPI StartListen(void); +SCODE UnregisterDcomInterfaces(void); + +RPC_STATUS CheckClientMswmsg ( WCHAR *pProtseq, DWORD * ); +HRESULT CopyStringArray ( DUALSTRINGARRAY *psaStringBind, + DUALSTRINGARRAY *psaSecurity, + DUALSTRINGARRAY **ppsaNew ); +LPWSTR GetLocalEndpoint(); +HRESULT GetStringBindings ( DUALSTRINGARRAY **psaStrings ); + + +extern DWORD gdwEndPoint; // endpoint for current process +extern DWORD gdwPsaMaxSize; // max size of any known psa +extern DUALSTRINGARRAY *gpsaCurrentProcess; // string bindings for current process + + +#endif // __SERVICE__ diff --git a/private/ole32/com/dcomrem/stdid.cxx b/private/ole32/com/dcomrem/stdid.cxx new file mode 100644 index 000000000..2cd070026 --- /dev/null +++ b/private/ole32/com/dcomrem/stdid.cxx @@ -0,0 +1,1378 @@ +//+------------------------------------------------------------------- +// +// File: stdid.cxx +// +// Contents: identity object and creation function +// +// History: 1-Dec-93 CraigWi Created +// 13-Sep-95 Rickhi Simplified +// +//-------------------------------------------------------------------- +#include <ole2int.h> +#include <stdid.hxx> // CStdIdentity +#include <marshal.hxx> // CStdMarshal +#include <idtable.hxx> // Indentity Table + +#include "..\objact\objact.hxx" // used in IProxyManager::CreateServer + + +#if DBG==1 +// head of linked list of identities for debug tracking purposes +CStdIdentity gDbgIDHead; +#endif // DBG + + +//+---------------------------------------------------------------- +// +// Class: CStdIdentity (stdid) +// +// Purpose: To be the representative of the identity of the object. +// +// History: 11-Dec-93 CraigWi Created. +// 21-Apr-94 CraigWi Stubmgr addref's object; move strong cnt +// 10-May-94 CraigWi IEC called for strong connections +// 17-May-94 CraigWi Container weak connections +// 31-May-94 CraigWi Tell object of weak pointers +// +// Details: +// +// The identity is determined on creation of the identity object. On the +// server side a new OID is created, on the client side, the OID contained +// in the OBJREF is used. +// +// The identity pointer is typically stored in the OIDTable, NOT AddRef'd. +// SetOID adds the identity to the table, and can be called from ctor or +// from Unmarshal. RevokeOID removes the identity from the table, and can +// be called from Disconnect, or final Release. +// +//-------------------------------------------------------------------- + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CStdIdentity, private +// +// Synopsis: ctor for identity object +// +// Arguments: for all but the last param, see CreateIdentityHandler. +// [ppUnkInternal] -- +// when aggregated, this the internal unknown; +// when not aggregated, this is the controlling unknown +// +// History: 15-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +CStdIdentity::CStdIdentity(DWORD flags, IUnknown *pUnkOuter, + IUnknown *pUnkControl, IUnknown **ppUnkInternal) : + m_refs(1), + m_cStrongRefs(0), + m_flags(flags), + m_pIEC(NULL), + m_moid(GUID_NULL), + m_pUnkOuter((pUnkOuter) ? pUnkOuter : (IMultiQI *)&m_InternalUnk), + m_pUnkControl((pUnkControl) ? pUnkControl : m_pUnkOuter), + CClientSecurity( this ) +{ + ComDebOut((DEB_MARSHAL, "CStdIdentity %s Created this:%x\n", + IsClient() ? "CLIENT" : "SERVER", this)); + Win4Assert(!!IsClient() == (pUnkControl == NULL)); + +#if DBG==1 + // Chain this identity onto the global list of instantiated identities + // so we can track even the ones that are not placed in the ID table. + LOCK + m_pNext = gDbgIDHead.m_pNext; + m_pPrev = &gDbgIDHead; + gDbgIDHead.m_pNext = this; + m_pNext->m_pPrev = this; + UNLOCK +#endif + + + if (pUnkOuter) + { + m_flags |= STDID_AGGREGATED; + } + + CLSID clsidHandler; + DWORD dwSMFlags = SMFLAGS_CLIENT_SIDE; // assume client side + + if (!IsClient()) + { +#if DBG == 1 + // the caller should have a strong reference and so these tests + // should not disturb the object. These just check the sanity of + // the object we are attempting to marshal. + + // addref/release pUnkControl; shouldn't go away (i.e., + // should be other ref to it). + // Do this only if it is not Excel as it always returns which will + // trigger the assert on debug builds unnecessarily! + if (!IsTaskName(L"EXCEL.EXE")) + { + pUnkControl->AddRef(); + Verify(pUnkControl->Release() != 0); + + // verify that pUnkControl is in fact the controlling unknown + IUnknown *pUnkT; + Verify(pUnkControl->QueryInterface(IID_IUnknown,(void **)&pUnkT)==NOERROR); + Win4Assert(pUnkControl == pUnkT); + Verify(pUnkT->Release() != 0); + } +#endif + + dwSMFlags = 0; // server side + m_pUnkControl->AddRef(); + + // determine if we will write a standard or handler objref. we write + // standard unless the object implements IStdMarshalInfo and overrides + // the standard class. we ignore all errors from this point onward in + // order to maintain backward compatibility. + + ASSERT_LOCK_RELEASED + + IStdMarshalInfo *pSMI; + HRESULT hr = m_pUnkControl->QueryInterface(IID_IStdMarshalInfo, + (void **)&pSMI); + if (SUCCEEDED(hr)) + { + hr = pSMI->GetClassForHandler(NULL, NULL, &clsidHandler); + if (SUCCEEDED(hr) && !IsEqualCLSID(clsidHandler, CLSID_NULL)) + { + dwSMFlags |= SMFLAGS_HANDLER; + } + else + { + clsidHandler = GUID_NULL; + } + pSMI->Release(); + } + + // look for the IExternalConnection interface. The StdId will use + // this for Inc/DecStrongCnt. We do the QI here while we are not + // holding the LOCK. + + hr = m_pUnkControl->QueryInterface(IID_IExternalConnection, + (void **)&m_pIEC); + if (FAILED(hr)) + { + // make sure it is NULL + m_pIEC = NULL; + } + + ASSERT_LOCK_RELEASED + } + else + { + m_cStrongRefs = 1; + } + + // now intialize the standard marshaler + CStdMarshal::Init(m_pUnkControl, this, clsidHandler, dwSMFlags); + + *ppUnkInternal = (IMultiQI *)&m_InternalUnk; // this is what the m_refs=1 is for + + AssertValid(); +} + +#if DBG==1 +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CStdIdentity, public +// +// Synopsis: Special Identity ctor for the debug list head. +// +//+------------------------------------------------------------------- +CStdIdentity::CStdIdentity() : CClientSecurity(this) +{ + Win4Assert(this == &gDbgIDHead); + m_pNext = this; + m_pPrev = this; +} +#endif // DBG + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::~CStdIdentity, private +// +// Synopsis: Final destruction of the identity object. ID has been +// revoked by now (in internal ::Release). Here we disconnect +// on server. +// +// History: 15-Dec-93 CraigWi Created. +// Rickhi Simplified +// +//-------------------------------------------------------------------- +CStdIdentity::~CStdIdentity() +{ +#if DBG==1 + if (this != &gDbgIDHead) + { +#endif // DBG + + ComDebOut((DEB_MARSHAL, "CStdIdentity %s Deleted this:%x\n", + IsClient() ? "CLIENT" : "SERVER", this)); + + Win4Assert(m_refs == 0); + m_refs++; // simple guard against reentry of dtor + SetNowInDestructor(); // debug flag which enables asserts to detect + + // make sure we have disconnected + Disconnect(); + +#if DBG==1 + // UnChain this identity from the global list of instantiated identities + // so we can track even the ones that are not placed in the ID table. + LOCK + m_pPrev->m_pNext = m_pNext; + m_pNext->m_pPrev = m_pPrev; + UNLOCK + } +#endif // DBG +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CInternalUnk::QueryInterface, private +// +// Synopsis: Queries for an interface. Just delegates to the common +// code in QueryMultipleInterfaces. +// +// History: 26-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::CInternalUnk::QueryInterface(REFIID riid, VOID **ppv) +{ + MULTI_QI mqi; + mqi.pIID = &riid; + mqi.pItf = NULL; + + QueryMultipleInterfaces(1, &mqi); + + *ppv = (void *)mqi.pItf; + return mqi.hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CInternalUnk::QueryMultipleInterfaces, public +// +// Synopsis: +// +// History: 26-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::CInternalUnk::QueryMultipleInterfaces(ULONG cMQIs, + MULTI_QI *pMQIs) +{ + // Make sure TLS is initialized. + HRESULT hr; + COleTls tls(hr); + if (FAILED(hr)) + return hr; + + CStdIdentity *pStdID = GETPPARENT(this, CStdIdentity, m_InternalUnk); + pStdID->AssertValid(); + + // allocate some space on the stack for the intermediate results. declare + // working pointers and remember the start address of the allocations. + + MULTI_QI **ppMQIAlloc = (MULTI_QI **)_alloca(sizeof(MULTI_QI *) * cMQIs); + IID *pIIDAlloc = (IID *) _alloca(sizeof(IID) * cMQIs); + SQIResult *pSQIAlloc = (SQIResult *)_alloca(sizeof(SQIResult) * cMQIs); + + MULTI_QI **ppMQIPending = ppMQIAlloc; + IID *pIIDPending = pIIDAlloc; + SQIResult *pSQIPending = pSQIAlloc; + + + // loop over the interfaces looking for locally supported interfaces, + // instantiated proxies, and unsupported interfaces. Gather up all the + // interfaces that dont fall into the above categories, and issue a + // remote query to the server. + + USHORT cPending = 0; + ULONG cAcquired = 0; + MULTI_QI *pMQI = pMQIs; + + for (ULONG i=0; i<cMQIs; i++, pMQI++) + { + if (pMQI->pItf != NULL) + { + // skip any entries that are not set to NULL. This allows + // progressive layers of handlers to optionally fill in the + // interfaces that they know about and pass the whole array + // on to the next level. + continue; + } + + pMQI->hr = S_OK; + + // always allow - IUnknown, IMarshal, IStdIdentity, Instantiated proxies + if (InlineIsEqualGUID(*(pMQI->pIID), IID_IUnknown)) + { + pMQI->pItf = (IMultiQI *)this; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IMarshal)) + { + pMQI->pItf = (IMarshal *)pStdID; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IStdIdentity)) + { + pMQI->pItf = (IUnknown *)(void*)pStdID; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IProxyManager)) + { + // old code exposed this IID and things now depend on it. + pMQI->pItf = (IProxyManager *)pStdID; + } + else if (pStdID->InstantiatedProxy(*(pMQI->pIID),(void **)&pMQI->pItf, + &pMQI->hr)) + { + // a proxy for this interface already exists + // + // NOTE: this call also set pMQI->hr = E_NOINTERFACE if the + // StId has never been connected, and to CO_E_OBJNOTCONNECTED if + // it has been connected but is not currently connected. This is + // required for backwards compatibility, and will cause us to skip + // the QueryRemoteInterface. + ; + } + else if (pStdID->IsAggregated()) + { + // aggregate case + // allow - IInternalUnknown + // dissallow - IMultiQI, IClientSecurity, IServerSecurity + + if (InlineIsEqualGUID(*(pMQI->pIID), IID_IInternalUnknown)) + { + pMQI->pItf = (IInternalUnknown *)this; + pMQI->hr = S_OK; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IMultiQI) || + InlineIsEqualGUID(*(pMQI->pIID), IID_IClientSecurity) || + InlineIsEqualGUID(*(pMQI->pIID), IID_IServerSecurity)) + { + pMQI->hr = E_NOINTERFACE; + } + else if (pMQI->hr == S_OK) + { + // InstantiatedProxy did not return E_NOINTERFACE or + // CO_E_OBJNOTCONNECTED so add this interface to the + // list to pass to the QueryRemoteInterfaces. + + pMQI->hr = RPC_S_CALLPENDING; + } + } + else + { + // non-aggregate case + // allow - IClientSecurity, IMultiQI + // dissallow - IInternalUnknown, IServerSecurity + + if (InlineIsEqualGUID(*(pMQI->pIID), IID_IClientSecurity)) + { + pMQI->pItf = (IClientSecurity *)pStdID; + pMQI->hr = S_OK; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IMultiQI)) + { + pMQI->pItf = (IMultiQI *)this; + pMQI->hr = S_OK; + } + else if (InlineIsEqualGUID(*(pMQI->pIID), IID_IInternalUnknown) || + InlineIsEqualGUID(*(pMQI->pIID), IID_IServerSecurity)) + { + pMQI->hr = E_NOINTERFACE; + } + else if (pMQI->hr == S_OK) + { + // InstantiatedProxy did not return E_NOINTERFACE or + // CO_E_OBJNOTCONNECTED so add this interface to the + // list to pass to the QueryRemoteInterfaces. + + pMQI->hr = RPC_S_CALLPENDING; + } + } + + if (pMQI->hr == S_OK) + { + // got an interface to return, AddRef it and count one more + // interface acquired. + + pMQI->pItf->AddRef(); + cAcquired++; + } + else if (pMQI->hr == RPC_S_CALLPENDING) + { + // fill in a remote QI structure and count one more + // pending interface + + pSQIPending->pv = NULL; + pSQIPending->hr = S_OK; + *pIIDPending = *(pMQI->pIID); + *ppMQIPending = pMQI; + + pSQIPending++; + pIIDPending++; + ppMQIPending++; + cPending++; + } + } + + if (cPending > 0) + { + // there are some interfaces which we dont yet know about, so + // go ask the remoting layer to Query the server and build proxies + // where possible. The results are returned in the individual + // SQIResults, so the overall return code is ignored. + + pStdID->QueryRemoteInterfaces(cPending, pIIDAlloc, pSQIAlloc); + + // got some interfaces, loop over the remote QI structure filling + // in the rest of the MULTI_QI structure to return to the caller. + // the proxies are already AddRef'd. + + pSQIPending = pSQIAlloc; + ppMQIPending = ppMQIAlloc; + + for (i=0; i<cPending; i++, pSQIPending++, ppMQIPending++) + { + pMQI = *ppMQIPending; + pMQI->pItf = (IUnknown *)(pSQIPending->pv); + pMQI->hr = pSQIPending->hr; + + if (SUCCEEDED(pMQI->hr)) + { + // count one more acquired interface + cAcquired++; + } + } + } + + // if we got all the interfaces, return S_OK. If we got none of the + // interfaces, return E_NOINTERFACE. If we got some, but not all, of + // the interfaces, return S_FALSE; + + if (cAcquired == cMQIs) + return S_OK; + else if (cAcquired > 0) + return S_FALSE; + else + return E_NOINTERFACE; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CInternalUnk::QueryInternalInterface, public +// +// Synopsis: return interfaces that are internal to the aggregated +// proxy manager. +// +// History: 26-Feb-96 Rickhi Created +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::CInternalUnk::QueryInternalInterface(REFIID riid, + VOID **ppv) +{ + CStdIdentity *pStdID = GETPPARENT(this, CStdIdentity, m_InternalUnk); + pStdID->AssertValid(); + + if (!pStdID->IsAggregated()) + { + // this method is only valid when we are part of a client-side + // aggregate. + return E_NOTIMPL; + } + + if (InlineIsEqualGUID(riid, IID_IUnknown) || + InlineIsEqualGUID(riid, IID_IInternalUnknown)) + { + *ppv = (IInternalUnknown *)this; + } + else if (InlineIsEqualGUID(riid, IID_IMultiQI)) + { + *ppv = (IMultiQI *)this; + } + else if (InlineIsEqualGUID(riid, IID_IStdIdentity)) + { + *ppv = pStdID; + } + else if (InlineIsEqualGUID(riid, IID_IClientSecurity)) + { + *ppv = (IClientSecurity *)pStdID; + } + else if (InlineIsEqualGUID(riid, IID_IProxyManager)) + { + *ppv = (IProxyManager *)pStdID; + } + else + { + *ppv = NULL; + return E_NOINTERFACE; + } + + ((IUnknown *)*ppv)->AddRef(); + return S_OK; +} + + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CInternalUnk::AddRef, public +// +// Synopsis: Nothing special. +// +// History: 15-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CStdIdentity::CInternalUnk::AddRef(void) +{ + CStdIdentity *pStdID = GETPPARENT(this, CStdIdentity, m_InternalUnk); + pStdID->AssertValid(); + + AssertSz(!pStdID->IsInDestructor(), "CStdIdentity AddRef'd during destruction"); + + InterlockedIncrement((long *)&pStdID->m_refs); + // ComDebOut((DEB_MARSHAL, "StdId:CtrlUnk::AddRef this:%x m_refs:%x\n", pStdID, pStdID->m_refs)); + return pStdID->m_refs; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CInternalUnk::Release, public +// +// Synopsis: Releases the identity object. When the ref count goes +// to zero, revokes the id and destroys the object. +// +// History: 15-Dec-93 CraigWi Created. +// 18-Apr-95 Rickhi Rewrote much faster/simpler +// +//-------------------------------------------------------------------- +STDMETHODIMP_(ULONG) CStdIdentity::CInternalUnk::Release(void) +{ + CStdIdentity *pStdID = GETPPARENT(this, CStdIdentity, m_InternalUnk); + pStdID->AssertValid(); + + DWORD refs = pStdID->m_refs - 1; + // ComDebOut((DEB_MARSHAL, "StdId:CtrlUnk::Release this:%x m_refs:%x\n", pStdID, refs)); + + if (InterlockedDecrement((long *)&pStdID->m_refs) == 0) + { + BOOL fDelete = FALSE; + ASSERT_LOCK_RELEASED + LOCK + + // check if we are already in the dtor and skip a second destruction + // if so. The reason we need this is that some crusty old apps do + // CoMarshalInterface followed by CoLockObjectExternal(FALSE,TRUE), + // expecting this to accomplish a Disconnect. It subtracts from the + // references, but it takes away the ones that the IPIDEntry put on, + // without telling the IPIDEntry, so when we release the IPIDEntry, + // our count goes negative!!! + + // the LockedInMemory flag is for the gpStdMarshal instance that we + // may hand out to clients, but which we never want to go away, + // regardless of how many times they call Release. + + if (pStdID->m_refs == 0) + { + // refcnt is still zero, so the idtable did not just hand + // out a reference behind our back. + + if (!pStdID->IsLockedOrInDestructor()) + { + // remove from the OID table and delete the identity + // We dont delete while holding the table mutex. + + pStdID->RevokeOID(); + fDelete = TRUE; + } + else + { + // this object is locked in memory and we should never + // get here, but some broken test app was doing this in + // stress. + + pStdID->m_refs = 100; + } + } + + UNLOCK + ASSERT_LOCK_RELEASED + + if (fDelete) + { + delete pStdID; + return 0; + } + } + + return refs; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::IUnknown methods, public +// +// Synopsis: External IUnknown methods; delegates to m_pUnkOuter. +// +// History: 15-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::QueryInterface(REFIID riid, VOID **ppvObj) +{ + AssertValid(); + return m_pUnkOuter->QueryInterface(riid, ppvObj); +} + +STDMETHODIMP_(ULONG) CStdIdentity::AddRef(void) +{ + AssertValid(); + return m_pUnkOuter->AddRef(); +} + +STDMETHODIMP_(ULONG) CStdIdentity::Release(void) +{ + AssertValid(); + return m_pUnkOuter->Release(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::UnlockAndRelease, public +// +// Synopsis: Version of Release used for gpStdMarshal, that is +// currently locked in memory so nobody but us can +// release it, regardless of refcnt. +// +// History: 19-Apr-96 Rickhi Created +// +//-------------------------------------------------------------------- +ULONG CStdIdentity::UnlockAndRelease(void) +{ + m_flags &= ~STDID_LOCKEDINMEM; + m_refs = 1; + return m_pUnkOuter->Release(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::IncStrongCnt, public +// +// Synopsis: Increments the strong reference count on the identity. +// +// History: 15-Dec-93 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdIdentity::IncStrongCnt() +{ + Win4Assert(!IsClient()); + + // we might be holding the lock here if this is called from + // LookupIDFromUnk, since we have to be holding the lock while + // doing the lookup. We cant release it or we could go away. + + ASSERT_LOCK_DONTCARE + + ComDebOut((DEB_MARSHAL, + "CStdIdentity::IncStrongCnt this:%x cStrong:%x\n", + this, m_cStrongRefs+1)); + + AddRef(); + InterlockedIncrement(&m_cStrongRefs); + + if (m_pIEC) + { + m_pIEC->AddConnection(EXTCONN_STRONG, 0); + } + + ASSERT_LOCK_DONTCARE +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::DecStrongCnt, public +// +// Synopsis: Decrements the strong reference count on the identity, +// and releases the object if that was the last strong +// reference. +// +// History: 15-Dec-93 Rickhi Created. +// +//-------------------------------------------------------------------- +void CStdIdentity::DecStrongCnt(BOOL fKeepAlive) +{ + Win4Assert(!IsClient()); + ASSERT_LOCK_RELEASED + + ComDebOut((DEB_MARSHAL, + "CStdIdentity::DecStrongCnt this:%x cStrong:%x fKeepAlive:%x\n", + this, m_cStrongRefs-1, fKeepAlive)); + + LONG cStrongRefs = InterlockedDecrement(&m_cStrongRefs); + + if (m_pIEC) + { + m_pIEC->ReleaseConnection(EXTCONN_STRONG, 0, !fKeepAlive); + } + + if (cStrongRefs == 0 && !fKeepAlive && (IsWOWThread() || m_pIEC == NULL)) + { + // strong count has gone to zero, disconnect. + DisconnectObject(0); + } + + if (cStrongRefs >= 0) + { + // some apps call CoMarshalInterface + CoLockObjectExternal(F,T) + // and expect the object to go away. Doing that causes Release to + // be called too many times (once for each IPID, once for CLOE, and + // once for the original Lookup). + Release(); + } + + ASSERT_LOCK_RELEASED +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::LockObjectExternal, public +// +// Synopsis: locks (or unlocks) the object so the remoting layer does +// not (or does) go away. +// +// History: 09-Oct-96 Rickhi Moved from CoLockObjectExternal. +// +//-------------------------------------------------------------------- +HRESULT CStdIdentity::LockObjectExternal(BOOL fLock, BOOL fLastUR) +{ + HRESULT hr = S_OK; + + if (GetServer() == NULL) + { + // attempt to lock handler, return error! + hr = E_UNEXPECTED; + } + else if (fLock) + { + // lock (and ignore rundowns) so it does not go away + IncStrongCnt(); + LOCK; + IncTableCnt(); + UNLOCK; + } + else + { + // unlock so that it can go away + LOCK; + DecTableCnt(); + UNLOCK; + DecStrongCnt(!fLastUR); + } + + return hr; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::GetServer, public +// +// Synopsis: Returns a pUnk for the identified object; NULL on client side +// The pointer is optionally addrefed depending upon fAddRef +// +// Returns: The pUnk on the object. +// +// History: 15-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +IUnknown * CStdIdentity::GetServer() +{ + if (IsClient() || m_pUnkControl == NULL) + return NULL; + + // Verify validity + Win4Assert(IsValidInterface(m_pUnkControl)); + return m_pUnkControl; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::ReleaseCtrlUnk, public +// +// Synopsis: Releases the server side controlling unknown +// This code is safe for reentrant calls. +// +// History: 11-Jun-95 Rickhi Created +// +//-------------------------------------------------------------------- +void CStdIdentity::ReleaseCtrlUnk(void) +{ + AssertValid(); + Win4Assert(!IsClient()); + + if (m_pUnkControl) + { + // server side: release the real object's m_pUnkControl; + // prevent problem on recursive disconnect + + AssertSz(IsValidInterface(m_pUnkControl), + "Invalid IUnknown during disconnect"); + IUnknown *pUnkControl = m_pUnkControl; + m_pUnkControl = NULL; + + if (m_pIEC) + { + AssertSz(IsValidInterface(m_pIEC), + "Invalid IExternalConnection during disconnect"); + m_pIEC->Release(); + m_pIEC = NULL; + } + + pUnkControl->Release(); + } +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::SetOID, public +// +// Synopsis: Associates the OID and the object (handler or server). +// +// History: 20-Feb-95 Rickhi Simplified +// +//-------------------------------------------------------------------- +HRESULT CStdIdentity::SetOID(REFMOID rmoid) +{ + Win4Assert(rmoid != GUID_NULL); + ASSERT_LOCK_HELD + + HRESULT hr = S_OK; + + if (!(m_flags & STDID_HAVEID)) + { + if (!(m_flags & STDID_IGNOREID)) + { + Win4Assert(!(m_flags & STDID_FREETHREADED)); + hr = SetObjectID(rmoid, m_pUnkControl, this); + } + + if (SUCCEEDED(hr)) + { + m_flags |= STDID_HAVEID; + m_moid = rmoid; + } + } + + ComDebErr(hr != S_OK, "SetOID Failed. Probably OOM.\n"); + return hr; +} + + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::RevokeOID, public +// +// Synopsis: Disassociates the OID and the object (handler or server). +// Various other methods will fail (e.g., MarshalInterface). +// +// History: 15-Dec-93 CraigWi Created. +// 20-Feb-95 Rickhi Simplified +// +//-------------------------------------------------------------------- +void CStdIdentity::RevokeOID(void) +{ + AssertValid(); + ASSERT_LOCK_HELD + + if (m_flags & STDID_HAVEID) + { + m_flags &= ~STDID_HAVEID; + + if (!(m_flags & STDID_IGNOREID)) + (void)ClearObjectID(m_moid, m_pUnkControl, this); + } +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::IsConnected, public +// +// Synopsis: Indicates if the client is connected to the server. +// Only the negative answer is definitive because we +// might not be able to tell if the server is connected +// and even if we could, the answer might be wrong by +// the time the caller acted on it. +// +// Returns: TRUE if the server might be connected; FALSE if +// definitely not. +// +// History: 16-Dec-93 CraigWi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(BOOL) CStdIdentity::IsConnected(void) +{ + Win4Assert(IsClient()); // must be client side + AssertValid(); + + return RemIsConnected(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::Disconnect, public +// +// Synopsis: IProxyManager::Disconnect implementation, just forwards +// to the standard marshaller, which may call us back to +// revoke our OID and release our CtrlUnk. +// +// May also be called by the IDTable cleanup code. +// +// History: 11-Jun-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP_(void) CStdIdentity::Disconnect(void) +{ + AssertValid(); + CStdMarshal::Disconnect(); +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::LockConnection, public +// +// Synopsis: IProxyManager::LockConnection implementation. Changes +// all interfaces to weak from strong, or strong from weak. +// +// History: 11-Jun-95 Rickhi Created. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::LockConnection(BOOL fLock, BOOL fLastUnlockReleases) +{ + AssertValid(); + + if (!IsClient()) + { + // this operation does not make sense on the server side. + return E_NOTIMPL; + } + + if (IsMTAThread()) + { + // this call is not allowed if we are FreeThreaded. Report + // success, even though we did not do anything. + return S_OK; + } + + + if (( fLock && (++m_cStrongRefs == 1)) || + (!fLock && (--m_cStrongRefs == 0))) + { + // the strong count transitioned from 0 to 1 or 1 to 0, so + // call the server to change our references. + + return RemoteChangeRef(fLock, fLastUnlockReleases); + } + + return S_OK; +} + +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CreateServer, public +// +// Synopsis: Creates the server clsid in the given context and +// attaches it to this handler. +// +// History: 16-Dec-93 CraigWi Created. +// +// CODEWORK: this code is not thread safe in the freethreading case. We +// need to decide if the thread safety is the responsibility +// of the caller, or us. In the latter case, we would check +// if we are already connected before doing UnmarshalObjRef, and +// instead do a ::ReleaseMarshalObjRef. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::CreateServer(REFCLSID rclsid, DWORD clsctx, void *pv) +{ + ComDebOut((DEB_ACTIVATE, "ScmCreateObjectInstance this:%x clsctx:%x pv:%x\n", + this, clsctx, pv)); + AssertValid(); + Win4Assert(IsClient()); // must be client side + Win4Assert(IsValidInterface(m_pUnkControl)); // must be valid + //Win4Assert(!IsConnected()); + ASSERT_LOCK_RELEASED + + // Loop trying to get object from the server. Because the server can be + // in the process of shutting down and respond with a marshaled interface, + // we will retry this call if unmarshaling fails assuming that the above + // is true. + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + const int MAX_SERVER_TRIES = 3; + + for (int i = 0; i < MAX_SERVER_TRIES; i++) + { + // create object and get back marshaled interface pointer + InterfaceData *pIFD = NULL; + + // Dll ignored here since we are just doing this to get + // the remote handler. + WCHAR *pwszDllPath = NULL; + DWORD dwDllType = IsSTAThread() ? APT_THREADED : FREE_THREADED; + +#ifdef DCOM + HRESULT hrinterface; + hr = gResolver.CreateInstance( NULL, (CLSID *)&rclsid, clsctx, 1, + (IID *)&IID_IUnknown, (MInterfacePointer **)&pIFD, + &hrinterface,&dwDllType, &pwszDllPath ); +#else + // The first three NULLs (pwszFrom, pstgFrom, pwszNew) trigger a + // simple creation. + hr = gResolver.CreateObject(rclsid, clsctx, 0, + NULL, NULL, NULL, &pIFD, &dwDllType, &pwszDllPath, NULL); +#endif + + if (pwszDllPath != NULL) + { + CoTaskMemFree(pwszDllPath); + } + + if (FAILED(hr)) + { + // If an error occurred, return that otherwise convert a wierd + // success into E_FAIL. The point here is to return an error that + // the caller can figure out what happened. + hr = FAILED(hr) ? hr : E_FAIL; + break; + } + + + // make a stream out of the interface data returned, then read the + // objref from the stream. No need to find another instance of + // CStdMarshal because we already know it is for us! + + CXmitRpcStream Stm(pIFD); + OBJREF objref; + hr = ReadObjRef(&Stm, objref); + + if (SUCCEEDED(hr)) + { + // become this identity by unmarshaling the objref into this + // object. Note the objref must use standard marshaling. + Win4Assert(objref.flags & (OBJREF_HANDLER | OBJREF_STANDARD)); + Win4Assert(IsEqualIID(objref.iid, IID_IUnknown)); + + IUnknown *pUnk = NULL; + hr = UnmarshalObjRef(objref, (void **)&pUnk); + if (SUCCEEDED(hr)) + { + // release the AddRef done by unmarshaling + pUnk->Release(); + + // Reconnect the interface proxies + CStdMarshal::ReconnectProxies(); + } + + // free the objref we read above. + FreeObjRef(objref); + } + + CoTaskMemFree(pIFD); + + + // If either this worked or we got a packet we couldn't unmarshal + // at all we give up. Otherwise, we will hope that recontacting the + // SCM will fix things. + + if (SUCCEEDED(hr) || (hr == E_FAIL)) + { + break; + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_ACTIVATE, "ScmCreateObjectInstance this:%x hr:%x\n", + this, hr)); + return hr; +} +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::CreateServerWithHandler, public +// +// Synopsis: Creates the server clsid in the given context and +// attaches it to this handler. +// +// History: 10-Oct-95 JohannP Created +// +// CODEWORK: this code is not thread safe in the freethreading case. We +// need to decide if the thread safety is the responsibility +// of the caller, or us. In the latter case, we would check +// if we are already connected before doing UnmarshalObjRef, and +// instead do a ::ReleaseMarshalObjRef. +// +//-------------------------------------------------------------------- +STDMETHODIMP CStdIdentity::CreateServerWithHandler(REFCLSID rclsid, DWORD clsctx, void *pv, + REFCLSID rclsidHandler, IID iidSrv, void **ppv, + IID iidClnt, void *pClientSiteInterface) +{ + ComDebOut((DEB_ACTIVATE, "ScmCreateObjectInstance this:%x clsctx:%x pv:%x\n", + this, clsctx, pv)); + AssertValid(); + Win4Assert(IsClient()); // must be client side + Win4Assert(IsValidInterface(m_pUnkControl)); // must be valid + //Win4Assert(!IsConnected()); + Win4Assert(ppv != NULL); + + ASSERT_LOCK_RELEASED + + // Loop trying to get object from the server. Because the server can be + // in the process of shutting down and respond with a marshaled interface, + // we will retry this call if unmarshaling fails assuming that the above + // is true. + + HRESULT hr = InitChannelIfNecessary(); + if (FAILED(hr)) + return hr; + + IClientSiteHandler *pClientSiteHandler = (IClientSiteHandler *)pClientSiteInterface; + + const int MAX_SERVER_TRIES = 3; + + for (int i = 0; i < MAX_SERVER_TRIES; i++) + { + // create object and get back marshaled interface pointer + InterfaceData *pIFD = NULL; + + // Dll ignored here since we are just doing this to get + // the remote handler. + WCHAR *pwszDllPath = NULL; + DWORD dwDllType = IsSTAThread() ? APT_THREADED : FREE_THREADED; + +#ifdef DCOM + HRESULT hrinterface; + + // marshal ClientSiteHandler + MInterfacePointer * pIFPServerHandler = NULL; + MInterfacePointer * pIFPClientSiteHandler = NULL; + + if (pClientSiteHandler) + { + // addref once here - MarshalHelper calls release on the object + pClientSiteHandler->AddRef(); + hr = MarshalHelper(pClientSiteHandler, IID_IClientSiteHandler, + MSHLFLAGS_NORMAL, + (InterfaceData **) &pIFPClientSiteHandler); + } + + if (SUCCEEDED(hr)) + { + hr = gResolver.CreateInstance( NULL, (CLSID *)&rclsid, clsctx, 1, + (IID *)&IID_IUnknown, (MInterfacePointer **)&pIFD, + &hrinterface, &dwDllType, &pwszDllPath ); + + if (pIFPServerHandler) + { + if (SUCCEEDED(hr)) + { + CXmitRpcStream Stm((InterfaceData *) pIFPServerHandler); + hr = CoUnmarshalInterface(&Stm, IID_IServerHandler, ppv); + } + CoTaskMemFree(pIFPServerHandler); + } + + PrivMemFree(pIFPClientSiteHandler); + } + else + { + hr = gResolver.CreateInstance( NULL, (CLSID *)&rclsid, clsctx, 1, + (IID *)&IID_IUnknown, (MInterfacePointer **)&pIFD, + &hrinterface,&dwDllType, &pwszDllPath ); + } + +#else + // The first three NULLs (pwszFrom, pstgFrom, pwszNew) trigger a + // simple creation. + hr = gResolver.CreateObject(rclsid, clsctx, 0, + NULL, NULL, NULL, &pIFD, &dwDllType, &pwszDllPath, NULL); +#endif + + if (pwszDllPath != NULL) + { + CoTaskMemFree(pwszDllPath); + } + + if (FAILED(hr)) + { + // If an error occurred, return that otherwise convert a wierd + // success into E_FAIL. The point here is to return an error that + // the caller can figure out what happened. + hr = FAILED(hr) ? hr : E_FAIL; + break; + } + + + // make a stream out of the interface data returned, then read the + // objref from the stream. No need to find another instance of + // CStdMarshal because we already know it is for us! + + CXmitRpcStream Stm(pIFD); + OBJREF objref; + hr = ReadObjRef(&Stm, objref); + + if (SUCCEEDED(hr)) + { + // become this identity by unmarshaling the objref into this + // object. Note the objref must use standard marshaling. + Win4Assert(objref.flags & (OBJREF_HANDLER | OBJREF_STANDARD)); + Win4Assert(IsEqualIID(objref.iid, IID_IUnknown)); + + IUnknown *pUnk = NULL; + hr = UnmarshalObjRef(objref, (void **)&pUnk); + if (SUCCEEDED(hr)) + { + // release the AddRef done by unmarshaling + pUnk->Release(); + + // Reconnect the interface proxies + CStdMarshal::ReconnectProxies(); + } + + // free the objref we read above. + FreeObjRef(objref); + } + + CoTaskMemFree(pIFD); + + + // If either this worked or we got a packet we couldn't unmarshal + // at all we give up. Otherwise, we will hope that recontacting the + // SCM will fix things. + + if (SUCCEEDED(hr) || (hr == E_FAIL)) + { + break; + } + } + + ASSERT_LOCK_RELEASED + ComDebOut((DEB_ACTIVATE, "ScmCreateObjectInstance this:%x hr:%x\n", + this, hr)); + return hr; +} + + + +#if DBG == 1 +//+------------------------------------------------------------------- +// +// Member: CStdIdentity::AssertValid +// +// Synopsis: Validates that the state of the object is consistent. +// +// History: 26-Jan-94 CraigWi Created. +// +//-------------------------------------------------------------------- +void CStdIdentity::AssertValid() +{ + LOCK + AssertSz(m_refs < 0x7fff, "Identity ref count unreasonable"); + + // ensure we have the controlling unknown + Win4Assert(IsValidInterface(m_pUnkOuter)); // must be valid + + // NOTE: don't carelessly AddRef/Release because of weak references + + Win4Assert((m_flags & ~(STDID_SERVER | STDID_CLIENT | STDID_HAVEID | + STDID_FREETHREADED | STDID_INDESTRUCTOR | + STDID_IGNOREID | STDID_AGGREGATED | + STDID_LOCKEDINMEM)) == 0); + + if ((m_flags & STDID_HAVEID) && + !(m_flags & (STDID_FREETHREADED | STDID_IGNOREID))) + { + CStdIdentity *pStdID; + Verify(LookupIDFromID(m_moid, FALSE /*fAddRef*/, &pStdID) == NOERROR); + Win4Assert(pStdID == this); + // pStdID not addref'd + } + + if (IsClient()) + Win4Assert(m_pUnkControl == m_pUnkOuter); + + // must have RH tell identity when object goes away so we can NULL this + if (m_pUnkControl != NULL) + Win4Assert(IsValidInterface(m_pUnkControl)); // must be valid + + if (m_pIEC != NULL) + Win4Assert(IsValidInterface(m_pIEC)); // must be valid + + UNLOCK +} +#endif // DBG == 1 + +//+------------------------------------------------------------------- +// +// Function: CreateIdentityHandler, private +// +// Synopsis: Creates a client side identity object (one which is +// initialized by the first unmarshal). +// +// Arguments: [pUnkOuter] - controlling unknown if aggregated +// [flags] - flags (indicates free-threaded or not) +// [riid] - interface requested +// [ppv] - place for pointer to that interface. +// +// History: 16-Dec-93 CraigWi Created. +// 20-Feb-95 Rickhi Simplified +// +//-------------------------------------------------------------------- +INTERNAL CreateIdentityHandler(IUnknown *pUnkOuter, DWORD flags, + REFIID riid, void **ppv) +{ +#if DBG == 1 + Win4Assert(IsApartmentInitialized()); + + // if aggregating, it must ask for IUnknown. + Win4Assert(pUnkOuter == NULL || InlineIsEqualGUID(riid, IID_IUnknown)); + + if (pUnkOuter != NULL) + { + // addref/release pUnkOuter; shouldn't go away (i.e., + // should be other ref to it). + // Except Excel which always returns 0 on Release! + if (!IsTaskName(L"EXCEL.EXE")) + { + pUnkOuter->AddRef(); + Verify(pUnkOuter->Release() != 0); + + // verify that pUnkOuter is in fact the controlling unknown + IUnknown *pUnkT; + Verify(pUnkOuter->QueryInterface(IID_IUnknown,(void**)&pUnkT)==NOERROR); + Win4Assert(pUnkOuter == pUnkT); + Verify(pUnkT->Release() != 0); + } + } +#endif + + *ppv = NULL; + IUnknown *pUnkID; + HRESULT hr = E_OUTOFMEMORY; + + DWORD StdIdFlags = (flags & SORF_FREETHREADED) ? STDID_CLIENT | STDID_FREETHREADED : + STDID_CLIENT; + + CStdIdentity *pStdId = new CStdIdentity(StdIdFlags, pUnkOuter, + NULL, &pUnkID); + if (pStdId) + { + // get the interface the caller asked for. + hr = pUnkID->QueryInterface(riid, ppv); + pUnkID->Release(); + } + + CALLHOOKOBJECTCREATE(hr,CLSID_NULL,riid,(IUnknown **)ppv); + return hr; +} diff --git a/private/ole32/com/dcomrem/stdid.hxx b/private/ole32/com/dcomrem/stdid.hxx new file mode 100644 index 000000000..2e187fc65 --- /dev/null +++ b/private/ole32/com/dcomrem/stdid.hxx @@ -0,0 +1,143 @@ +//+------------------------------------------------------------------- +// +// File: stdid.hxx +// +// Contents: identity object and creation function +// +// History: 1-Dec-93 CraigWi Created +// +//-------------------------------------------------------------------- +#ifndef _STDID_HXX_ +#define _STDID_HXX_ + +#include <marshal.hxx> // CStdMarshal +#include <idtable.hxx> // IDTable APIs +#include <srvhdl.h> +#include <security.hxx> // CClientSecurity + + + +#define DECLARE_INTERNAL_UNK() \ + class CInternalUnk : public IInternalUnknown, public IMultiQI \ + { \ + public: \ + /* IUnknown methods */ \ + STDMETHOD(QueryInterface)(REFIID riid, VOID **ppv); \ + STDMETHOD_(ULONG,AddRef)(void) ; \ + STDMETHOD_(ULONG,Release)(void); \ + \ + /* IInternalUnknown methods */ \ + STDMETHOD(QueryInternalInterface)(REFIID riid, VOID **ppv); \ + \ + /* IMultiQI methods */ \ + STDMETHOD(QueryMultipleInterfaces)(ULONG cMQIs, MULTI_QI *pMQIs); \ + }; \ + friend CInternalUnk; \ + CInternalUnk m_InternalUnk; + + +typedef enum tagSTDID_FLAGS +{ + STDID_SERVER = 0x0, // on server side + STDID_CLIENT = 0x1, // on client side (non-local in RH terms) + STDID_FREETHREADED = 0x2, // this object is callable on any thread + STDID_HAVEID = 0x4, // have an OID in the table + STDID_IGNOREID = 0x8, // dont put OID in the table + STDID_AGGREGATED = 0x10, // dont put OID in the table + STDID_INDESTRUCTOR = 0x100,// dtor entered; assert on AddRef and others + STDID_LOCKEDINMEM = 0x200,// dont delete this object regardless of refcnt +} STDID_FLAGS; + + +class CStdIdentity : public IProxyManager, public CStdMarshal, + public CClientSecurity +{ +public: + CStdIdentity(DWORD flags, IUnknown *pUnkOuter, IUnknown *pUnkControl, + IUnknown **ppUnkInternal); + ~CStdIdentity(); + + // IUnknown + STDMETHOD(QueryInterface) (REFIID riid, LPVOID *ppvObj); + STDMETHOD_(ULONG,AddRef) (void); + STDMETHOD_(ULONG,Release) (void); + + // IProxyManager (only if client side) + STDMETHOD(CreateServer)(REFCLSID rclsid, DWORD clsctx, void *pv); + STDMETHOD_(BOOL, IsConnected)(void); + STDMETHOD(LockConnection)(BOOL fLock, BOOL fLastUnlockReleases); + STDMETHOD_(void, Disconnect)(); + STDMETHOD(CreateServerWithHandler)(REFCLSID rclsid, DWORD clsctx, void *pv, + REFCLSID rclsidHandler, IID iidSrv, void **ppv, + IID iidClnt, void *pClientSiteInterface); + + + IUnknown *GetCtrlUnk(void) { return m_pUnkControl; }; + IUnknown *GetServer(); + void ReleaseCtrlUnk(void); + + REFMOID GetOID (void) { return m_moid; } + HRESULT SetOID (REFMOID rmoid); + void IgnoreOID (void) { m_flags |= STDID_IGNOREID; } + void RevokeOID (void); + ULONG GetRC (void) { return m_refs; } + BOOL IsFreeThreaded(void) { return m_flags & STDID_FREETHREADED; } + BOOL IsAggregated(void) { return m_flags & STDID_AGGREGATED; } + void SetLockedInMemory() { m_flags |= STDID_LOCKEDINMEM; } + ULONG UnlockAndRelease(void); + + // methods to manipulate the strong reference count. + void IncStrongCnt(); + void DecStrongCnt(BOOL fKeepAlive); + + // method used by CoLockObjectExternal + HRESULT LockObjectExternal(BOOL fLock, BOOL fLastUR); + + // internal unknown + DECLARE_INTERNAL_UNK() + + friend INTERNAL CreateIdentityHandler(IUnknown *pUnkOuter, + DWORD flags, REFIID riid, void **ppv); + + friend INTERNAL_(void) IDTableThreadUninitializeHelper(DWORD); + + +#if DBG == 1 + void AssertValid(); + CStdIdentity(); // debug ctor for debug list head +#else + void AssertValid() { } +#endif + +private: + + BOOL IsClient() { return m_flags & STDID_CLIENT; } + void SetNowInDestructor() { m_flags |= STDID_INDESTRUCTOR; } + BOOL IsInDestructor() { return m_flags & STDID_INDESTRUCTOR; } + BOOL IsLockedOrInDestructor(){ return (m_flags & (STDID_INDESTRUCTOR | + STDID_LOCKEDINMEM)); } + + DWORD m_refs; // number of pointer refs + DWORD m_flags; // see STDID_* values above; set once. + + IUnknown *m_pUnkOuter; // controlling unknown; set once. + + IUnknown *m_pUnkControl; // the controlling unk of the object; + // this member has three possible values: + // pUnkOuter - client side; not addref'd + // pUnkControl - server side (which may + // be pUnkOuter if aggregated); addref'd + // NULL - server side, disconnected + + MOID m_moid; // the identity (OID + MID) + IExternalConnection *m_pIEC;// of the server if supported + LONG m_cStrongRefs; // count of strong references + +#if DBG==1 + CStdIdentity *m_pNext; // double chain list of instantiated + CStdIdentity *m_pPrev; // identity objects for debugging +#endif // DBG +}; + +#endif // _STDID_HXX + diff --git a/private/ole32/com/dcomrem/stream.cxx b/private/ole32/com/dcomrem/stream.cxx new file mode 100644 index 000000000..496d9bf92 --- /dev/null +++ b/private/ole32/com/dcomrem/stream.cxx @@ -0,0 +1,605 @@ +/*++ + +Microsoft Windows +Copyright (c) 1994 Microsoft Corporation. All rights reserved. + +Module Name: + stream.cxx + +Abstract: + Implements the IStream interface on a memory buffer. + +Author: + ShannonC 09-Mar-1994 + +Environment: + Windows NT and Windows 95. We do not support DOS and Win16. + +Revision History: + 12-Oct-94 ShannonC Reformat for code review. + +--*/ + +#include <ole2int.h> +#include <stream.hxx> + +CNdrStream::CNdrStream( + IN unsigned char * pData, + IN unsigned long cbMax) + : pBuffer(pData), cbBufferLength(cbMax) +/*++ + +Routine Description: + This function creates a stream on the specified memory buffer. + +Arguments: + pData - Supplies pointer to memory buffer. + cbMax - Supplies size of memory buffer. + +Return Value: + None. + +--*/ +{ + RefCount = 1; + position = 0; +} + + +ULONG STDMETHODCALLTYPE +CNdrStream::AddRef() +/*++ + +Routine Description: + Increment the reference count. + +Arguments: + +Return Value: + Reference count. + +--*/ +{ + InterlockedIncrement(&RefCount); + return (ULONG) RefCount; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Clone( + OUT IStream **ppstm) +/*++ + +Routine Description: + Create a new IStream object. The new IStream gets an + independent seek pointer but it shares the underlying + data buffer with the original IStream object. + +Arguments: + ppstm - Pointer to the new stream. + +Return Value: + S_OK - The stream was successfully copied. + E_OUTOFMEMORY - The stream could not be copied due to lack of memory. + +--*/ +{ + HRESULT hr; + CNdrStream *pStream = new CNdrStream(pBuffer, cbBufferLength); + + if(pStream != 0) + { + pStream->position = position; + hr = S_OK; + } + else + { + hr = E_OUTOFMEMORY; + } + + *ppstm = (IStream *) pStream; + + return hr; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Commit( + IN DWORD grfCommitFlags) +/*++ + +Routine Description: + This stream does not support transacted mode. This function does nothing. + +Arguments: + grfCommitFlags + +Return Value: + S_OK + +--*/ +{ + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::CopyTo( + IN IStream * pstm, + IN ULARGE_INTEGER cb, + OUT ULARGE_INTEGER *pcbRead, + OUT ULARGE_INTEGER *pcbWritten) +/*++ + +Routine Description: + Copies data from one stream to another stream. + +Arguments: + pstm - Specifies the destination stream. + cb - Specifies the number of bytes to be copied to the destination stream. + pcbRead - Returns the number of bytes read from the source stream. + pcbWritten - Returns the number of bytes written to the destination stream. + +Return Value: + S_OK - The data was successfully copied. + Other errors from IStream::Write. + +--*/ +{ + HRESULT hr; + unsigned char * pSource; + unsigned long cbRead; + unsigned long cbWritten; + unsigned long cbRemaining; + + //Check if we are going off the end of the buffer. + if(position < cbBufferLength) + cbRemaining = cbBufferLength - position; + else + cbRemaining = 0; + + if((cb.HighPart == 0) && (cb.LowPart <= cbRemaining)) + cbRead = cb.LowPart; + else + cbRead = cbRemaining; + + pSource = pBuffer + position; + + //copy the data + hr = pstm->Write(pSource, cbRead, &cbWritten); + + //advance the current position + position += cbRead; + + if (pcbRead != 0) + { + pcbRead->LowPart = cbRead; + pcbRead->HighPart = 0; + } + if (pcbWritten != 0) + { + pcbWritten->LowPart = cbWritten; + pcbWritten->HighPart = 0; + } + + return hr; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::LockRegion( + IN ULARGE_INTEGER libOffset, + IN ULARGE_INTEGER cb, + IN DWORD dwLockType) +/*++ + +Routine Description: + Range locking is not supported by this stream. + +Return Value: + STG_E_INVALIDFUNCTION. + +--*/ +{ + return STG_E_INVALIDFUNCTION; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::QueryInterface( + REFIID riid, + void **ppvObj) +/*++ + +Routine Description: + Query for an interface on the stream. The stream supports + the IUnknown and IStream interfaces. + +Arguments: + riid - Supplies the IID of the interface being requested. + ppvObject - Returns a pointer to the requested interface. + +Return Value: + S_OK + E_NOINTERFACE + +--*/ +{ + HRESULT hr; + + if ((memcmp(&riid, &IID_IUnknown, sizeof(IID)) == 0) || + (memcmp(&riid, &IID_IStream, sizeof(IID)) == 0)) + { + this->AddRef(); + *ppvObj = (IStream *) this; + hr = S_OK; + } + else + { + *ppvObj = 0; + hr = E_NOINTERFACE; + } + + return hr; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Read( + OUT void * pv, + IN ULONG cb, + OUT ULONG *pcbRead) +/*++ + +Routine Description: + Reads data from the stream starting at the current seek pointer. + +Arguments: + pv - Returns the data read from the stream. + cb - Supplies the number of bytes to read from the stream. + pcbRead - Returns the number of bytes actually read from the stream. + +Return Value: + S_OK - The data was successfully read from the stream. + S_FALSE - The number of bytes read was smaller than the number requested. + +--*/ +{ + HRESULT hr; + unsigned long cbRead; + unsigned long cbRemaining; + + //Check if we are reading past the end of the buffer. + if(position < cbBufferLength) + cbRemaining = cbBufferLength - position; + else + cbRemaining = 0; + + if(cb <= cbRemaining) + { + cbRead = cb; + hr = S_OK; + } + else + { + cbRead = cbRemaining; + hr = S_FALSE; + } + + //copy the data + memcpy(pv, pBuffer + position, cbRead); + + //advance the current position + position += cbRead; + + if(pcbRead != 0) + *pcbRead = cbRead; + + return hr; +} + +ULONG STDMETHODCALLTYPE +CNdrStream::Release() +/*++ + +Routine Description: + Decrement the reference count. When the reference count + reaches zero, the stream is deleted. + +Arguments: + +Return Value: + Reference count. + +--*/ +{ + unsigned long count; + + count = RefCount - 1; + if(InterlockedDecrement(&RefCount) == 0) + { + count = 0; + delete this; + } + + return count; +} + + +HRESULT STDMETHODCALLTYPE +CNdrStream::Revert() +/*++ + +Routine Description: + This stream does not support transacted mode. This function does nothing. + +Arguments: + None. + +Return Value: + S_OK. + +--*/ +{ + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Seek( + IN LARGE_INTEGER dlibMove, + IN DWORD dwOrigin, + OUT ULARGE_INTEGER *plibNewPosition) +/*++ + +Routine Description: + Sets the position of the seek pointer. It is an error to seek + before the beginning of the stream or past the end of the stream. + +Arguments: + dlibMove - Supplies the offset from the position specified in dwOrigin. + dwOrigin - Supplies the seek mode. + plibNewPosition - Returns the new position of the seek pointer. + +Return Value: + S_OK - The seek pointer was successfully adjusted. + STG_E_INVALIDFUNCTION - dwOrigin contains invalid value. + STG_E_SEEKERROR - The seek pointer cannot be positioned before the + beginning of the stream or past the + end of the stream. + +--*/ +{ + HRESULT hr; + long high; + long low; + unsigned long offset; + unsigned long cbRemaining; + + switch (dwOrigin) + { + case STREAM_SEEK_SET: + //Set the seek position relative to the beginning of the stream. + if((dlibMove.HighPart == 0) && (dlibMove.LowPart <= cbBufferLength)) + { + position = dlibMove.LowPart; + hr = S_OK; + } + else + { + //It is an error to seek past the end of the stream. + hr = STG_E_SEEKERROR; + } + break; + + case STREAM_SEEK_CUR: + //Set the seek position relative to the current position of the stream. + high = (long) dlibMove.HighPart; + if(high < 0) + { + //Negative offset + low = (long) dlibMove.LowPart; + offset = -low; + + if((high == -1) && (offset <= position)) + { + position -= offset; + hr = S_OK; + } + else + { + //It is an error to seek before the beginning of the stream. + hr = STG_E_SEEKERROR; + } + } + else + { + //Positive offset + if(position < cbBufferLength) + cbRemaining = cbBufferLength - position; + else + cbRemaining = 0; + + if((dlibMove.HighPart == 0) && (dlibMove.LowPart <= cbRemaining)) + { + position += dlibMove.LowPart; + hr = S_OK; + } + else + { + //It is an error to seek past the end of the stream. + hr = STG_E_SEEKERROR; + } + } + break; + + case STREAM_SEEK_END: + //Set the seek position relative to the end of the stream. + high = (long) dlibMove.HighPart; + if(high < 0) + { + //Negative offset + low = (long) dlibMove.LowPart; + offset = -low; + + if((high == -1) && (offset <= cbBufferLength)) + { + position = cbBufferLength - offset; + hr = S_OK; + } + else + { + //It is an error to seek before the beginning of the stream. + hr = STG_E_SEEKERROR; + } + } + else if(dlibMove.QuadPart == 0) + { + position = cbBufferLength; + hr = S_OK; + } + else + { + //Positive offset + //It is an error to seek past the end of the stream. + hr = STG_E_SEEKERROR; + } + break; + + default: + //dwOrigin contains an invalid value. + hr = STG_E_INVALIDFUNCTION; + } + + if (plibNewPosition != 0) + { + plibNewPosition->LowPart = position; + plibNewPosition->HighPart = 0; + } + + return hr; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::SetSize( + IN ULARGE_INTEGER libNewSize) +/*++ + +Routine Description: + Changes the size of the stream. + +Arguments: + libNewSize - Supplies the new size of the stream. + +Return Value: + S_OK - The stream size was successfully changed. + STG_E_MEDIUMFULL - The stream size could not be changed. + +--*/ +{ + HRESULT hr; + + if((libNewSize.HighPart == 0) && (libNewSize.LowPart <= cbBufferLength)) + { + cbBufferLength = libNewSize.LowPart; + hr = S_OK; + } + else + { + hr = STG_E_MEDIUMFULL; + } + + return hr; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Stat( + OUT STATSTG * pstatstg, + IN DWORD grfStatFlag) +/*++ + +Routine Description: + This function gets information about this stream. + +Arguments: + pstatstg - Returns information about this stream. + grfStatFlg - Specifies the information to be returned in pstatstg. + +Return Value: + S_OK. + +--*/ +{ + memset(pstatstg, 0, sizeof(STATSTG)); + pstatstg->type = STGTY_STREAM; + pstatstg->cbSize.LowPart = cbBufferLength; + pstatstg->cbSize.HighPart = 0; + + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::UnlockRegion( + IN ULARGE_INTEGER libOffset, + IN ULARGE_INTEGER cb, + IN DWORD dwLockType) +/*++ + +Routine Description: + Range locking is not supported by this stream. + +Return Value: + STG_E_INVALIDFUNCTION. + +--*/ +{ + return STG_E_INVALIDFUNCTION; +} + +HRESULT STDMETHODCALLTYPE +CNdrStream::Write( + IN void const *pv, + IN ULONG cb, + OUT ULONG * pcbWritten) +/*++ + +Routine Description: + Write data to the stream starting at the current seek pointer. + +Arguments: + pv - Supplies the data to be written to the stream. + cb - Specifies the number of bytes to be written to the stream. + pcbWritten - Returns the number of bytes actually written to the stream. + +Return Value: + S_OK - The data was successfully written to the stream. + STG_E_MEDIUMFULL - Data cannot be written past the end of the stream. + +--*/ +{ + HRESULT hr; + unsigned long cbRemaining; + unsigned long cbWritten; + + //Check if we are writing past the end of the buffer. + if(position < cbBufferLength) + cbRemaining = cbBufferLength - position; + else + cbRemaining = 0; + + if(cb <= cbRemaining) + { + cbWritten = cb; + hr = S_OK; + } + else + { + cbWritten = cbRemaining; + hr = STG_E_MEDIUMFULL; + } + + // Write the data. + memcpy(pBuffer + position, pv, cbWritten); + + //Advance the current position + position += cbWritten; + + //update pcbWritten + if (pcbWritten != 0) + *pcbWritten = cbWritten; + + return hr; +} diff --git a/private/ole32/com/dcomrem/stream.hxx b/private/ole32/com/dcomrem/stream.hxx new file mode 100644 index 000000000..af128e10b --- /dev/null +++ b/private/ole32/com/dcomrem/stream.hxx @@ -0,0 +1,94 @@ +//+------------------------------------------------------------------- +// +// File: stream.hxx +// +// Contents: Implements the IStream interface on a memory buffer. +// +//-------------------------------------------------------------------- +#ifndef _STREAM_HXX_ +#define _STREAM_HXX_ + + +class CNdrStream : public IStream +{ +public: + virtual HRESULT STDMETHODCALLTYPE + QueryInterface( + IN REFIID riid, + OUT void **ppvObj); + + virtual ULONG STDMETHODCALLTYPE + AddRef(); + + virtual ULONG STDMETHODCALLTYPE + Release(); + + virtual HRESULT STDMETHODCALLTYPE + Read( + IN void * pv, + IN ULONG cb, + OUT ULONG * pcbRead); + + virtual HRESULT STDMETHODCALLTYPE + Write( + IN void const *pv, + IN ULONG cb, + OUT ULONG * pcbWritten); + + virtual HRESULT STDMETHODCALLTYPE + Seek( + IN LARGE_INTEGER dlibMove, + IN DWORD dwOrigin, + OUT ULARGE_INTEGER *plibNewPosition); + + virtual HRESULT STDMETHODCALLTYPE + SetSize( + IN ULARGE_INTEGER libNewSize); + + virtual HRESULT STDMETHODCALLTYPE + CopyTo( + IN IStream * pstm, + IN ULARGE_INTEGER cb, + OUT ULARGE_INTEGER *pcbRead, + OUT ULARGE_INTEGER *pcbWritten); + + virtual HRESULT STDMETHODCALLTYPE + Commit( + IN DWORD grfCommitFlags); + + virtual HRESULT STDMETHODCALLTYPE + Revert(); + + virtual HRESULT STDMETHODCALLTYPE + LockRegion( + IN ULARGE_INTEGER libOffset, + IN ULARGE_INTEGER cb, + IN DWORD dwLockType); + + virtual HRESULT STDMETHODCALLTYPE + UnlockRegion( + IN ULARGE_INTEGER libOffset, + IN ULARGE_INTEGER cb, + IN DWORD dwLockType); + + virtual HRESULT STDMETHODCALLTYPE + Stat( + OUT STATSTG * pstatstg, + IN DWORD grfStatFlag); + + virtual HRESULT STDMETHODCALLTYPE + Clone( + OUT IStream **ppstm); + + CNdrStream( + IN unsigned char * pData, + IN unsigned long cbMax); + +private: + long RefCount; + unsigned char * pBuffer; + unsigned long cbBufferLength; + unsigned long position; +}; + +#endif // _STREAM_HXX_ diff --git a/private/ole32/com/dcomrem/threads.cxx b/private/ole32/com/dcomrem/threads.cxx new file mode 100644 index 000000000..35ca9c212 --- /dev/null +++ b/private/ole32/com/dcomrem/threads.cxx @@ -0,0 +1,340 @@ +//+------------------------------------------------------------------- +// +// File: threads.cxx +// +// Contents: Rpc thread cache +// +// Classes: CRpcThread - single thread +// CRpcThreadCache - cache of threads +// +// Notes: This code represents the cache of Rpc threads used to +// make outgoing calls in the SINGLETHREADED object Rpc +// model. +// +// History: Rickhi Created +// 07-31-95 Rickhi Fix event handle leak +// +//+------------------------------------------------------------------- +#include <ole2int.h> +#include <olerem.h> +#include <chancont.hxx> // ThreadDispatch +#include <threads.hxx> + + +// static members of ThreadCache class +CRpcThread * CRpcThreadCache::_pFreeList = NULL;// list of free threads +COleStaticMutexSem CRpcThreadCache::_mxs; // for list manipulation + + +//+------------------------------------------------------------------- +// +// Member: CRpcThreadCache::RpcWorkerThreadEntry +// +// Purpose: Entry point for an Rpc worker thread. +// +// Returns: nothing, it never returns. +// +// Callers: Called ONLY by a worker thread. +// +//+------------------------------------------------------------------- +DWORD _stdcall CRpcThreadCache::RpcWorkerThreadEntry(void *param) +{ + // First thing we need to do is LoadLibrary ourselves in order to + // prevent our code from going away while this worker thread exists. + // The library will be freed when this thread exits. + + HINSTANCE hInst = LoadLibrary(L"OLE32.DLL"); + + + // construct a thread object on the stack, and call the main worker + // loop. Do this in nested scope so the dtor is called before ExitThread. + + { + CRpcThread Thrd(param); + Thrd.WorkerLoop(); + } + + + // Simultaneously free our Dll and exit our thread. This allows us to + // keep our Dll around incase a remote call was cancelled and the + // worker thread is still blocked on the call, and allows us to cleanup + // properly when all threads are done with the code. + + FreeLibraryAndExitThread(hInst, 0); + + // compiler wants a return value + return 0; +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThread::CRpcThread +// +// Purpose: Constructor for a thread object. +// +// Notes: Allocates a wakeup event. +// +// Callers: Called ONLY by a worker thread. +// +//+------------------------------------------------------------------- +CRpcThread::CRpcThread(void *param) : + _param(param), + _pNext(NULL), + _fDone(FALSE) +{ + // create the Wakeup event. Do NOT use the event cache, as there are + // some exit paths that leave this event in the signalled state! + +#ifdef _CHICAGO_ // Chicago ANSI optimization + _hWakeup = CreateEventA(NULL, FALSE, FALSE, NULL); +#else //_CHICAGO_ + _hWakeup = CreateEvent(NULL, FALSE, FALSE, NULL); +#endif //_CHICAGO_ + + ComDebOut((DEB_CHANNEL, + "CRpcThread::CRpcThread pThrd:%x _hWakeup:%x\n", this, _hWakeup)); +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThread::~CRpcThread +// +// Purpose: Destructor for an Rpc thread object. +// +// Notes: When threads are exiting, they place the CRpcThread +// object on the delete list. The main thread then later +// pulls it from the delete list and calls this destructor. +// +// Callers: Called ONLY by a worker thread. +// +//+------------------------------------------------------------------- +CRpcThread::~CRpcThread() +{ + // close the event handle. Do NOT use the event cache, since not all + // exit paths leave this event in the non-signalled state. Also, do + // not close NULL handle. + + if (_hWakeup) + { + CloseHandle(_hWakeup); + } + + ComDebOut((DEB_CHANNEL, + "CRpcThread::~CRpcThread pThrd:%x _hWakeup:%x\n", this, _hWakeup)); +} + + +//+------------------------------------------------------------------- +// +// Function: CRpcThread::WorkerLoop +// +// Purpose: Entry point for a new Rpc call thread. +// +// Notes: This dispatches a call to the function ThreadDispatch. That +// code signals an event that the COM thread is waiting on, then +// returns to us. We put the thread on the free list, and wait +// for more work to do. +// +// When there is no more work after some timeout period, we +// pull it from the free list and exit. +// +// Callers: Called ONLY by worker thread. +// +//+------------------------------------------------------------------- +void CRpcThread::WorkerLoop() +{ + // Main worker loop where we do some work then wait for more. + // When the thread has been inactive for some period of time + // it will exit the loop. + + while (!_fDone) + { + // Dispatch the call. + ThreadDispatch((CChannelCallInfo **)&_param); + + if (!_hWakeup) + { + // we failed to create an event in the ctor so we cant + // get put on the freelist to be re-awoken later with more + // work. Just exit. + break; + } + + // put the thread object on the free list + gRpcThreadCache.AddToFreeList(this); + + // Wait for more work or for a timeout. + while (WaitForSingleObjectEx(_hWakeup, THREAD_INACTIVE_TIMEOUT, 0) + == WAIT_TIMEOUT) + { + // try to remove ourselves from the queue of free threads. + // if _fDone is still FALSE, it means someone is about to + // give us more work to do (so go wait for that to happen). + + gRpcThreadCache.RemoveFromFreeList(this); + + if (_fDone) + { + // OK to exit and let this thread die. + break; + } + } + } +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThreadCache::Dispatch +// +// Purpose: Finds the first free thread, and dispatches the request +// to that thread, or creates a new thread if none are +// available. +// +// Returns: S_OK if dispatched OK +// Win32 error if it cant create a thread. +// +// Callers: Called ONLY by the main thread. +// +//+------------------------------------------------------------------- +HRESULT CRpcThreadCache::Dispatch(void *param) +{ + HRESULT hr = S_OK; + + _mxs.Request(); + + // grab the first thread from the list + CRpcThread *pThrd = _pFreeList; + + if (pThrd) + { + // update the free list pointer + _pFreeList = pThrd->GetNext(); + _mxs.Release(); + + // dispatch the call + pThrd->Dispatch(param); + } + else + { + _mxs.Release(); + + // no free threads, spin up a new one and dispatch directly to it. + DWORD dwThrdId; + HANDLE hThrd = CreateThread(NULL, 0, + RpcWorkerThreadEntry, + param, 0, + &dwThrdId); + + if (hThrd) + { + // close the thread handle since we dont need it for anything. + CloseHandle(hThrd); + } + else + { + ComDebOut((DEB_ERROR,"CreatThread failed:%x\n", GetLastError())); + hr = HRESULT_FROM_WIN32(GetLastError()); + } + } + + return hr; +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThreadCache::RemoveFromFreeList +// +// Purpose: Tries to pull a thread from the free list. +// +// Returns: pThrd->_fDone TRUE if it was successfull and thread can exit. +// pThrd->_fDone FALSE otherwise. +// +// Callers: Called ONLY by a worker thread. +// +//+------------------------------------------------------------------- +void CRpcThreadCache::RemoveFromFreeList(CRpcThread *pThrd) +{ + ComDebOut((DEB_CHANNEL, + "CRpcThreadCache::RemoveFromFreeList pThrd:%x\n", pThrd)); + + COleStaticLock lck(_mxs); + + // pull pThrd from the free list. if it is not on the free list + // then either it has just been dispatched OR ClearFreeList has + // just removed it, set _fDone to TRUE, and kicked the wakeup event. + + CRpcThread *pPrev = NULL; + CRpcThread *pCurr = _pFreeList; + + while (pCurr && pCurr != pThrd) + { + pPrev = pCurr; + pCurr = pCurr->GetNext(); + } + + if (pCurr == pThrd) + { + // remove it from the free list. + if (pPrev) + pPrev->SetNext(pThrd->GetNext()); + else + _pFreeList = pThrd->GetNext(); + + // tell the thread to wakeup and exit + pThrd->WakeAndExit(); + } +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThreadCache::ClearFreeList +// +// Purpose: Cleans up all threads on the free list. +// +// Notes: For any threads still on the free list, it pulls them +// off the freelist, sets their _fDone flag to TRUE, and +// kicks their event to wake them up. When the threads +// wakeup, they will exit. +// +// We do not free active threads. The only way for a thread +// to still be active at this time is if it was making an Rpc +// call and was cancelled by the message filter and the thread has +// still not returned to us. We cant do much about that until +// Rpc supports cancel for all protocols. If the thread ever +// does return to us, it will eventually idle-out and delete +// itself. This is safe because the threads LoadLibrary OLE32. +// +// Callers: Called ONLY by the last COM thread during +// ProcessUninitialize. +// +//+------------------------------------------------------------------- +void CRpcThreadCache::ClearFreeList(void) +{ + ComDebOut((DEB_CHANNEL, "CRpcThreadCache::ClearFreeList\n")); + + { + COleStaticLock lck(_mxs); + + CRpcThread *pThrd = _pFreeList; + while (pThrd) + { + // use temp variable incase thread exits before we call GetNext + CRpcThread *pThrdNext = pThrd->GetNext(); + pThrd->WakeAndExit(); + pThrd = pThrdNext; + } + + _pFreeList = NULL; + + // the lock goes out of scope at this point. we dont want to hold + // it while we sleep. + } + + // yield to let the other threads run if necessary. + Sleep(0); +} diff --git a/private/ole32/com/dcomrem/threads.hxx b/private/ole32/com/dcomrem/threads.hxx new file mode 100644 index 000000000..5c66b3941 --- /dev/null +++ b/private/ole32/com/dcomrem/threads.hxx @@ -0,0 +1,190 @@ +//+------------------------------------------------------------------- +// +// File: threads.hxx +// +// Contents: Rpc thread cache +// +// Classes: CRpcThread - single thread +// CRpcThreadCache - cache of threads +// +// Notes: This code represents the cache of Rpc threads used to +// make outgoing calls in the APARTMENT object Rpc +// model. +// +// History: Rickhi Created +// 07-31-95 Rickhi Fix event handle leak +// +//+------------------------------------------------------------------- +#ifndef __THREADS_HXX__ +#define __THREADS_HXX__ + +#include <olesem.hxx> + + +// inactive thread timeout. this is how long a thread will sit idle +// in the thread cache before deleting itself. + +#define THREAD_INACTIVE_TIMEOUT 30000 // in milliseconds + + +//+------------------------------------------------------------------- +// +// Class: CRpcThread +// +// Purpose: Represents one thread in the cache of Rpc callout +// threads. +// +// Notes: In order to make Rpc calls in the OLE Single-Threaded +// model, we must leave the main thread and perform the +// blocking Rpc call on a worker thread. This object +// represents such a worker thread. +// +//+------------------------------------------------------------------- +class CRpcThread +{ +public: + CRpcThread(void *param); + ~CRpcThread(); + + // dispatch methods + void Dispatch(void *param); + void WorkerLoop(); + CRpcThread * GetNext(void) { return _pNext; } + void SetNext(CRpcThread *pNext) { _pNext = pNext; } + + // cleanup methods + void WakeAndExit(); + +private: + + HANDLE _hWakeup; // thread wakeup event + BOOL _fDone; // completion flag + + void * _param; // parameter packet + CRpcThread * _pNext; // next thread in free list +}; + + +//+------------------------------------------------------------------- +// +// Class: CRpcThreadCache +// +// Purpose: Holds a cache of Rpc threads. It finds the first +// free CRpcThread or creates a new one and dispatches +// the call to it. +// +// Notes: the free list is kept in a most recently used order +// so that uneeded threads can time out and go away. +// +//+------------------------------------------------------------------- +class CRpcThreadCache +{ +public: + // no ctor, since only work is init'ing a static + // no dtor since nothing to do + + // dispatch methods + HRESULT Dispatch(void *param); + void AddToFreeList(CRpcThread *pThrd); + + // cleanup methods + void RemoveFromFreeList(CRpcThread *pThrd); + void ClearFreeList(void); + +private: + + static DWORD _stdcall RpcWorkerThreadEntry(void *param); + + static CRpcThread * _pFreeList; // list of free threads + static COleStaticMutexSem _mxs; // for list manipulation +}; + + +// Rpc SendReceive thread pool. This must be static to handle Rpc threads +// that block and dont return until after CoUninitialize has been called. + +extern CRpcThreadCache gRpcThreadCache; + + + +//+------------------------------------------------------------------- +// +// Member: CRpcThread::Dispatch +// +// Purpose: wakes up a thread blocked in WorkerLoop. +// +// Notes: folks who want to execute code on another thread +// call this method. It fills in the parameter packet +// and wakes up the sleeping thread. +// +// Callers: Called ONLY by the COM thread. +// +//+------------------------------------------------------------------- +inline void CRpcThread::Dispatch(void *param) +{ + CairoleDebugOut((DEB_CHANNEL, + "Dispatch pThrd:%x param:%x\n", this, param)); + + // set the call info and the completion event + _param = param; + + // signal the Rpc thread to wakeup + SetEvent(_hWakeup); +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThread::WakeAndExit +// +// Purpose: Tells the thread object to free itself +// +// Note: This is called by CRpcThreadCache::RemoveFromFreeList +// when we want to free this thread, eg at ProcessUninitialize. +// +// Callers: Called by the COM thread OR worker thread. +// +//+------------------------------------------------------------------- +inline void CRpcThread::WakeAndExit() +{ + // _fDone should only be set inside this function and in the + // constructor. _fDone must only ever transition from FALSE + // to TRUE and that must only happen once in the life of this + // object. + + Win4Assert(_fDone == FALSE); + _fDone = TRUE; + + CairoleDebugOut((DEB_CHANNEL, + "CRpcThreadCache:WakeAndExit pThrd:%x _hWakeup:%x\n", this, _hWakeup)); + + SetEvent(_hWakeup); +} + + +//+------------------------------------------------------------------- +// +// Member: CRpcThreadCache::AddToFreeLlist +// +// Purpose: puts a thread back onto the free list after the +// thread has completed its job. +// +// Callers: Called ONLY by worker thread. +// +//+------------------------------------------------------------------- +inline void CRpcThreadCache::AddToFreeList(CRpcThread *pThrd) +{ + COleStaticLock lck(_mxs); + + // place this thread on the front of the free list. it is + // important that we add and remove only from the front of + // the list so that unused threads will eventually time out + // and release themselves...that is, it keeps our thread pool + // as small as possible. + + pThrd->SetNext(_pFreeList); + _pFreeList = pThrd; +} + + +#endif // __THREADS_HXX__ |