diff options
Diffstat (limited to 'LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp')
-rw-r--r-- | LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp | 974 |
1 files changed, 974 insertions, 0 deletions
diff --git a/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp b/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp new file mode 100644 index 0000000..a74ab75 --- /dev/null +++ b/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp @@ -0,0 +1,974 @@ +/************************************************************************************
+
+Filename : OVR_Win32_DisplayShim.cpp
+Content : Shared static functions for inclusion that allow for an application
+ to inject the usermode driver into an application
+Created : March 21, 2014
+Authors : Dean Beeler
+
+Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved.
+
+Use of this software is subject to the terms of the Oculus Inc license
+agreement provided at the time of installation or download, or which
+otherwise accompanies this software in either electronic or hard copy form.
+
+************************************************************************************/
+
+#include "../../Include/OVR_Version.h"
+
+#ifndef AVOID_LIB_OVR
+#include "../Kernel/OVR_Types.h" // Without this we can get warnings (due to VC++ bugs) about _malloca being redefined, and miss LibOVR overrides.
+#endif
+
+#include <windows.h>
+#include <DbgHelp.h>
+#include <malloc.h>
+
+#include "OVR_Win32_Dxgi_Display.h"
+#include "OVR_Win32_ShimVersion.h"
+
+#if AVOID_LIB_OVR
+#define IN_COMPATIBILITY_MODE() (0)
+#else
+#include "OVR_Win32_Display.h"
+#define IN_COMPATIBILITY_MODE() OVR::Display::InCompatibilityMode()
+#endif
+
+#pragma comment(lib, "DbgHelp.lib")
+
+#define WIDE_TO_MB(wideString) \
+ int wideString ## _slen = (int)wcslen(wideString); \
+ char* wideString ## _cstr = (char*)alloca(wideString ## _slen * 2); \
+ int count = WideCharToMultiByte(GetACP(), 0, wideString, -1, wideString ## _cstr, wideString ## _slen * 2, NULL, NULL); \
+ wideString ## _cstr[count] = '\0';
+
+// Forward declarations
+// These functions are implemented in OVR_Win32_DisplayDevice.cpp.
+
+BOOL WINAPI OVRIsInitializingDisplay( PVOID context, UINT width, UINT height );
+BOOL WINAPI OVRIsCreatingBackBuffer( PVOID context );
+BOOL WINAPI OVRShouldVSync( );
+ULONG WINAPI OVRRiftForContext( PVOID context, HANDLE driverHandle );
+BOOL WINAPI OVRCloseRiftForContext( PVOID context, HANDLE driverHandle, ULONG rift );
+BOOL WINAPI OVRWindowDisplayResolution( PVOID context, UINT* width, UINT* height,
+ UINT* titleHeight, UINT* borderWidth,
+ BOOL* vsyncEnabled );
+BOOL WINAPI OVRExpectedResolution( PVOID context, UINT* width, UINT* height, UINT* rotationInDegrees );
+BOOL WINAPI OVRShouldEnableDebug();
+BOOL WINAPI OVRMirroringEnabled( PVOID context );
+HWND WINAPI OVRGetWindowForContext(PVOID context);
+BOOL WINAPI OVRShouldPresentOnContext(PVOID context);
+
+static const char* GFX_DRIVER_KEY_FMT = "SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\%04d";
+#ifdef _WIN64
+static const char* RTFilter = "OVRDisplayRT64.dll";
+static const char* UMFilter = "OVRDisplay64.dll";
+#else
+static const char* RTFilter = "OVRDisplayRT32.dll";
+static const char* UMFilter = "OVRDisplay32.dll";
+#endif
+static const char* OptimusDrivers = "nvumdshimx.dll nvumdshim.dll";
+
+typedef enum OVRTargetAPI
+{
+ DirectX,
+ OpenGL
+};
+
+static PVOID lastContext = NULL;
+LINK_APPLICATION_DRIVER appDriver = {0};
+static INT apiVersion = 10;
+
+static CHAR* ReadRegStr(HKEY keySub, const char* keyName, const char* valName)
+{
+ CHAR *val = NULL;
+ REGSAM access = KEY_READ;
+ HKEY hKey;
+
+TryAgainWOW64:
+ NTSTATUS res = RegOpenKeyExA( keySub, keyName, 0, access, &hKey );
+ if ( res == ERROR_SUCCESS )
+ {
+ DWORD valLen;
+ res = RegQueryValueExA( hKey, valName, NULL, NULL, NULL, &valLen );
+ if( res == ERROR_SUCCESS ) {
+ val = (CHAR*)calloc( valLen + 1, sizeof(CHAR) );
+ res = RegQueryValueExA( hKey, valName, NULL, NULL, (LPBYTE)val, &valLen );
+
+ if( res == ERROR_SUCCESS )
+ {
+ CHAR* byte = val;
+ for( DWORD j = 0; j < valLen; ++j )
+ {
+ if( byte[j] == 0 )
+ byte[j] = ' ';
+ }
+ }
+ else
+ {
+ free( val );
+ val = NULL;
+ }
+ }
+ RegCloseKey( hKey );
+ }
+
+ if( res == ERROR_FILE_NOT_FOUND && keySub == HKEY_LOCAL_MACHINE && access == KEY_READ ) {
+#ifdef _WIN64
+ access = KEY_READ | KEY_WOW64_32KEY;
+#else
+ access = KEY_READ | KEY_WOW64_64KEY;
+#endif
+ goto TryAgainWOW64;
+ }
+ return val;
+}
+
+#define OLD_DATA_BACKUP_SIZE 16
+
+static WinLoadLibraryA oldProcA = NULL; // Note: This is used to indicate that the shim is in place
+static WinLoadLibraryExA oldProcExA = NULL;
+static WinLoadLibraryW oldProcW = NULL;
+static WinLoadLibraryExW oldProcExW = NULL;
+static WinGetModuleHandleExA oldProcModExA = NULL;
+static WinGetModuleHandleExW oldProcModExW = NULL;
+static WinDirect3DCreate9 oldDirectX9Create = NULL;
+static BYTE oldDirectX9CreateData[OLD_DATA_BACKUP_SIZE];
+static WinDirect3DCreate9Ex oldDirectX9ExCreate = NULL;
+static BYTE oldDirectX9ExCreateData[OLD_DATA_BACKUP_SIZE];
+static WinCreateDXGIFactory oldCreateDXGIFactory = NULL;
+static BYTE oldCreateDXGIFactoryData[OLD_DATA_BACKUP_SIZE];
+static WinCreateDXGIFactory1 oldCreateDXGIFactory1 = NULL;
+static BYTE oldCreateDXGIFactory1Data[OLD_DATA_BACKUP_SIZE];
+static WinCreateDXGIFactory2 oldCreateDXGIFactory2 = NULL;
+static BYTE oldCreateDXGIFactory2Data[OLD_DATA_BACKUP_SIZE];
+
+#define NUM_LOADER_LIBS 4
+
+static const char* loaderLibraryList[NUM_LOADER_LIBS] = {
+ "kernel32.dll",
+ "api-ms-win-core-libraryloader-l1-2-0.dll",
+ "api-ms-win-core-libraryloader-l1-1-0.dll",
+ "api-ms-win-core-libraryloader-l1-1-1.dll"
+};
+
+enum ShimedLibraries
+{
+ ShimLibDXGI = 0,
+ ShimLibD3D9 = 1,
+ ShimLibD3D11 = 2,
+ ShimLibDXGIDebug = 3,
+ ShimLibD3D10Core = 4,
+ ShimLibD3D10 = 5,
+ ShimLibGL = 6,
+ ShimCountMax = 7
+};
+
+static const char* dllList[ShimCountMax] = {
+ "dxgi.dll",
+ "d3d9.dll",
+ "d3d11.dll",
+ "dxgidebug.dll",
+ "d3d10core.dll",
+ "d3d10.dll",
+ "opengl32.dll"
+};
+
+static HINSTANCE oldLoaderInstances[ShimCountMax] = { NULL };
+static PROC oldLoaderProcA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+static PROC oldLoaderProcW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+static PROC oldLoaderProcExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+static PROC oldLoaderProcExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+static PROC oldLoaderProcModExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+static PROC oldLoaderProcModExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
+
+static HMODULE rtFilterModule = NULL;
+
+static bool checkForOverride( LPCSTR libFileName, OVRTargetAPI& targetApi )
+{
+ for (int i=0; ; i++)
+ {
+ CHAR keyString[256] = {0};
+
+ sprintf_s( keyString, 256, GFX_DRIVER_KEY_FMT, i );
+
+ CHAR* infSection = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, "InfSection" );
+
+ // No provider name means we're out of display enumerations
+ if( infSection == NULL )
+ break;
+
+ free(infSection);
+
+ // Check 64-bit driver names followed by 32-bit driver names
+ const char* driverKeys[] = {"UserModeDriverName", "UserModeDriverNameWoW", "OpenGLDriverName", "OpenGLDriverNameWoW", "InstalledDisplayDrivers" };
+ for( int j = 0; j < 6; ++j )
+ {
+ CHAR userModeList[4096] = {0};
+
+ switch(j)
+ {
+ case 5:
+ strcpy_s( userModeList, 4095, OptimusDrivers );
+ break;
+ default:
+ {
+ CHAR* regString = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, driverKeys[j] );
+ if( regString )
+ {
+ strcpy_s( userModeList, 4095, regString );
+ free( regString );
+ }
+
+ }
+ break;
+ }
+
+ char *nextToken = NULL;
+
+ if( userModeList )
+ {
+ char* first = strtok_s( userModeList, " ", &nextToken );
+ while( first )
+ {
+ if( strstr( libFileName, first ) != 0 )
+ {
+ if( j < 2 )
+ targetApi = DirectX;
+ else
+ targetApi = OpenGL;
+
+ return true;
+ }
+ first = strtok_s( NULL, " ", &nextToken );
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
+static HMODULE createShim( LPCSTR lpLibFileName, OVRTargetAPI targetAPI )
+{
+ //Sleep(10000);
+ if( IN_COMPATIBILITY_MODE() )
+ {
+ return (*oldProcA)( lpLibFileName );
+ }
+
+ UNREFERENCED_PARAMETER( targetAPI );
+
+ HMODULE result = NULL;
+
+ result = (*oldProcA)( UMFilter );
+
+ if( result )
+ {
+ PreloadLibraryFn loadFunc = (PreloadLibraryFn)GetProcAddress( result, "PreloadLibrary" );
+ if( loadFunc )
+ {
+ HRESULT localRes = (*loadFunc)( oldProcA, lpLibFileName, &appDriver );
+ if( localRes != S_OK )
+ result = NULL;
+ }
+ }
+
+ if( !result )
+ {
+ MessageBox(nullptr, L"Unable to find the Rift Display Driver.", L"Configuration Error", MB_OK | MB_ICONERROR);
+ result = (*oldProcA)( lpLibFileName );
+ }
+ return result;
+}
+
+static HMODULE
+ WINAPI
+ OVRLoadLibraryA(
+ __in LPCSTR lpLibFileName
+ )
+{
+ OVRTargetAPI targetAPI = DirectX;
+ bool needShim = checkForOverride( lpLibFileName, targetAPI );
+ if( !needShim )
+ return (*oldProcA)( lpLibFileName );
+
+ return createShim( lpLibFileName, targetAPI );
+}
+
+static HMODULE
+ WINAPI
+ OVRLoadLibraryW(
+ __in LPCWSTR lpLibFileName
+ )
+{
+ WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr
+
+ OVRTargetAPI targetAPI = DirectX;
+
+ bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI );
+ if( !needShim )
+ return (*oldProcW)( lpLibFileName );
+
+ return createShim( lpLibFileName_cstr, targetAPI );
+}
+
+static HMODULE
+ WINAPI
+ OVRLoadLibraryExA(
+ __in LPCSTR lpLibFileName,
+ __reserved HANDLE hFile,
+ __in DWORD dwFlags
+
+ )
+{
+ OVRTargetAPI targetAPI = DirectX;
+
+ bool needShim = checkForOverride( lpLibFileName, targetAPI );
+ if( !needShim )
+ return (*oldProcExA)( lpLibFileName, hFile, dwFlags );
+
+ // FIXME: Don't throw away the flags parameter
+ return createShim( lpLibFileName, targetAPI );
+}
+
+static HMODULE
+ WINAPI
+ OVRLoadLibraryExW(
+ __in LPCWSTR lpLibFileName,
+ __reserved HANDLE hFile,
+ __in DWORD dwFlags
+ )
+{
+ WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr
+
+ OVRTargetAPI targetAPI = DirectX;
+
+ bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI );
+ if( !needShim )
+ return (*oldProcExW)( lpLibFileName, hFile, dwFlags );
+
+ // FIXME: Don't throw away the flags parameter
+ return createShim( lpLibFileName_cstr, targetAPI );
+}
+
+static BOOL WINAPI OVRGetModuleHandleExA(
+ __in DWORD dwFlags,
+ __in_opt LPCSTR lpModuleName,
+ __out HMODULE *phModule
+ )
+{
+ OVRTargetAPI targetAPI = DirectX;
+
+ bool needShim = checkForOverride( lpModuleName, targetAPI );
+ if( !needShim )
+ {
+ return (*oldProcModExA)( dwFlags, lpModuleName, phModule );
+ }
+
+ *phModule = createShim( lpModuleName, targetAPI );
+
+ return TRUE;
+}
+
+static BOOL WINAPI OVRGetModuleHandleExW(
+ __in DWORD dwFlags,
+ __in_opt LPCWSTR lpModuleName,
+ __out HMODULE *phModule
+ )
+{
+ WIDE_TO_MB(lpModuleName); // Convert lpModuleName -> lpModuleName_cstr
+
+ OVRTargetAPI targetAPI = DirectX;
+
+ bool needShim = checkForOverride( lpModuleName_cstr, targetAPI );
+ if( !needShim )
+ {
+ return (*oldProcModExW)( dwFlags, lpModuleName, phModule );
+ }
+
+ *phModule = createShim( lpModuleName_cstr, targetAPI );
+
+ return TRUE;
+}
+
+#ifdef _AMD64_
+static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData )
+{
+ static const LONGLONG addressSize = sizeof(PROC);
+ static const LONGLONG jmpSize = addressSize + 6;
+
+ DWORD oldProtect;
+ VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE,
+ PAGE_EXECUTE_READWRITE, &oldProtect);
+
+ memcpy(pfnHookAPIAddr, oldData, OLD_DATA_BACKUP_SIZE);
+
+ VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL);
+}
+
+static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData )
+{
+ static const LONGLONG addressSize = sizeof(PROC);
+ static const LONGLONG jmpSize = addressSize + 6;
+
+ INT_PTR jumpOffset = (INT_PTR)replacementFunction;
+
+ DWORD oldProtect;
+ VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE,
+ PAGE_EXECUTE_READWRITE, &oldProtect);
+
+ memcpy(oldData, pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE);
+
+ PBYTE functionData = (PBYTE)pfnHookAPIAddr;
+ functionData[0] = 0xff; // JMP [RIP+0]
+ functionData[1] = 0x25; //
+ functionData[2] = 0x00; //
+ functionData[3] = 0x00; //
+ functionData[4] = 0x00; //
+ functionData[5] = 0x00; //
+ memcpy( functionData + 6, &jumpOffset, sizeof( INT_PTR ) );
+
+ VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL);
+}
+#else
+static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData )
+{
+ static const LONGLONG addressSize = sizeof(PROC);
+ static const LONGLONG jmpSize = addressSize + 1;
+
+ DWORD oldProtect;
+ VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize,
+ PAGE_EXECUTE_READWRITE, &oldProtect);
+
+ memcpy(pfnHookAPIAddr, oldData, jmpSize);
+
+ VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL);
+}
+
+static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData )
+{
+ static const LONGLONG addressSize = sizeof(PROC);
+ static const LONGLONG jmpSize = addressSize + 1;
+
+ INT_PTR jumpOffset = (INT_PTR)replacementFunction - (INT_PTR)pfnHookAPIAddr - jmpSize;
+
+ DWORD oldProtect;
+ VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize,
+ PAGE_EXECUTE_READWRITE, &oldProtect);
+
+ memcpy(oldData, pfnHookAPIAddr, jmpSize);
+
+ PBYTE functionData = (PBYTE)pfnHookAPIAddr;
+ memcpy( oldData, functionData, jmpSize );
+ functionData[0] = 0xe9;
+ memcpy( functionData + 1, &jumpOffset, sizeof( INT_PTR ) );
+
+ VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL);
+}
+#endif
+
+static BOOL WINAPI OVRLocalIsInitializingDisplay( PVOID context, UINT width, UINT height )
+{
+ UINT expectedWidth, expectedHeight, rotation;
+
+ OVRExpectedResolution( context, &expectedWidth, &expectedHeight, &rotation );
+
+ if( appDriver.pfnActiveAPIVersion )
+ apiVersion = (*appDriver.pfnActiveAPIVersion)( context );
+
+ switch( apiVersion )
+ {
+ case 1: // OpenGL
+ case 10: // DirectX 1X
+ if( width == expectedWidth && height == expectedHeight )
+ return TRUE;
+ break;
+ case 9: // DirectX 9
+ if( rotation == 90 || rotation == 270 )
+ {
+ if( width == expectedHeight && height == expectedWidth )
+ return TRUE;
+ }
+ else
+ {
+ if( width == expectedWidth && height == expectedHeight )
+ return TRUE;
+ }
+ break;
+ default:
+ break;
+ }
+
+ return FALSE;
+}
+
+
+HRESULT APIENTRY OVRDirect3DCreate9Ex(UINT SDKVersion, void** aDevice)
+{
+ apiVersion = 9;
+
+ HRESULT result = S_OK;
+
+ restoreFunction( (PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData );
+
+ if (IN_COMPATIBILITY_MODE())
+ {
+ result = (*oldDirectX9ExCreate)(SDKVersion, aDevice);
+ }
+ else
+ {
+ WinDirect3DCreate9Ex createFunction = (WinDirect3DCreate9Ex)GetProcAddress(rtFilterModule, "Direct3DCreate9Ex");
+ result = (*createFunction)(SDKVersion, aDevice);
+ }
+
+ setFunction( (PROC)oldDirectX9ExCreate, (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData );
+
+ printf("%s result 0x%x\n", __FUNCTION__, result);
+
+ return result;
+}
+
+void* APIENTRY OVRDirect3DCreate9(UINT SDKVersion)
+{
+ void* result = NULL;
+
+ OVRDirect3DCreate9Ex( SDKVersion, &result );
+
+ return result;
+}
+
+HRESULT APIENTRY OVRCreateDXGIFactory(
+ __in REFIID riid,
+ __out void **ppFactory
+ )
+{
+ HRESULT result = E_FAIL;
+
+ restoreFunction( (PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData );
+
+ if (IN_COMPATIBILITY_MODE())
+ {
+ result = (*oldCreateDXGIFactory)(riid, ppFactory);
+ }
+ else
+ {
+ WinCreateDXGIFactory createFunction = (WinCreateDXGIFactory)GetProcAddress(rtFilterModule, "CreateDXGIFactory");
+ result = (*createFunction)(riid, ppFactory);
+ }
+
+ setFunction( (PROC)oldCreateDXGIFactory, (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData );
+
+ printf("%s result 0x%x\n", __FUNCTION__, result);
+
+ return result;
+}
+
+HRESULT APIENTRY OVRCreateDXGIFactory1(
+ __in REFIID riid,
+ __out void **ppFactory
+ )
+{
+ HRESULT result = E_FAIL;
+
+ restoreFunction( (PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data );
+
+ if (IN_COMPATIBILITY_MODE())
+ {
+ result = (*oldCreateDXGIFactory1)(riid, ppFactory);
+ }
+ else
+ {
+ WinCreateDXGIFactory1 createFunction = (WinCreateDXGIFactory1)GetProcAddress(rtFilterModule, "CreateDXGIFactory1");
+ result = (*createFunction)(riid, ppFactory);
+ }
+
+ setFunction( (PROC)oldCreateDXGIFactory1, (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data );
+
+ printf("%s result 0x%x\n", __FUNCTION__, result);
+
+ return result;
+}
+
+HRESULT APIENTRY OVRCreateDXGIFactory2(
+ __in UINT flags,
+ __in const IID &riid,
+ __out void **ppFactory
+ )
+{
+ HRESULT result = E_FAIL;
+
+ restoreFunction( (PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data );
+
+ if (IN_COMPATIBILITY_MODE())
+ {
+ result = (*oldCreateDXGIFactory2)(flags, riid, ppFactory);
+ }
+ else
+ {
+ WinCreateDXGIFactory2 createFunction = (WinCreateDXGIFactory2)GetProcAddress(rtFilterModule, "CreateDXGIFactory2");
+ result = (*createFunction)(flags, riid, ppFactory);
+ }
+
+ setFunction( (PROC)oldCreateDXGIFactory2, (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data );
+
+ printf("%s result 0x%x\n", __FUNCTION__, result);
+
+ return result;
+}
+
+static PROC SetProcAddressDirect(
+ __in HINSTANCE hInstance,
+ __in LPCSTR lpProcName,
+ __in PROC newFunction,
+ __inout BYTE* oldData
+ )
+{
+ static const LONGLONG addressSize = sizeof(PROC);
+ static const LONGLONG jmpSize = addressSize + 1;
+
+ PROC result = NULL;
+
+ PROC pfnHookAPIAddr = GetProcAddress( hInstance, lpProcName );
+
+ if( pfnHookAPIAddr )
+ {
+ result = pfnHookAPIAddr;
+
+ setFunction( pfnHookAPIAddr, newFunction, oldData );
+ }
+
+ return result;
+}
+
+static PROC SetProcAddressA(
+ __in HINSTANCE targetModule,
+ __in LPCSTR lpLibFileName,
+ __in LPCSTR lpProcName,
+ __in PROC newFunction
+ )
+{
+ HMODULE hModule = LoadLibraryA( lpLibFileName );
+ if(hModule == NULL)
+ return NULL;
+
+ // To do: call FreeLibrary(hModule) at the appropriate time.
+ PROC pfnHookAPIAddr = GetProcAddress(hModule, lpProcName );
+
+ HINSTANCE hInstance = targetModule;
+
+ ULONG ulSize;
+ PIMAGE_IMPORT_DESCRIPTOR pImportDesc =
+ (PIMAGE_IMPORT_DESCRIPTOR)ImageDirectoryEntryToData(
+ hInstance,
+ TRUE,
+ IMAGE_DIRECTORY_ENTRY_IMPORT,
+ &ulSize
+ );
+
+ while (pImportDesc->Name)
+ {
+ PSTR pszModName = (PSTR)((PBYTE) hInstance + pImportDesc->Name);
+ if (_stricmp(pszModName, lpLibFileName) == 0)
+ break;
+ pImportDesc++;
+ }
+
+ PIMAGE_THUNK_DATA pThunk =
+ (PIMAGE_THUNK_DATA)((PBYTE) hInstance + pImportDesc->FirstThunk);
+
+ while (pThunk->u1.Function)
+ {
+ PROC* ppfn = (PROC*) &pThunk->u1.Function;
+ BOOL bFound = (*ppfn == pfnHookAPIAddr);
+
+ if (bFound)
+ {
+ MEMORY_BASIC_INFORMATION mbi;
+ VirtualQuery(
+ ppfn,
+ &mbi,
+ sizeof(MEMORY_BASIC_INFORMATION)
+ );
+ VirtualProtect(
+ mbi.BaseAddress,
+ mbi.RegionSize,
+ PAGE_READWRITE,
+ &mbi.Protect);
+
+ *ppfn = *newFunction;
+
+ DWORD dwOldProtect;
+ VirtualProtect(
+ mbi.BaseAddress,
+ mbi.RegionSize,
+ mbi.Protect,
+ &dwOldProtect
+ );
+ break;
+ }
+ pThunk++;
+ }
+
+ return pfnHookAPIAddr;
+}
+
+
+bool checkUMDriverOverrides(void* context)
+{
+ lastContext = context;
+ if (oldProcA != NULL)
+ return true;
+
+ PreloadLibraryRTFn loadFunc = NULL;
+ for( int i = 0; i < ShimCountMax; ++i )
+ {
+ HINSTANCE hInst = NULL;
+ try
+ {
+ hInst = LoadLibraryA(dllList[i]);
+ }
+ catch(...)
+ {
+ }
+
+ oldLoaderInstances[i] = hInst;
+
+ if (hInst == NULL)
+ continue;
+
+ ShimedLibraries libCount = (ShimedLibraries)i;
+ switch( libCount )
+ {
+ case ShimLibDXGI:
+ oldCreateDXGIFactory = (WinCreateDXGIFactory)SetProcAddressDirect( hInst, "CreateDXGIFactory", (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData );
+ oldCreateDXGIFactory1 = (WinCreateDXGIFactory1)SetProcAddressDirect( hInst, "CreateDXGIFactory1", (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data );
+ oldCreateDXGIFactory2 = (WinCreateDXGIFactory2)SetProcAddressDirect( hInst, "CreateDXGIFactory2", (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data );
+ break;
+ case ShimLibD3D9:
+ oldDirectX9Create = (WinDirect3DCreate9)SetProcAddressDirect( hInst, "Direct3DCreate9", (PROC)OVRDirect3DCreate9, oldDirectX9CreateData );
+ oldDirectX9ExCreate = (WinDirect3DCreate9Ex)SetProcAddressDirect( hInst, "Direct3DCreate9Ex", (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData );
+ break;
+ default:
+ break;
+ }
+
+ for (int j = 0; j < NUM_LOADER_LIBS; ++j)
+ {
+ const char* loaderLibrary = loaderLibraryList[j];
+
+ PROC temp = NULL;
+ temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", (PROC)OVRLoadLibraryA);
+ if (!oldProcA)
+ {
+ oldProcA = (WinLoadLibraryA)temp;
+ }
+ oldLoaderProcA[i][j] = temp;
+
+ temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", (PROC)OVRLoadLibraryW);
+ if (!oldProcW)
+ {
+ oldProcW = (WinLoadLibraryW)temp;
+ }
+ oldLoaderProcW[i][j] = temp;
+
+ temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", (PROC)OVRLoadLibraryExA);
+ if (!oldProcExA)
+ {
+ oldProcExA = (WinLoadLibraryExA)temp;
+ }
+ oldLoaderProcExA[i][j] = temp;
+
+ temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", (PROC)OVRLoadLibraryExW);
+ if (!oldProcExW)
+ {
+ oldProcExW = (WinLoadLibraryExW)temp;
+ }
+ oldLoaderProcExW[i][j] = temp;
+
+ temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", (PROC)OVRGetModuleHandleExA);
+ if (!oldProcModExA)
+ {
+ oldProcModExA = (WinGetModuleHandleExA)temp;
+ }
+ oldLoaderProcModExA[i][j] = temp;
+
+ temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", (PROC)OVRGetModuleHandleExW);
+ if (!oldProcModExW)
+ {
+ oldProcModExW = (WinGetModuleHandleExW)temp;
+ }
+ oldLoaderProcModExW[i][j] = temp;
+ }
+
+ if (loadFunc == NULL)
+ {
+ loadFunc = (PreloadLibraryRTFn)GetProcAddress(hInst, "PreloadLibraryRT");
+ }
+ }
+
+ rtFilterModule = oldProcA ? (*oldProcA)(RTFilter) : NULL;
+
+ IsCreatingBackBuffer backBufferFunc = NULL;
+ ShouldVSync shouldVSyncFunc = NULL;
+ GetRTFilterVersion getRTFilterVersionFunc = NULL;
+
+ if (rtFilterModule != NULL)
+ {
+ loadFunc = (PreloadLibraryRTFn)GetProcAddress(rtFilterModule, "PreloadLibraryRT");
+ backBufferFunc = (IsCreatingBackBuffer)GetProcAddress(rtFilterModule, "OVRIsCreatingBackBuffer");
+ shouldVSyncFunc = (ShouldVSync)GetProcAddress(rtFilterModule, "OVRShouldVSync");
+ getRTFilterVersionFunc = (GetRTFilterVersion)GetProcAddress(rtFilterModule, "OVRGetRTFilterVersion");
+ }
+
+ if (loadFunc == NULL)
+ {
+ MessageBox(nullptr, L"Unable to load the Oculus Display Driver. Please reinstall the Oculus Runtime.", L"Configuration Error", MB_OK | MB_ICONERROR);
+ return false;
+ }
+
+ if (getRTFilterVersionFunc == NULL)
+ {
+ WCHAR message[1000] = {};
+ swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_VERSION);
+ MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR);
+ return false;
+ }
+
+ // Verify that we are running with the appropriate display driver
+ {
+ const ULONG rtFilterVersion = (*getRTFilterVersionFunc)();
+ const ULONG rtFilterMajor = OVR_GET_VERSION_MAJOR(rtFilterVersion);
+ const ULONG rtFilterMinor = OVR_GET_VERSION_MINOR(rtFilterVersion);
+
+ if ((rtFilterMajor != OVR_RTFILTER_VERSION_MAJOR) || (rtFilterMinor < OVR_RTFILTER_VERSION_MINOR))
+ {
+ WCHAR message[1000] = {};
+ swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_VERSION);
+ MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR);
+ return false;
+ }
+ }
+
+ appDriver.version = OVR_RENDER_SHIM_VERSION_MAJOR;
+ appDriver.context = lastContext;
+
+// appDriver.pfnInitializingDisplay = OVRIsInitializingDisplay;
+ appDriver.pfnInitializingDisplay = OVRLocalIsInitializingDisplay;
+ appDriver.pfnRiftForContext = OVRRiftForContext;
+ appDriver.pfnCloseRiftForContext = OVRCloseRiftForContext;
+ appDriver.pfnWindowDisplayResolution = OVRWindowDisplayResolution;
+ appDriver.pfnShouldEnableDebug = OVRShouldEnableDebug;
+ appDriver.pfnIsCreatingBackBuffer = (backBufferFunc == NULL) ? OVRIsCreatingBackBuffer : backBufferFunc;
+ appDriver.pfnShouldVSync = (shouldVSyncFunc == NULL) ? OVRShouldVSync : shouldVSyncFunc;
+ appDriver.pfnExpectedResolution = OVRExpectedResolution;
+ appDriver.pfnMirroringEnabled = OVRMirroringEnabled;
+ appDriver.pfnGetWindowForContext = OVRGetWindowForContext;
+ appDriver.pfnPresentRiftOnContext = OVRShouldPresentOnContext;
+
+ appDriver.pfnDirect3DCreate9 = oldDirectX9Create;
+ appDriver.pfnDirect3DCreate9Ex = oldDirectX9ExCreate;
+ appDriver.pfnCreateDXGIFactory = oldCreateDXGIFactory;
+ appDriver.pfnCreateDXGIFactory1 = oldCreateDXGIFactory1;
+ appDriver.pfnCreateDXGIFactory2 = oldCreateDXGIFactory2;
+
+ (*loadFunc)( &appDriver );
+
+ return true;
+}
+
+void clearUMDriverOverrides()
+{
+ if (oldProcA != NULL)
+ {
+ // Unpatch all the things.
+
+ if (oldCreateDXGIFactory)
+ {
+ restoreFunction((PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData);
+ }
+ if (oldCreateDXGIFactory1)
+ {
+ restoreFunction((PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data);
+ }
+ if (oldCreateDXGIFactory2)
+ {
+ restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data);
+ }
+ if (oldDirectX9Create)
+ {
+ restoreFunction((PROC)oldDirectX9Create, oldDirectX9CreateData);
+ }
+ if (oldDirectX9ExCreate)
+ {
+ restoreFunction((PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData);
+ }
+ if (oldCreateDXGIFactory2)
+ {
+ restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data);
+ }
+
+ for (int i = 0; i < ShimCountMax; ++i)
+ {
+ HINSTANCE hInst = oldLoaderInstances[i];
+
+ if (hInst != NULL)
+ {
+ for (int j = 0; j < NUM_LOADER_LIBS; ++j)
+ {
+ const char* loaderLibrary = loaderLibraryList[j];
+
+ if (oldLoaderProcA[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", oldLoaderProcA[i][j]);
+ }
+ if (oldLoaderProcW[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", oldLoaderProcW[i][j]);
+ }
+ if (oldLoaderProcExA[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", oldLoaderProcExA[i][j]);
+ }
+ if (oldLoaderProcExW[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", oldLoaderProcExW[i][j]);
+ }
+ if (oldLoaderProcModExA[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", oldLoaderProcModExA[i][j]);
+ }
+ if (oldLoaderProcModExW[j])
+ {
+ SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", oldLoaderProcModExW[i][j]);
+ }
+ }
+
+ FreeLibrary(hInst);
+ }
+ }
+
+ if (rtFilterModule != NULL)
+ {
+ LPFNCANUNLOADNOW pfnCanUnloadNow = (LPFNCANUNLOADNOW)GetProcAddress(rtFilterModule, "DllCanUnloadNow");
+ if (pfnCanUnloadNow && pfnCanUnloadNow() == S_OK)
+ {
+ FreeLibrary(rtFilterModule);
+ rtFilterModule = NULL;
+ }
+ }
+
+ oldProcA = NULL;
+ oldProcExA = NULL;
+ oldProcW = NULL;
+ oldProcExW = NULL;
+ oldProcModExA = NULL;
+ oldProcModExW = NULL;
+ oldDirectX9Create = NULL;
+ oldDirectX9ExCreate = NULL;
+ oldCreateDXGIFactory = NULL;
+ oldCreateDXGIFactory1 = NULL;
+ oldCreateDXGIFactory2 = NULL;
+ lastContext = NULL;
+ }
+}
|