﻿// Georgy Treshchev 2025

#include "RuntimeSileroVADProvider.h"
#include "RuntimeAudioImporterSileroVAD.h"
#include "Engine/AssetManager.h"
#include "RuntimeSileroVADModel.h"

/**
 * Constructs the provider in an uninitialized state.
 * The detection threshold starts at 0.5, all sample counters start at zero, the current
 * speech timestamp is set to (-1, -1), and no model data is held. The remaining tuning
 * parameters are filled in by InitializeSileroParameters().
 */
URuntimeSileroVADProvider::URuntimeSileroVADProvider() :
    SileroThreshold(0.5f),
    bSileroInitialized(false),
    bSileroTriggered(false),
    SileroTempEnd(0),
    SileroCurrentSample(0),
    SileroPrevEnd(0),
    SileroNextStart(0),
    SileroCurrentSpeech(-1, -1),
    ModelDataPtr(nullptr)
{
    // Sample-based limits (min/max speech, min silence, padding) are derived here
    InitializeSileroParameters();
}

/**
 * Releases the ONNX session, the ONNX memory info and the raw model buffer,
 * then forwards to the parent implementation.
 */
void URuntimeSileroVADProvider::BeginDestroy()
{
    // Drop the ONNX objects before the model buffer is released
    SileroSession.Reset();
    MemoryInfo.Reset();

    // The model buffer is a raw allocation owned by this object
    if (ModelDataPtr != nullptr)
    {
        FMemory::Free(ModelDataPtr);
    }
    ModelDataPtr = nullptr;

    Super::BeginDestroy();
}

/**
 * Loads the Silero model and puts the provider into a clean, ready-to-use state.
 *
 * @return true if the model was initialized and the state was reset, false otherwise.
 */
bool URuntimeSileroVADProvider::Initialize()
{
    const bool bModelReady = InitializeSileroModel();
    if (bModelReady)
    {
        // Start from a clean detection state and mark the provider as usable
        ResetSileroState();
        bSileroInitialized = true;

        UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully initialized Silero VAD provider"));
        return true;
    }

    UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Failed to initialize Silero VAD model"));
    return false;
}

/**
 * Resets the detection state of an already initialized provider.
 *
 * @return true if the state was reset, false if the provider has not been initialized yet.
 */
bool URuntimeSileroVADProvider::Reset()
{
    if (bSileroInitialized)
    {
        ResetSileroState();
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully reset Silero VAD provider"));
        return true;
    }

    // Nothing to reset before the first successful initialization
    return false;
}

/**
 * Feeds PCM audio into the detector and reports the current speech state.
 *
 * Incoming samples are appended to an internal buffer and consumed in windows of 512 samples;
 * any remainder shorter than one window stays buffered for the next call.
 *
 * @param PCMData    Audio samples to append to the internal buffer.
 * @param SampleRate Sample rate of PCMData; must be 16000.
 * @return 1 if speech is currently detected, 0 if not, -1 on error
 *         (initialization failure or unsupported sample rate).
 */
int32 URuntimeSileroVADProvider::ProcessAudio(const TArray<float>& PCMData, int32 SampleRate)
{
    static constexpr int32 SileroTargetSampleRate = 16000;
    static constexpr int32 SileroWindowSize = 512;

    // Lazily initialize on first use. Bail out on failure: running inference without a valid
    // session would dereference null ONNX objects.
    if (!bSileroInitialized && !Initialize())
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Unable to initialize!"));
        return -1;
    }

    if (SampleRate != SileroTargetSampleRate)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Silero VAD requires 16kHz sample rate, got %d"), SampleRate);
        return -1;
    }

    // Accumulate audio data
    SileroAccumulatedAudio.Append(PCMData);

    // Process every complete window of 512 samples currently available
    const int32 NumWindows = SileroAccumulatedAudio.Num() / SileroWindowSize;
    if (NumWindows > 0)
    {
        // A single scratch buffer is reused for all windows to avoid one allocation per window
        TArray<float> ChunkData;
        ChunkData.Reserve(SileroWindowSize);

        for (int32 WindowIndex = 0; WindowIndex < NumWindows; ++WindowIndex)
        {
            // Extract exactly 512 samples
            ChunkData.Reset();
            ChunkData.Append(SileroAccumulatedAudio.GetData() + WindowIndex * SileroWindowSize, SileroWindowSize);

            // Predict using this chunk and run the speech detection logic on the result
            const float SpeechProb = PredictSileroChunk(ChunkData);
            ProcessSileroSpeechDetection(SpeechProb);
        }

        // Drop all consumed samples in one pass instead of shifting the buffer once per window
        SileroAccumulatedAudio.RemoveAt(0, NumWindows * SileroWindowSize);
    }

    // Return the current speech detection state
    return bSileroTriggered ? 1 : 0;
}

// Speech-start notification. This provider performs no work here.
// NOTE(review): nothing in this file calls it - presumably invoked by the owning VAD object; confirm.
void URuntimeSileroVADProvider::OnSpeechStarted()
{
    
}

// Speech-end notification. This provider performs no work here.
// NOTE(review): nothing in this file calls it - presumably invoked by the owning VAD object; confirm.
void URuntimeSileroVADProvider::OnSpeechEnded()
{
    
}

/**
 * Sets the probability at or above which a window is treated as speech.
 *
 * @param Threshold New threshold; values below 0.0 or above 1.0 are rejected.
 * @return true if the threshold was applied, false if it was out of range.
 */
bool URuntimeSileroVADProvider::SetSpeechThreshold(float Threshold)
{
    const bool bOutOfRange = (Threshold < 0.0f) || (Threshold > 1.0f);
    if (!bOutOfRange)
    {
        SileroThreshold = Threshold;
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully set Silero speech threshold to %f"), Threshold);
        return true;
    }

    UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Speech threshold must be between 0.0 and 1.0"));
    return false;
}

bool URuntimeSileroVADProvider::SetMinimumSpeechDuration(float Duration)
{
    if (Duration <= 0.0f)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Minimum speech duration must be more than 0.0"));
        return false;
    }

    int32 SrPerMs = 16000 / 1000; // 16
    SileroMinSpeechSamples = SrPerMs * Duration;
    UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully set Silero minimum speech duration to %f"), Duration);
    return true;
}

/**
 * Sets the length after which an ongoing speech segment is eligible to be force-ended.
 *
 * @param Duration Maximum duration; it is scaled by 16000 samples, i.e. treated as seconds.
 * @return true if the value was applied, false if Duration was zero or negative.
 */
bool URuntimeSileroVADProvider::SetMaximumSpeechDuration(float Duration)
{
    const bool bRejected = Duration <= 0.0f;
    if (bRejected)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Maximum speech duration must be more than 0.0"));
        return false;
    }

    // 16000 samples per second at the 16 kHz rate used by this provider
    SileroMaxSpeechSamples = 16000 * Duration;

    UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully set Silero maximum speech duration to %f"), Duration);
    return true;
}

bool URuntimeSileroVADProvider::SetMinimumSilenceDuration(float Duration)
{
    if (Duration <= 0.0f)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Minimum silence duration must be more than 0.0"));
        return false;
    }

    int32 SrPerMs = 16000 / 1000; // 16
    SileroMinSilenceSamples = SrPerMs * Duration;
    UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully set Silero minimum silence duration to %f"), Duration);
    return true;
}

void URuntimeSileroVADProvider::InitializeSileroParameters()
{
    // Initialize Silero parameters (matching official implementation)
    static constexpr int32 SileroTargetSampleRate = 16000;
    static constexpr int32 MinSilenceDurationMs = 100;
    static constexpr int32 SpeechPadMs = 30;
    static constexpr int32 MinSpeechDurationMs = 250;
    
    int32 SrPerMs = SileroTargetSampleRate / 1000; // 16
    SileroMinSpeechSamples = SrPerMs * MinSpeechDurationMs; // 250ms * 16 = 4000 samples
    SileroMaxSpeechSamples = SileroTargetSampleRate * 30.0f; // 30 seconds max
    SileroMinSilenceSamples = SrPerMs * MinSilenceDurationMs; // 100ms * 16 = 1600 samples
    SileroMinSilenceSamplesAtMaxSpeech = SrPerMs * 98; // 98ms * 16 = 1568 samples
    SileroSpeechPadSamples = SpeechPadMs; // 30 samples
}

void URuntimeSileroVADProvider::ResetSileroState()
{
    // Initialize state tensor [2, 1, 128]
    SileroState.SetNumZeroed(2 * 1 * 128);
    
    // Initialize context buffer (64 samples)
    SileroContext.SetNumZeroed(64);
    
    // Reset accumulated audio buffer
    SileroAccumulatedAudio.Reset();
    
    // Reset speech detection state
    bSileroTriggered = false;
    SileroTempEnd = 0;
    SileroCurrentSample = 0;
    SileroPrevEnd = 0;
    SileroNextStart = 0;
    SileroCurrentSpeech = FRuntimeSileroVADTimestamp(-1, -1);
    SileroSpeeches.Reset();
}

bool URuntimeSileroVADProvider::InitializeSileroModel()
{
    // If we're already on the game thread, execute directly
    if (IsInGameThread())
    {
        return InitializeSileroModelInternal();
    }
    
    // If we're not on the game thread, we need to dispatch to it
    bool bResult = false;
    FGraphEventRef Task = FFunctionGraphTask::CreateAndDispatchWhenReady([this, &bResult]() {
        bResult = InitializeSileroModelInternal();
    }, TStatId(), nullptr, ENamedThreads::GameThread);
    
    // Wait for the task to complete without blocking the game thread
    FTaskGraphInterface::Get().WaitUntilTaskCompletes(Task);
    
    return bResult;
}

bool URuntimeSileroVADProvider::InitializeSileroModelInternal()
{
    FRuntimeAudioImporterSileroVADModule& Module = FModuleManager::LoadModuleChecked<FRuntimeAudioImporterSileroVADModule>("RuntimeAudioImporterSileroVAD");

    if (!OnnxEnv.IsValid())
    {
        OnnxEnv = MakeUnique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "VAD");
        if (!OnnxEnv.IsValid())
        {
            UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("ONNX environment is not initialized"));
            return false;
        }
    }

    // Create session options
    Ort::SessionOptions SessionOptions;
    SessionOptions.SetIntraOpNumThreads(1);
    SessionOptions.SetInterOpNumThreads(1);

    // Load model data
    const FRuntimeAudioImporterSileroVADModule* RuntimeAudioImporterSileroVADModulePtr = 
        static_cast<FRuntimeAudioImporterSileroVADModule*>(FModuleManager::Get().GetModule("RuntimeAudioImporterSileroVAD"));
    if (!RuntimeAudioImporterSileroVADModulePtr)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Cannot get the RuntimeAudioImporter module"));
        return false;
    }

    FSoftObjectPath VoiceModelAssetPath = FSoftObjectPath(RuntimeAudioImporterSileroVADModulePtr->GetVADModelDataAssetPath());
    
    // Asynchronously load the asset
    TSharedPtr<FStreamableHandle> StreamableHandle = UAssetManager::GetStreamableManager().RequestSyncLoad(
        VoiceModelAssetPath,
        false
    );
    
    // Check if the asset was loaded successfully
    if (!StreamableHandle || !StreamableHandle->GetLoadedAsset())
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Failed to load the VAD model"));
        return false;
    }
    
    URuntimeSileroVADModel* VADModel = Cast<URuntimeSileroVADModel>(StreamableHandle->GetLoadedAsset());
    if (!VADModel)
    {
        UE_LOG(LogRuntimeAudioImporterSileroVAD, Error, TEXT("Failed to cast the loaded asset to URuntimeSileroVADModel"));
        return false;
    }

    if (ModelDataPtr)
    {
        FMemory::Free(ModelDataPtr);
        ModelDataPtr = nullptr;
    }

    VADModel->ModelBulkData.GetCopy(&ModelDataPtr, false);
    const int64 ModelBulkDataSize = VADModel->ModelBulkData.GetBulkDataSize();

    // Create session
    SileroSession = MakeUnique<Ort::Session>(*OnnxEnv, ModelDataPtr, ModelBulkDataSize, SessionOptions);

    // Create memory info
    MemoryInfo = MakeUnique<Ort::MemoryInfo>(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault));

    UE_LOG(LogRuntimeAudioImporterSileroVAD, Log, TEXT("Successfully initialized Silero VAD model"));
    return true;
}

float URuntimeSileroVADProvider::PredictSileroChunk(const TArray<float>& ChunkData)
{
    static constexpr int32 SileroContextSize = 64;
    static constexpr int32 SileroWindowSize = 512;
    static constexpr int32 SileroTargetSampleRate = 16000;
    
    // Build input: context + current chunk
    TArray<float> InputData;
    InputData.Reserve(SileroContextSize + SileroWindowSize);
    InputData.Append(SileroContext);
    InputData.Append(ChunkData);

    // Prepare ONNX inputs
    std::vector<int64_t> InputShape = {1, static_cast<int64_t>(InputData.Num())};
    std::vector<int64_t> StateShape = {2, 1, 128};
    std::vector<int64_t> SrShape = {1};

    // Create input tensors
    Ort::Value InputTensor = Ort::Value::CreateTensor<float>(*MemoryInfo, InputData.GetData(), InputData.Num(), InputShape.data(), InputShape.size());
    Ort::Value StateTensor = Ort::Value::CreateTensor<float>(*MemoryInfo, SileroState.GetData(), SileroState.Num(), StateShape.data(), StateShape.size());
        
    int64_t SampleRate = SileroTargetSampleRate;
    Ort::Value SrTensor = Ort::Value::CreateTensor<int64_t>(*MemoryInfo, &SampleRate, 1, SrShape.data(), SrShape.size());

    // Input names
    std::vector<const char*> InputNames = {"input", "state", "sr"};
    std::vector<Ort::Value> InputTensors;
    InputTensors.push_back(std::move(InputTensor));
    InputTensors.push_back(std::move(StateTensor));
    InputTensors.push_back(std::move(SrTensor));

    // Output names
    std::vector<const char*> OutputNames = {"output", "stateN"};

    // Run inference
    std::vector<Ort::Value> OutputTensors = SileroSession->Run(Ort::RunOptions{nullptr}, InputNames.data(), InputTensors.data(), InputTensors.size(), OutputNames.data(), OutputNames.size());

    // Get output probability
    float* OutputData = OutputTensors[0].GetTensorMutableData<float>();
    float SpeechProb = OutputData[0];

    // Update state
    float* NewStateData = OutputTensors[1].GetTensorMutableData<float>();
    FMemory::Memcpy(SileroState.GetData(), NewStateData, SileroState.Num() * sizeof(float));

    // Update context (last 64 samples of input)
    FMemory::Memcpy(SileroContext.GetData(), InputData.GetData() + InputData.Num() - SileroContextSize, SileroContextSize * sizeof(float));
        
    // Advance current sample counter
    SileroCurrentSample += SileroWindowSize;
        
    return SpeechProb;
}

/**
 * Advances the speech/silence state machine by one window.
 *
 * Three probability bands are distinguished:
 *  - SpeechProb >= threshold: speech; starts a segment if none is active and cancels any pending end.
 *  - SpeechProb < threshold - 0.15: silence; marks a tentative end and closes the segment once the
 *    silence has lasted long enough.
 *  - Anything in between: no state change (unless the maximum-duration check below fires).
 *
 * Finished segments longer than SileroMinSpeechSamples are appended to SileroSpeeches.
 * All positions are derived from SileroCurrentSample minus one window (512 samples), i.e. the
 * start of the window the probability belongs to (the counter is advanced in PredictSileroChunk).
 *
 * @param SpeechProb Speech probability of the window that was just evaluated.
 */
void URuntimeSileroVADProvider::ProcessSileroSpeechDetection(float SpeechProb)
{
    static constexpr int32 SileroWindowSize = 512;
    
    // High confidence speech detected
    if (SpeechProb >= SileroThreshold)
    {
        // Clear any pending end detection
        SileroTempEnd = 0;
        
        // Start speech if not already active
        if (!bSileroTriggered)
        {
            bSileroTriggered = true;
            SileroCurrentSpeech.Start = SileroCurrentSample - SileroWindowSize;
            UE_LOG(LogRuntimeAudioImporterSileroVAD, Verbose, TEXT("Speech started at sample %d"), SileroCurrentSpeech.Start);
        }
        // Early return: the maximum-duration check below is therefore only reached on windows
        // whose probability is below the threshold
        return;
    }

    // Handle maximum speech duration limit
    if (bSileroTriggered && ((SileroCurrentSample - SileroCurrentSpeech.Start) > SileroMaxSpeechSamples))
    {
        // Force end the current speech segment
        SileroCurrentSpeech.End = SileroCurrentSample - SileroWindowSize;
        
        // Record the segment only if it is longer than the minimum speech length
        if (SileroCurrentSpeech.End - SileroCurrentSpeech.Start > SileroMinSpeechSamples)
        {
            SileroSpeeches.Add(SileroCurrentSpeech);
            UE_LOG(LogRuntimeAudioImporterSileroVAD, Verbose, TEXT("Speech ended (max duration) from %d to %d"), 
                   SileroCurrentSpeech.Start, SileroCurrentSpeech.End);
        }
        
        // Reset state completely
        SileroCurrentSpeech = FRuntimeSileroVADTimestamp();
        SileroTempEnd = 0;
        SileroPrevEnd = 0;
        SileroNextStart = 0;
        bSileroTriggered = false;
        return;
    }

    // Low confidence - potential speech end
    // (the 0.15 margin below the threshold keeps borderline windows from ending a segment)
    if (SpeechProb < (SileroThreshold - 0.15f))
    {
        if (bSileroTriggered)
        {
            // Mark potential end point if not already marked
            if (SileroTempEnd == 0)
            {
                SileroTempEnd = SileroCurrentSample - SileroWindowSize;
                UE_LOG(LogRuntimeAudioImporterSileroVAD, Verbose, TEXT("Potential speech end marked at sample %d"), SileroTempEnd);
            }
            
            // Check if we've been silent long enough to end speech
            int32 SilenceDuration = SileroCurrentSample - SileroTempEnd;
            // NOTE(review): the max-duration condition here cannot be true at this point - the block above
            // already returned for that case - so SileroMinSilenceSamples is always selected
            int32 RequiredSilence = ((SileroCurrentSample - SileroCurrentSpeech.Start) > SileroMaxSpeechSamples) ? 
                                   SileroMinSilenceSamplesAtMaxSpeech : SileroMinSilenceSamples;
            
            if (SilenceDuration >= RequiredSilence)
            {
                // End the speech segment at the point where the silence began
                SileroCurrentSpeech.End = SileroTempEnd;
                
                // Only add if it meets minimum duration
                if (SileroCurrentSpeech.End - SileroCurrentSpeech.Start > SileroMinSpeechSamples)
                {
                    SileroSpeeches.Add(SileroCurrentSpeech);
                    UE_LOG(LogRuntimeAudioImporterSileroVAD, Verbose, TEXT("Speech ended from %d to %d (duration: %d samples)"), 
                           SileroCurrentSpeech.Start, SileroCurrentSpeech.End, 
                           SileroCurrentSpeech.End - SileroCurrentSpeech.Start);
                }
                else
                {
                    UE_LOG(LogRuntimeAudioImporterSileroVAD, Verbose, TEXT("Speech segment too short, discarded"));
                }
                
                // Reset state completely
                SileroCurrentSpeech = FRuntimeSileroVADTimestamp();
                SileroTempEnd = 0;
                SileroPrevEnd = 0;
                SileroNextStart = 0;
                bSileroTriggered = false;
            }
        }
        return;
    }
}