From 44633d64c68a278bc0c93e072fb48f406b12ae87 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Sat, 7 Mar 2026 08:52:10 +0300 Subject: [PATCH] Chore: move to own file --- extra/kokoro_onnx.go | 451 ++++++++++++++++++++++++++++++++++++++++ extra/tts.go | 477 ------------------------------------------- 2 files changed, 451 insertions(+), 477 deletions(-) create mode 100644 extra/kokoro_onnx.go diff --git a/extra/kokoro_onnx.go b/extra/kokoro_onnx.go new file mode 100644 index 0000000..442c511 --- /dev/null +++ b/extra/kokoro_onnx.go @@ -0,0 +1,451 @@ +//go:build extra +// +build extra + +package extra + +import ( + "bytes" + "fmt" + "gf-lt/models" + "gf-lt/onnx" + "log/slog" + "os/exec" + "strings" + "sync" + "time" + + "github.com/gopxl/beep/v2" + "github.com/gopxl/beep/v2/speaker" + "github.com/gopxl/beep/v2/wav" + "github.com/neurosnap/sentences/english" + "github.com/yalue/onnxruntime_go" +) + +// KokoroONNXOrator implements Kokoro TTS using ONNX runtime +type KokoroONNXOrator struct { + logger *slog.Logger + mu sync.Mutex + session *onnxruntime_go.DynamicAdvancedSession + phonemeMap map[string]int + espeakCmd string + voice string + speed float32 + styleVector []float32 + currentStream *beep.Ctrl + currentDone chan bool + textBuffer strings.Builder + interrupt bool + modelLoaded bool + modelPath string + voicesPath string +} + +// Phoneme to token ID mapping from Kokoro tokenizer.json +var kokoroPhonemeMap = map[string]int{ + "$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23, + "A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42, + "a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68, + "ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148, + "ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177, +} + +func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error { + o.logger.Debug("ensureInitialized called", "modelPath", modelPath) + if o.modelLoaded { + o.logger.Debug("model already loaded") + return nil + } + o.mu.Lock() + defer o.mu.Unlock() + if o.modelLoaded { + return nil + } + if modelPath == "" { + o.logger.Error("modelPath is empty, cannot load ONNX model") + return fmt.Errorf("modelPath is empty, set KokoroModelPath in config") + } + // Initialize ONNX runtime (shared with embedder) + if err := onnx.Init(); err != nil { + o.logger.Error("ONNX init failed", "error", err) + return fmt.Errorf("ONNX init failed: %w", err) + } + if onnx.HasCUDASupport() { + o.logger.Info("ONNX using CUDA") + } else { + o.logger.Info("ONNX using CPU fallback") + } + if o.phonemeMap == nil { + o.phonemeMap = kokoroPhonemeMap + } + if o.espeakCmd == "" { + o.espeakCmd = "espeak-ng" + if _, err := exec.LookPath(o.espeakCmd); err != nil { + o.espeakCmd = "espeak" + if _, err := exec.LookPath(o.espeakCmd); err != nil { + return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng") + } + } + } + o.logger.Info("using espeak command", "cmd", o.espeakCmd) + // Load voice embedding if not already loaded + if o.styleVector == nil { + voiceName := o.voice + if voiceName == "" { + voiceName = "af_bella" + } + if o.voicesPath != "" { + styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName) + if err != nil { + o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName) + o.styleVector = make([]float32, 256) + } else { + // Shape is (510, 1, 256), we want the last 256 values (or first? let's use mean or just pick one) + // Actually, let's average across all 510 to get a single 256-dim vector + o.styleVector = make([]float32, 256) + for i := 0; i < 256; i++ { + var sum float32 + for j := 0; j < 510; j++ { + sum += styleVec[j*256+i] + } + o.styleVector[i] = sum / 510.0 + } + o.logger.Info("loaded voice embedding", "voice", voiceName) + } + } else { + o.logger.Warn("no voices path configured, using zeros for style") + o.styleVector = make([]float32, 256) + } + } + opts, err := onnx.NewSessionOptions() + if err != nil { + return fmt.Errorf("failed to create session options: %w", err) + } + defer func() { _ = opts.Destroy() }() + if onnx.HasCUDASupport() { + o.logger.Info("session options created with CUDA") + } else { + o.logger.Info("session options created with CPU") + } + session, err := onnxruntime_go.NewDynamicAdvancedSession( + modelPath, + []string{"input_ids", "style", "speed"}, + []string{"waveform"}, + opts, + ) + if err != nil { + o.logger.Error("failed to create ONNX session", "error", err) + return fmt.Errorf("failed to create ONNX session: %w", err) + } + o.session = session + o.modelLoaded = true + o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath) + return nil +} + +func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) { + o.logger.Debug("converting text to phonemes", "text", text) + cmd := exec.Command(o.espeakCmd, "-x", "-q", text) + output, err := cmd.Output() + if err != nil { + o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text) + return "", fmt.Errorf("espeak failed: %w", err) + } + + phonemeStr := strings.TrimSpace(string(output)) + o.logger.Debug("phonemes generated", "phonemes", phonemeStr) + return phonemeStr, nil +} + +func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) { + o.logger.Debug("converting phonemes to tokens", "phonemes", phonemeStr) + + if phonemeStr == "" { + o.logger.Error("empty phoneme string") + return nil, fmt.Errorf("empty phoneme string") + } + + // Iterate over each character in the phoneme string + tokens := make([]int, 0) + for _, ch := range phonemeStr { + chStr := string(ch) + if tokenID, ok := o.phonemeMap[chStr]; ok { + tokens = append(tokens, tokenID) + } + } + + if len(tokens) == 0 { + o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr) + return nil, fmt.Errorf("no valid phonemes mapped to tokens") + } + o.logger.Debug("tokens generated", "count", len(tokens), "tokens", tokens) + return tokens, nil +} + +func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) { + + o.logger.Debug("generateAudio called", "text", text, "speed", o.speed) + if err := o.ensureInitialized(o.modelPath); err != nil { + o.logger.Error("ensureInitialized failed", "error", err) + return nil, err + } + phonemeStr, err := o.textToPhonemes(text) + if err != nil { + o.logger.Error("phoneme conversion failed", "error", err) + return nil, fmt.Errorf("phoneme conversion failed: %w", err) + } + tokens, err := o.phonemesToTokens(phonemeStr) + if err != nil { + o.logger.Error("token conversion failed", "error", err) + return nil, fmt.Errorf("token conversion failed: %w", err) + } + if len(tokens) > 510 { + return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens)) + } + tokens = append([]int{0}, tokens...) + tokens = append(tokens, 0) + o.logger.Debug("tokens prepared", "count", len(tokens)) + inputIDs := make([]int64, len(tokens)) + for i, t := range tokens { + inputIDs[i] = int64(t) + } + inputTensor, err := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(1, int64(len(inputIDs))), + inputIDs, + ) + if err != nil { + o.logger.Error("failed to create input tensor", "error", err) + return nil, fmt.Errorf("failed to create input tensor: %w", err) + } + defer func() { _ = inputTensor.Destroy() }() + o.logger.Debug("input tensor created", "shape", fmt.Sprintf("[1,%d]", len(inputIDs))) + styleTensor, err := onnxruntime_go.NewTensor[float32]( + onnxruntime_go.NewShape(1, 256), + o.styleVector, + ) + if err != nil { + o.logger.Error("failed to create style tensor", "error", err) + return nil, fmt.Errorf("failed to create style tensor: %w", err) + } + defer func() { _ = styleTensor.Destroy() }() + speedTensor, err := onnxruntime_go.NewTensor[float32]( + onnxruntime_go.NewShape(1), + []float32{o.speed}, + ) + if err != nil { + o.logger.Error("failed to create speed tensor", "error", err) + return nil, fmt.Errorf("failed to create speed tensor: %w", err) + } + defer func() { _ = speedTensor.Destroy() }() + o.logger.Debug("speed tensor created", "speed", o.speed) + outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(1, 512), + ) + if err != nil { + o.logger.Error("failed to create output tensor", "error", err) + return nil, fmt.Errorf("failed to create output tensor: %w", err) + } + defer func() { _ = outputTensor.Destroy() }() + o.logger.Debug("output tensor created", "shape", "[1,512]") + o.logger.Info("running ONNX inference", "input_len", len(inputIDs)) + err = o.session.Run( + []onnxruntime_go.Value{inputTensor, styleTensor, speedTensor}, + []onnxruntime_go.Value{outputTensor}, + ) + if err != nil { + o.logger.Error("ONNX inference failed", "error", err) + return nil, fmt.Errorf("ONNX inference failed: %w", err) + } + o.logger.Debug("ONNX inference completed") + audioData := outputTensor.GetData() + if len(audioData) == 0 { + o.logger.Error("empty audio output from ONNX") + return nil, fmt.Errorf("empty audio output") + } + o.logger.Debug("audio generated", "samples", len(audioData)) + audio := make([]float32, len(audioData)) + copy(audio, audioData) + return audio, nil +} + +func (o *KokoroONNXOrator) Speak(text string) error { + o.logger.Debug("KokoroONNX Speak called", "text_len", len(text)) + audio, err := o.generateAudio(text) + if err != nil { + o.logger.Error("audio generation failed", "error", err) + return fmt.Errorf("audio generation failed: %w", err) + } + o.logger.Debug("audio ready for playback", "samples", len(audio)) + // Create streamer for encoding + encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) { + for i := range samples { + if i >= len(audio) { + return i, false + } + samples[i][0] = float64(audio[i]) + samples[i][1] = float64(audio[i]) + } + return len(audio), true + }) + buf := &seekableBuffer{new(bytes.Buffer)} + err = wav.Encode(buf, encodeStreamer, beep.Format{ + SampleRate: 24000, + NumChannels: 1, + Precision: 2, + }) + if err != nil { + o.logger.Error("wav encoding failed", "error", err) + return fmt.Errorf("wav encoding failed: %w", err) + } + o.logger.Debug("wav encoded", "size", buf.Len()) + decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes())) + if err != nil { + o.logger.Error("wav decode failed", "error", err) + return fmt.Errorf("wav decode failed: %w", err) + } + defer decodedStreamer.Close() + o.logger.Debug("wav decoded", "format", format) + if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil { + o.logger.Error("speaker init failed", "error", err) + return fmt.Errorf("speaker init failed: %w", err) + } + o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels) + done := make(chan bool) + o.mu.Lock() + o.currentDone = done + o.currentStream = &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() { + o.logger.Debug("playback finished") + o.mu.Lock() + close(done) + o.currentStream = nil + o.currentDone = nil + o.mu.Unlock() + })), Paused: false} + o.mu.Unlock() + speaker.Play(o.currentStream) + <-done + o.logger.Debug("Speak completed") + return nil +} + +func (o *KokoroONNXOrator) Stop() { + o.logger.Debug("stopping KokoroONNX orator") + speaker.Lock() + defer speaker.Unlock() + o.mu.Lock() + defer o.mu.Unlock() + if o.currentStream != nil { + o.currentStream.Streamer = nil + } +} + +func (o *KokoroONNXOrator) GetLogger() *slog.Logger { + return o.logger +} + +func (o *KokoroONNXOrator) stoproutine() { + o.logger.Debug("KokoroONNX stoproutine started") + for { + <-TTSDoneChan + o.logger.Debug("KokoroONNX got done signal") + o.Stop() + for len(TTSTextChan) > 0 { + <-TTSTextChan + } + o.mu.Lock() + o.textBuffer.Reset() + if o.currentDone != nil { + select { + case o.currentDone <- true: + default: + } + } + o.interrupt = true + o.mu.Unlock() + o.logger.Debug("KokoroONNX stoproutine finished") + } +} + +func (o *KokoroONNXOrator) readroutine() { + o.logger.Debug("KokoroONNX readroutine started") + tokenizer, _ := english.NewSentenceTokenizer(nil) + for { + select { + case chunk := <-TTSTextChan: + o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk)) + o.mu.Lock() + o.interrupt = false + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to buffer", "error", err) + o.mu.Unlock() + continue + } + text := o.textBuffer.String() + sentences := tokenizer.Tokenize(text) + o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text) + if len(sentences) <= 1 { + o.logger.Debug("KokoroONNX not enough sentences, waiting") + o.mu.Unlock() + continue + } + completeSentences := sentences[:len(sentences)-1] + remaining := sentences[len(sentences)-1].Text + o.textBuffer.Reset() + o.textBuffer.WriteString(remaining) + o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences)) + o.mu.Unlock() + for _, sentence := range completeSentences { + o.mu.Lock() + interrupted := o.interrupt + o.mu.Unlock() + if interrupted { + o.logger.Debug("KokoroONNX interrupted, exiting") + return + } + cleanedText := models.CleanText(sentence.Text) + if cleanedText == "" { + continue + } + o.logger.Info("KokoroONNX speak", "text", cleanedText) + if err := o.Speak(cleanedText); err != nil { + o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err) + } + } + case <-TTSFlushChan: + o.logger.Debug("KokoroONNX flush signal") + if len(TTSTextChan) > 0 { + for chunk := range TTSTextChan { + o.mu.Lock() + _, err := o.textBuffer.WriteString(chunk) + o.mu.Unlock() + if err != nil { + continue + } + if len(TTSTextChan) == 0 { + break + } + } + } + o.mu.Lock() + remaining := o.textBuffer.String() + remaining = models.CleanText(remaining) + o.textBuffer.Reset() + o.mu.Unlock() + if remaining == "" { + continue + } + sentencesRem := tokenizer.Tokenize(remaining) + for _, rs := range sentencesRem { + o.mu.Lock() + interrupt := o.interrupt + o.mu.Unlock() + if interrupt { + break + } + if err := o.Speak(rs.Text); err != nil { + o.logger.Error("tts failed", "text", rs.Text, "error", err) + } + } + } + } +} diff --git a/extra/tts.go b/extra/tts.go index f75d23e..c7d6566 100644 --- a/extra/tts.go +++ b/extra/tts.go @@ -9,12 +9,10 @@ import ( "fmt" "gf-lt/config" "gf-lt/models" - "gf-lt/onnx" "io" "log/slog" "net/http" "os" - "os/exec" "strings" "sync" "time" @@ -24,9 +22,7 @@ import ( "github.com/gopxl/beep/v2" "github.com/gopxl/beep/v2/mp3" "github.com/gopxl/beep/v2/speaker" - "github.com/gopxl/beep/v2/wav" "github.com/neurosnap/sentences/english" - "github.com/yalue/onnxruntime_go" ) var ( @@ -495,476 +491,3 @@ func (o *GoogleTranslateOrator) Stop() { _ = o.speech.Stop() } } - -// KokoroONNXOrator implements Kokoro TTS using ONNX runtime -type KokoroONNXOrator struct { - logger *slog.Logger - mu sync.Mutex - session *onnxruntime_go.DynamicAdvancedSession - phonemeMap map[string]int - espeakCmd string - voice string - speed float32 - styleVector []float32 - currentStream *beep.Ctrl - currentDone chan bool - textBuffer strings.Builder - interrupt bool - modelLoaded bool - modelPath string - voicesPath string -} - -// Phoneme to token ID mapping from Kokoro tokenizer.json -var kokoroPhonemeMap = map[string]int{ - "$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23, - "A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42, - "a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68, - "ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148, - "ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177, -} - -func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error { - o.logger.Debug("ensureInitialized called", "modelPath", modelPath) - if o.modelLoaded { - o.logger.Debug("model already loaded") - return nil - } - o.mu.Lock() - defer o.mu.Unlock() - if o.modelLoaded { - return nil - } - - if modelPath == "" { - o.logger.Error("modelPath is empty, cannot load ONNX model") - return fmt.Errorf("modelPath is empty, set KokoroModelPath in config") - } - - // Initialize ONNX runtime (shared with embedder) - if err := onnx.Init(); err != nil { - o.logger.Error("ONNX init failed", "error", err) - return fmt.Errorf("ONNX init failed: %w", err) - } - if onnx.HasCUDASupport() { - o.logger.Info("ONNX using CUDA") - } else { - o.logger.Info("ONNX using CPU fallback") - } - - if o.phonemeMap == nil { - o.phonemeMap = kokoroPhonemeMap - } - - if o.espeakCmd == "" { - o.espeakCmd = "espeak-ng" - if _, err := exec.LookPath(o.espeakCmd); err != nil { - o.espeakCmd = "espeak" - if _, err := exec.LookPath(o.espeakCmd); err != nil { - return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng") - } - } - } - o.logger.Info("using espeak command", "cmd", o.espeakCmd) - - // Load voice embedding if not already loaded - if o.styleVector == nil { - voiceName := o.voice - if voiceName == "" { - voiceName = "af_bella" - } - if o.voicesPath != "" { - styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName) - if err != nil { - o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName) - o.styleVector = make([]float32, 256) - } else { - // Shape is (510, 1, 256), we want the last 256 values (or first? let's use mean or just pick one) - // Actually, let's average across all 510 to get a single 256-dim vector - o.styleVector = make([]float32, 256) - for i := 0; i < 256; i++ { - var sum float32 - for j := 0; j < 510; j++ { - sum += styleVec[j*256+i] - } - o.styleVector[i] = sum / 510.0 - } - o.logger.Info("loaded voice embedding", "voice", voiceName) - } - } else { - o.logger.Warn("no voices path configured, using zeros for style") - o.styleVector = make([]float32, 256) - } - } - - opts, err := onnx.NewSessionOptions() - if err != nil { - return fmt.Errorf("failed to create session options: %w", err) - } - defer func() { _ = opts.Destroy() }() - - if onnx.HasCUDASupport() { - o.logger.Info("session options created with CUDA") - } else { - o.logger.Info("session options created with CPU") - } - - session, err := onnxruntime_go.NewDynamicAdvancedSession( - modelPath, - []string{"input_ids", "style", "speed"}, - []string{"waveform"}, - opts, - ) - if err != nil { - o.logger.Error("failed to create ONNX session", "error", err) - return fmt.Errorf("failed to create ONNX session: %w", err) - } - o.session = session - o.modelLoaded = true - o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath) - return nil -} - -func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) { - o.logger.Debug("converting text to phonemes", "text", text) - cmd := exec.Command(o.espeakCmd, "-x", "-q", text) - output, err := cmd.Output() - if err != nil { - o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text) - return "", fmt.Errorf("espeak failed: %w", err) - } - - phonemeStr := strings.TrimSpace(string(output)) - o.logger.Debug("phonemes generated", "phonemes", phonemeStr) - return phonemeStr, nil -} - -func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) { - o.logger.Debug("converting phonemes to tokens", "phonemes", phonemeStr) - - if phonemeStr == "" { - o.logger.Error("empty phoneme string") - return nil, fmt.Errorf("empty phoneme string") - } - - // Iterate over each character in the phoneme string - tokens := make([]int, 0) - for _, ch := range phonemeStr { - chStr := string(ch) - if tokenID, ok := o.phonemeMap[chStr]; ok { - tokens = append(tokens, tokenID) - } - } - - if len(tokens) == 0 { - o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr) - return nil, fmt.Errorf("no valid phonemes mapped to tokens") - } - o.logger.Debug("tokens generated", "count", len(tokens), "tokens", tokens) - return tokens, nil -} - -func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) { - defer func() { - if r := recover(); r != nil { - fmt.Printf("PANIC RECOVERED in generateAudio: %v\n", r) - } - }() - - o.logger.Debug("generateAudio called", "text", text, "speed", o.speed) - if err := o.ensureInitialized(o.modelPath); err != nil { - o.logger.Error("ensureInitialized failed", "error", err) - return nil, err - } - - phonemeStr, err := o.textToPhonemes(text) - if err != nil { - o.logger.Error("phoneme conversion failed", "error", err) - return nil, fmt.Errorf("phoneme conversion failed: %w", err) - } - - tokens, err := o.phonemesToTokens(phonemeStr) - if err != nil { - o.logger.Error("token conversion failed", "error", err) - return nil, fmt.Errorf("token conversion failed: %w", err) - } - - if len(tokens) > 510 { - return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens)) - } - - tokens = append([]int{0}, tokens...) - tokens = append(tokens, 0) - o.logger.Debug("tokens prepared", "count", len(tokens)) - - inputIDs := make([]int64, len(tokens)) - for i, t := range tokens { - inputIDs[i] = int64(t) - } - - inputTensor, err := onnxruntime_go.NewTensor[int64]( - onnxruntime_go.NewShape(1, int64(len(inputIDs))), - inputIDs, - ) - if err != nil { - o.logger.Error("failed to create input tensor", "error", err) - return nil, fmt.Errorf("failed to create input tensor: %w", err) - } - defer func() { _ = inputTensor.Destroy() }() - o.logger.Debug("input tensor created", "shape", fmt.Sprintf("[1,%d]", len(inputIDs))) - - styleTensor, err := onnxruntime_go.NewTensor[float32]( - onnxruntime_go.NewShape(1, 256), - o.styleVector, - ) - if err != nil { - o.logger.Error("failed to create style tensor", "error", err) - return nil, fmt.Errorf("failed to create style tensor: %w", err) - } - defer func() { _ = styleTensor.Destroy() }() - - speedTensor, err := onnxruntime_go.NewTensor[float32]( - onnxruntime_go.NewShape(1), - []float32{o.speed}, - ) - if err != nil { - o.logger.Error("failed to create speed tensor", "error", err) - return nil, fmt.Errorf("failed to create speed tensor: %w", err) - } - defer func() { _ = speedTensor.Destroy() }() - o.logger.Debug("speed tensor created", "speed", o.speed) - - outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( - onnxruntime_go.NewShape(1, 512), - ) - if err != nil { - o.logger.Error("failed to create output tensor", "error", err) - return nil, fmt.Errorf("failed to create output tensor: %w", err) - } - defer func() { _ = outputTensor.Destroy() }() - o.logger.Debug("output tensor created", "shape", "[1,512]") - - o.logger.Info("running ONNX inference", "input_len", len(inputIDs)) - err = o.session.Run( - []onnxruntime_go.Value{inputTensor, styleTensor, speedTensor}, - []onnxruntime_go.Value{outputTensor}, - ) - if err != nil { - o.logger.Error("ONNX inference failed", "error", err) - return nil, fmt.Errorf("ONNX inference failed: %w", err) - } - o.logger.Debug("ONNX inference completed") - - audioData := outputTensor.GetData() - if len(audioData) == 0 { - o.logger.Error("empty audio output from ONNX") - return nil, fmt.Errorf("empty audio output") - } - - o.logger.Debug("audio generated", "samples", len(audioData)) - audio := make([]float32, len(audioData)) - copy(audio, audioData) - return audio, nil -} - -func (o *KokoroONNXOrator) Speak(text string) error { - o.logger.Debug("KokoroONNX Speak called", "text_len", len(text)) - - audio, err := o.generateAudio(text) - if err != nil { - o.logger.Error("audio generation failed", "error", err) - return fmt.Errorf("audio generation failed: %w", err) - } - - o.logger.Debug("audio ready for playback", "samples", len(audio)) - - // Create streamer for encoding - encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) { - for i := range samples { - if i >= len(audio) { - return i, false - } - samples[i][0] = float64(audio[i]) - samples[i][1] = float64(audio[i]) - } - return len(audio), true - }) - - buf := &seekableBuffer{new(bytes.Buffer)} - err = wav.Encode(buf, encodeStreamer, beep.Format{ - SampleRate: 24000, - NumChannels: 1, - Precision: 2, - }) - if err != nil { - o.logger.Error("wav encoding failed", "error", err) - return fmt.Errorf("wav encoding failed: %w", err) - } - - o.logger.Debug("wav encoded", "size", buf.Len()) - - decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes())) - if err != nil { - o.logger.Error("wav decode failed", "error", err) - return fmt.Errorf("wav decode failed: %w", err) - } - defer decodedStreamer.Close() - - o.logger.Debug("wav decoded", "format", format) - - if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil { - o.logger.Error("speaker init failed", "error", err) - return fmt.Errorf("speaker init failed: %w", err) - } - - o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels) - done := make(chan bool) - o.mu.Lock() - o.currentDone = done - o.currentStream = &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() { - o.logger.Debug("playback finished") - o.mu.Lock() - close(done) - o.currentStream = nil - o.currentDone = nil - o.mu.Unlock() - })), Paused: false} - o.mu.Unlock() - - speaker.Play(o.currentStream) - <-done - o.logger.Debug("Speak completed") - return nil -} - -func (o *KokoroONNXOrator) Stop() { - o.logger.Debug("stopping KokoroONNX orator") - speaker.Lock() - defer speaker.Unlock() - o.mu.Lock() - defer o.mu.Unlock() - if o.currentStream != nil { - o.currentStream.Streamer = nil - } -} - -func (o *KokoroONNXOrator) GetLogger() *slog.Logger { - return o.logger -} - -func (o *KokoroONNXOrator) stoproutine() { - o.logger.Debug("KokoroONNX stoproutine started") - for { - <-TTSDoneChan - o.logger.Debug("KokoroONNX got done signal") - o.Stop() - for len(TTSTextChan) > 0 { - <-TTSTextChan - } - o.mu.Lock() - o.textBuffer.Reset() - if o.currentDone != nil { - select { - case o.currentDone <- true: - default: - } - } - o.interrupt = true - o.mu.Unlock() - o.logger.Debug("KokoroONNX stoproutine finished") - } -} - -func (o *KokoroONNXOrator) readroutine() { - defer func() { - if r := recover(); r != nil { - fmt.Printf("PANIC RECOVERED in readroutine: %v\n", r) - // Restart the goroutine - go o.readroutine() - } - }() - - o.logger.Debug("KokoroONNX readroutine started") - tokenizer, _ := english.NewSentenceTokenizer(nil) - for { - select { - case chunk := <-TTSTextChan: - o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk)) - o.mu.Lock() - o.interrupt = false - _, err := o.textBuffer.WriteString(chunk) - if err != nil { - o.logger.Warn("failed to write to buffer", "error", err) - o.mu.Unlock() - continue - } - text := o.textBuffer.String() - sentences := tokenizer.Tokenize(text) - o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text) - if len(sentences) <= 1 { - o.logger.Debug("KokoroONNX not enough sentences, waiting") - o.mu.Unlock() - continue - } - completeSentences := sentences[:len(sentences)-1] - remaining := sentences[len(sentences)-1].Text - o.textBuffer.Reset() - o.textBuffer.WriteString(remaining) - o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences)) - o.mu.Unlock() - - for _, sentence := range completeSentences { - o.mu.Lock() - interrupted := o.interrupt - o.mu.Unlock() - if interrupted { - o.logger.Debug("KokoroONNX interrupted, exiting") - return - } - cleanedText := models.CleanText(sentence.Text) - if cleanedText == "" { - continue - } - o.logger.Info("KokoroONNX speak", "text", cleanedText) - if err := o.Speak(cleanedText); err != nil { - o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err) - } - } - case <-TTSFlushChan: - o.logger.Debug("KokoroONNX flush signal") - if len(TTSTextChan) > 0 { - for chunk := range TTSTextChan { - o.mu.Lock() - _, err := o.textBuffer.WriteString(chunk) - o.mu.Unlock() - if err != nil { - continue - } - if len(TTSTextChan) == 0 { - break - } - } - } - o.mu.Lock() - remaining := o.textBuffer.String() - remaining = models.CleanText(remaining) - o.textBuffer.Reset() - o.mu.Unlock() - if remaining == "" { - continue - } - sentencesRem := tokenizer.Tokenize(remaining) - for _, rs := range sentencesRem { - o.mu.Lock() - interrupt := o.interrupt - o.mu.Unlock() - if interrupt { - break - } - if err := o.Speak(rs.Text); err != nil { - o.logger.Error("tts failed", "text", rs.Text, "error", err) - } - } - } - } -}