Add a "kokoro_onnx" TTS provider (Kokoro via ONNX runtime + espeak phonemes)
and route ONNX runtime setup in rag/embedder.go through the gf-lt/onnx package.

Review revisions folded into this patch:
  Makefile
    - fetch-kokoro-onnx is now listed in .PHONY.
    - fetch-kokoro-voices saves af_bella.pt from af_bella.pt (it previously
      stored af_heart.pt under the af_bella name); new curl calls use -f so an
      HTTP error aborts the recipe instead of saving an error page.
      NOTE(review): confirm the af_bella.pt URL exists upstream.
  extra/tts.go
    - Dropped the no-op seekableBuffer and the WAV encode/decode round trip;
      PCM is streamed directly with a position-tracking streamer.
    - Output tensor is allocated by onnxruntime (variable-length waveform).
    - Speaker is initialized once; the Ctrl sits inside the Seq so Stop still
      reaches the playback-finished callback.
    - ensureInitialized takes the mutex before reading modelLoaded, retries
      the espeak lookup after a failure and bounds-checks the voice table.
    - espeak is invoked with --ipa (the token map is keyed by IPA symbols).
    - readroutine no longer exits permanently on interrupt; channel drains are
      non-blocking; a recovered panic in generateAudio is returned as an error.

Notes:
  - gf-lt/onnx is imported but not created by this patch; it must already exist.
  - Whitespace of context lines was reconstructed; apply with
    --ignore-whitespace if needed. Index lines are omitted for the two files
    whose post-image changed during review.

diff --git a/Makefile b/Makefile
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: setconfig run lint lintall install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve fetch-onnx install-onnx-deps
+.PHONY: setconfig run lint lintall install-linters setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs noextra-run installdelve checkdelve fetch-onnx fetch-kokoro-onnx install-onnx-deps fetch-kokoro-voices install-espeak
 
 run: setconfig
 	go build -tags extra -o gf-lt && ./gf-lt
@@ -33,6 +33,9 @@ lintall: lint
 fetch-onnx:
 	mkdir -p onnx/embedgemma && curl -o onnx/embedgemma/config.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/config.json && curl -o onnx/embedgemma/tokenizer.json -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/tokenizer.json && curl -o onnx/embedgemma/model_q4.onnx -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx && curl -o onnx/embedgemma/model_q4.onnx_data -L https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_q4.onnx_data?download=true
 
+fetch-kokoro-onnx:
+	mkdir -p onnx/kokoro && curl -f -o onnx/kokoro/config.json -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/config.json && curl -f -o onnx/kokoro/tokenizer.json -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/tokenizer.json && curl -f -o onnx/kokoro/model_quantized.onnx -L https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX/resolve/main/onnx/model_quantized.onnx && curl -f -o onnx/kokoro/voices.bin -L https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin
+
 install-onnx-deps: ## Install ONNX Runtime with CUDA support (or CPU fallback)
 	@echo "=== ONNX Runtime Installer ===" && \
 	echo "" && \
@@ -194,3 +197,25 @@ docker-logs-whisper: ## View logs from Whisper STT service only
 docker-logs-kokoro: ## View logs from Kokoro TTS service only
 	@echo "Displaying logs from Kokoro TTS service..."
 	docker-compose -f batteries/docker-compose.yml logs -f kokoro-tts
+
+# Kokoro ONNX TTS Setup
+install-espeak: ## Install espeak-ng for phoneme tokenization
+	@echo "=== Installing espeak-ng ===" && \
+	if command -v espeak-ng >/dev/null 2>&1; then \
+		echo "espeak-ng is already installed:" && \
+		espeak-ng --version && \
+		exit 0; \
+	fi && \
+	echo "Installing espeak-ng..." && \
+	sudo apt-get update && \
+	sudo apt-get install -y espeak-ng espeak && \
+	echo "espeak-ng installed successfully!" && \
+	espeak-ng --version
+
+fetch-kokoro-voices: ## Download Kokoro voice files (PyTorch format)
+	@echo "=== Downloading Kokoro voices ===" && \
+	mkdir -p onnx/kokoro/voices && \
+	echo "Downloading af_bella voice..." && \
+	curl -fL -o onnx/kokoro/voices/af_bella.pt https://raw.githubusercontent.com/hexgrad/kokoro/main/kokoro/voices/af_bella.pt && \
+	echo "Voice file downloaded to onnx/kokoro/voices/" && \
+	ls -lh onnx/kokoro/voices/
diff --git a/bot.go b/bot.go
index ad52059..16bfecc 100644
--- a/bot.go
+++ b/bot.go
@@ -1497,7 +1497,7 @@ func init() {
 	// load cards
 	basicCard.Role = cfg.AssistantRole
 	logLevel.Set(slog.LevelInfo)
-	logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
+	logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel, AddSource: true}))
 	store = storage.NewProviderSQL(cfg.DBPATH, logger)
 	if store == nil {
 		cancel()
diff --git a/config/config.go b/config/config.go
index fab3237..417c2b9 100644
--- a/config/config.go
+++ b/config/config.go
@@ -61,6 +61,10 @@ type Config struct {
 	TTS_SPEED    float32 `toml:"TTS_SPEED"`
 	TTS_PROVIDER string  `toml:"TTS_PROVIDER"`
 	TTS_LANGUAGE string  `toml:"TTS_LANGUAGE"`
+	// Kokoro ONNX TTS
+	KokoroModelPath  string `toml:"KokoroModelPath"`
+	KokoroVoicesPath string `toml:"KokoroVoicesPath"`
+	KokoroVoice      string `toml:"KokoroVoice"`
 	// STT
 	STT_TYPE string `toml:"STT_TYPE"` // WHISPER_SERVER, WHISPER_BINARY
 	STT_URL  string `toml:"STT_URL"`
diff --git a/extra/tts.go b/extra/tts.go
--- a/extra/tts.go
+++ b/extra/tts.go
@@ -9,10 +9,12 @@ import (
 	"fmt"
 	"gf-lt/config"
 	"gf-lt/models"
+	"gf-lt/onnx"
 	"io"
 	"log/slog"
 	"net/http"
 	"os"
+	"os/exec"
 	"strings"
 	"sync"
 	"time"
@@ -22,7 +24,8 @@ import (
 	"github.com/gopxl/beep/v2"
 	"github.com/gopxl/beep/v2/mp3"
 	"github.com/gopxl/beep/v2/speaker"
 	"github.com/neurosnap/sentences/english"
+	"github.com/yalue/onnxruntime_go"
 )
 
 var (
@@ -194,6 +197,18 @@ func NewOrator(log *slog.Logger, cfg *config.Config) Orator {
 		go orator.readroutine()
 		go orator.stoproutine()
 		return orator
+	case "kokoro_onnx":
+		log.Info("Initializing Kokoro ONNX TTS", "modelPath", cfg.KokoroModelPath, "voicesPath", cfg.KokoroVoicesPath, "voice", cfg.KokoroVoice, "speed", cfg.TTS_SPEED)
+		orator := &KokoroONNXOrator{
+			logger:     log,
+			modelPath:  cfg.KokoroModelPath,
+			voicesPath: cfg.KokoroVoicesPath,
+			speed:      cfg.TTS_SPEED,
+			voice:      cfg.KokoroVoice,
+		}
+		go orator.readroutine()
+		go orator.stoproutine()
+		return orator
 	default:
 		language := cfg.TTS_LANGUAGE
 		if language == "" {
@@ -471,3 +486,504 @@ func (o *GoogleTranslateOrator) Stop() {
 		_ = o.speech.Stop()
 	}
 }
+
+// Audio and style dimensions used by the Kokoro ONNX integration.
+const (
+	kokoroSampleRate = beep.SampleRate(24000) // playback rate used for the generated waveform
+	kokoroStyleDim   = 256                    // length of the style vector fed to the model
+	kokoroMaxTokens  = 510                    // max phoneme tokens per request, excluding the two pad tokens
+)
+
+// KokoroONNXOrator implements Kokoro TTS using ONNX runtime.
+// Text is phonemized with espeak, mapped to token IDs, run through the model
+// and the resulting waveform is played through the beep speaker.
+type KokoroONNXOrator struct {
+	logger        *slog.Logger
+	mu            sync.Mutex // guards initialization, textBuffer, interrupt and the current stream
+	session       *onnxruntime_go.DynamicAdvancedSession
+	phonemeMap    map[string]int
+	espeakCmd     string
+	voice         string
+	speed         float32
+	styleVector   []float32
+	currentStream *beep.Ctrl
+	currentDone   chan bool
+	textBuffer    strings.Builder
+	interrupt     bool
+	modelLoaded   bool
+	modelPath     string
+	voicesPath    string
+	speakerOnce   sync.Once // the speaker is initialized once, not on every utterance
+	speakerErr    error     // result of that one-time initialization
+}
+
+// Phoneme to token ID mapping from Kokoro tokenizer.json
+// NOTE(review): confirm every key against the tokenizer.json shipped with the
+// model; symbols missing here are silently dropped by phonemesToTokens.
+var kokoroPhonemeMap = map[string]int{
+	"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23,
+	"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42,
+	"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
+	"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
+	"ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177,
+}
+
+// ensureInitialized lazily prepares everything generateAudio needs: the ONNX
+// runtime, the espeak binary, the voice style vector and the model session.
+// The whole check-and-load sequence runs under o.mu so modelLoaded is never
+// read without the lock. It returns nil immediately once the model is loaded.
+func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if o.modelLoaded {
+		return nil
+	}
+	o.logger.Debug("ensureInitialized called", "modelPath", modelPath)
+
+	if modelPath == "" {
+		return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
+	}
+
+	// Initialize ONNX runtime (shared with embedder)
+	if err := onnx.Init(); err != nil {
+		return fmt.Errorf("ONNX init failed: %w", err)
+	}
+	if onnx.HasCUDASupport() {
+		o.logger.Info("ONNX using CUDA")
+	} else {
+		o.logger.Info("ONNX using CPU fallback")
+	}
+
+	if o.phonemeMap == nil {
+		o.phonemeMap = kokoroPhonemeMap
+	}
+
+	if o.espeakCmd == "" {
+		espeakCmd := "espeak-ng"
+		if _, err := exec.LookPath(espeakCmd); err != nil {
+			espeakCmd = "espeak"
+			if _, err := exec.LookPath(espeakCmd); err != nil {
+				// o.espeakCmd is left unset so a later call retries the lookup.
+				return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
+			}
+		}
+		o.espeakCmd = espeakCmd
+	}
+	o.logger.Info("using espeak command", "cmd", o.espeakCmd)
+
+	// Load voice embedding if not already loaded
+	if o.styleVector == nil {
+		o.styleVector = o.loadStyleVector()
+	}
+
+	opts, err := onnx.NewSessionOptions()
+	if err != nil {
+		return fmt.Errorf("failed to create session options: %w", err)
+	}
+	defer func() { _ = opts.Destroy() }()
+
+	session, err := onnxruntime_go.NewDynamicAdvancedSession(
+		modelPath,
+		[]string{"input_ids", "style", "speed"},
+		[]string{"waveform"},
+		opts,
+	)
+	if err != nil {
+		return fmt.Errorf("failed to create ONNX session: %w", err)
+	}
+	o.session = session
+	o.modelLoaded = true
+	o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
+	return nil
+}
+
+// loadStyleVector returns the style vector for the configured voice (default
+// "af_bella"), averaging the rows of the voice table returned by
+// onnx.LoadVoice. It falls back to an all-zero vector when no voices path is
+// configured, the voice cannot be loaded, or the table holds no complete row.
+// The caller must hold o.mu.
+func (o *KokoroONNXOrator) loadStyleVector() []float32 {
+	style := make([]float32, kokoroStyleDim)
+	if o.voicesPath == "" {
+		o.logger.Warn("no voices path configured, using zeros for style")
+		return style
+	}
+	voiceName := o.voice
+	if voiceName == "" {
+		voiceName = "af_bella"
+	}
+	styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName)
+	if err != nil {
+		o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
+		return style
+	}
+	// The table is expected to be (510, 1, 256) flattened; average over the
+	// rows actually present so a short table cannot index out of range.
+	rows := min(len(styleVec)/kokoroStyleDim, kokoroMaxTokens)
+	if rows == 0 {
+		o.logger.Warn("voice table too short, using zeros", "voice", voiceName, "values", len(styleVec))
+		return style
+	}
+	for i := range style {
+		var sum float32
+		for j := 0; j < rows; j++ {
+			sum += styleVec[j*kokoroStyleDim+i]
+		}
+		style[i] = sum / float32(rows)
+	}
+	o.logger.Info("loaded voice embedding", "voice", voiceName)
+	return style
+}
+
+// textToPhonemes runs espeak on text and returns its IPA transcription with
+// every run of whitespace (including any newlines in espeak's output)
+// collapsed to a single space.
+func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
+	o.logger.Debug("converting text to phonemes", "text", text)
+	// --ipa because kokoroPhonemeMap is keyed by IPA symbols; -x would emit
+	// espeak's ASCII mnemonics instead. "--" ends option parsing so text that
+	// starts with '-' is not interpreted as a flag.
+	cmd := exec.Command(o.espeakCmd, "-q", "--ipa", "--", text)
+	output, err := cmd.Output()
+	if err != nil {
+		return "", fmt.Errorf("espeak failed: %w", err)
+	}
+
+	phonemeStr := strings.Join(strings.Fields(string(output)), " ")
+	o.logger.Debug("phonemes generated", "phonemes", phonemeStr)
+	return phonemeStr, nil
+}
+
+// phonemesToTokens maps each rune of phonemeStr to its Kokoro token ID,
+// silently skipping runes that are not in the phoneme map. It returns an
+// error when the input is empty or nothing could be mapped.
+func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
+	if phonemeStr == "" {
+		return nil, fmt.Errorf("empty phoneme string")
+	}
+
+	tokens := make([]int, 0, len(phonemeStr))
+	for _, ch := range phonemeStr {
+		if tokenID, ok := o.phonemeMap[string(ch)]; ok {
+			tokens = append(tokens, tokenID)
+		}
+	}
+
+	if len(tokens) == 0 {
+		return nil, fmt.Errorf("no valid phonemes mapped to tokens")
+	}
+	o.logger.Debug("tokens generated", "count", len(tokens), "tokens", tokens)
+	return tokens, nil
+}
+
+// generateAudio synthesizes text and returns a copy of the model's waveform.
+// A panic raised while running inference is recovered and reported as an
+// error so the caller never receives a nil slice together with a nil error.
+func (o *KokoroONNXOrator) generateAudio(text string) (audio []float32, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			o.logger.Error("panic recovered in generateAudio", "panic", r)
+			audio, err = nil, fmt.Errorf("panic in generateAudio: %v", r)
+		}
+	}()
+
+	if err := o.ensureInitialized(o.modelPath); err != nil {
+		return nil, err
+	}
+
+	phonemeStr, err := o.textToPhonemes(text)
+	if err != nil {
+		return nil, fmt.Errorf("phoneme conversion failed: %w", err)
+	}
+
+	tokens, err := o.phonemesToTokens(phonemeStr)
+	if err != nil {
+		return nil, fmt.Errorf("token conversion failed: %w", err)
+	}
+	if len(tokens) > kokoroMaxTokens {
+		return nil, fmt.Errorf("text too long: %d tokens (max %d)", len(tokens), kokoroMaxTokens)
+	}
+
+	// Pad with token 0 on both sides of the sequence.
+	inputIDs := make([]int64, 0, len(tokens)+2)
+	inputIDs = append(inputIDs, 0)
+	for _, t := range tokens {
+		inputIDs = append(inputIDs, int64(t))
+	}
+	inputIDs = append(inputIDs, 0)
+
+	inputTensor, err := onnxruntime_go.NewTensor[int64](
+		onnxruntime_go.NewShape(1, int64(len(inputIDs))),
+		inputIDs,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create input tensor: %w", err)
+	}
+	defer func() { _ = inputTensor.Destroy() }()
+
+	styleTensor, err := onnxruntime_go.NewTensor[float32](
+		onnxruntime_go.NewShape(1, kokoroStyleDim),
+		o.styleVector,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create style tensor: %w", err)
+	}
+	defer func() { _ = styleTensor.Destroy() }()
+
+	// An unset TTS_SPEED arrives as 0; fall back to normal speed.
+	speed := o.speed
+	if speed <= 0 {
+		speed = 1
+	}
+	speedTensor, err := onnxruntime_go.NewTensor[float32](
+		onnxruntime_go.NewShape(1),
+		[]float32{speed},
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create speed tensor: %w", err)
+	}
+	defer func() { _ = speedTensor.Destroy() }()
+
+	// The waveform length depends on the input, so the output slot is left
+	// nil for onnxruntime to allocate instead of using a fixed-size tensor.
+	outputs := []onnxruntime_go.Value{nil}
+	defer func() {
+		if outputs[0] != nil {
+			_ = outputs[0].Destroy()
+		}
+	}()
+
+	o.logger.Info("running ONNX inference", "input_len", len(inputIDs))
+	err = o.session.Run(
+		[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
+		outputs,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("ONNX inference failed: %w", err)
+	}
+
+	waveform, ok := outputs[0].(*onnxruntime_go.Tensor[float32])
+	if !ok {
+		return nil, fmt.Errorf("unexpected waveform tensor type %T", outputs[0])
+	}
+	audioData := waveform.GetData()
+	if len(audioData) == 0 {
+		return nil, fmt.Errorf("empty audio output")
+	}
+
+	// Copy out of the tensor before the deferred Destroy releases it.
+	audio = make([]float32, len(audioData))
+	copy(audio, audioData)
+	o.logger.Debug("audio generated", "samples", len(audio))
+	return audio, nil
+}
+
+// Speak synthesizes text and blocks until playback has finished or been stopped.
+func (o *KokoroONNXOrator) Speak(text string) error {
+	o.logger.Debug("KokoroONNX Speak called", "text_len", len(text))
+
+	audio, err := o.generateAudio(text)
+	if err != nil {
+		return fmt.Errorf("audio generation failed: %w", err)
+	}
+
+	o.speakerOnce.Do(func() {
+		o.speakerErr = speaker.Init(kokoroSampleRate, kokoroSampleRate.N(time.Second/10))
+	})
+	if o.speakerErr != nil {
+		return fmt.Errorf("speaker init failed: %w", o.speakerErr)
+	}
+
+	// Feed the mono samples straight to the speaker on both channels. pos
+	// records how much has been consumed so the stream advances and ends.
+	pos := 0
+	pcm := beep.StreamerFunc(func(samples [][2]float64) (int, bool) {
+		if pos >= len(audio) {
+			return 0, false
+		}
+		n := min(len(samples), len(audio)-pos)
+		for i := 0; i < n; i++ {
+			v := float64(audio[pos+i])
+			samples[i][0], samples[i][1] = v, v
+		}
+		pos += n
+		return n, true
+	})
+
+	// ctrl sits inside the Seq so that when Stop sets ctrl.Streamer to nil the
+	// sequence still moves on to the callback and done gets closed.
+	done := make(chan bool)
+	ctrl := &beep.Ctrl{Streamer: pcm}
+	playback := beep.Seq(ctrl, beep.Callback(func() {
+		o.logger.Debug("playback finished")
+		o.mu.Lock()
+		defer o.mu.Unlock()
+		close(done)
+		// Leave the fields alone if a newer Speak has replaced them.
+		if o.currentDone == done {
+			o.currentStream = nil
+			o.currentDone = nil
+		}
+	}))
+	o.mu.Lock()
+	o.currentDone = done
+	o.currentStream = ctrl
+	o.mu.Unlock()
+
+	o.logger.Info("playing audio", "sampleRate", kokoroSampleRate, "samples", len(audio))
+	speaker.Play(playback)
+	<-done
+	o.logger.Debug("Speak completed")
+	return nil
+}
+
+// Stop ends the current playback, if any, by clearing the active Ctrl's
+// streamer while holding the speaker lock.
+func (o *KokoroONNXOrator) Stop() {
+	o.logger.Debug("stopping KokoroONNX orator")
+	speaker.Lock()
+	defer speaker.Unlock()
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if o.currentStream != nil {
+		o.currentStream.Streamer = nil
+	}
+}
+
+// GetLogger returns the logger this orator was created with.
+func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
+	return o.logger
+}
+
+// stoproutine waits for signals on TTSDoneChan. For each one it stops
+// playback, discards queued and buffered text, wakes a waiting Speak and
+// marks the read loop as interrupted.
+func (o *KokoroONNXOrator) stoproutine() {
+	o.logger.Debug("KokoroONNX stoproutine started")
+	for {
+		<-TTSDoneChan
+		o.logger.Debug("KokoroONNX got done signal")
+		o.Stop()
+		// Non-blocking drain: readroutine may empty the channel concurrently,
+		// so a plain receive guarded by len() could block forever.
+	drain:
+		for {
+			select {
+			case <-TTSTextChan:
+			default:
+				break drain
+			}
+		}
+		o.mu.Lock()
+		o.textBuffer.Reset()
+		if o.currentDone != nil {
+			select {
+			case o.currentDone <- true:
+			default:
+			}
+		}
+		o.interrupt = true
+		o.mu.Unlock()
+		o.logger.Debug("KokoroONNX stoproutine finished")
+	}
+}
+
+// readroutine consumes TTSTextChan, buffering text until more than one
+// sentence is available and speaking every sentence but the last, which stays
+// buffered. A signal on TTSFlushChan speaks whatever text remains.
+func (o *KokoroONNXOrator) readroutine() {
+	defer func() {
+		if r := recover(); r != nil {
+			o.logger.Error("panic recovered in readroutine", "panic", r)
+			// Restart the goroutine
+			go o.readroutine()
+		}
+	}()
+
+	o.logger.Debug("KokoroONNX readroutine started")
+	tokenizer, err := english.NewSentenceTokenizer(nil)
+	if err != nil {
+		o.logger.Error("failed to create sentence tokenizer", "error", err)
+		return
+	}
+	for {
+		select {
+		case chunk := <-TTSTextChan:
+			o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk))
+			o.mu.Lock()
+			o.interrupt = false
+			if _, err := o.textBuffer.WriteString(chunk); err != nil {
+				o.logger.Warn("failed to write to buffer", "error", err)
+				o.mu.Unlock()
+				continue
+			}
+			text := o.textBuffer.String()
+			sentences := tokenizer.Tokenize(text)
+			o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text)
+			if len(sentences) <= 1 {
+				o.logger.Debug("KokoroONNX not enough sentences, waiting")
+				o.mu.Unlock()
+				continue
+			}
+			completeSentences := sentences[:len(sentences)-1]
+			remaining := sentences[len(sentences)-1].Text
+			o.textBuffer.Reset()
+			o.textBuffer.WriteString(remaining)
+			o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences))
+			o.mu.Unlock()
+
+			for _, sentence := range completeSentences {
+				o.mu.Lock()
+				interrupted := o.interrupt
+				o.mu.Unlock()
+				if interrupted {
+					// Drop the rest of this batch but keep the routine alive;
+					// returning here would silence all later text.
+					o.logger.Debug("KokoroONNX interrupted, dropping remaining sentences")
+					break
+				}
+				cleanedText := models.CleanText(sentence.Text)
+				if cleanedText == "" {
+					continue
+				}
+				o.logger.Info("KokoroONNX speak", "text", cleanedText)
+				if err := o.Speak(cleanedText); err != nil {
+					o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
+				}
+			}
+		case <-TTSFlushChan:
+			o.logger.Debug("KokoroONNX flush signal")
+			// Pull in any chunks still queued, without blocking if
+			// stoproutine empties the channel at the same time.
+		pending:
+			for {
+				select {
+				case chunk := <-TTSTextChan:
+					o.mu.Lock()
+					o.textBuffer.WriteString(chunk)
+					o.mu.Unlock()
+				default:
+					break pending
+				}
+			}
+			o.mu.Lock()
+			remaining := models.CleanText(o.textBuffer.String())
+			o.textBuffer.Reset()
+			o.mu.Unlock()
+			if remaining == "" {
+				continue
+			}
+			for _, rs := range tokenizer.Tokenize(remaining) {
+				o.mu.Lock()
+				interrupted := o.interrupt
+				o.mu.Unlock()
+				if interrupted {
+					break
+				}
+				if err := o.Speak(rs.Text); err != nil {
+					o.logger.Error("tts failed", "text", rs.Text, "error", err)
+				}
+			}
+		}
+	}
+}
diff --git a/rag/embedder.go b/rag/embedder.go
index 5a4aae0..41d49cd 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"gf-lt/config"
 	"gf-lt/models"
+	"gf-lt/onnx"
 	"log/slog"
 	"net/http"
 	"os"
@@ -156,43 +157,6 @@ type ONNXEmbedder struct {
 	modelPath string
 }
 
-var onnxInitOnce sync.Once
-var onnxReady bool
-var onnxLibPath string
-var cudaLibPath string
-
-var onnxLibPaths = []string{
-	"/usr/lib/libonnxruntime.so",
-	"/usr/lib/libonnxruntime.so.1.24.2",
-	"/usr/local/lib/libonnxruntime.so",
-	"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
-	"/opt/onnxruntime/lib/libonnxruntime.so",
-}
-
-var cudaLibPaths = []string{
-	"/usr/lib/libonnxruntime_providers_cuda.so",
-	"/usr/local/lib/libonnxruntime_providers_cuda.so",
-	"/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so",
-}
-
-func findONNXLibrary() string {
-	for _, path := range onnxLibPaths {
-		if _, err := os.Stat(path); err == nil {
-			return path
-		}
-	}
-	return ""
-}
-
-func findCUDALibrary() string {
-	for _, path := range cudaLibPaths {
-		if _, err := os.Stat(path); err == nil {
-			return path
-		}
-	}
-	return ""
-}
-
 func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
 	// Check if model and tokenizer files exist
 	if _, err := os.Stat(modelPath); err != nil {
@@ -202,17 +166,16 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
 		return nil, fmt.Errorf("tokenizer not found: %w", err)
 	}
 
-	// Find ONNX library
-	onnxLibPath = findONNXLibrary()
-	if onnxLibPath == "" {
-		return nil, errors.New("ONNX runtime library not found in standard locations")
+	// Initialize ONNX runtime
+	if err := onnx.Init(); err != nil {
+		return nil, fmt.Errorf("ONNX init failed: %w", err)
+	}
+	if onnx.HasCUDASupport() {
+		logger.Info("ONNX CUDA support enabled")
+	} else {
+		logger.Info("ONNX using CPU fallback")
 	}
 
-	// Find CUDA provider library (optional)
-	cudaLibPath = findCUDALibrary()
-	if cudaLibPath == "" {
-		fmt.Println("WARNING: CUDA provider library not found, will use CPU")
-	}
 	emb := &ONNXEmbedder{
 		tokenizerPath: tokenizerPath,
 		dims:          dims,
@@ -239,26 +202,12 @@ func (e *ONNXEmbedder) ensureInitialized() error {
 		}
 		e.tokenizer = tok
 	}
-	onnxInitOnce.Do(func() {
-		onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
-		if err := onnxruntime_go.InitializeEnvironment(); err != nil {
-			e.logger.Error("failed to initialize ONNX runtime", "error", err)
-			onnxReady = false
-			return
-		}
-		// Register CUDA provider if available
-		if cudaLibPath != "" {
-			if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil {
-				e.logger.Warn("failed to register CUDA provider", "error", err)
-			}
-		}
-		onnxReady = true
-	})
-	if !onnxReady {
+	// ONNX runtime already initialized by onnx.Init() in NewONNXEmbedder
+	if !onnx.IsReady() {
 		return errors.New("ONNX runtime not ready")
 	}
 	// Create session options
-	opts, err := onnxruntime_go.NewSessionOptions()
+	opts, err := onnx.NewSessionOptions()
 	if err != nil {
 		return fmt.Errorf("failed to create session options: %w", err)
 	}
@@ -266,27 +215,7 @@ func (e *ONNXEmbedder) ensureInitialized() error {
 		_ = opts.Destroy()
 	}()
 
-	// Try to add CUDA provider
-	useCUDA := cudaLibPath != ""
-	if useCUDA {
-		cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions()
-		if err != nil {
-			e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
-			useCUDA = false
-		} else {
-			defer func() {
-				_ = cudaOpts.Destroy()
-			}()
-			if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
-				e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
-				useCUDA = false
-			} else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil {
-				e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err)
-				useCUDA = false
-			}
-		}
-	}
-	if useCUDA {
+	if onnx.HasCUDASupport() {
 		e.logger.Info("Using CUDA for ONNX inference")
 	} else {
 		e.logger.Info("Using CPU for ONNX inference")