用C++ RAII思想写windows驱动
作者:互联网
随着VS2017的普及,编写Windows驱动已经不必局限于C语言了。
C语言时代,最难处理的是资源泄漏:分配的每一份资源都必须手动编写代码进行回收。
现在有了C++,完全可以利用RAII等特性来提升代码质量。
(1)首先要注意,C++隐藏了太多细节(例如全局对象的构造与析构、new/delete的实现都依赖CRT,而内核中并没有CRT),稍不留意写出的驱动就容易蓝屏。
所以,用C++写驱动,一定要先实现一个简化的C++运行环境。
(2)接下来我们实现一个基于CreateProcessCallback的禁止某些进程执行的防护型驱动
首先实现全局环境
// Container for every driver-wide object. It is never a real C++ global:
// CreateCxxGlobal_OnDriverEntry placement-constructs it into static storage
// and DeleteCxxGlobal_OnDriverUnload runs its destructor explicitly, so the
// lifetime of all members is tied to DriverEntry / DriverUnload.
struct CxxGlobal
{
    CxxGlobal();
    ~CxxGlobal();

    // ---- all driver-wide objects go below (declaration order = construction order) ----
    wstring str_DeviceName;
    wstring str_SYMBOLIC_LINK_NAME;
    // Lock protecting the CreateProcessCallback rule list (g_MapCpCBKRule).
    CMyEResourceLock g_lock_MapCreateProcessCBKRule;
    // Rule table consulted by the CreateProcessCallback.
    CLinkListMapByString g_MapCpCBKRule;
};
我们利用new的重载(包括placement new),把这个全局对象构造在预先分配好的全局数据区中,并在DriverEntry和DriverUnload中分别显式地构造和析构它:
// Constructs the driver-wide CxxGlobal object. Call it from DriverEntry and
// check that it succeeded; otherwise DriverEntry should return a failure status.
CxxGlobal* CreateCxxGlobal_OnDriverEntry(void);
// Destroys the driver-wide CxxGlobal object. Call it from DriverUnload, and only
// after all uninitialization that still uses the global objects has finished.
void DeleteCxxGlobal_OnDriverUnload(void);
// Global (replaceable) allocation function used by every non-placement `new`
// expression in the driver; forwards to the driver's Alloc helper.
// Returns nullptr on failure (kernel code runs without exceptions).
// NOTE(review): a null return is only safe if the compiler null-checks it
// before running the constructor (MSVC does unless /Zc:throwingNew is
// specified) -- confirm the build flags.
void * __cdecl operator new(size_t sizex)
{
    // Alloc is handed a DWORD size. On 64-bit builds a larger request would be
    // silently truncated and yield an undersized block, so refuse it instead.
    if (sizex > static_cast<size_t>(static_cast<DWORD>(-1)))
        return nullptr;
    // C++ requires a distinct non-null pointer even for zero-byte requests.
    if (sizex == 0)
        sizex = 1;
    return Alloc(static_cast<DWORD>(sizex));
}
// Placement operator new: allocates nothing and simply hands back the storage
// the caller supplied (CreateCxxGlobal_OnDriverEntry uses it to build the
// CxxGlobal instance inside g_cxxGlobalObj). The size argument is ignored.
void * __cdecl operator new(size_t /*sizex*/, void* p)
{
    return static_cast<void*>(p);
}
// Deallocation functions paired with the global operator new, which allocates
// via Alloc. The compiler may route a delete-expression to either the sized
// or the unsized form, so both are defined. There is no CRT in the kernel to
// provide a fallback.
// A null pointer must be a no-op. It is not visible here whether Free()
// tolerates nullptr, so it is filtered out before calling it.
void __cdecl operator delete(void * ptr)
{
    if (ptr)
        Free(ptr);
}

void __cdecl operator delete(void * ptr, size_t sizex)
{
    (void)sizex;   // the underlying allocator does not need the size
    operator delete(ptr);
}
unsigned char g_cxxGlobalObj[sizeof(CxxGlobal)];
CxxGlobal& Global = *((CxxGlobal*)(&(g_cxxGlobalObj[0])));
CxxGlobal* g_pGlobal = 0;
// Constructs the CxxGlobal instance inside g_cxxGlobalObj via placement new.
// Repeated calls do nothing and return the instance that already exists.
// Call from DriverEntry before anything touches Global.
CxxGlobal * CreateCxxGlobal_OnDriverEntry()
{
    if (g_pGlobal != 0)
        return g_pGlobal;   // already constructed

    void* const storage = &g_cxxGlobalObj[0];
    g_pGlobal = new(storage) CxxGlobal();
    return g_pGlobal;
}
// Runs the CxxGlobal destructor (destroying every member object) if the
// instance exists, then clears g_pGlobal. The static storage itself is not
// freed. Safe to call when nothing was created.
// Call from DriverUnload after all users of Global have been shut down.
void DeleteCxxGlobal_OnDriverUnload()
{
    if (g_pGlobal != NULL)
    {
        g_pGlobal->~CxxGlobal();
    }
    g_pGlobal = NULL;   // reset unconditionally
}
(3)字符串类的实现
// Counts the characters before the terminating 0 of a C string of any
// character type. A null pointer counts as an empty string.
template<typename chType>
__inline int Tstrlen(const chType* pstr)
{
    if (!pstr)
        return 0;
    const chType* cursor = pstr;
    while (*cursor != chType(0))
        ++cursor;
    return static_cast<int>(cursor - pstr);
}
// Minimal owning string for kernel code (no CRT / std::basic_string).
// Invariant: when m_p is non-null it points to an Alloc()'d buffer of
// length + 1 characters whose last character is 0. m_p is null (and length
// is 0) for a default-constructed string or after a failed allocation, so
// callers can detect allocation failure via c_str() == nullptr.
template<typename chType>
class _chStringT
{
public:
    chType* m_p;   // owned buffer (see invariant above)
    int length;    // character count, excluding the terminator
public:
    _chStringT() :m_p(nullptr), length(0) {}

    // Deep copy. Members are initialized first so a failed allocation leaves
    // a valid empty string instead of an uninitialized length.
    _chStringT(const _chStringT& dest) :m_p(nullptr), length(0)
    {
        assignFrom(dest.m_p, dest.length);
    }

    // Copies a 0-terminated string; nullptr is treated as "".
    _chStringT(const chType* pStr) :m_p(nullptr), length(0)
    {
        assignFrom(pStr, Tstrlen(pStr));
    }

    // Takes over src's buffer without allocating; src is left empty.
    _chStringT(_chStringT&& src) noexcept :m_p(src.m_p), length(src.length)
    {
        src.m_p = nullptr;
        src.length = 0;
    }

    // Deep copy assignment. Self-assignment is a no-op; it used to free the
    // buffer before reading it and so emptied the string.
    _chStringT& operator=(const _chStringT& dest)
    {
        if (this != &dest)
            assignFrom(dest.m_p, dest.length);
        return (*this);
    }

    // Assigns from a 0-terminated string. pStr may point into this string's
    // own buffer (e.g. s = s.c_str()): the copy is made before the old buffer
    // is released.
    _chStringT& operator=(const chType* pStr)
    {
        assignFrom(pStr, Tstrlen(pStr));
        return (*this);
    }

    // Releases the current buffer and takes over src's; src is left empty.
    _chStringT& operator=(_chStringT&& src) noexcept
    {
        if (this != &src)
        {
            __freedata();
            m_p = src.m_p;
            length = src.length;
            src.m_p = nullptr;
            src.length = 0;
        }
        return (*this);
    }

    ~_chStringT()
    {
        __freedata();
    }

    // Replaces the contents with len copies of initch. A negative len, or an
    // allocation failure, leaves the current contents untouched.
    void resize(int len, chType initch)
    {
        if (len < 0)
            return;   // would otherwise index pNewBuffer[-1]
        chType* pNewBuffer = (chType*)Alloc((len + 1) * sizeof(chType));
        if (pNewBuffer)
        {
            for (int i = 0; i < len; i++)
                pNewBuffer[i] = initch;
            pNewBuffer[len] = 0;   // terminator
            __freedata();
            m_p = pNewBuffer;
            length = len;
        }
    }

    // Bounds-checked element access. Out-of-range indices (or an empty
    // string) return a reference to a shared dummy character, so they never
    // touch memory outside the buffer.
    chType& operator[](int index)
    {
        if (m_p)
        {
            if (index >= 0 && index < length)
                return m_p[index];
        }
        return __dummy_ref_char_for_guard_safe;
    }

    // Read-only element access for const strings; 0 when out of range.
    chType operator[](int index) const
    {
        return (m_p && index >= 0 && index < length) ? m_p[index] : chType(0);
    }

    // 0-terminated contents, or nullptr (see invariant above).
    const chType* c_str() const
    {
        return m_p;
    }
protected:
    // Releases the buffer (if any) and resets to the empty state.
    void __freedata()
    {
        if (m_p)
        {
            Free(m_p);
            m_p = nullptr;
            length = 0;
        }
    }

    // Replaces the contents with a copy of src[0, len). The new buffer is
    // filled before the old one is released, so src may alias m_p. On
    // allocation failure the string ends up empty with m_p == nullptr.
    void assignFrom(const chType* src, int len)
    {
        chType* pNew = (chType*)Alloc((len + 1) * sizeof(chType));
        if (pNew)
        {
            for (int i = 0; i < len; i++)
                pNew[i] = src[i];
            pNew[len] = 0;
        }
        __freedata();
        m_p = pNew;
        length = pNew ? len : 0;
    }

    static chType __dummy_ref_char_for_guard_safe;
};
// Definition of the shared sink character that _chStringT::operator[]
// returns for out-of-range indices; starts out as 0.
template<typename chType>
chType _chStringT<chType>::__dummy_ref_char_for_guard_safe = chType(0);

// Narrow and wide string types used by the driver.
using string = _chStringT<char>;
using wstring = _chStringT<wchar_t>;
(4)RAII思想设计的资源锁
// Owns one ERESOURCE. The constructor and destructor are defined elsewhere
// (presumably ExInitializeResourceLite / ExDeleteResourceLite -- confirm).
// Non-copyable: a memberwise copy would duplicate a live ERESOURCE and the
// destructor would then run on both copies.
class CMyEResourceLock
{
public:
    CMyEResourceLock();
    ~CMyEResourceLock();
    CMyEResourceLock(const CMyEResourceLock&) = delete;
    CMyEResourceLock& operator=(const CMyEResourceLock&) = delete;
    // Raw resource pointer, e.g. for CAutoLockEResource.
    ERESOURCE* GetResource(){return &m_lock_eresource;}
private:
    ERESOURCE m_lock_eresource;
};
// Enters a critical region, then acquires `section` (an ERESOURCE*) for
// SHARED access, waiting if necessary. Each macro is wrapped in do/while(0)
// so it behaves as a single statement. Previously an unbraced `if` guarded
// only the first of the two calls.
// NOTE(review): despite the "critical section" name this is a shared
// (reader) acquisition. Code that modifies the protected rule list needs an
// exclusive acquire (ExAcquireResourceExclusiveLite) -- verify the writers.
#define MY_ENTER_CRITICAL_SECTION(section) \
    do { KeEnterCriticalRegion(); ExAcquireResourceSharedLite((section), TRUE); } while (0)
// Releases `section`, then leaves the critical region (reverse order of entry).
#define MY_LEAVE_CRITICAL_SECTION(section) \
    do { ExReleaseResourceLite((section)); KeLeaveCriticalRegion(); } while (0)
// Scope guard: acquires the given ERESOURCE via MY_ENTER_CRITICAL_SECTION on
// construction and releases it via MY_LEAVE_CRITICAL_SECTION on destruction.
// A null resource is tolerated and makes the guard a no-op.
// Non-copyable: a copy would release the same resource a second time.
class CAutoLockEResource
{
public:
    CAutoLockEResource(ERESOURCE* pResource)
        : m_pRes(pResource)
    {
        // Guarded exactly like the release in the destructor, so a null
        // resource never reaches the acquire call.
        if (m_pRes)
        {
            MY_ENTER_CRITICAL_SECTION(m_pRes);
        }
    }
    ~CAutoLockEResource()
    {
        if (m_pRes)
        {
            MY_LEAVE_CRITICAL_SECTION(m_pRes);
        }
    }
    CAutoLockEResource(const CAutoLockEResource&) = delete;
    CAutoLockEResource& operator=(const CAutoLockEResource&) = delete;
private:
    ERESOURCE* m_pRes;
};
(5)链表类
// One entry of the CreateProcessCallback rule table.
struct _LinkListMapObj
{
    WCHAR szPath[260];   // path the rule applies to (260 is presumably MAX_PATH)
    BOOL  bLetItOK;      // rule flag; presumably TRUE = allow -- confirm against the matching code
};
typedef struct _LinkListMapObj LinkListMapObj;
typedef BOOL PFN_LinkListMapTravelCallback(void* pContext, LinkListMapObj* pDataObj, _LinkListEntryT<LinkListMapObj>* pCurrent, _LinkListEntryT<LinkListMapObj>* pParentOfCurrent);
// Rule table for the CreateProcessCallback: a linked list (LinkListT) of
// LinkListMapObj entries looked up by string. All member functions are
// defined elsewhere; the per-method notes below are inferred from the names.
class CLinkListMapByString
{
public:
    CLinkListMapByString();
    ~CLinkListMapByString();

    BOOL AddString(LPCWSTR lpString, BOOL bLetItOK);      // add a rule for lpString
    LinkListMapObj* FindByName(LPCWSTR lpFind);           // look up the rule for lpFind
    BOOL Erase(LPCWSTR lpString);                         // remove the rule for lpString
    BOOL DeleteEntry(_LinkListEntryT<LinkListMapObj> * pEntry,
                     _LinkListEntryT<LinkListMapObj> * pParentOfEntry);   // unlink a node given its predecessor
    void Clear();                                         // remove all rules
    void Travel(PFN_LinkListMapTravelCallback fnCallback, void * pContext);   // visit nodes with fnCallback
    int GetCount();                                       // number of rules

private:
    LinkListT<LinkListMapObj> m_list;
};
(6)实现CreateProcessCallback
// Process create/exit notify routine (registration is not shown; presumably
// PsSetCreateProcessNotifyRoutine). For each newly created process whose
// image path matches the rule list (IsMatchProcessInList) the process is
// terminated. Exit notifications (bCreate == FALSE) are ignored, and so is
// everything while g_CpCbkUninit is set.
VOID ProcessMonitorCallback(IN HANDLE hParentId, IN HANDLE hProcessId, IN BOOLEAN bCreate)
{
    (void)hParentId;   // only the new process itself is inspected

    if (!bCreate || g_CpCbkUninit)
        return;

    // Open the new process by PID. OBJ_KERNEL_HANDLE keeps the handle out of
    // the handle table of the process this callback happens to run in, where
    // user mode could otherwise use or close it.
    OBJECT_ATTRIBUTES Obja;
    InitializeObjectAttributes(&Obja, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);
    CLIENT_ID ClientId;
    ClientId.UniqueProcess = hProcessId;
    ClientId.UniqueThread = NULL;

    HANDLE procHandle = NULL;
    NTSTATUS status = ZwOpenProcess(&procHandle, PROCESS_ALL_ACCESS, &Obja, &ClientId);
    if (!NT_SUCCESS(status) || procHandle == NULL)
        return;

    // The buffer is zeroed and MaximumLength (in bytes) stops one wchar_t
    // short of its end, so the result stays 0-terminated. This matters because
    // it may be passed to IsMatchProcessInList as a plain C string. The old
    // limit of 512 bytes used only half the buffer, so longer image paths
    // could not be matched.
    wchar_t szPath[520];
    memset(szPath, 0, sizeof(szPath));
    UNICODE_STRING us_ProcName;
    us_ProcName.Length = 0;
    us_ProcName.MaximumLength = static_cast<USHORT>(sizeof(szPath) - sizeof(wchar_t));
    us_ProcName.Buffer = szPath;

    if (GetProcessImageName(&us_ProcName, procHandle) == STATUS_SUCCESS)
    {
        // Presumably converts the volume (device) path to a DOS path. When the
        // helper sets usDosPath.Buffer we own it and must Free() it; otherwise
        // fall back to the raw image path.
        UNICODE_STRING usDosPath = {};
        (void)InitUnicode_AllocIoVolPathToDosPath(&us_ProcName, &usDosPath);

        wchar_t* pathToCheck = usDosPath.Buffer ? usDosPath.Buffer : us_ProcName.Buffer;
        if (IsMatchProcessInList(pathToCheck))
        {
            // The process matches a rule: kill it.
            ZwTerminateProcess(procHandle, 0);
        }
        if (usDosPath.Buffer)
            Free(usDosPath.Buffer);
    }
    ZwClose(procHandle);
}
标签:__,CxxGlobal,windows,RAII,void,C++,chType,int,length 来源: https://blog.csdn.net/lif12345/article/details/122308967