diff options
Diffstat (limited to 'LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp')
-rw-r--r-- | LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp | 1965 |
1 files changed, 991 insertions, 974 deletions
diff --git a/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp b/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp index a74ab75..48bcf2b 100644 --- a/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp +++ b/LibOVR/Src/Displays/OVR_Win32_RenderShim.cpp @@ -1,974 +1,991 @@ -/************************************************************************************
-
-Filename : OVR_Win32_DisplayShim.cpp
-Content : Shared static functions for inclusion that allow for an application
- to inject the usermode driver into an application
-Created : March 21, 2014
-Authors : Dean Beeler
-
-Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved.
-
-Use of this software is subject to the terms of the Oculus Inc license
-agreement provided at the time of installation or download, or which
-otherwise accompanies this software in either electronic or hard copy form.
-
-************************************************************************************/
-
-#include "../../Include/OVR_Version.h"
-
-#ifndef AVOID_LIB_OVR
-#include "../Kernel/OVR_Types.h" // Without this we can get warnings (due to VC++ bugs) about _malloca being redefined, and miss LibOVR overrides.
-#endif
-
-#include <windows.h>
-#include <DbgHelp.h>
-#include <malloc.h>
-
-#include "OVR_Win32_Dxgi_Display.h"
-#include "OVR_Win32_ShimVersion.h"
-
-#if AVOID_LIB_OVR
-#define IN_COMPATIBILITY_MODE() (0)
-#else
-#include "OVR_Win32_Display.h"
-#define IN_COMPATIBILITY_MODE() OVR::Display::InCompatibilityMode()
-#endif
-
-#pragma comment(lib, "DbgHelp.lib")
-
-#define WIDE_TO_MB(wideString) \
- int wideString ## _slen = (int)wcslen(wideString); \
- char* wideString ## _cstr = (char*)alloca(wideString ## _slen * 2); \
- int count = WideCharToMultiByte(GetACP(), 0, wideString, -1, wideString ## _cstr, wideString ## _slen * 2, NULL, NULL); \
- wideString ## _cstr[count] = '\0';
-
-// Forward declarations
-// These functions are implemented in OVR_Win32_DisplayDevice.cpp.
-
-BOOL WINAPI OVRIsInitializingDisplay( PVOID context, UINT width, UINT height );
-BOOL WINAPI OVRIsCreatingBackBuffer( PVOID context );
-BOOL WINAPI OVRShouldVSync( );
-ULONG WINAPI OVRRiftForContext( PVOID context, HANDLE driverHandle );
-BOOL WINAPI OVRCloseRiftForContext( PVOID context, HANDLE driverHandle, ULONG rift );
-BOOL WINAPI OVRWindowDisplayResolution( PVOID context, UINT* width, UINT* height,
- UINT* titleHeight, UINT* borderWidth,
- BOOL* vsyncEnabled );
-BOOL WINAPI OVRExpectedResolution( PVOID context, UINT* width, UINT* height, UINT* rotationInDegrees );
-BOOL WINAPI OVRShouldEnableDebug();
-BOOL WINAPI OVRMirroringEnabled( PVOID context );
-HWND WINAPI OVRGetWindowForContext(PVOID context);
-BOOL WINAPI OVRShouldPresentOnContext(PVOID context);
-
-static const char* GFX_DRIVER_KEY_FMT = "SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\%04d";
-#ifdef _WIN64
-static const char* RTFilter = "OVRDisplayRT64.dll";
-static const char* UMFilter = "OVRDisplay64.dll";
-#else
-static const char* RTFilter = "OVRDisplayRT32.dll";
-static const char* UMFilter = "OVRDisplay32.dll";
-#endif
-static const char* OptimusDrivers = "nvumdshimx.dll nvumdshim.dll";
-
-typedef enum OVRTargetAPI
-{
- DirectX,
- OpenGL
-};
-
-static PVOID lastContext = NULL;
-LINK_APPLICATION_DRIVER appDriver = {0};
-static INT apiVersion = 10;
-
-static CHAR* ReadRegStr(HKEY keySub, const char* keyName, const char* valName)
-{
- CHAR *val = NULL;
- REGSAM access = KEY_READ;
- HKEY hKey;
-
-TryAgainWOW64:
- NTSTATUS res = RegOpenKeyExA( keySub, keyName, 0, access, &hKey );
- if ( res == ERROR_SUCCESS )
- {
- DWORD valLen;
- res = RegQueryValueExA( hKey, valName, NULL, NULL, NULL, &valLen );
- if( res == ERROR_SUCCESS ) {
- val = (CHAR*)calloc( valLen + 1, sizeof(CHAR) );
- res = RegQueryValueExA( hKey, valName, NULL, NULL, (LPBYTE)val, &valLen );
-
- if( res == ERROR_SUCCESS )
- {
- CHAR* byte = val;
- for( DWORD j = 0; j < valLen; ++j )
- {
- if( byte[j] == 0 )
- byte[j] = ' ';
- }
- }
- else
- {
- free( val );
- val = NULL;
- }
- }
- RegCloseKey( hKey );
- }
-
- if( res == ERROR_FILE_NOT_FOUND && keySub == HKEY_LOCAL_MACHINE && access == KEY_READ ) {
-#ifdef _WIN64
- access = KEY_READ | KEY_WOW64_32KEY;
-#else
- access = KEY_READ | KEY_WOW64_64KEY;
-#endif
- goto TryAgainWOW64;
- }
- return val;
-}
-
-#define OLD_DATA_BACKUP_SIZE 16
-
-static WinLoadLibraryA oldProcA = NULL; // Note: This is used to indicate that the shim is in place
-static WinLoadLibraryExA oldProcExA = NULL;
-static WinLoadLibraryW oldProcW = NULL;
-static WinLoadLibraryExW oldProcExW = NULL;
-static WinGetModuleHandleExA oldProcModExA = NULL;
-static WinGetModuleHandleExW oldProcModExW = NULL;
-static WinDirect3DCreate9 oldDirectX9Create = NULL;
-static BYTE oldDirectX9CreateData[OLD_DATA_BACKUP_SIZE];
-static WinDirect3DCreate9Ex oldDirectX9ExCreate = NULL;
-static BYTE oldDirectX9ExCreateData[OLD_DATA_BACKUP_SIZE];
-static WinCreateDXGIFactory oldCreateDXGIFactory = NULL;
-static BYTE oldCreateDXGIFactoryData[OLD_DATA_BACKUP_SIZE];
-static WinCreateDXGIFactory1 oldCreateDXGIFactory1 = NULL;
-static BYTE oldCreateDXGIFactory1Data[OLD_DATA_BACKUP_SIZE];
-static WinCreateDXGIFactory2 oldCreateDXGIFactory2 = NULL;
-static BYTE oldCreateDXGIFactory2Data[OLD_DATA_BACKUP_SIZE];
-
-#define NUM_LOADER_LIBS 4
-
-static const char* loaderLibraryList[NUM_LOADER_LIBS] = {
- "kernel32.dll",
- "api-ms-win-core-libraryloader-l1-2-0.dll",
- "api-ms-win-core-libraryloader-l1-1-0.dll",
- "api-ms-win-core-libraryloader-l1-1-1.dll"
-};
-
-enum ShimedLibraries
-{
- ShimLibDXGI = 0,
- ShimLibD3D9 = 1,
- ShimLibD3D11 = 2,
- ShimLibDXGIDebug = 3,
- ShimLibD3D10Core = 4,
- ShimLibD3D10 = 5,
- ShimLibGL = 6,
- ShimCountMax = 7
-};
-
-static const char* dllList[ShimCountMax] = {
- "dxgi.dll",
- "d3d9.dll",
- "d3d11.dll",
- "dxgidebug.dll",
- "d3d10core.dll",
- "d3d10.dll",
- "opengl32.dll"
-};
-
-static HINSTANCE oldLoaderInstances[ShimCountMax] = { NULL };
-static PROC oldLoaderProcA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-static PROC oldLoaderProcW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-static PROC oldLoaderProcExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-static PROC oldLoaderProcExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-static PROC oldLoaderProcModExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-static PROC oldLoaderProcModExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } };
-
-static HMODULE rtFilterModule = NULL;
-
-static bool checkForOverride( LPCSTR libFileName, OVRTargetAPI& targetApi )
-{
- for (int i=0; ; i++)
- {
- CHAR keyString[256] = {0};
-
- sprintf_s( keyString, 256, GFX_DRIVER_KEY_FMT, i );
-
- CHAR* infSection = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, "InfSection" );
-
- // No provider name means we're out of display enumerations
- if( infSection == NULL )
- break;
-
- free(infSection);
-
- // Check 64-bit driver names followed by 32-bit driver names
- const char* driverKeys[] = {"UserModeDriverName", "UserModeDriverNameWoW", "OpenGLDriverName", "OpenGLDriverNameWoW", "InstalledDisplayDrivers" };
- for( int j = 0; j < 6; ++j )
- {
- CHAR userModeList[4096] = {0};
-
- switch(j)
- {
- case 5:
- strcpy_s( userModeList, 4095, OptimusDrivers );
- break;
- default:
- {
- CHAR* regString = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, driverKeys[j] );
- if( regString )
- {
- strcpy_s( userModeList, 4095, regString );
- free( regString );
- }
-
- }
- break;
- }
-
- char *nextToken = NULL;
-
- if( userModeList )
- {
- char* first = strtok_s( userModeList, " ", &nextToken );
- while( first )
- {
- if( strstr( libFileName, first ) != 0 )
- {
- if( j < 2 )
- targetApi = DirectX;
- else
- targetApi = OpenGL;
-
- return true;
- }
- first = strtok_s( NULL, " ", &nextToken );
- }
- }
- }
- }
-
- return false;
-}
-
-static HMODULE createShim( LPCSTR lpLibFileName, OVRTargetAPI targetAPI )
-{
- //Sleep(10000);
- if( IN_COMPATIBILITY_MODE() )
- {
- return (*oldProcA)( lpLibFileName );
- }
-
- UNREFERENCED_PARAMETER( targetAPI );
-
- HMODULE result = NULL;
-
- result = (*oldProcA)( UMFilter );
-
- if( result )
- {
- PreloadLibraryFn loadFunc = (PreloadLibraryFn)GetProcAddress( result, "PreloadLibrary" );
- if( loadFunc )
- {
- HRESULT localRes = (*loadFunc)( oldProcA, lpLibFileName, &appDriver );
- if( localRes != S_OK )
- result = NULL;
- }
- }
-
- if( !result )
- {
- MessageBox(nullptr, L"Unable to find the Rift Display Driver.", L"Configuration Error", MB_OK | MB_ICONERROR);
- result = (*oldProcA)( lpLibFileName );
- }
- return result;
-}
-
-static HMODULE
- WINAPI
- OVRLoadLibraryA(
- __in LPCSTR lpLibFileName
- )
-{
- OVRTargetAPI targetAPI = DirectX;
- bool needShim = checkForOverride( lpLibFileName, targetAPI );
- if( !needShim )
- return (*oldProcA)( lpLibFileName );
-
- return createShim( lpLibFileName, targetAPI );
-}
-
-static HMODULE
- WINAPI
- OVRLoadLibraryW(
- __in LPCWSTR lpLibFileName
- )
-{
- WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr
-
- OVRTargetAPI targetAPI = DirectX;
-
- bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI );
- if( !needShim )
- return (*oldProcW)( lpLibFileName );
-
- return createShim( lpLibFileName_cstr, targetAPI );
-}
-
-static HMODULE
- WINAPI
- OVRLoadLibraryExA(
- __in LPCSTR lpLibFileName,
- __reserved HANDLE hFile,
- __in DWORD dwFlags
-
- )
-{
- OVRTargetAPI targetAPI = DirectX;
-
- bool needShim = checkForOverride( lpLibFileName, targetAPI );
- if( !needShim )
- return (*oldProcExA)( lpLibFileName, hFile, dwFlags );
-
- // FIXME: Don't throw away the flags parameter
- return createShim( lpLibFileName, targetAPI );
-}
-
-static HMODULE
- WINAPI
- OVRLoadLibraryExW(
- __in LPCWSTR lpLibFileName,
- __reserved HANDLE hFile,
- __in DWORD dwFlags
- )
-{
- WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr
-
- OVRTargetAPI targetAPI = DirectX;
-
- bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI );
- if( !needShim )
- return (*oldProcExW)( lpLibFileName, hFile, dwFlags );
-
- // FIXME: Don't throw away the flags parameter
- return createShim( lpLibFileName_cstr, targetAPI );
-}
-
-static BOOL WINAPI OVRGetModuleHandleExA(
- __in DWORD dwFlags,
- __in_opt LPCSTR lpModuleName,
- __out HMODULE *phModule
- )
-{
- OVRTargetAPI targetAPI = DirectX;
-
- bool needShim = checkForOverride( lpModuleName, targetAPI );
- if( !needShim )
- {
- return (*oldProcModExA)( dwFlags, lpModuleName, phModule );
- }
-
- *phModule = createShim( lpModuleName, targetAPI );
-
- return TRUE;
-}
-
-static BOOL WINAPI OVRGetModuleHandleExW(
- __in DWORD dwFlags,
- __in_opt LPCWSTR lpModuleName,
- __out HMODULE *phModule
- )
-{
- WIDE_TO_MB(lpModuleName); // Convert lpModuleName -> lpModuleName_cstr
-
- OVRTargetAPI targetAPI = DirectX;
-
- bool needShim = checkForOverride( lpModuleName_cstr, targetAPI );
- if( !needShim )
- {
- return (*oldProcModExW)( dwFlags, lpModuleName, phModule );
- }
-
- *phModule = createShim( lpModuleName_cstr, targetAPI );
-
- return TRUE;
-}
-
-#ifdef _AMD64_
-static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData )
-{
- static const LONGLONG addressSize = sizeof(PROC);
- static const LONGLONG jmpSize = addressSize + 6;
-
- DWORD oldProtect;
- VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE,
- PAGE_EXECUTE_READWRITE, &oldProtect);
-
- memcpy(pfnHookAPIAddr, oldData, OLD_DATA_BACKUP_SIZE);
-
- VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL);
-}
-
-static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData )
-{
- static const LONGLONG addressSize = sizeof(PROC);
- static const LONGLONG jmpSize = addressSize + 6;
-
- INT_PTR jumpOffset = (INT_PTR)replacementFunction;
-
- DWORD oldProtect;
- VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE,
- PAGE_EXECUTE_READWRITE, &oldProtect);
-
- memcpy(oldData, pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE);
-
- PBYTE functionData = (PBYTE)pfnHookAPIAddr;
- functionData[0] = 0xff; // JMP [RIP+0]
- functionData[1] = 0x25; //
- functionData[2] = 0x00; //
- functionData[3] = 0x00; //
- functionData[4] = 0x00; //
- functionData[5] = 0x00; //
- memcpy( functionData + 6, &jumpOffset, sizeof( INT_PTR ) );
-
- VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL);
-}
-#else
-static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData )
-{
- static const LONGLONG addressSize = sizeof(PROC);
- static const LONGLONG jmpSize = addressSize + 1;
-
- DWORD oldProtect;
- VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize,
- PAGE_EXECUTE_READWRITE, &oldProtect);
-
- memcpy(pfnHookAPIAddr, oldData, jmpSize);
-
- VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL);
-}
-
-static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData )
-{
- static const LONGLONG addressSize = sizeof(PROC);
- static const LONGLONG jmpSize = addressSize + 1;
-
- INT_PTR jumpOffset = (INT_PTR)replacementFunction - (INT_PTR)pfnHookAPIAddr - jmpSize;
-
- DWORD oldProtect;
- VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize,
- PAGE_EXECUTE_READWRITE, &oldProtect);
-
- memcpy(oldData, pfnHookAPIAddr, jmpSize);
-
- PBYTE functionData = (PBYTE)pfnHookAPIAddr;
- memcpy( oldData, functionData, jmpSize );
- functionData[0] = 0xe9;
- memcpy( functionData + 1, &jumpOffset, sizeof( INT_PTR ) );
-
- VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL);
-}
-#endif
-
-static BOOL WINAPI OVRLocalIsInitializingDisplay( PVOID context, UINT width, UINT height )
-{
- UINT expectedWidth, expectedHeight, rotation;
-
- OVRExpectedResolution( context, &expectedWidth, &expectedHeight, &rotation );
-
- if( appDriver.pfnActiveAPIVersion )
- apiVersion = (*appDriver.pfnActiveAPIVersion)( context );
-
- switch( apiVersion )
- {
- case 1: // OpenGL
- case 10: // DirectX 1X
- if( width == expectedWidth && height == expectedHeight )
- return TRUE;
- break;
- case 9: // DirectX 9
- if( rotation == 90 || rotation == 270 )
- {
- if( width == expectedHeight && height == expectedWidth )
- return TRUE;
- }
- else
- {
- if( width == expectedWidth && height == expectedHeight )
- return TRUE;
- }
- break;
- default:
- break;
- }
-
- return FALSE;
-}
-
-
-HRESULT APIENTRY OVRDirect3DCreate9Ex(UINT SDKVersion, void** aDevice)
-{
- apiVersion = 9;
-
- HRESULT result = S_OK;
-
- restoreFunction( (PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData );
-
- if (IN_COMPATIBILITY_MODE())
- {
- result = (*oldDirectX9ExCreate)(SDKVersion, aDevice);
- }
- else
- {
- WinDirect3DCreate9Ex createFunction = (WinDirect3DCreate9Ex)GetProcAddress(rtFilterModule, "Direct3DCreate9Ex");
- result = (*createFunction)(SDKVersion, aDevice);
- }
-
- setFunction( (PROC)oldDirectX9ExCreate, (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData );
-
- printf("%s result 0x%x\n", __FUNCTION__, result);
-
- return result;
-}
-
-void* APIENTRY OVRDirect3DCreate9(UINT SDKVersion)
-{
- void* result = NULL;
-
- OVRDirect3DCreate9Ex( SDKVersion, &result );
-
- return result;
-}
-
-HRESULT APIENTRY OVRCreateDXGIFactory(
- __in REFIID riid,
- __out void **ppFactory
- )
-{
- HRESULT result = E_FAIL;
-
- restoreFunction( (PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData );
-
- if (IN_COMPATIBILITY_MODE())
- {
- result = (*oldCreateDXGIFactory)(riid, ppFactory);
- }
- else
- {
- WinCreateDXGIFactory createFunction = (WinCreateDXGIFactory)GetProcAddress(rtFilterModule, "CreateDXGIFactory");
- result = (*createFunction)(riid, ppFactory);
- }
-
- setFunction( (PROC)oldCreateDXGIFactory, (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData );
-
- printf("%s result 0x%x\n", __FUNCTION__, result);
-
- return result;
-}
-
-HRESULT APIENTRY OVRCreateDXGIFactory1(
- __in REFIID riid,
- __out void **ppFactory
- )
-{
- HRESULT result = E_FAIL;
-
- restoreFunction( (PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data );
-
- if (IN_COMPATIBILITY_MODE())
- {
- result = (*oldCreateDXGIFactory1)(riid, ppFactory);
- }
- else
- {
- WinCreateDXGIFactory1 createFunction = (WinCreateDXGIFactory1)GetProcAddress(rtFilterModule, "CreateDXGIFactory1");
- result = (*createFunction)(riid, ppFactory);
- }
-
- setFunction( (PROC)oldCreateDXGIFactory1, (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data );
-
- printf("%s result 0x%x\n", __FUNCTION__, result);
-
- return result;
-}
-
-HRESULT APIENTRY OVRCreateDXGIFactory2(
- __in UINT flags,
- __in const IID &riid,
- __out void **ppFactory
- )
-{
- HRESULT result = E_FAIL;
-
- restoreFunction( (PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data );
-
- if (IN_COMPATIBILITY_MODE())
- {
- result = (*oldCreateDXGIFactory2)(flags, riid, ppFactory);
- }
- else
- {
- WinCreateDXGIFactory2 createFunction = (WinCreateDXGIFactory2)GetProcAddress(rtFilterModule, "CreateDXGIFactory2");
- result = (*createFunction)(flags, riid, ppFactory);
- }
-
- setFunction( (PROC)oldCreateDXGIFactory2, (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data );
-
- printf("%s result 0x%x\n", __FUNCTION__, result);
-
- return result;
-}
-
-static PROC SetProcAddressDirect(
- __in HINSTANCE hInstance,
- __in LPCSTR lpProcName,
- __in PROC newFunction,
- __inout BYTE* oldData
- )
-{
- static const LONGLONG addressSize = sizeof(PROC);
- static const LONGLONG jmpSize = addressSize + 1;
-
- PROC result = NULL;
-
- PROC pfnHookAPIAddr = GetProcAddress( hInstance, lpProcName );
-
- if( pfnHookAPIAddr )
- {
- result = pfnHookAPIAddr;
-
- setFunction( pfnHookAPIAddr, newFunction, oldData );
- }
-
- return result;
-}
-
-static PROC SetProcAddressA(
- __in HINSTANCE targetModule,
- __in LPCSTR lpLibFileName,
- __in LPCSTR lpProcName,
- __in PROC newFunction
- )
-{
- HMODULE hModule = LoadLibraryA( lpLibFileName );
- if(hModule == NULL)
- return NULL;
-
- // To do: call FreeLibrary(hModule) at the appropriate time.
- PROC pfnHookAPIAddr = GetProcAddress(hModule, lpProcName );
-
- HINSTANCE hInstance = targetModule;
-
- ULONG ulSize;
- PIMAGE_IMPORT_DESCRIPTOR pImportDesc =
- (PIMAGE_IMPORT_DESCRIPTOR)ImageDirectoryEntryToData(
- hInstance,
- TRUE,
- IMAGE_DIRECTORY_ENTRY_IMPORT,
- &ulSize
- );
-
- while (pImportDesc->Name)
- {
- PSTR pszModName = (PSTR)((PBYTE) hInstance + pImportDesc->Name);
- if (_stricmp(pszModName, lpLibFileName) == 0)
- break;
- pImportDesc++;
- }
-
- PIMAGE_THUNK_DATA pThunk =
- (PIMAGE_THUNK_DATA)((PBYTE) hInstance + pImportDesc->FirstThunk);
-
- while (pThunk->u1.Function)
- {
- PROC* ppfn = (PROC*) &pThunk->u1.Function;
- BOOL bFound = (*ppfn == pfnHookAPIAddr);
-
- if (bFound)
- {
- MEMORY_BASIC_INFORMATION mbi;
- VirtualQuery(
- ppfn,
- &mbi,
- sizeof(MEMORY_BASIC_INFORMATION)
- );
- VirtualProtect(
- mbi.BaseAddress,
- mbi.RegionSize,
- PAGE_READWRITE,
- &mbi.Protect);
-
- *ppfn = *newFunction;
-
- DWORD dwOldProtect;
- VirtualProtect(
- mbi.BaseAddress,
- mbi.RegionSize,
- mbi.Protect,
- &dwOldProtect
- );
- break;
- }
- pThunk++;
- }
-
- return pfnHookAPIAddr;
-}
-
-
-bool checkUMDriverOverrides(void* context)
-{
- lastContext = context;
- if (oldProcA != NULL)
- return true;
-
- PreloadLibraryRTFn loadFunc = NULL;
- for( int i = 0; i < ShimCountMax; ++i )
- {
- HINSTANCE hInst = NULL;
- try
- {
- hInst = LoadLibraryA(dllList[i]);
- }
- catch(...)
- {
- }
-
- oldLoaderInstances[i] = hInst;
-
- if (hInst == NULL)
- continue;
-
- ShimedLibraries libCount = (ShimedLibraries)i;
- switch( libCount )
- {
- case ShimLibDXGI:
- oldCreateDXGIFactory = (WinCreateDXGIFactory)SetProcAddressDirect( hInst, "CreateDXGIFactory", (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData );
- oldCreateDXGIFactory1 = (WinCreateDXGIFactory1)SetProcAddressDirect( hInst, "CreateDXGIFactory1", (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data );
- oldCreateDXGIFactory2 = (WinCreateDXGIFactory2)SetProcAddressDirect( hInst, "CreateDXGIFactory2", (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data );
- break;
- case ShimLibD3D9:
- oldDirectX9Create = (WinDirect3DCreate9)SetProcAddressDirect( hInst, "Direct3DCreate9", (PROC)OVRDirect3DCreate9, oldDirectX9CreateData );
- oldDirectX9ExCreate = (WinDirect3DCreate9Ex)SetProcAddressDirect( hInst, "Direct3DCreate9Ex", (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData );
- break;
- default:
- break;
- }
-
- for (int j = 0; j < NUM_LOADER_LIBS; ++j)
- {
- const char* loaderLibrary = loaderLibraryList[j];
-
- PROC temp = NULL;
- temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", (PROC)OVRLoadLibraryA);
- if (!oldProcA)
- {
- oldProcA = (WinLoadLibraryA)temp;
- }
- oldLoaderProcA[i][j] = temp;
-
- temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", (PROC)OVRLoadLibraryW);
- if (!oldProcW)
- {
- oldProcW = (WinLoadLibraryW)temp;
- }
- oldLoaderProcW[i][j] = temp;
-
- temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", (PROC)OVRLoadLibraryExA);
- if (!oldProcExA)
- {
- oldProcExA = (WinLoadLibraryExA)temp;
- }
- oldLoaderProcExA[i][j] = temp;
-
- temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", (PROC)OVRLoadLibraryExW);
- if (!oldProcExW)
- {
- oldProcExW = (WinLoadLibraryExW)temp;
- }
- oldLoaderProcExW[i][j] = temp;
-
- temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", (PROC)OVRGetModuleHandleExA);
- if (!oldProcModExA)
- {
- oldProcModExA = (WinGetModuleHandleExA)temp;
- }
- oldLoaderProcModExA[i][j] = temp;
-
- temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", (PROC)OVRGetModuleHandleExW);
- if (!oldProcModExW)
- {
- oldProcModExW = (WinGetModuleHandleExW)temp;
- }
- oldLoaderProcModExW[i][j] = temp;
- }
-
- if (loadFunc == NULL)
- {
- loadFunc = (PreloadLibraryRTFn)GetProcAddress(hInst, "PreloadLibraryRT");
- }
- }
-
- rtFilterModule = oldProcA ? (*oldProcA)(RTFilter) : NULL;
-
- IsCreatingBackBuffer backBufferFunc = NULL;
- ShouldVSync shouldVSyncFunc = NULL;
- GetRTFilterVersion getRTFilterVersionFunc = NULL;
-
- if (rtFilterModule != NULL)
- {
- loadFunc = (PreloadLibraryRTFn)GetProcAddress(rtFilterModule, "PreloadLibraryRT");
- backBufferFunc = (IsCreatingBackBuffer)GetProcAddress(rtFilterModule, "OVRIsCreatingBackBuffer");
- shouldVSyncFunc = (ShouldVSync)GetProcAddress(rtFilterModule, "OVRShouldVSync");
- getRTFilterVersionFunc = (GetRTFilterVersion)GetProcAddress(rtFilterModule, "OVRGetRTFilterVersion");
- }
-
- if (loadFunc == NULL)
- {
- MessageBox(nullptr, L"Unable to load the Oculus Display Driver. Please reinstall the Oculus Runtime.", L"Configuration Error", MB_OK | MB_ICONERROR);
- return false;
- }
-
- if (getRTFilterVersionFunc == NULL)
- {
- WCHAR message[1000] = {};
- swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_VERSION);
- MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR);
- return false;
- }
-
- // Verify that we are running with the appropriate display driver
- {
- const ULONG rtFilterVersion = (*getRTFilterVersionFunc)();
- const ULONG rtFilterMajor = OVR_GET_VERSION_MAJOR(rtFilterVersion);
- const ULONG rtFilterMinor = OVR_GET_VERSION_MINOR(rtFilterVersion);
-
- if ((rtFilterMajor != OVR_RTFILTER_VERSION_MAJOR) || (rtFilterMinor < OVR_RTFILTER_VERSION_MINOR))
- {
- WCHAR message[1000] = {};
- swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_VERSION);
- MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR);
- return false;
- }
- }
-
- appDriver.version = OVR_RENDER_SHIM_VERSION_MAJOR;
- appDriver.context = lastContext;
-
-// appDriver.pfnInitializingDisplay = OVRIsInitializingDisplay;
- appDriver.pfnInitializingDisplay = OVRLocalIsInitializingDisplay;
- appDriver.pfnRiftForContext = OVRRiftForContext;
- appDriver.pfnCloseRiftForContext = OVRCloseRiftForContext;
- appDriver.pfnWindowDisplayResolution = OVRWindowDisplayResolution;
- appDriver.pfnShouldEnableDebug = OVRShouldEnableDebug;
- appDriver.pfnIsCreatingBackBuffer = (backBufferFunc == NULL) ? OVRIsCreatingBackBuffer : backBufferFunc;
- appDriver.pfnShouldVSync = (shouldVSyncFunc == NULL) ? OVRShouldVSync : shouldVSyncFunc;
- appDriver.pfnExpectedResolution = OVRExpectedResolution;
- appDriver.pfnMirroringEnabled = OVRMirroringEnabled;
- appDriver.pfnGetWindowForContext = OVRGetWindowForContext;
- appDriver.pfnPresentRiftOnContext = OVRShouldPresentOnContext;
-
- appDriver.pfnDirect3DCreate9 = oldDirectX9Create;
- appDriver.pfnDirect3DCreate9Ex = oldDirectX9ExCreate;
- appDriver.pfnCreateDXGIFactory = oldCreateDXGIFactory;
- appDriver.pfnCreateDXGIFactory1 = oldCreateDXGIFactory1;
- appDriver.pfnCreateDXGIFactory2 = oldCreateDXGIFactory2;
-
- (*loadFunc)( &appDriver );
-
- return true;
-}
-
-void clearUMDriverOverrides()
-{
- if (oldProcA != NULL)
- {
- // Unpatch all the things.
-
- if (oldCreateDXGIFactory)
- {
- restoreFunction((PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData);
- }
- if (oldCreateDXGIFactory1)
- {
- restoreFunction((PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data);
- }
- if (oldCreateDXGIFactory2)
- {
- restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data);
- }
- if (oldDirectX9Create)
- {
- restoreFunction((PROC)oldDirectX9Create, oldDirectX9CreateData);
- }
- if (oldDirectX9ExCreate)
- {
- restoreFunction((PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData);
- }
- if (oldCreateDXGIFactory2)
- {
- restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data);
- }
-
- for (int i = 0; i < ShimCountMax; ++i)
- {
- HINSTANCE hInst = oldLoaderInstances[i];
-
- if (hInst != NULL)
- {
- for (int j = 0; j < NUM_LOADER_LIBS; ++j)
- {
- const char* loaderLibrary = loaderLibraryList[j];
-
- if (oldLoaderProcA[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", oldLoaderProcA[i][j]);
- }
- if (oldLoaderProcW[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", oldLoaderProcW[i][j]);
- }
- if (oldLoaderProcExA[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", oldLoaderProcExA[i][j]);
- }
- if (oldLoaderProcExW[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", oldLoaderProcExW[i][j]);
- }
- if (oldLoaderProcModExA[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", oldLoaderProcModExA[i][j]);
- }
- if (oldLoaderProcModExW[j])
- {
- SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", oldLoaderProcModExW[i][j]);
- }
- }
-
- FreeLibrary(hInst);
- }
- }
-
- if (rtFilterModule != NULL)
- {
- LPFNCANUNLOADNOW pfnCanUnloadNow = (LPFNCANUNLOADNOW)GetProcAddress(rtFilterModule, "DllCanUnloadNow");
- if (pfnCanUnloadNow && pfnCanUnloadNow() == S_OK)
- {
- FreeLibrary(rtFilterModule);
- rtFilterModule = NULL;
- }
- }
-
- oldProcA = NULL;
- oldProcExA = NULL;
- oldProcW = NULL;
- oldProcExW = NULL;
- oldProcModExA = NULL;
- oldProcModExW = NULL;
- oldDirectX9Create = NULL;
- oldDirectX9ExCreate = NULL;
- oldCreateDXGIFactory = NULL;
- oldCreateDXGIFactory1 = NULL;
- oldCreateDXGIFactory2 = NULL;
- lastContext = NULL;
- }
-}
+/************************************************************************************ + +Filename : OVR_Win32_DisplayShim.cpp +Content : Shared static functions for inclusion that allow for an application + to inject the usermode driver into an application +Created : March 21, 2014 +Authors : Dean Beeler + +Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved. + +Use of this software is subject to the terms of the Oculus Inc license +agreement provided at the time of installation or download, or which +otherwise accompanies this software in either electronic or hard copy form. + +************************************************************************************/ + +#include "../../Include/OVR_Version.h" + +#ifndef AVOID_LIB_OVR +#include "Kernel/OVR_Types.h" // Without this we can get warnings (due to VC++ bugs) about _malloca being redefined, and miss LibOVR overrides. +#include "Kernel/OVR_Win32_IncludeWindows.h" +#else +#include <windows.h> +#ifndef NTSTATUS +#define NTSTATUS DWORD +#endif +#endif + +#include <ObjBase.h> + +#include <DbgHelp.h> +#include "OVR_Win32_Dxgi_Display.h" +#include "OVR_Win32_ShimVersion.h" + +#if AVOID_LIB_OVR +#define IN_COMPATIBILITY_MODE() (0) +#else +#include "OVR_Win32_Display.h" +#define IN_COMPATIBILITY_MODE() OVR::Display::InCompatibilityMode() +#endif + +#pragma comment(lib, "DbgHelp.lib") + +#ifndef alloca +#include <malloc.h> // alloca +#endif + +#define WIDE_TO_MB(wideString) \ + int wideString ## _slen = (int)wcslen(wideString); \ + char* wideString ## _cstr = (char*)alloca(wideString ## _slen * 2); \ + int count = WideCharToMultiByte(GetACP(), 0, wideString, -1, wideString ## _cstr, wideString ## _slen * 2, NULL, NULL); \ + wideString ## _cstr[count] = '\0'; + +// Forward declarations +// These functions are implemented in OVR_Win32_DisplayDevice.cpp. + +BOOL WINAPI OVRIsInitializingDisplay( PVOID context, UINT width, UINT height ); +BOOL WINAPI OVRIsCreatingBackBuffer( PVOID context ); +BOOL WINAPI OVRShouldVSync( ); +ULONG WINAPI OVRRiftForContext( PVOID context, HANDLE driverHandle ); +BOOL WINAPI OVRCloseRiftForContext( PVOID context, HANDLE driverHandle, ULONG rift ); +BOOL WINAPI OVRWindowDisplayResolution( PVOID context, UINT* width, UINT* height, + UINT* titleHeight, UINT* borderWidth, + BOOL* vsyncEnabled ); +BOOL WINAPI OVRExpectedResolution( PVOID context, UINT* width, UINT* height, UINT* rotationInDegrees ); +BOOL WINAPI OVRShouldEnableDebug(); +BOOL WINAPI OVRMirroringEnabled( PVOID context ); +HWND WINAPI OVRGetWindowForContext(PVOID context); +BOOL WINAPI OVRShouldPresentOnContext(PVOID context); + +static const char* GFX_DRIVER_KEY_FMT = "SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\%04d"; +#ifdef _WIN64 +static const char* RTFilter = "OVRDisplayRT64.dll"; +static const char* UMFilter = "OVRDisplay64.dll"; +#else +static const char* RTFilter = "OVRDisplayRT32.dll"; +static const char* UMFilter = "OVRDisplay32.dll"; +#endif +static const char* OptimusDrivers = "nvumdshimx.dll nvumdshim.dll"; + +typedef enum OVRTargetAPI +{ + DirectX, + OpenGL +}; + +static PVOID lastContext = NULL; +LINK_APPLICATION_DRIVER appDriver = {0}; +static INT apiVersion = 10; + +static CHAR* ReadRegStr(HKEY keySub, const char* keyName, const char* valName) +{ + CHAR *val = NULL; + REGSAM access = KEY_READ; + HKEY hKey; + +TryAgainWOW64: + NTSTATUS res = RegOpenKeyExA( keySub, keyName, 0, access, &hKey ); + if ( res == ERROR_SUCCESS ) + { + DWORD valLen; + res = RegQueryValueExA( hKey, valName, NULL, NULL, NULL, &valLen ); + if( res == ERROR_SUCCESS ) { + val = (CHAR*)calloc( valLen + 1, sizeof(CHAR) ); + res = RegQueryValueExA( hKey, valName, NULL, NULL, (LPBYTE)val, &valLen ); + + if( res == ERROR_SUCCESS ) + { + CHAR* byte = val; + for( DWORD j = 0; j < valLen; ++j ) + { + if( byte[j] == 0 ) + byte[j] = ' '; + } + } + else + { + free( val ); + val = NULL; + } + } + RegCloseKey( hKey ); + } + + if( res == ERROR_FILE_NOT_FOUND && keySub == HKEY_LOCAL_MACHINE && access == KEY_READ ) { +#ifdef _WIN64 + access = KEY_READ | KEY_WOW64_32KEY; +#else + access = KEY_READ | KEY_WOW64_64KEY; +#endif + goto TryAgainWOW64; + } + return val; +} + +#define OLD_DATA_BACKUP_SIZE 16 + +static WinLoadLibraryA oldProcA = NULL; // Note: This is used to indicate that the shim is in place +static WinLoadLibraryExA oldProcExA = NULL; +static WinLoadLibraryW oldProcW = NULL; +static WinLoadLibraryExW oldProcExW = NULL; +static WinGetModuleHandleExA oldProcModExA = NULL; +static WinGetModuleHandleExW oldProcModExW = NULL; +static WinDirect3DCreate9 oldDirectX9Create = NULL; +static BYTE oldDirectX9CreateData[OLD_DATA_BACKUP_SIZE]; +static WinDirect3DCreate9Ex oldDirectX9ExCreate = NULL; +static BYTE oldDirectX9ExCreateData[OLD_DATA_BACKUP_SIZE]; +static WinCreateDXGIFactory oldCreateDXGIFactory = NULL; +static BYTE oldCreateDXGIFactoryData[OLD_DATA_BACKUP_SIZE]; +static WinCreateDXGIFactory1 oldCreateDXGIFactory1 = NULL; +static BYTE oldCreateDXGIFactory1Data[OLD_DATA_BACKUP_SIZE]; +static WinCreateDXGIFactory2 oldCreateDXGIFactory2 = NULL; +static BYTE oldCreateDXGIFactory2Data[OLD_DATA_BACKUP_SIZE]; + +#define NUM_LOADER_LIBS 4 + +static const char* loaderLibraryList[NUM_LOADER_LIBS] = { + "kernel32.dll", + "api-ms-win-core-libraryloader-l1-2-0.dll", + "api-ms-win-core-libraryloader-l1-1-0.dll", + "api-ms-win-core-libraryloader-l1-1-1.dll" +}; + +enum ShimedLibraries +{ + ShimLibDXGI = 0, + ShimLibD3D9 = 1, + ShimLibD3D11 = 2, + ShimLibDXGIDebug = 3, + ShimLibD3D10Core = 4, + ShimLibD3D10 = 5, + ShimLibGL = 6, + ShimCountMax = 7 +}; + +static const char* dllList[ShimCountMax] = { + "dxgi.dll", + "d3d9.dll", + "d3d11.dll", + "dxgidebug.dll", + "d3d10core.dll", + "d3d10.dll", + "opengl32.dll" +}; + +static HINSTANCE oldLoaderInstances[ShimCountMax] = { NULL }; +static PROC oldLoaderProcA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; +static PROC oldLoaderProcW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; +static PROC oldLoaderProcExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; +static PROC oldLoaderProcExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; +static PROC oldLoaderProcModExA[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; +static PROC oldLoaderProcModExW[ShimCountMax][NUM_LOADER_LIBS] = { { NULL } }; + +static HMODULE rtFilterModule = NULL; + +static bool checkForOverride( LPCSTR libFileName, OVRTargetAPI& targetApi ) +{ + for (int i=0; ; i++) + { + CHAR keyString[256] = {0}; + + sprintf_s( keyString, 256, GFX_DRIVER_KEY_FMT, i ); + + CHAR* infSection = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, "InfSection" ); + + // No provider name means we're out of display enumerations + if( infSection == NULL ) + break; + + free(infSection); + + // Check 64-bit driver names followed by 32-bit driver names + const char* driverKeys[] = {"UserModeDriverName", "UserModeDriverNameWoW", "OpenGLDriverName", "OpenGLDriverNameWoW", "InstalledDisplayDrivers" }; + for( int j = 0; j < 6; ++j ) + { + CHAR userModeList[4096] = {0}; + + switch(j) + { + case 5: + strcpy_s( userModeList, 4095, OptimusDrivers ); + break; + default: + { + CHAR* regString = ReadRegStr( HKEY_LOCAL_MACHINE, keyString, driverKeys[j] ); + if( regString ) + { + strcpy_s( userModeList, 4095, regString ); + free( regString ); + } + + } + break; + } + + char *nextToken = NULL; + + if( userModeList ) + { + char* first = strtok_s( userModeList, " ", &nextToken ); + while( first ) + { + if( strstr( libFileName, first ) != 0 ) + { + if( j < 2 ) + targetApi = DirectX; + else + targetApi = OpenGL; + + return true; + } + first = strtok_s( NULL, " ", &nextToken ); + } + } + } + } + + return false; +} + +static HMODULE createShim( LPCSTR lpLibFileName, OVRTargetAPI targetAPI ) +{ + //Sleep(10000); + if( IN_COMPATIBILITY_MODE() ) + { + return (*oldProcA)( lpLibFileName ); + } + + UNREFERENCED_PARAMETER( targetAPI ); + + HMODULE result = NULL; + + result = (*oldProcA)( UMFilter ); + + if( result ) + { + PreloadLibraryFn loadFunc = (PreloadLibraryFn)GetProcAddress( result, "PreloadLibrary" ); + if( loadFunc ) + { + HRESULT localRes = (*loadFunc)( oldProcA, lpLibFileName, &appDriver ); + if( localRes != S_OK ) + result = NULL; + } + } + + if( !result ) + { + MessageBox(nullptr, L"Unable to find the Rift Display Driver.", L"Configuration Error", MB_OK | MB_ICONERROR); + result = (*oldProcA)( lpLibFileName ); + } + return result; +} + +static HMODULE + WINAPI + OVRLoadLibraryA( + __in LPCSTR lpLibFileName + ) +{ + OVRTargetAPI targetAPI = DirectX; + bool needShim = checkForOverride( lpLibFileName, targetAPI ); + if( !needShim ) + return (*oldProcA)( lpLibFileName ); + + return createShim( lpLibFileName, targetAPI ); +} + +static HMODULE + WINAPI + OVRLoadLibraryW( + __in LPCWSTR lpLibFileName + ) +{ + WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr + + OVRTargetAPI targetAPI = DirectX; + + bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI ); + if( !needShim ) + return (*oldProcW)( lpLibFileName ); + + return createShim( lpLibFileName_cstr, targetAPI ); +} + +static HMODULE + WINAPI + OVRLoadLibraryExA( + __in LPCSTR lpLibFileName, + __reserved HANDLE hFile, + __in DWORD dwFlags + + ) +{ + OVRTargetAPI targetAPI = DirectX; + + bool needShim = checkForOverride( lpLibFileName, targetAPI ); + if( !needShim ) + return (*oldProcExA)( lpLibFileName, hFile, dwFlags ); + + // FIXME: Don't throw away the flags parameter + return createShim( lpLibFileName, targetAPI ); +} + +static HMODULE + WINAPI + OVRLoadLibraryExW( + __in LPCWSTR lpLibFileName, + __reserved HANDLE hFile, + __in DWORD dwFlags + ) +{ + WIDE_TO_MB(lpLibFileName); // Convert lpLibFileName -> lpLibFileName_cstr + + OVRTargetAPI targetAPI = DirectX; + + bool needShim = checkForOverride( lpLibFileName_cstr, targetAPI ); + if( !needShim ) + return (*oldProcExW)( lpLibFileName, hFile, dwFlags ); + + // FIXME: Don't throw away the flags parameter + return createShim( lpLibFileName_cstr, targetAPI ); +} + +static BOOL WINAPI OVRGetModuleHandleExA( + __in DWORD dwFlags, + __in_opt LPCSTR lpModuleName, + __out HMODULE *phModule + ) +{ + OVRTargetAPI targetAPI = DirectX; + + bool needShim = checkForOverride( lpModuleName, targetAPI ); + if( !needShim ) + { + return (*oldProcModExA)( dwFlags, lpModuleName, phModule ); + } + + *phModule = createShim( lpModuleName, targetAPI ); + + return TRUE; +} + +static BOOL WINAPI OVRGetModuleHandleExW( + __in DWORD dwFlags, + __in_opt LPCWSTR lpModuleName, + __out HMODULE *phModule + ) +{ + WIDE_TO_MB(lpModuleName); // Convert lpModuleName -> lpModuleName_cstr + + OVRTargetAPI targetAPI = DirectX; + + bool needShim = checkForOverride( lpModuleName_cstr, targetAPI ); + if( !needShim ) + { + return (*oldProcModExW)( dwFlags, lpModuleName, phModule ); + } + + *phModule = createShim( lpModuleName_cstr, targetAPI ); + + return TRUE; +} + +#ifdef _AMD64_ +static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData ) +{ + static const LONGLONG addressSize = sizeof(PROC); + static const LONGLONG jmpSize = addressSize + 6; + + DWORD oldProtect; + VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, + PAGE_EXECUTE_READWRITE, &oldProtect); + + memcpy(pfnHookAPIAddr, oldData, OLD_DATA_BACKUP_SIZE); + + VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL); +} + +static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData ) +{ + static const LONGLONG addressSize = sizeof(PROC); + static const LONGLONG jmpSize = addressSize + 6; + + INT_PTR jumpOffset = (INT_PTR)replacementFunction; + + DWORD oldProtect; + VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, + PAGE_EXECUTE_READWRITE, &oldProtect); + + memcpy(oldData, pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE); + + PBYTE functionData = (PBYTE)pfnHookAPIAddr; + functionData[0] = 0xff; // JMP [RIP+0] + functionData[1] = 0x25; // + functionData[2] = 0x00; // + functionData[3] = 0x00; // + functionData[4] = 0x00; // + functionData[5] = 0x00; // + memcpy( functionData + 6, &jumpOffset, sizeof( INT_PTR ) ); + + VirtualProtect((LPVOID)pfnHookAPIAddr, OLD_DATA_BACKUP_SIZE, oldProtect, NULL); +} +#else +static void restoreFunction( PROC pfnHookAPIAddr, PBYTE oldData ) +{ + static const LONGLONG addressSize = sizeof(PROC); + static const LONGLONG jmpSize = addressSize + 1; + + DWORD oldProtect; + VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, + PAGE_EXECUTE_READWRITE, &oldProtect); + + memcpy(pfnHookAPIAddr, oldData, jmpSize); + + VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL); +} + +static void setFunction( PROC pfnHookAPIAddr, PROC replacementFunction, PBYTE oldData ) +{ + static const LONGLONG addressSize = sizeof(PROC); + static const LONGLONG jmpSize = addressSize + 1; + + INT_PTR jumpOffset = (INT_PTR)replacementFunction - (INT_PTR)pfnHookAPIAddr - jmpSize; + + DWORD oldProtect; + VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, + PAGE_EXECUTE_READWRITE, &oldProtect); + + memcpy(oldData, pfnHookAPIAddr, jmpSize); + + PBYTE functionData = (PBYTE)pfnHookAPIAddr; + memcpy( oldData, functionData, jmpSize ); + functionData[0] = 0xe9; + memcpy( functionData + 1, &jumpOffset, sizeof( INT_PTR ) ); + + VirtualProtect((LPVOID)pfnHookAPIAddr, jmpSize, oldProtect, NULL); +} +#endif + +static BOOL WINAPI OVRLocalIsInitializingDisplay( PVOID context, UINT width, UINT height ) +{ + UINT expectedWidth, expectedHeight, rotation; + + OVRExpectedResolution( context, &expectedWidth, &expectedHeight, &rotation ); + + if( appDriver.pfnActiveAPIVersion ) + apiVersion = (*appDriver.pfnActiveAPIVersion)( context ); + + switch( apiVersion ) + { + case 1: // OpenGL + case 10: // DirectX 1X + if( width == expectedWidth && height == expectedHeight ) + return TRUE; + break; + case 9: // DirectX 9 + if( rotation == 90 || rotation == 270 ) + { + if( width == expectedHeight && height == expectedWidth ) + return TRUE; + } + else + { + if( width == expectedWidth && height == expectedHeight ) + return TRUE; + } + break; + default: + break; + } + + return FALSE; +} + + +HRESULT APIENTRY OVRDirect3DCreate9Ex(UINT SDKVersion, void** aDevice) +{ + apiVersion = 9; + + HRESULT result = S_OK; + + restoreFunction( (PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData ); + + if (IN_COMPATIBILITY_MODE()) + { + result = (*oldDirectX9ExCreate)(SDKVersion, aDevice); + } + else + { + WinDirect3DCreate9Ex createFunction = (WinDirect3DCreate9Ex)GetProcAddress(rtFilterModule, "Direct3DCreate9Ex"); + result = (*createFunction)(SDKVersion, aDevice); + } + + setFunction( (PROC)oldDirectX9ExCreate, (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData ); + + printf("%s result 0x%x\n", __FUNCTION__, result); + + return result; +} + +void* APIENTRY OVRDirect3DCreate9(UINT SDKVersion) +{ + void* result = NULL; + + OVRDirect3DCreate9Ex( SDKVersion, &result ); + + return result; +} + +HRESULT APIENTRY OVRCreateDXGIFactory( + __in REFIID riid, + __out void **ppFactory + ) +{ + HRESULT result = E_FAIL; + + restoreFunction( (PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData ); + + if (IN_COMPATIBILITY_MODE()) + { + result = (*oldCreateDXGIFactory)(riid, ppFactory); + } + else + { + WinCreateDXGIFactory createFunction = (WinCreateDXGIFactory)GetProcAddress(rtFilterModule, "CreateDXGIFactory"); + result = (*createFunction)(riid, ppFactory); + } + + setFunction( (PROC)oldCreateDXGIFactory, (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData ); + + printf("%s result 0x%x\n", __FUNCTION__, result); + + return result; +} + +HRESULT APIENTRY OVRCreateDXGIFactory1( + __in REFIID riid, + __out void **ppFactory + ) +{ + HRESULT result = E_FAIL; + + restoreFunction( (PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data ); + + if (IN_COMPATIBILITY_MODE()) + { + result = (*oldCreateDXGIFactory1)(riid, ppFactory); + } + else + { + WinCreateDXGIFactory1 createFunction = (WinCreateDXGIFactory1)GetProcAddress(rtFilterModule, "CreateDXGIFactory1"); + result = (*createFunction)(riid, ppFactory); + } + + setFunction( (PROC)oldCreateDXGIFactory1, (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data ); + + printf("%s result 0x%x\n", __FUNCTION__, result); + + return result; +} + +HRESULT APIENTRY OVRCreateDXGIFactory2( + __in UINT flags, + __in const IID &riid, + __out void **ppFactory + ) +{ + HRESULT result = E_FAIL; + + restoreFunction( (PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data ); + + if (IN_COMPATIBILITY_MODE()) + { + result = (*oldCreateDXGIFactory2)(flags, riid, ppFactory); + } + else + { + WinCreateDXGIFactory2 createFunction = (WinCreateDXGIFactory2)GetProcAddress(rtFilterModule, "CreateDXGIFactory2"); + result = (*createFunction)(flags, riid, ppFactory); + } + + setFunction( (PROC)oldCreateDXGIFactory2, (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data ); + + printf("%s result 0x%x\n", __FUNCTION__, result); + + return result; +} + +static PROC SetProcAddressDirect( + __in HINSTANCE hInstance, + __in LPCSTR lpProcName, + __in PROC newFunction, + __inout BYTE* oldData + ) +{ + static const LONGLONG addressSize = sizeof(PROC); + static const LONGLONG jmpSize = addressSize + 1; + + PROC result = NULL; + + PROC pfnHookAPIAddr = GetProcAddress( hInstance, lpProcName ); + + if( pfnHookAPIAddr ) + { + result = pfnHookAPIAddr; + + setFunction( pfnHookAPIAddr, newFunction, oldData ); + } + + return result; +} + +static PROC SetProcAddressA( + __in HINSTANCE targetModule, + __in LPCSTR lpLibFileName, + __in LPCSTR lpProcName, + __in PROC newFunction, + __in PROC origFunction = NULL + ) +{ + HMODULE hModule = LoadLibraryA( lpLibFileName ); + if(hModule == NULL) + return NULL; + + // To do: call FreeLibrary(hModule) at the appropriate time. + PROC pfnHookAPIAddr = (origFunction != NULL) ? origFunction : GetProcAddress(hModule, lpProcName ); + + HINSTANCE hInstance = targetModule; + + ULONG ulSize; + PIMAGE_IMPORT_DESCRIPTOR pImportDesc = + (PIMAGE_IMPORT_DESCRIPTOR)ImageDirectoryEntryToData( + hInstance, + TRUE, + IMAGE_DIRECTORY_ENTRY_IMPORT, + &ulSize + ); + + while (pImportDesc->Name) + { + PSTR pszModName = (PSTR)((PBYTE) hInstance + pImportDesc->Name); + if (_stricmp(pszModName, lpLibFileName) == 0) + break; + pImportDesc++; + } + + PIMAGE_THUNK_DATA pThunk = + (PIMAGE_THUNK_DATA)((PBYTE) hInstance + pImportDesc->FirstThunk); + + while (pThunk->u1.Function) + { + PROC* ppfn = (PROC*) &pThunk->u1.Function; + BOOL bFound = (*ppfn == pfnHookAPIAddr); + + if (bFound) + { + MEMORY_BASIC_INFORMATION mbi; + VirtualQuery( + ppfn, + &mbi, + sizeof(MEMORY_BASIC_INFORMATION) + ); + VirtualProtect( + mbi.BaseAddress, + mbi.RegionSize, + PAGE_READWRITE, + &mbi.Protect); + + *ppfn = *newFunction; + + DWORD dwOldProtect; + VirtualProtect( + mbi.BaseAddress, + mbi.RegionSize, + mbi.Protect, + &dwOldProtect + ); + break; + } + pThunk++; + } + + return pfnHookAPIAddr; +} + +void clearUMDriverOverrides() +{ + if (lastContext != NULL) + { + // Unpatch all the things. + + if (oldCreateDXGIFactory) + { + restoreFunction((PROC)oldCreateDXGIFactory, oldCreateDXGIFactoryData); + } + if (oldCreateDXGIFactory1) + { + restoreFunction((PROC)oldCreateDXGIFactory1, oldCreateDXGIFactory1Data); + } + if (oldCreateDXGIFactory2) + { + restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data); + } + if (oldDirectX9Create) + { + restoreFunction((PROC)oldDirectX9Create, oldDirectX9CreateData); + } + if (oldDirectX9ExCreate) + { + restoreFunction((PROC)oldDirectX9ExCreate, oldDirectX9ExCreateData); + } + if (oldCreateDXGIFactory2) + { + restoreFunction((PROC)oldCreateDXGIFactory2, oldCreateDXGIFactory2Data); + } + + for (int i = 0; i < ShimCountMax; ++i) + { + HINSTANCE hInst = oldLoaderInstances[i]; + + if (hInst != NULL) + { + for (int j = 0; j < NUM_LOADER_LIBS; ++j) + { + const char* loaderLibrary = loaderLibraryList[j]; + + if (oldLoaderProcA[j]) + { + SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", oldLoaderProcA[i][j], (PROC)OVRLoadLibraryA); + } + if (oldLoaderProcW[j]) + { + SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", oldLoaderProcW[i][j], (PROC)OVRLoadLibraryW); + } + if (oldLoaderProcExA[j]) + { + SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", oldLoaderProcExA[i][j], (PROC)OVRLoadLibraryExA); + } + if (oldLoaderProcExW[j]) + { + SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", oldLoaderProcExW[i][j], (PROC)OVRLoadLibraryExW); + } + if (oldLoaderProcModExA[j]) + { + SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", oldLoaderProcModExA[i][j], (PROC)OVRGetModuleHandleExA); + } + if (oldLoaderProcModExW[j]) + { + SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", oldLoaderProcModExW[i][j], (PROC)OVRGetModuleHandleExW); + } + } + + FreeLibrary(hInst); + } + } + + if (rtFilterModule != NULL) + { + LPFNCANUNLOADNOW pfnCanUnloadNow = (LPFNCANUNLOADNOW)GetProcAddress(rtFilterModule, "DllCanUnloadNow"); + if (pfnCanUnloadNow && pfnCanUnloadNow() == S_OK) + { + FreeLibrary(rtFilterModule); + rtFilterModule = NULL; + } + } + + // Do not set the following to NULL, because it's possible that we may be called after this function is executed. + // This is due to quirks in how third party DLLs implement their linking to these functions. In any case + // this module will be going away within a month or so of this writing. + //oldProcA = NULL; + //oldProcExA = NULL; + //oldProcW = NULL; + //oldProcExW = NULL; + //oldProcModExA = NULL; + //oldProcModExW = NULL; + + oldDirectX9Create = NULL; + oldDirectX9ExCreate = NULL; + oldCreateDXGIFactory = NULL; + oldCreateDXGIFactory1 = NULL; + oldCreateDXGIFactory2 = NULL; + + lastContext = NULL; + } +} + +bool checkUMDriverOverrides(void* context) +{ + lastContext = context; + if (oldCreateDXGIFactory != NULL) + return true; + + PreloadLibraryRTFn loadFunc = NULL; + for( int i = 0; i < ShimCountMax; ++i ) + { + HINSTANCE hInst = NULL; + try + { + hInst = LoadLibraryA(dllList[i]); + } + catch(...) + { + } + + oldLoaderInstances[i] = hInst; + + if (hInst == NULL) + continue; + + ShimedLibraries libCount = (ShimedLibraries)i; + switch( libCount ) + { + case ShimLibDXGI: + oldCreateDXGIFactory = (WinCreateDXGIFactory)SetProcAddressDirect( hInst, "CreateDXGIFactory", (PROC)OVRCreateDXGIFactory, oldCreateDXGIFactoryData ); + oldCreateDXGIFactory1 = (WinCreateDXGIFactory1)SetProcAddressDirect( hInst, "CreateDXGIFactory1", (PROC)OVRCreateDXGIFactory1, oldCreateDXGIFactory1Data ); + oldCreateDXGIFactory2 = (WinCreateDXGIFactory2)SetProcAddressDirect( hInst, "CreateDXGIFactory2", (PROC)OVRCreateDXGIFactory2, oldCreateDXGIFactory2Data ); + break; + case ShimLibD3D9: + oldDirectX9Create = (WinDirect3DCreate9)SetProcAddressDirect( hInst, "Direct3DCreate9", (PROC)OVRDirect3DCreate9, oldDirectX9CreateData ); + oldDirectX9ExCreate = (WinDirect3DCreate9Ex)SetProcAddressDirect( hInst, "Direct3DCreate9Ex", (PROC)OVRDirect3DCreate9Ex, oldDirectX9ExCreateData ); + break; + default: + break; + } + + for (int j = 0; j < NUM_LOADER_LIBS; ++j) + { + const char* loaderLibrary = loaderLibraryList[j]; + + PROC temp = NULL; + temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryA", (PROC)OVRLoadLibraryA); + if (!oldProcA) + { + oldProcA = (WinLoadLibraryA)temp; + } + oldLoaderProcA[i][j] = temp; + + temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryW", (PROC)OVRLoadLibraryW); + if (!oldProcW) + { + oldProcW = (WinLoadLibraryW)temp; + } + oldLoaderProcW[i][j] = temp; + + temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExA", (PROC)OVRLoadLibraryExA); + if (!oldProcExA) + { + oldProcExA = (WinLoadLibraryExA)temp; + } + oldLoaderProcExA[i][j] = temp; + + temp = SetProcAddressA(hInst, loaderLibrary, "LoadLibraryExW", (PROC)OVRLoadLibraryExW); + if (!oldProcExW) + { + oldProcExW = (WinLoadLibraryExW)temp; + } + oldLoaderProcExW[i][j] = temp; + + temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExA", (PROC)OVRGetModuleHandleExA); + if (!oldProcModExA) + { + oldProcModExA = (WinGetModuleHandleExA)temp; + } + oldLoaderProcModExA[i][j] = temp; + + temp = SetProcAddressA(hInst, loaderLibrary, "GetModuleHandleExW", (PROC)OVRGetModuleHandleExW); + if (!oldProcModExW) + { + oldProcModExW = (WinGetModuleHandleExW)temp; + } + oldLoaderProcModExW[i][j] = temp; + } + + if (loadFunc == NULL) + { + loadFunc = (PreloadLibraryRTFn)GetProcAddress(hInst, "PreloadLibraryRT"); + } + } + + rtFilterModule = oldProcA ? (*oldProcA)(RTFilter) : NULL; + + IsCreatingBackBuffer backBufferFunc = NULL; + ShouldVSync shouldVSyncFunc = NULL; + GetRTFilterVersion getRTFilterVersionFunc = NULL; + + if (rtFilterModule != NULL) + { + loadFunc = (PreloadLibraryRTFn)GetProcAddress(rtFilterModule, "PreloadLibraryRT"); + backBufferFunc = (IsCreatingBackBuffer)GetProcAddress(rtFilterModule, "OVRIsCreatingBackBuffer"); + shouldVSyncFunc = (ShouldVSync)GetProcAddress(rtFilterModule, "OVRShouldVSync"); + getRTFilterVersionFunc = (GetRTFilterVersion)GetProcAddress(rtFilterModule, "OVRGetRTFilterVersion"); + } + + if (loadFunc == NULL) + { + MessageBox(nullptr, L"Unable to load the Oculus Display Driver. Please reinstall the Oculus Runtime.", L"Configuration Error", MB_OK | MB_ICONERROR); + clearUMDriverOverrides(); + return false; + } + + if (getRTFilterVersionFunc == NULL) + { + WCHAR message[1000] = {}; + swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_NUMBER); + MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR); + clearUMDriverOverrides(); + return false; + } + + // Verify that we are running with the appropriate display driver + { + const ULONG rtFilterVersion = (*getRTFilterVersionFunc)(); + const ULONG rtFilterMajor = OVR_GET_VERSION_MAJOR(rtFilterVersion); + const ULONG rtFilterMinor = OVR_GET_VERSION_MINOR(rtFilterVersion); + + if ((rtFilterMajor != OVR_RTFILTER_VERSION_MAJOR) || (rtFilterMinor < OVR_RTFILTER_VERSION_MINOR)) + { + WCHAR message[1000] = {}; + swprintf_s(message, L"This app requires the %d.%d.%d version of the Oculus Runtime.", OVR_MAJOR_VERSION, OVR_MINOR_VERSION, OVR_BUILD_NUMBER); + MessageBox(nullptr, message, L"Configuration Error", MB_OK | MB_ICONERROR); + clearUMDriverOverrides(); + return false; + } + } + + appDriver.version = OVR_RENDER_SHIM_VERSION_MAJOR; + appDriver.context = lastContext; + +// appDriver.pfnInitializingDisplay = OVRIsInitializingDisplay; + appDriver.pfnInitializingDisplay = OVRLocalIsInitializingDisplay; + appDriver.pfnRiftForContext = OVRRiftForContext; + appDriver.pfnCloseRiftForContext = OVRCloseRiftForContext; + appDriver.pfnWindowDisplayResolution = OVRWindowDisplayResolution; + appDriver.pfnShouldEnableDebug = OVRShouldEnableDebug; + appDriver.pfnIsCreatingBackBuffer = (backBufferFunc == NULL) ? OVRIsCreatingBackBuffer : backBufferFunc; + appDriver.pfnShouldVSync = (shouldVSyncFunc == NULL) ? OVRShouldVSync : shouldVSyncFunc; + appDriver.pfnExpectedResolution = OVRExpectedResolution; + appDriver.pfnMirroringEnabled = OVRMirroringEnabled; + appDriver.pfnGetWindowForContext = OVRGetWindowForContext; + appDriver.pfnPresentRiftOnContext = OVRShouldPresentOnContext; + + appDriver.pfnDirect3DCreate9 = oldDirectX9Create; + appDriver.pfnDirect3DCreate9Ex = oldDirectX9ExCreate; + appDriver.pfnCreateDXGIFactory = oldCreateDXGIFactory; + appDriver.pfnCreateDXGIFactory1 = oldCreateDXGIFactory1; + appDriver.pfnCreateDXGIFactory2 = oldCreateDXGIFactory2; + + (*loadFunc)( &appDriver ); + + return true; +} |