系统相关
首页 > 系统相关 > 用C++ RAII思想写Windows驱动

用C++ RAII思想写windows驱动

作者:互联网

随着VS2017的普及,Windows驱动的编写已经不必完全使用C语言了,也可以使用C++。

C语言时代,最难处理的是资源泄露:每一个分配了的资源,都要手写代码进行回收。

现在,有了C++,完全可以利用C++来提升代码质量

(1)首先我们应注意,C++隐藏了太多细节,这使得写驱动容易蓝屏

所以,用C++写驱动,一定要实现一个简化的C++运行环境

(2)接下来我们实现一个基于CreateProcessCallback的禁止某些进程执行的防护型驱动

首先实现全局环境

// Aggregate that holds every global object of the driver in one place.
// The constructor and destructor are only declared here. Members are
// constructed in declaration order and destroyed in reverse order, so the
// lock is constructed before (and destroyed after) the rule table it guards.
struct CxxGlobal
{
public:
	CxxGlobal();
	~CxxGlobal();

public:
	// Place all global objects here.
	wstring               str_DeviceName;
	wstring               str_SYMBOLIC_LINK_NAME;

	CMyEResourceLock      g_lock_MapCreateProcessCBKRule;// lock for the CreateProcessCallback rule list


	CLinkListMapByString  g_MapCpCBKRule;// rule table used by the CreateProcessCallback
};

我们利用placement new(定位new)的重载,让这个全局对象直接构造在预先定义好的全局缓冲区里,而不是从内存池中动态分配:

CxxGlobal* CreateCxxGlobal_OnDriverEntry(); // Call this from DriverEntry and make sure it succeeded before continuing; otherwise DriverEntry should return failure.
void DeleteCxxGlobal_OnDriverUnload();// Call this from DriverUnload; finish all uninitialization work before calling it.



// Global allocating operator new: routes every `new` expression to Alloc.
//
// sizex - number of bytes requested.
// Returns whatever Alloc returns, or nullptr when the request cannot be
// represented in the DWORD that is handed to Alloc.
void * __cdecl operator new(size_t sizex)
{
	// size_t is wider than DWORD on 64-bit builds. A request that does not
	// survive the round trip would be silently truncated and yield a block
	// smaller than the caller asked for, so refuse it instead.
	const DWORD cbRequest = static_cast<DWORD>(sizex);
	if (static_cast<size_t>(cbRequest) != sizex)
		return nullptr;

	return Alloc(cbRequest);
}

// Placement form of operator new: performs no allocation and simply hands
// back the caller-supplied storage so an object can be constructed in it.
//
// sizex - size of the object being constructed (not used).
// p     - address of the storage to construct into.
void * __cdecl operator new(size_t sizex, void* p)
{
	(void)sizex; // the storage is provided by the caller, its size is irrelevant here
	return p;
}

// Global sized operator delete: returns a block obtained through the global
// operator new to Free.
//
// ptr   - block to release; may be null.
// sizex - size of the block (not used).
void __cdecl operator delete(void * ptr, size_t sizex)
{
	(void)sizex;

	// Deleting a null pointer must be a harmless no-op. Whether Free accepts
	// null is not visible here, so never forward one to it.
	if (ptr)
		Free(ptr);
}

unsigned char g_cxxGlobalObj[sizeof(CxxGlobal)];

CxxGlobal& Global = *((CxxGlobal*)(&(g_cxxGlobalObj[0])));
CxxGlobal* g_pGlobal = 0;

// Constructs the CxxGlobal instance inside g_cxxGlobalObj on the first call
// and records it in g_pGlobal; later calls return the existing instance.
// Returns the value of g_pGlobal.
CxxGlobal * CreateCxxGlobal_OnDriverEntry()
{
	// Already constructed: hand back the existing instance.
	if (g_pGlobal)
		return g_pGlobal;

	// Construct in place in the static buffer; nothing is allocated.
	void* pStorage = &(g_cxxGlobalObj[0]);
	g_pGlobal = new(pStorage) CxxGlobal();
	return g_pGlobal;
}

void DeleteCxxGlobal_OnDriverUnload()
{
	if (g_pGlobal)
		g_pGlobal->~CxxGlobal();
	g_pGlobal = NULL;
}

(3)字符串类的实现


// Length of a zero-terminated string of any character type.
//
// pstr - string to measure; may be null.
// Returns the number of characters before the terminator, or 0 for null.
template<typename chType>
__inline int Tstrlen(const chType* pstr)
{
	if (!pstr)
		return 0;

	int count = 0;
	for (const chType* cursor = pstr; *cursor; ++cursor)
		++count;
	return count;
}

// Minimal owning string for code that cannot use the standard library.
// The characters live in a zero-terminated buffer obtained from Alloc and
// released with Free.
//
// Allocation failure is not reported through exceptions: after a failed
// construction or assignment the string is empty (m_p == nullptr,
// length == 0); a failed resize leaves the string unchanged.
template<typename chType>
class _chStringT
{
public:
	chType* m_p;   // owned buffer of length + 1 characters, or nullptr when nothing is allocated
	int length;    // number of characters, not counting the terminator

public:
	_chStringT() :m_p(nullptr), length(0) {}

	// Copy constructor. Both members are initialized up front so the object is
	// well formed (empty) even when the allocation fails.
	_chStringT(const _chStringT& dest) :m_p(nullptr), length(0)
	{
		__assign(dest.m_p, dest.length);
	}

	// Builds the string from a zero-terminated source; a null pointer is
	// treated as an empty string.
	_chStringT(const chType* pStr) :m_p(nullptr), length(0)
	{
		__assign(pStr, Tstrlen(pStr));
	}

	// Copy assignment. Assigning a string to itself is a no-op.
	_chStringT& operator=(const _chStringT& dest)
	{
		if (this != &dest)
			__assign(dest.m_p, dest.length);
		return (*this);
	}

	// Assignment from a zero-terminated source. The source may point into this
	// string's own buffer; it is copied before the old buffer is released.
	_chStringT& operator=(const chType* pStr)
	{
		__assign(pStr, Tstrlen(pStr));
		return (*this);
	}

	~_chStringT()
	{
		__freedata();
	}

	// Replaces the content with len copies of initch. A negative len is
	// rejected, and the string is left untouched when the allocation fails.
	void resize(int len, chType initch)
	{
		if (len < 0)
			return; // would otherwise write the terminator in front of the buffer

		chType* pNewBuffer = (chType*)Alloc((len + 1) * sizeof(chType));
		if (pNewBuffer)
		{
			__freedata();
			m_p = pNewBuffer;
			length = len;
			for (int i = 0; i < len; i++)
				pNewBuffer[i] = initch;
			// terminating zero
			pNewBuffer[len] = 0;
		}
	}

	// Bounds-checked element access. An out-of-range index (or an unallocated
	// string) yields a reference to a shared dummy character instead of
	// touching memory outside the buffer.
	chType& operator[](int index)
	{
		if (m_p)
		{
			if (index >= 0 && index < length)
				return m_p[index];
		}
		return __dummy_ref_char_for_guard_safe;
	}

	// Zero-terminated view of the content; nullptr when nothing is allocated.
	const chType* c_str() const
	{
		return m_p;
	}

protected:
	// Replaces the content with a copy of the first len characters of pSrc.
	// The new buffer is filled before the old one is released, so pSrc may
	// alias the current buffer. If the allocation fails the string becomes
	// empty.
	void __assign(const chType* pSrc, int len)
	{
		if (len < 0)
			len = 0;

		chType* pNewBuffer = (chType*)Alloc((len + 1) * sizeof(chType));
		if (!pNewBuffer)
		{
			__freedata();
			return;
		}

		for (int i = 0; i < len; i++)
			pNewBuffer[i] = pSrc[i];
		pNewBuffer[len] = 0;

		__freedata();
		m_p = pNewBuffer;
		length = len;
	}

	// Releases the buffer (if any) and resets the string to empty.
	void __freedata()
	{
		if (m_p)
		{
			Free(m_p);
			m_p = nullptr;
			length = 0;
		}
	}

	// Target of the reference returned by operator[] for invalid indexes.
	static chType  __dummy_ref_char_for_guard_safe;
};



// Definition of the shared dummy character that _chStringT::operator[]
// returns a reference to for invalid indexes; starts out as zero.
template<typename chType>
chType _chStringT<chType>::__dummy_ref_char_for_guard_safe = 0;

// Narrow and wide character flavours of the string class.
using string = _chStringT<char>;
using wstring = _chStringT<wchar_t>;

(4)RAII思想设计的资源锁


// Owns one ERESOURCE for the lifetime of the object. The constructor and
// destructor are defined elsewhere (presumably they initialize and delete the
// resource -- confirm against their definitions).
class CMyEResourceLock
{
public:
	CMyEResourceLock();
	~CMyEResourceLock();

	// Not copyable: callers lock through the address of the embedded
	// ERESOURCE, and a member-wise copy would duplicate the internal state of
	// a kernel lock into a second object that was never initialized itself.
	CMyEResourceLock(const CMyEResourceLock&) = delete;
	CMyEResourceLock& operator=(const CMyEResourceLock&) = delete;

	// Address of the embedded resource, for use with the lock macros / guard.
	ERESOURCE* GetResource(){return &m_lock_eresource;}

private:
	ERESOURCE m_lock_eresource;
};

// Enter/leave helpers for an ERESOURCE: enter a critical region and acquire
// the resource shared (waiting for it); release the resource and leave the
// critical region. `section` is an ERESOURCE*.
//
// Each macro expands to a single statement (do { ... } while (0)), so it is
// safe under an unbraced if/else; the argument is parenthesized so any
// pointer expression can be passed.
//
// NOTE(review): the resource is acquired SHARED, which does not exclude other
// shared holders. Confirm that code modifying the protected data acquires the
// resource exclusively by other means.
#define MY_ENTER_CRITICAL_SECTION(section) \
	do { KeEnterCriticalRegion(); ExAcquireResourceSharedLite((section), TRUE); } while (0)
#define MY_LEAVE_CRITICAL_SECTION(section) \
	do { ExReleaseResourceLite((section)); KeLeaveCriticalRegion(); } while (0)


// Scope guard for an ERESOURCE: the constructor takes the lock through
// MY_ENTER_CRITICAL_SECTION and the destructor gives it back through
// MY_LEAVE_CRITICAL_SECTION, so the lock is released on every path out of
// the enclosing scope.
class CAutoLockEResource
{
public:
	// pResource - resource to hold for the lifetime of the guard. A null
	//             pointer produces an inert guard that locks nothing.
	CAutoLockEResource(ERESOURCE* pResource)
		: m_pRes(pResource)
	{
		// Mirror the destructor's null check so acquire and release always
		// pair up and a null pointer is never handed to the kernel.
		if (m_pRes)
		{
			MY_ENTER_CRITICAL_SECTION(m_pRes);
		}
	}

	~CAutoLockEResource()
	{
		if (m_pRes)
		{
			MY_LEAVE_CRITICAL_SECTION(m_pRes);
		}
	}

	// Not copyable: a copy would release the same acquisition a second time.
	CAutoLockEResource(const CAutoLockEResource&) = delete;
	CAutoLockEResource& operator=(const CAutoLockEResource&) = delete;

private:
	ERESOURCE* m_pRes; // resource held by this guard, or null
};

(5)链表类


// Payload stored in each node of the rule list: a path string and a flag.
typedef struct _LinkListMapObj
{
	WCHAR szPath[260];  // path text, fixed capacity of 260 WCHARs (presumably MAX_PATH -- TODO confirm)
	BOOL  bLetItOK;     // NOTE(review): presumably marks the entry as "allowed"; verify against the users of this field

} LinkListMapObj;

// Signature of the callback passed to CLinkListMapByString::Travel. It
// receives the caller's context pointer, the entry's data, the current list
// node and the node recorded as its parent.
// NOTE(review): the meaning of the BOOL return value (e.g. continue/stop) is
// not visible here -- confirm in the Travel implementation.
typedef BOOL PFN_LinkListMapTravelCallback(void* pContext, LinkListMapObj* pDataObj, _LinkListEntryT<LinkListMapObj>* pCurrent, _LinkListEntryT<LinkListMapObj>* pParentOfCurrent);


// String-keyed collection of LinkListMapObj entries kept in a LinkListT.
// Only the declarations are visible here; the per-method notes below are
// inferred from the signatures and should be confirmed against the
// implementation.
class CLinkListMapByString
{
public:
	CLinkListMapByString();
	~CLinkListMapByString();

public:
	// Presumably adds an entry for lpString carrying the bLetItOK flag.
	BOOL AddString(LPCWSTR lpString, BOOL bLetItOK);
	// Presumably returns the entry whose path matches lpFind, or null.
	LinkListMapObj*  FindByName(LPCWSTR lpFind);
	// Presumably removes the entry for lpString.
	BOOL Erase(LPCWSTR lpString);
	// Presumably unlinks pEntry, given the node that precedes it.
	BOOL DeleteEntry(_LinkListEntryT<LinkListMapObj> * pEntry, _LinkListEntryT<LinkListMapObj> * pParentOfEntry);
	// Presumably removes every entry.
	void Clear();
	// Presumably invokes fnCallback for the entries, passing pContext through.
	void Travel(PFN_LinkListMapTravelCallback fnCallback, void * pContext);
	// Presumably returns the number of entries.
	int GetCount();

private:
	LinkListT<LinkListMapObj> m_list; // underlying list storage
};

(6)实现CreateProcessCallback

// Process notification routine. When a process is being created and the
// monitor has not been shut down (g_CpCbkUninit), it opens the new process,
// resolves its image path and terminates the process if IsMatchProcessInList
// reports a match. Processes that do not match are left alone.
//
// hParentId  - id of the parent process (not used).
// hProcessId - id of the process the notification is about.
// bCreate    - TRUE for a creation, FALSE for a deletion; deletions are ignored.
VOID ProcessMonitorCallback(IN HANDLE hParentId, IN HANDLE hProcessId, IN BOOLEAN bCreate)
{
	(void)hParentId;

	// Nothing to do for process exit, or once the monitor is being torn down.
	if (!bCreate || g_CpCbkUninit)
		return;

	CLIENT_ID ClientId;
	ClientId.UniqueProcess = hProcessId;
	ClientId.UniqueThread = 0;

	OBJECT_ATTRIBUTES Obja;
	Obja.Length = sizeof(Obja);
	Obja.RootDirectory = 0;
	Obja.ObjectName = 0;
	// Request a kernel handle. Without OBJ_KERNEL_HANDLE the handle would be
	// created in the handle table of whichever process this callback happens
	// to run in, where user-mode code could get hold of an all-access handle.
	Obja.Attributes = OBJ_KERNEL_HANDLE;
	Obja.SecurityDescriptor = 0;
	Obja.SecurityQualityOfService = 0;

	// Open the process by its id to obtain a handle.
	HANDLE procHandle = NULL;
	NTSTATUS status = ZwOpenProcess(&procHandle, PROCESS_ALL_ACCESS, &Obja, &ClientId);
	if (status < 0 || procHandle == NULL) // negative NTSTATUS values are failures
		return;

	// Zero-filled buffer for the image name. The buffer is larger than
	// MaximumLength, so the text always stays zero-terminated and can be used
	// as a plain C string below.
	// NOTE(review): MaximumLength is in bytes, so 512 limits the name to 256
	// characters although the buffer holds 520 -- confirm this is intended.
	wchar_t szPath[520];
	memset(szPath, 0, sizeof(szPath));

	UNICODE_STRING us_ProcName;
	us_ProcName.MaximumLength = 512;
	us_ProcName.Length = 0;
	us_ProcName.Buffer = szPath;

	if (STATUS_SUCCESS == GetProcessImageName(&us_ProcName, procHandle))
	{
		// Try to convert the volume path into a DOS path; when no converted
		// buffer is produced, fall back to the name as reported.
		UNICODE_STRING usDosPath = {};
		InitUnicode_AllocIoVolPathToDosPath(&us_ProcName, &usDosPath);

		// Terminate the process only if its path is found in the rule list.
		if (IsMatchProcessInList(usDosPath.Buffer ? usDosPath.Buffer : us_ProcName.Buffer))
		{
			ZwTerminateProcess(procHandle, 0);
		}

		// The converted path buffer is released here once it is no longer needed.
		if (usDosPath.Buffer)
			Free(usDosPath.Buffer);
	}

	ZwClose(procHandle);
}

标签:__,CxxGlobal,windows,RAII,void,C++,chType,int,length
来源: https://blog.csdn.net/lif12345/article/details/122308967