﻿// Georgy Treshchev 2025

#pragma once

#include "CoreMinimal.h"
#include "UObject/Object.h"
#include "Misc/EngineVersionComparison.h"
#include "HAL/Platform.h"
#include "VAD/RuntimeVADProviderBase.h"

#if PLATFORM_WINDOWS && !UE_VERSION_OLDER_THAN(5, 6, 0)
#include "NNEOnnxruntime.h"
#else
#include <onnxruntime_cxx_api.h>
#endif

#include "RuntimeSileroVADProvider.generated.h"

/**
 * Silero VAD provider using ONNX Runtime.
 *
 * Audio is fed through ProcessAudio(); the model operates on 16 kHz mono input in
 * frames of 512 samples (32 ms), as reported by GetRequiredSampleRate() and
 * GetFrameDurationMs().
 */
UCLASS(BlueprintType, Category = "Voice Activity Detector")
class RUNTIMEAUDIOIMPORTERSILEROVAD_API URuntimeSileroVADProvider : public URuntimeVADProviderBase
{
    GENERATED_BODY()

public:
    URuntimeSileroVADProvider();

    //~ Begin UObject Interface
    virtual void BeginDestroy() override;
    //~ End UObject Interface

    //~ Begin URuntimeVADProviderBase Interface
    virtual bool Reset() override;
    virtual int32 ProcessAudio(const TArray<float>& PCMData, int32 SampleRate) override;
    /** Silero VAD operates on 16 kHz audio only */
    virtual int32 GetRequiredSampleRate() const override { return 16000; }
    /** One model frame: 512 samples at 16 kHz = 32 ms */
    virtual float GetFrameDurationMs() const override { return 32.0f; }
    /** Whether the Silero speech trigger is currently active */
    virtual bool IsSpeechActive() const override { return bSileroTriggered; }
    virtual void OnSpeechStarted() override;
    virtual void OnSpeechEnded() override;
    //~ End URuntimeVADProviderBase Interface

    /**
     * Set the speech detection threshold
     * @param Threshold Threshold value (0.0 to 1.0)
     * @return True if threshold was set successfully
     */
    UFUNCTION(BlueprintCallable, Category = "Voice Activity Detector|Silero")
    bool SetSpeechThreshold(float Threshold);

    /**
     * Set the minimum duration a speech segment must last to be reported
     * @param Duration Minimum speech duration (NOTE(review): unit — presumably milliseconds — is defined by the implementation; confirm)
     * @return True if the value was accepted
     */
    UFUNCTION(BlueprintCallable, Category = "Voice Activity Detector|Silero")
    bool SetMinimumSpeechDuration(float Duration);

    /**
     * Set the maximum duration of a single speech segment
     * @param Duration Maximum speech duration (NOTE(review): unit — presumably milliseconds — is defined by the implementation; confirm)
     * @return True if the value was accepted
     */
    UFUNCTION(BlueprintCallable, Category = "Voice Activity Detector|Silero")
    bool SetMaximumSpeechDuration(float Duration);

    /**
     * Set the minimum silence duration used to close a speech segment
     * @param Duration Minimum silence duration (NOTE(review): unit — presumably milliseconds — is defined by the implementation; confirm)
     * @return True if the value was accepted
     */
    UFUNCTION(BlueprintCallable, Category = "Voice Activity Detector|Silero")
    bool SetMinimumSilenceDuration(float Duration);

protected:
    /**
     * Initialize the provider (presumably parameters + model)
     * @return True on success (NOTE(review): confirm against the implementation)
     */
    bool Initialize();

    /** Speech detection threshold */
    float SileroThreshold = 0.0f;

    /** Whether Silero VAD has been initialized */
    bool bSileroInitialized = false;

    /**
     * ONNX Runtime environment.
     * Declared BEFORE MemoryInfo and SileroSession on purpose: members are destroyed in
     * reverse declaration order, so this guarantees the session is released before the
     * environment it was created from. ONNX Runtime requires the Ort::Env to outlive
     * every session; destroying it first can crash in the session destructor.
     */
    TUniquePtr<Ort::Env> OnnxEnv;

    /** ONNX Runtime memory info */
    TUniquePtr<Ort::MemoryInfo> MemoryInfo;

    /** ONNX Runtime session for Silero VAD (must be destroyed before OnnxEnv) */
    TUniquePtr<Ort::Session> SileroSession;

    /** Silero VAD recurrent state tensor, carried between chunks */
    TArray<float> SileroState;

    /** Silero VAD context buffer */
    TArray<float> SileroContext;

    /** Accumulated audio buffer for Silero VAD (input not yet consumed as whole frames) */
    TArray<float> SileroAccumulatedAudio;

    /** Speech detection state variables */
    bool bSileroTriggered = false;
    uint32 SileroTempEnd = 0;
    uint32 SileroCurrentSample = 0;
    int32 SileroPrevEnd = 0;
    int32 SileroNextStart = 0;

    /** Silero VAD configuration parameters, expressed in samples */
    int64 SileroMinSilenceSamples = 0;
    int64 SileroMinSilenceSamplesAtMaxSpeech = 0;
    int32 SileroMinSpeechSamples = 0;
    /** Stored as float — presumably so it can represent an unbounded limit; TODO confirm */
    float SileroMaxSpeechSamples = 0.0f;
    int32 SileroSpeechPadSamples = 0;

    /** Speech timestamp structure; -1 marks an unset boundary */
    struct FRuntimeSileroVADTimestamp
    {
        int32 Start;
        int32 End;

        FRuntimeSileroVADTimestamp(int32 InStart = -1, int32 InEnd = -1) : Start(InStart), End(InEnd) {}
    };

    /** Current speech segment */
    FRuntimeSileroVADTimestamp SileroCurrentSpeech;

    /** Detected speech segments */
    TArray<FRuntimeSileroVADTimestamp> SileroSpeeches;

    /** Model data pointer (NOTE(review): ownership is not visible from this header; confirm who frees it) */
    void* ModelDataPtr = nullptr;

    /**
     * Initialize Silero VAD internal state
     */
    void InitializeSileroParameters();

    /**
     * Reset Silero VAD state
     */
    void ResetSileroState();

    /**
     * Initialize the Silero VAD model
     * @return True if initialization was successful
     */
    bool InitializeSileroModel();

    /**
     * Internal step of model initialization (presumably session creation)
     * @return True if initialization was successful
     */
    bool InitializeSileroModelInternal();

    /**
     * Predict speech probability for a single chunk using Silero VAD
     * @param ChunkData Audio chunk (512 samples)
     * @return Speech probability (0.0 to 1.0)
     */
    float PredictSileroChunk(const TArray<float>& ChunkData);

    /**
     * Process speech detection logic based on speech probability
     * @param SpeechProb Speech probability from model
     */
    void ProcessSileroSpeechDetection(float SpeechProb);
};