diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.cpp b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.cpp index 2d7b6db8e9..c03527c766 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.cpp +++ b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.cpp @@ -121,34 +121,21 @@ gst_d3d12_get_converter_vertex_shader_blob (D3D12_SHADER_BYTECODE * vs, /* root signature * - * +-----+---------+--------------+ - * | RS | size in | | - * | idx | DWORD | | - * +-----+---------+--------------+ - * | 0 | 1 | table (SRV) | - * +-----+---------+--------------+ - * | 1 | 16 | VS matrix | - * +-----+---------+--------------+ - * | 2 | 1 | PS alpha | - * +-----+---------+--------------+ - * | 3 | 2 | PS CBV | - * +-----+---------+--------------+ + * +-----+---------+------------------+ + * | RS | size in | | + * | idx | DWORD | | + * +-----+---------+------------------+ + * | 0 | 1 | table (SRV) | + * +-----+---------+------------------+ + * | 1 | 1 | table (Sampler) | + * +-----+---------+------------------+ + * | 2 | 16 | VS matrix | + * +-----+---------+------------------+ + * | 3 | 1 | PS alpha | + * +-----+---------+------------------+ + * | 4 | 2 | PS CBV | + * +-----+---------+------------------+ */ -static const D3D12_STATIC_SAMPLER_DESC static_sampler_desc_ = { - D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT, - D3D12_TEXTURE_ADDRESS_MODE_CLAMP, - D3D12_TEXTURE_ADDRESS_MODE_CLAMP, - D3D12_TEXTURE_ADDRESS_MODE_CLAMP, - 0, - 1, - D3D12_COMPARISON_FUNC_ALWAYS, - D3D12_STATIC_BORDER_COLOR_OPAQUE_BLACK, - 0, - D3D12_FLOAT32_MAX, - 0, - 0, - D3D12_SHADER_VISIBILITY_PIXEL -}; static const D3D12_ROOT_SIGNATURE_FLAGS rs_flags_ = D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT | @@ -159,28 +146,15 @@ static const D3D12_ROOT_SIGNATURE_FLAGS rs_flags_ = D3D12_ROOT_SIGNATURE_FLAG_DENY_MESH_SHADER_ROOT_ACCESS; ConverterRootSignature::ConverterRootSignature (D3D_ROOT_SIGNATURE_VERSION - version, UINT num_srv, D3D12_FILTER filter, bool build_lut) + version, UINT num_srv, bool build_lut) { D3D12_VERSIONED_ROOT_SIGNATURE_DESC desc = { }; num_srv_ = num_srv; have_lut_ = build_lut; - std::vector < D3D12_STATIC_SAMPLER_DESC > static_sampler; - D3D12_STATIC_SAMPLER_DESC sampler_desc = static_sampler_desc_; - sampler_desc.Filter = filter; - if (filter == D3D12_FILTER_ANISOTROPIC) - sampler_desc.MaxAnisotropy = 16; - - static_sampler.push_back (sampler_desc); - - if (build_lut) { - sampler_desc = static_sampler_desc_; - sampler_desc.ShaderRegister = 1; - static_sampler.push_back (sampler_desc); - } - std::vector < D3D12_DESCRIPTOR_RANGE1 > range_v1_1; + std::vector < D3D12_DESCRIPTOR_RANGE1 > sampler_range_v1_1; std::vector < D3D12_ROOT_PARAMETER1 > param_list_v1_1; CD3DX12_ROOT_PARAMETER1 param; @@ -207,6 +181,20 @@ ConverterRootSignature::ConverterRootSignature (D3D_ROOT_SIGNATURE_VERSION range_v1_1.data (), D3D12_SHADER_VISIBILITY_PIXEL); param_list_v1_1.push_back (param); + /* sampler state, can be updated */ + ps_sampler_ = (UINT) param_list_v1_1.size (); + sampler_range_v1_1.push_back (CD3DX12_DESCRIPTOR_RANGE1 + (D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, 1, 0, 0, + D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE)); + if (build_lut) { + sampler_range_v1_1.push_back (CD3DX12_DESCRIPTOR_RANGE1 + (D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, 1, 1, 0, + D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE)); + } + param.InitAsDescriptorTable (sampler_range_v1_1.size (), + sampler_range_v1_1.data (), D3D12_SHADER_VISIBILITY_PIXEL); + param_list_v1_1.push_back (param); + /* VS root const, maybe updated */ vs_root_const_ = (UINT) param_list_v1_1.size (); param.InitAsConstants (16, 0, 0, D3D12_SHADER_VISIBILITY_VERTEX); @@ -225,8 +213,7 @@ ConverterRootSignature::ConverterRootSignature (D3D_ROOT_SIGNATURE_VERSION param_list_v1_1.push_back (param); CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC::Init_1_1 (desc, - param_list_v1_1.size (), param_list_v1_1.data (), - static_sampler.size (), static_sampler.data (), rs_flags_); + param_list_v1_1.size (), param_list_v1_1.data (), 0, nullptr, rs_flags_); ComPtr < ID3DBlob > error_blob; hr_ = D3DX12SerializeVersionedRootSignature (&desc, @@ -242,7 +229,7 @@ ConverterRootSignature::ConverterRootSignature (D3D_ROOT_SIGNATURE_VERSION ConverterRootSignaturePtr gst_d3d12_get_converter_root_signature (GstD3D12Device * device, - GstVideoFormat in_format, CONVERT_TYPE type, D3D12_FILTER filter) + GstVideoFormat in_format, CONVERT_TYPE type) { auto info = gst_video_format_get_info (in_format); auto num_planes = GST_VIDEO_FORMAT_INFO_N_PLANES (info); @@ -264,7 +251,7 @@ gst_d3d12_get_converter_root_signature (GstD3D12Device * device, build_lut = true; auto rs = std::make_shared < ConverterRootSignature > - (rs_version, num_planes, filter, build_lut); + (rs_version, num_planes, build_lut); if (!rs->IsValid ()) return nullptr; diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.h b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.h index 7cd772f8a4..d8ec798483 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter-builder.h @@ -47,13 +47,18 @@ class ConverterRootSignature public: ConverterRootSignature () = delete; ConverterRootSignature (D3D_ROOT_SIGNATURE_VERSION version, UINT num_srv, - D3D12_FILTER filter, bool build_lut); + bool build_lut); UINT GetPsSrvIdx () { return ps_srv_; } + UINT GetPsSamplerIdx () + { + return ps_sampler_; + } + UINT GetNumSrv () { return num_srv_; @@ -98,6 +103,7 @@ private: Microsoft::WRL::ComPtr blob_; UINT ps_srv_ = 0; + UINT ps_sampler_ = 0; UINT ps_cbv_ = 0; UINT vs_root_const_ = 0; UINT num_srv_ = 0; @@ -123,5 +129,4 @@ gst_d3d12_get_converter_vertex_shader_blob (D3D12_SHADER_BYTECODE * vs, ConverterRootSignaturePtr gst_d3d12_get_converter_root_signature (GstD3D12Device * device, GstVideoFormat in_format, - CONVERT_TYPE type, - D3D12_FILTER filter); \ No newline at end of file + CONVERT_TYPE type); \ No newline at end of file diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter.cpp b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter.cpp index d549dec778..671f1539ba 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter.cpp +++ b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12converter.cpp @@ -274,6 +274,7 @@ struct _GstD3D12ConverterPrivate ComPtr gamma_enc_lut; D3D12_PLACED_SUBRESOURCE_FOOTPRINT gamma_lut_layout; ComPtr gamma_lut_heap; + ComPtr sampler_heap; std::vector quad_data; @@ -281,6 +282,7 @@ struct _GstD3D12ConverterPrivate guint srv_inc_size; guint rtv_inc_size; + guint sampler_inc_size; guint64 input_texture_width; guint input_texture_height; @@ -673,11 +675,13 @@ gst_d3d12_converter_setup_resource (GstD3D12Converter * self, (D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); priv->rtv_inc_size = device->GetDescriptorHandleIncrementSize (D3D12_DESCRIPTOR_HEAP_TYPE_RTV); + priv->sampler_inc_size = device->GetDescriptorHandleIncrementSize + (D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER); ComPtr < ID3DBlob > rs_blob; priv->crs = gst_d3d12_get_converter_root_signature (self->device, - GST_VIDEO_INFO_FORMAT (in_info), priv->convert_type, sampler_filter); + GST_VIDEO_INFO_FORMAT (in_info), priv->convert_type); if (!priv->crs) { GST_ERROR_OBJECT (self, "Couldn't get root signature blob"); return FALSE; @@ -691,6 +695,60 @@ gst_d3d12_converter_setup_resource (GstD3D12Converter * self, return FALSE; } + ComPtr < ID3D12DescriptorHeap > sampler_heap; + hr = gst_d3d12_device_get_sampler_state (self->device, sampler_filter, + &sampler_heap); + if (FAILED (hr) && sampler_filter != D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT) { + sampler_filter = D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT; + + GST_WARNING_OBJECT (self, + "Couldn't create requested sampler, trying linear sampler"); + hr = gst_d3d12_device_get_sampler_state (self->device, + sampler_filter, &sampler_heap); + } + + if (!gst_d3d12_result (hr, self->device)) { + GST_ERROR_OBJECT (self, "Couldn't create sampler"); + return FALSE; + } + + if (priv->crs->HaveLut ()) { + D3D12_DESCRIPTOR_HEAP_DESC heap_desc = { }; + heap_desc.NumDescriptors = 1; + heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER; + heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + hr = device->CreateDescriptorHeap (&heap_desc, + IID_PPV_ARGS (&priv->sampler_heap)); + if (!gst_d3d12_result (hr, self->device)) { + GST_ERROR_OBJECT (self, "Couldn't create sampler heap"); + return FALSE; + } + + auto dst_handle = CD3DX12_CPU_DESCRIPTOR_HANDLE + (GetCPUDescriptorHandleForHeapStart (priv->sampler_heap)); + device->CopyDescriptorsSimple (1, dst_handle, + GetCPUDescriptorHandleForHeapStart (sampler_heap), + D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER); + + if (sampler_filter != D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT) { + hr = gst_d3d12_device_get_sampler_state (self->device, + D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT, + sampler_heap.ReleaseAndGetAddressOf ()); + + if (!gst_d3d12_result (hr, self->device)) { + GST_ERROR_OBJECT (self, "Couldn't create sampler heap"); + return FALSE; + } + } + + dst_handle.Offset (priv->sampler_inc_size); + device->CopyDescriptorsSimple (1, dst_handle, + GetCPUDescriptorHandleForHeapStart (sampler_heap), + D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER); + } else { + priv->sampler_heap = sampler_heap; + } + auto psblob_list = gst_d3d12_get_converter_pixel_shader_blob (GST_VIDEO_INFO_FORMAT (in_info), GST_VIDEO_INFO_FORMAT (out_info), @@ -2138,10 +2196,12 @@ gst_d3d12_converter_execute (GstD3D12Converter * self, GstD3D12Frame * in_frame, cl->SetGraphicsRootSignature (priv->rs.Get ()); cl->SetPipelineState (pso); - ID3D12DescriptorHeap *heaps[] = { srv_heap }; - cl->SetDescriptorHeaps (1, heaps); + ID3D12DescriptorHeap *heaps[] = { srv_heap, priv->sampler_heap.Get () }; + cl->SetDescriptorHeaps (2, heaps); cl->SetGraphicsRootDescriptorTable (priv->crs->GetPsSrvIdx (), GetGPUDescriptorHandleForHeapStart (srv_heap)); + cl->SetGraphicsRootDescriptorTable (priv->crs->GetPsSamplerIdx (), + GetGPUDescriptorHandleForHeapStart (priv->sampler_heap)); cl->SetGraphicsRoot32BitConstants (priv->crs->GetVsRootConstIdx (), 16, &priv->transform, 0); cl->SetGraphicsRoot32BitConstants (priv->crs->GetPsRootConstIdx (), diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device-private.h b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device-private.h index 262398ea15..70661158ef 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device-private.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device-private.h @@ -92,5 +92,10 @@ void gst_d3d12_device_decoder_unlock (GstD3D12Device * device); GST_D3D12_API GstD3D12WAFlags gst_d3d12_device_get_workaround_flags (GstD3D12Device * device); +GST_D3D12_API +HRESULT gst_d3d12_device_get_sampler_state (GstD3D12Device * device, + D3D12_FILTER filter, + ID3D12DescriptorHeap ** heap); + G_END_DECLS diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device.cpp b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device.cpp index 52dfee31b6..8422c78e68 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device.cpp +++ b/subprojects/gst-plugins-bad/gst-libs/gst/d3d12/gstd3d12device.cpp @@ -143,6 +143,8 @@ struct DeviceInner gst_clear_object (&fence_data_pool); + samplers.clear (); + factory = nullptr; adapter = nullptr; @@ -276,6 +278,8 @@ struct DeviceInner std::atomic removed_reason = { S_OK }; std::vector clients; + + std::unordered_map> samplers; }; typedef std::shared_ptr DeviceInnerPtr; @@ -2158,3 +2162,69 @@ gst_d3d12_device_get_workaround_flags (GstD3D12Device * device) return device->priv->inner->wa_flags; } + +HRESULT +gst_d3d12_device_get_sampler_state (GstD3D12Device * device, + D3D12_FILTER filter, ID3D12DescriptorHeap ** heap) +{ + g_return_val_if_fail (GST_IS_D3D12_DEVICE (device), E_INVALIDARG); + g_return_val_if_fail (heap, E_INVALIDARG); + + UINT max_anisotropy = 1; + switch (filter) { + case D3D12_FILTER_MIN_MAG_MIP_POINT: + case D3D12_FILTER_MIN_LINEAR_MAG_MIP_POINT: + case D3D12_FILTER_MIN_MAG_LINEAR_MIP_POINT: + break; + case D3D12_FILTER_ANISOTROPIC: + max_anisotropy = 16; + break; + default: + GST_WARNING_OBJECT (device, "Not supported sampler filter %d", filter); + return E_INVALIDARG; + } + + auto priv = device->priv->inner; + + std::lock_guard < std::mutex > lk (priv->lock); + auto it = priv->samplers.find (filter); + if (it != priv->samplers.end ()) { + auto sampler = it->second; + *heap = it->second.Get (); + (*heap)->AddRef (); + } else { + ComPtr < ID3D12DescriptorHeap > new_heap; + D3D12_DESCRIPTOR_HEAP_DESC heap_desc = { }; + heap_desc.NumDescriptors = 1; + heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER; + heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + + auto hr = priv->device->CreateDescriptorHeap (&heap_desc, + IID_PPV_ARGS (&new_heap)); + if (FAILED (hr)) { + GST_ERROR_OBJECT (device, "Couldn't create heap"); + return hr; + } + + D3D12_SAMPLER_DESC sampler_desc = { }; + sampler_desc.Filter = filter; + sampler_desc.AddressU = D3D12_TEXTURE_ADDRESS_MODE_CLAMP; + sampler_desc.AddressV = D3D12_TEXTURE_ADDRESS_MODE_CLAMP; + sampler_desc.AddressW = D3D12_TEXTURE_ADDRESS_MODE_CLAMP; + sampler_desc.MipLODBias = 0.0f; + sampler_desc.MaxAnisotropy = max_anisotropy; + sampler_desc.ComparisonFunc = D3D12_COMPARISON_FUNC_ALWAYS; + sampler_desc.MinLOD = 0; + sampler_desc.MaxLOD = D3D12_FLOAT32_MAX; + + auto cpu_handle = GetCPUDescriptorHandleForHeapStart (new_heap); + priv->device->CreateSampler (&sampler_desc, cpu_handle); + + priv->samplers[filter] = new_heap; + + *heap = new_heap.Get (); + (*heap)->AddRef (); + } + + return S_OK; +}