You've already forked UnrealEngineUWP
mirror of
https://github.com/izzy2lost/UnrealEngineUWP.git
synced 2026-03-26 18:15:20 -07:00
541248193b
This CL adds neural post processing to NNE. A view extension registers a callback for enqueueing post procssing calls. The added code will downsample the input scene color to the neural network input shape, rune the neural network (currently disabled due to missing weights) and upsample it to overwrite the scene color. There is a blueprint wrapper class to make this functionality accessible through blueprints. Also a simple sobel filter has been added once with fixed input, once with dynamic shape for testing purposes. #rb florent.guinier #preflight 6396e0439549ddaa2847f4bd [CL 23476263 by nico ranieri in ue5-main branch]
204 lines
7.4 KiB
C++
204 lines
7.4 KiB
C++
// Copyright Epic Games, Inc. All Rights Reserved.
|
|
|
|
#include "NNIRuntimeRDGConv.h"
|
|
#include "NNEHlslShadersConvCS.h"
|
|
#include "NNECoreAttributeMap.h"
|
|
#include "NNXRuntimeHLSLHelper.h"
|
|
|
|
namespace UE::NNIRuntimeRDG::Private::Hlsl
|
|
{
|
|
DECLARE_GPU_STAT_NAMED(FNNIOperatorConv, TEXT("NNI.Operator.Hlsl.Conv"));
|
|
|
|
using EConvAutoPad = UE::NNEHlslShaders::Internal::EConvAutoPad;
|
|
|
|
/**
|
|
* Convolution operator implementation
|
|
*/
|
|
class FConv : public NNX::FMLOperatorHlsl
|
|
{
|
|
public:
|
|
|
|
static NNX::FMLOperatorHlsl* Create()
|
|
{
|
|
return new FConv();
|
|
}
|
|
|
|
virtual ~FConv() = default;
|
|
|
|
private:
|
|
|
|
FConv() {}
|
|
|
|
int32 NumDimensions = 0;
|
|
|
|
EConvAutoPad AutoPad = EConvAutoPad::NOTSET;
|
|
TArray<int32> Dilations;
|
|
int32 Group = 1;
|
|
TArray<int32> Pads;
|
|
TArray<int32> Strides;
|
|
|
|
public:
|
|
|
|
virtual int PrepareOutputs(TConstArrayView<NNX::FTensorRef> InputTensors, TArrayView<NNX::FTensorRef> OutputTensors) const override
|
|
{
|
|
check(InputTensors.Num() >= 2 && InputTensors.Num() <= 3);
|
|
check(OutputTensors.Num() == 1);
|
|
|
|
const NNX::FTensorShape& Input = InputTensors[0]->GetShape();
|
|
const NNX::FTensorShape& Weights = InputTensors[1]->GetShape();
|
|
NNX::FSymbolicTensorShape OutputShape;
|
|
|
|
OutputShape.Data = UE::NNEHlslShaders::Internal::FConvCS::GetOutputShape(Input.Data, Weights.Data, AutoPad, Dilations, Strides, Pads);
|
|
if (!OutputShape.IsConcrete())
|
|
{
|
|
return -1;
|
|
}
|
|
OutputTensors[0]->SetShape(NNX::FTensorShape::MakeFromSymbolic(OutputShape));
|
|
|
|
return 0;
|
|
};
|
|
|
|
virtual bool Initialize(TConstArrayView<NNX::FTensorDesc> InputTensorDescs, TConstArrayView<NNX::FTensorDesc> OutputTensorDescs, const UE::NNECore::FAttributeMap& Attributes) override
|
|
{
|
|
using namespace UE::NNEHlslShaders::Internal;
|
|
|
|
check(InputTensorDescs.Num() >= 2 && InputTensorDescs.Num() <= 3);
|
|
check(OutputTensorDescs.Num() == 1);
|
|
|
|
const NNX::FTensorDesc& Input = InputTensorDescs[0];
|
|
const NNX::FTensorDesc& Weights = InputTensorDescs[1];
|
|
const NNX::FTensorDesc& Output = OutputTensorDescs[0];
|
|
|
|
if (Input.GetShape().Rank() < 2)
|
|
{
|
|
UE_LOG(LogNNX, Warning, TEXT("Conv first input should be at least of rank 2"));
|
|
return false;
|
|
}
|
|
if (Weights.GetShape().Rank() != Input.GetShape().Rank())
|
|
{
|
|
UE_LOG(LogNNX, Warning, TEXT("Conv first and second inputs should be of same ranks"));
|
|
return false;
|
|
}
|
|
if (Output.GetShape().Rank() != Input.GetShape().Rank())
|
|
{
|
|
UE_LOG(LogNNX, Warning, TEXT("Conv first and output should be of same ranks"));
|
|
return false;
|
|
}
|
|
|
|
NumDimensions = Input.GetShape().Rank() - 2;
|
|
|
|
TArray<int32> DilationsOrStridesDefault;
|
|
DilationsOrStridesDefault.Init(1, NumDimensions);
|
|
|
|
FConvCS::LexFromString(AutoPad, *Attributes.GetValue<FString>(TEXT("auto_pad")));
|
|
Dilations = Attributes.GetValueOrDefault<TArray<int32>>(TEXT("dilations"), DilationsOrStridesDefault);
|
|
Group = Attributes.GetValueOrDefault<int>(TEXT("group"), 1);
|
|
if (AutoPad == EConvAutoPad::NOTSET)
|
|
{
|
|
TArray<int32> PadsDefault;
|
|
PadsDefault.Init(1, 2 * NumDimensions);
|
|
|
|
Pads = Attributes.GetValueOrDefault<TArray<int32>>(TEXT("pads"), PadsDefault);
|
|
}
|
|
Strides = Attributes.GetValueOrDefault<TArray<int32>>(TEXT("strides"), DilationsOrStridesDefault);
|
|
|
|
return true;
|
|
}
|
|
|
|
virtual void Dispatch(FRDGBuilder& GraphBuilder, TConstArrayView<NNX::FTensorRDGRef> InputTensors, TConstArrayView<NNX::FTensorRDGRef> OutputTensors) override
|
|
{
|
|
using namespace UE::NNEHlslShaders::Internal;
|
|
|
|
constexpr EConvAlgorithm Algorithm = EConvAlgorithm::SharedMemory;
|
|
constexpr EConvGroupSize GroupSize = EConvGroupSize::Size256;
|
|
|
|
check(InputTensors.Num() >= 2 && InputTensors.Num() <= 3);
|
|
check(OutputTensors.Num() == 1);
|
|
check(InputTensors[0] != nullptr);
|
|
check(InputTensors[1] != nullptr);
|
|
check(OutputTensors[0] != nullptr);
|
|
|
|
const NNX::FTensorRDG& Input = *InputTensors[0];
|
|
const NNX::FTensorRDG& Weights = *InputTensors[1];
|
|
const NNX::FTensorRDG& Output = *OutputTensors[0];
|
|
const NNX::FTensorRDG* Bias = nullptr;
|
|
bool HasBias = false;
|
|
|
|
if (InputTensors.Num() == 3) {
|
|
HasBias = true;
|
|
check(InputTensors[2] != nullptr);
|
|
Bias = InputTensors[2];
|
|
}
|
|
|
|
check(Input.GetShape().Rank() > 2);
|
|
check(Weights.GetShape().Rank() == Input.GetShape().Rank());
|
|
check(Output.GetShape().Rank() == Input.GetShape().Rank());
|
|
check(NumDimensions == (Input.GetShape().Rank() - 2));
|
|
|
|
TArray<int32> OutputShape = FConvCS::GetOutputShape(Input.GetShape().Data, Weights.GetShape().Data, AutoPad, Dilations, Strides, Pads);
|
|
|
|
// Set parameters
|
|
FConvCS::FParameters* Params = GraphBuilder.AllocParameters<FConvCS::FParameters>();
|
|
FConvCS::FillInParameters(GroupSize, Input.GetShape().Data, Weights.GetShape().Data, HasBias, AutoPad, Group, Dilations,Strides, Pads, *Params);
|
|
Params->X = GraphBuilder.CreateSRV(FRDGBufferSRVDesc(Input.GetBuffer(), PF_R32_FLOAT));
|
|
Params->W = GraphBuilder.CreateSRV(FRDGBufferSRVDesc(Weights.GetBuffer(), PF_R32_FLOAT));
|
|
if (HasBias) {
|
|
Params->B = GraphBuilder.CreateSRV(FRDGBufferSRVDesc(Bias->GetBuffer(), PF_R32_FLOAT));
|
|
}
|
|
Params->Y = GraphBuilder.CreateUAV(FRDGBufferUAVDesc(Output.GetBuffer(), PF_R32_FLOAT));
|
|
|
|
FConvCS::FPermutationDomain PermutationVector;
|
|
|
|
PermutationVector.Set<FConvCS::FConvAlgorithm>(Algorithm);
|
|
PermutationVector.Set<FConvCS::FConvGroupSize>(GroupSize);
|
|
PermutationVector.Set<FConvCS::FConvNumDimensions>(NumDimensions);
|
|
PermutationVector.Set<FConvCS::FConvNumReadsPerThread>(FConvCS::GetNumReadsPerThread(GroupSize, Weights.GetShape().Data, Dilations, Strides));
|
|
PermutationVector.Set<FConvCS::FConvHasB>(HasBias);
|
|
TShaderMapRef<FConvCS> ComputeShader(GetGlobalShaderMap(GMaxRHIFeatureLevel), PermutationVector);
|
|
|
|
RDG_EVENT_SCOPE(GraphBuilder, "NNI.Operator.Hlsl.Conv");
|
|
RDG_GPU_STAT_SCOPE(GraphBuilder, FNNIOperatorConv);
|
|
|
|
FComputeShaderUtils::AddPass(
|
|
GraphBuilder,
|
|
RDG_EVENT_NAME("NNI.Operator.Hlsl.Conv.Dispatch"),
|
|
ERDGPassFlags::Compute | ERDGPassFlags::NeverCull,
|
|
ComputeShader,
|
|
Params,
|
|
FConvCS::GetGroupCount(OutputShape, FConvCS::GetGroupShape(GroupSize, NumDimensions)));
|
|
}
|
|
};
|
|
|
|
bool ValidateConvOperator(const UE::NNECore::FAttributeMap& AttributeMap, TConstArrayView<EMLTensorDataType> InputTypes, TConstArrayView<NNX::FSymbolicTensorShape> InputShapes)
|
|
{
|
|
bool bIsValid = true;
|
|
|
|
NNX::FAttributeValidator AttributeValidator;
|
|
AttributeValidator.AddOptional(TEXT("auto_pad"), ENNEAttributeDataType::String);
|
|
AttributeValidator.AddOptional(TEXT("dilations"), ENNEAttributeDataType::Int32Array);
|
|
AttributeValidator.AddOptional(TEXT("group"), ENNEAttributeDataType::Int32);
|
|
AttributeValidator.AddOptional(TEXT("kernel_shape"), ENNEAttributeDataType::Int32Array); // idea: cross check input weight shape with this attribute if present
|
|
AttributeValidator.AddOptional(TEXT("pads"), ENNEAttributeDataType::Int32Array);
|
|
AttributeValidator.AddOptional(TEXT("strides"), ENNEAttributeDataType::Int32Array);
|
|
|
|
bIsValid &= AttributeValidator.Validate(AttributeMap);
|
|
|
|
NNX::FInputValidator InputValidator;
|
|
InputValidator.AddSupportedType(EMLTensorDataType::Float);
|
|
InputValidator.AddRequired();
|
|
InputValidator.AddRequired();
|
|
InputValidator.AddOptional();
|
|
bIsValid &= InputValidator.Validate(InputTypes);
|
|
|
|
return bIsValid;
|
|
}
|
|
|
|
bool RegisterConvOperator(NNX::FMLOperatorRegistryHlsl& Registry)
|
|
{
|
|
Registry.OpAdd(TEXT("Conv"), FConv::Create, ValidateConvOperator);
|
|
|
|
return true;
|
|
}
|
|
|
|
} // UE::NNIRuntimeRDG::Private::Hlsl
|