// 簡易 x86で関数のサイズを調べる (simple x86 tool: determine the size of a function)

#include "stdafx.h"

#include <map>
#include <set>
#include <vector>

#include <detours.h>
#pragma comment(lib, "detours.lib")

// Returns TRUE when the instruction at pb is a direct branch whose target the
// size scan in main() should remember. The recognised forms are:
//   * short Jcc (0x70-0x7F) and near Jcc (0x0F 0x80-0x8F);
//   * the counter-based branches loopne/loope/loop/jcxz (0xE0-0xE3);
//   * the unconditional jmp forms 0xE9 (near), 0xEA (far direct), 0xEB (short).
// Despite the name, unconditional jumps are included on purpose. A forward jmp
// over a ret also shows that the code after the ret is still reachable.
// pb[0] is treated as the opcode byte, so a caller that wants prefixed
// instructions recognised must skip the prefixes first. NULL yields FALSE.
BOOL is_jcc_instruction(PBYTE pb)
{
	if (!pb)
		return FALSE;

	const BYTE op = pb[0];

	if (0x70 <= op && op <= 0x7F)
	{
		// jo, jno, jb, jnb, jz, jnz, jbe, ja, js, jns, jp, jnp, jl, jnl, jle, jnle (rel8)
		return TRUE;
	}
	if (0xE0 <= op && op <= 0xE3)
	{
		// loopne, loope, loop, jcxz/jecxz (rel8) - all conditional branches
		return TRUE;
	}
	if (op == 0xE9 || op == 0xEA || op == 0xEB)
	{
		// jmp long, jmp special (far direct), jmp short
		return TRUE;
	}
	if (op == 0x0F && 0x80 <= pb[1] && pb[1] <= 0x8F)
	{
		// two-byte Jcc (rel32); pb[1] is only read once 0x0F guarantees a
		// second opcode byte exists
		return TRUE;
	}

	return FALSE;
}

// Reports whether pb starts with an unconditional jmp that main() treats as a
// possible end of the function. The accepted forms are:
//   * 0xE9 (jmp long), 0xEA (jmp special), 0xEB (jmp short);
//   * the indirect memory form 0xFF 0x25 ("jmp dword ptr []").
// Only pb[0] is examined, plus pb[1] when pb[0] is 0xFF. NULL yields FALSE.
BOOL is_jmp_instruction(PBYTE pb)
{
	if (pb == NULL)
	{
		return FALSE;
	}

	switch (pb[0])
	{
	case 0xE9: // jmp long
	case 0xEA: // jmp special
	case 0xEB: // jmp short
		return TRUE;

	case 0xFF:
		// jmp dword ptr []: opcode FF followed by ModR/M byte 0x25
		return (pb[1] == 0x25) ? TRUE : FALSE;

	default:
		return FALSE;
	}
}

int main()
{
	HMODULE hModule;
	LPVOID lpvFunction;

	//hModule = GetModuleHandle(_T("KERNEL32"));
	hModule = GetModuleHandle(_T("KERNELBASE"));
	//hModule = GetModuleHandle(_T("ntdll"));
	if (hModule == NULL)
	{
		printf("baka\n");
		return 1;
	}

	lpvFunction = GetProcAddress(hModule, "CreateFileA");
	if (!lpvFunction)
	{
		printf("baka\n");
		return 1;
	}

	//
	BYTE bDst[256];
	PVOID pDstPool;
	PVOID pSrc, pSrcNext;
	PVOID pTarget;
	LONG lExtra;
	UINT uSize;
	std::vector<PVOID> vpTarget;
	std::map<PVOID, BOOL> mpTarget;

	pSrc = lpvFunction;
	pDstPool = &bDst[256];

	do
	{
		lExtra = 0;
		pTarget = NULL;

		// Returns : the address of the next instruction
		pSrcNext = DetourCopyInstruction(bDst, &pDstPool, pSrc, &pTarget, &lExtra);

		// Get instruction size
		uSize = PBYTE(pSrcNext) - PBYTE(pSrc);

		// Move to next instruction
		pSrc = pSrcNext;

		if (pTarget)
		{
			if (is_jcc_instruction(bDst))
			{
				// jump if condition is met
				printf("jump if condition is met - target = %p\n", pTarget);

				vpTarget.push_back(pTarget);
				mpTarget[pTarget] = TRUE;
			}
			else if (bDst[0] == 0xFF)
			{
				if (bDst[1] == 0x15)
				{
					// call dword ptr []
					printf("call dword ptr [] - target = %p\n", pTarget);
				}
				else if (bDst[1] == 0x25)
				{
					// jmp dword ptr []
					printf("jmp dword ptr [] - target = %p\n", pTarget);
				}
			}
		}

		// printf("size = %d\n", uSize);
		if (bDst[0] == 0xCC)
		{
			// int 3
			break;
		}
		else if (bDst[0] == 0xC2 || bDst[0] == 0xC3)
		{
			// ret X
			auto it = mpTarget.find(pSrcNext);
			if (it == mpTarget.end())
			{
				// not found jump toward next instruction
				break;
			}

			// found jump toward next instruction
		}
		else if (is_jmp_instruction(bDst))
		{
			// jmp or jmp dword ptr []
			auto it = mpTarget.find(pSrcNext);
			if (it == mpTarget.end())
			{
				// not found jump toward next instruction
				break;
			}

			// found jump toward next instruction
		}
	} while (TRUE);

	printf("start = %p, end = %p, size = %d\n", lpvFunction, pSrcNext, PBYTE(pSrcNext) - PBYTE(lpvFunction));

	return 0;
}