﻿// Georgy Treshchev 2025

#include "RuntimeAudioImporterSileroVAD.h"
#include "Misc/EngineVersionComparison.h"
#include "HAL/Platform.h"
#include "Interfaces/IPluginManager.h"

#if PLATFORM_WINDOWS && !UE_VERSION_OLDER_THAN(5, 6, 0)
#include "NNEOnnxruntime.h"
#else
#include <onnxruntime_cxx_api.h>
#endif

#define LOCTEXT_NAMESPACE "FRuntimeAudioImporterSileroVADModule"

void FRuntimeAudioImporterSileroVADModule::StartupModule()
{
#if PLATFORM_WINDOWS && !UE_VERSION_OLDER_THAN(5, 6, 0)
	UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Loading ONNX Runtime from NNEOnnxruntime module"));

	const FString PluginDir = IPluginManager::Get().FindPlugin("NNERuntimeORT")->GetBaseDir();
	const FString OrtSharedLibPath = FPaths::Combine(PluginDir, TEXT(PREPROCESSOR_TO_STRING(ONNXRUNTIME_SHAREDLIB_PATH)));

	FPlatformProcess::PushDllDirectory(*FPaths::GetPath(OrtSharedLibPath));
	OrtSharedLibHandle = FPlatformProcess::GetDllHandle(*OrtSharedLibPath);
	FPlatformProcess::PopDllDirectory(*FPaths::GetPath(OrtSharedLibPath));

	TUniquePtr<UE::NNEOnnxruntime::OrtApiFunctions> OrtApiFunctions = UE::NNEOnnxruntime::LoadApiFunctions(OrtSharedLibHandle);
	if (!OrtApiFunctions.IsValid())
	{
		UE_LOG(LogRuntimeAudioImporterSileroVAD, Fatal, TEXT("Failed to load ONNX Runtime shared library functions!"));
		return;
	}

	Ort::InitApi(OrtApiFunctions->OrtGetApiBase()->GetApi(ORT_API_VERSION));
#else
#if PLATFORM_WINDOWS || PLATFORM_MAC || PLATFORM_LINUX
	const FString PluginDir = IPluginManager::Get().FindPlugin("RuntimeAudioImporterSileroVAD")->GetBaseDir();
	const FString OrtSharedLibFilePath = FPaths::Combine(PluginDir, TEXT("Source"), TEXT("RuntimeAudioImporterSileroVAD"), TEXT(PREPROCESSOR_TO_STRING(ONNXRUNTIME_SHAREDLIB_PATH_SILERO_VAD)));
	const FString OrtSharedLibDirPath = FPaths::GetPath(OrtSharedLibFilePath);

	UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Loading ONNX Runtime dynamic/shared library from '%s'"), *OrtSharedLibFilePath);

	FPlatformProcess::PushDllDirectory(*FPaths::GetPath(OrtSharedLibDirPath));
	OrtSharedLibHandle = FPlatformProcess::GetDllHandle(*OrtSharedLibFilePath);
	FPlatformProcess::PopDllDirectory(*FPaths::GetPath(OrtSharedLibDirPath));
#endif
	
	Ort::InitApi();
#endif
}

void FRuntimeAudioImporterSileroVADModule::ShutdownModule()
{
	if (OrtSharedLibHandle != nullptr)
	{
#if !(PLATFORM_WINDOWS && !UE_VERSION_OLDER_THAN(5, 6, 0))
		FPlatformProcess::FreeDllHandle(OrtSharedLibHandle);
#endif
		OrtSharedLibHandle = nullptr;
	}
}

FString FRuntimeAudioImporterSileroVADModule::GetVADModelDataAssetName() const
{
	return TEXT("SileroVADModel");
}

FString FRuntimeAudioImporterSileroVADModule::GetVADModelDataPackagePath() const
{
	return TEXT("/RuntimeAudioImporterSileroVAD/VADModelData");
}

FString FRuntimeAudioImporterSileroVADModule::GetVADModelDataFullPackagePath() const
{
	return FPaths::Combine(GetVADModelDataPackagePath(), GetVADModelDataAssetName());
}

FString FRuntimeAudioImporterSileroVADModule::GetVADModelDataAssetPath() const
{
	const FString AssetName = GetVADModelDataAssetName();
	return FString::Printf(TEXT("%s.%s"), *FPaths::Combine(*GetVADModelDataPackagePath(), *AssetName), *AssetName);
}

#undef LOCTEXT_NAMESPACE
	
IMPLEMENT_MODULE(FRuntimeAudioImporterSileroVADModule, RuntimeAudioImporterSileroVAD)

DEFINE_LOG_CATEGORY(LogRuntimeAudioImporterSileroVAD);