Skip to content

Commit

Permalink
D3D12Hook: Initial fixes for frame generation crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
praydog committed Jun 2, 2024
1 parent 654bac5 commit c07073b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 30 deletions.
132 changes: 105 additions & 27 deletions src/D3D12Hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#include <future>
#include <unordered_set>
#include <stacktrace>
#include <wrl/client.h>

#include <spdlog/spdlog.h>
#include <utility/Thread.hpp>
#include <utility/Module.hpp>
#include <utility/String.hpp>
#include <utility/RTTI.hpp>
#include <utility/Scan.hpp>

#include "REFramework.hpp"

Expand All @@ -21,6 +23,66 @@ D3D12Hook::~D3D12Hook() {
unhook();
}

void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9) {
std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress());

g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around
g_d3d12_hook->unhook(); // Removes all vtable hooks

auto& hook = g_d3d12_hook->m_streamline.link_swapchain_to_cmd_queue_hook;
const auto result = hook->get_original<decltype(link_swapchain_to_cmd_queue)>()(rcx, rdx, r8, r9);

return result;
}

void D3D12Hook::hook_streamline() {
if (m_streamline.setup) {
return;
}

spdlog::info("[Streamline] Hooking Streamline");

const auto dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll");

if (dlssg_module == nullptr) {
spdlog::error("[Streamline] Failed to get sl.dlss_g.dll module handle");
return;
}

const auto str = utility::scan_string(dlssg_module, "linkSwapchainToCmdQueue");

if (!str) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue");
return;
}

const auto str_ref = utility::scan_displacement_reference(dlssg_module, *str);

if (!str_ref) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue reference");
return;
}

const auto fn = utility::find_function_start_with_call(*str_ref);

if (!fn) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue function");
return;
}

m_streamline.link_swapchain_to_cmd_queue_hook = std::make_unique<FunctionHook>(*fn, (uintptr_t)&Streamline::link_swapchain_to_cmd_queue);

if (m_streamline.link_swapchain_to_cmd_queue_hook->create()) {
spdlog::info("[Streamline] Hooked linkSwapchainToCmdQueue");
} else {
spdlog::error("[Streamline] Failed to hook linkSwapchainToCmdQueue");
}

m_streamline.setup = true;
}

bool D3D12Hook::hook() {
spdlog::info("Hooking D3D12");

Expand Down Expand Up @@ -330,7 +392,9 @@ bool D3D12Hook::hook() {
return false;
}

utility::ThreadSuspender suspender{};
hook_streamline();

//utility::ThreadSuspender suspender{};

try {
spdlog::info("Initializing hooks");
Expand All @@ -341,20 +405,21 @@ bool D3D12Hook::hook() {
m_is_phase_1 = true;

auto& present_fn = (*(void***)target_swapchain)[8]; // Present
m_present_hook = std::make_unique<PointerHook>(&present_fn, (void*)&D3D12Hook::present);
m_present_hook = std::make_unique<FunctionHook>((uintptr_t)present_fn, (uintptr_t)&D3D12Hook::present);
m_present_hook->create();
m_hooked = true;
} catch (const std::exception& e) {
spdlog::error("Failed to initialize hooks: {}", e.what());
m_hooked = false;
}

suspender.resume();
//suspender.resume();

device->Release();
command_queue->Release();
factory->Release();
swap_chain1->Release();
swap_chain->Release();
device->Release();
factory->Release();

if (hwnd) {
::DestroyWindow(hwnd);
Expand All @@ -368,6 +433,8 @@ bool D3D12Hook::hook() {
}

bool D3D12Hook::unhook() {
std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

if (!m_hooked) {
return true;
}
Expand All @@ -385,57 +452,68 @@ bool D3D12Hook::unhook() {

thread_local int32_t g_present_depth = 0;

HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags) {
HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9) {
std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

auto d3d12 = g_d3d12_hook;

HWND swapchain_wnd{nullptr};
swap_chain->GetHwnd(&swapchain_wnd);

decltype(D3D12Hook::present)* present_fn{nullptr};

//if (d3d12->m_is_phase_1) {
present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
/*} else {
if (d3d12->m_is_phase_1) {
//present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)>();
} else {
present_fn = d3d12->m_swapchain_hook->get_method<decltype(D3D12Hook::present)*>(8);
}*/
}

HWND swapchain_wnd{nullptr};
swap_chain->GetHwnd(&swapchain_wnd);

if (d3d12->m_is_phase_1 && WindowFilter::get().is_filtered(swapchain_wnd)) {
//present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
return present_fn(swap_chain, sync_interval, flags);
return present_fn(swap_chain, sync_interval, flags, r9);
}

if (!d3d12->m_is_phase_1 && swap_chain != d3d12->m_swapchain_hook->get_instance()) {
return present_fn(swap_chain, sync_interval, flags);
return present_fn(swap_chain, sync_interval, flags, r9);
}

if (d3d12->m_is_phase_1) {
//d3d12->m_present_hook.reset();
// Remove the present hook, we will just rely on the vtable hook below
// because we don't want to cause any conflicts with other hooks
// vtable hooks are the least intrusive
// And doing a global pointer replacement seems to have
// conflicts with Streamline's hooks, causing unexplainable crashes
d3d12->m_present_hook.reset();

// vtable hook the swapchain instead of global hooking
// this seems safer for whatever reason
// if we globally hook the vtable pointers, it causes all sorts of weird conflicts with other hooks
// dont hook present though via this hook so other hooks dont get confused
d3d12->m_swapchain_hook = std::make_unique<VtableHook>(swap_chain);
//d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present);
//d3d12->m_swapchain_hook->hook_method(2, (uintptr_t)&D3D12Hook::release);
d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present);
d3d12->m_swapchain_hook->hook_method(13, (uintptr_t)&D3D12Hook::resize_buffers);
d3d12->m_swapchain_hook->hook_method(14, (uintptr_t)&D3D12Hook::resize_target);
d3d12->m_is_phase_1 = false;

present_fn = d3d12->m_swapchain_hook->get_method<decltype(D3D12Hook::present)*>(8);
}

d3d12->m_inside_present = true;
d3d12->m_swap_chain = swap_chain;

swap_chain->GetDevice(IID_PPV_ARGS(&d3d12->m_device));
{
Microsoft::WRL::ComPtr<ID3D12Device4> temp_device{};
swap_chain->GetDevice(IID_PPV_ARGS(&temp_device));
d3d12->m_device = temp_device.Get();
}

if (d3d12->m_device != nullptr) {
if (d3d12->m_using_proton_swapchain) {
const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset);
d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset);
} else {
d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset);
}
if (d3d12->m_using_proton_swapchain) {
const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset);
d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset);
} else {
d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset);
}

if (d3d12->m_swapchain_0 == nullptr) {
Expand All @@ -462,7 +540,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva
spdlog::info("Attempting to call real present function");

++g_present_depth;
const auto result = present_fn(swap_chain, sync_interval, flags);
const auto result = present_fn(swap_chain, sync_interval, flags, r9);
--g_present_depth;

if (result != S_OK) {
Expand All @@ -485,7 +563,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva
auto result = S_OK;

if (!d3d12->m_ignore_next_present) {
result = present_fn(swap_chain, sync_interval, flags);
result = present_fn(swap_chain, sync_interval, flags, r9);

if (result != S_OK) {
spdlog::error("Present failed: {:x}", result);
Expand Down
17 changes: 14 additions & 3 deletions src/D3D12Hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <dxgi1_4.h>

#include "utility/PointerHook.hpp"
#include "utility/FunctionHook.hpp"
#include "utility/VtableHook.hpp"

class D3D12Hook
Expand Down Expand Up @@ -88,7 +89,7 @@ class D3D12Hook
bool is_proton_swapchain() const {
return m_using_proton_swapchain;
}

bool is_framegen_swapchain() const {
return m_using_frame_generation_swapchain;
}
Expand All @@ -97,6 +98,8 @@ class D3D12Hook
m_ignore_next_present = true;
}

void hook_streamline();

protected:
ID3D12Device4* m_device{ nullptr };
IDXGISwapChain3* m_swap_chain{ nullptr };
Expand All @@ -118,17 +121,25 @@ class D3D12Hook
bool m_inside_present{false};
bool m_ignore_next_present{false};

std::unique_ptr<PointerHook> m_present_hook{};
std::unique_ptr<FunctionHook> m_present_hook{};
//std::unique_ptr<PointerHook> m_release_hook{};
std::unique_ptr<VtableHook> m_swapchain_hook{};
//std::unique_ptr<FunctionHook> m_create_swap_chain_hook{};

struct Streamline {
static void* link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9);

std::unique_ptr<FunctionHook> link_swapchain_to_cmd_queue_hook{};
bool setup{ false };
} m_streamline{};

OnPresentFn m_on_present{ nullptr };
OnPresentFn m_on_post_present{ nullptr };
OnResizeBuffersFn m_on_resize_buffers{ nullptr };
OnResizeTargetFn m_on_resize_target{ nullptr };
//OnCreateSwapChainFn m_on_create_swap_chain{ nullptr };

static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags);
static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9);
static HRESULT WINAPI resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags);
static HRESULT WINAPI resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters);
//static HRESULT WINAPI create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain);
Expand Down

0 comments on commit c07073b

Please sign in to comment.