Chore: move to own file

This commit is contained in:
Grail Finder
2026-03-07 08:52:10 +03:00
parent 0598e3e86d
commit 44633d64c6
2 changed files with 451 additions and 477 deletions

451
extra/kokoro_onnx.go Normal file
View File

@@ -0,0 +1,451 @@
//go:build extra
// +build extra
package extra
import (
"bytes"
"fmt"
"gf-lt/models"
"gf-lt/onnx"
"log/slog"
"os/exec"
"strings"
"sync"
"time"
"github.com/gopxl/beep/v2"
"github.com/gopxl/beep/v2/speaker"
"github.com/gopxl/beep/v2/wav"
"github.com/neurosnap/sentences/english"
"github.com/yalue/onnxruntime_go"
)
// KokoroONNXOrator implements Kokoro TTS using ONNX runtime
type KokoroONNXOrator struct {
logger *slog.Logger
mu sync.Mutex
session *onnxruntime_go.DynamicAdvancedSession
phonemeMap map[string]int
espeakCmd string
voice string
speed float32
styleVector []float32
currentStream *beep.Ctrl
currentDone chan bool
textBuffer strings.Builder
interrupt bool
modelLoaded bool
modelPath string
voicesPath string
}
// Phoneme to token ID mapping from Kokoro tokenizer.json
var kokoroPhonemeMap = map[string]int{
"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23,
"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42,
"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
"ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177,
}
func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
o.logger.Debug("ensureInitialized called", "modelPath", modelPath)
if o.modelLoaded {
o.logger.Debug("model already loaded")
return nil
}
o.mu.Lock()
defer o.mu.Unlock()
if o.modelLoaded {
return nil
}
if modelPath == "" {
o.logger.Error("modelPath is empty, cannot load ONNX model")
return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
}
// Initialize ONNX runtime (shared with embedder)
if err := onnx.Init(); err != nil {
o.logger.Error("ONNX init failed", "error", err)
return fmt.Errorf("ONNX init failed: %w", err)
}
if onnx.HasCUDASupport() {
o.logger.Info("ONNX using CUDA")
} else {
o.logger.Info("ONNX using CPU fallback")
}
if o.phonemeMap == nil {
o.phonemeMap = kokoroPhonemeMap
}
if o.espeakCmd == "" {
o.espeakCmd = "espeak-ng"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
o.espeakCmd = "espeak"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
}
}
}
o.logger.Info("using espeak command", "cmd", o.espeakCmd)
// Load voice embedding if not already loaded
if o.styleVector == nil {
voiceName := o.voice
if voiceName == "" {
voiceName = "af_bella"
}
if o.voicesPath != "" {
styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName)
if err != nil {
o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
o.styleVector = make([]float32, 256)
} else {
// Shape is (510, 1, 256), we want the last 256 values (or first? let's use mean or just pick one)
// Actually, let's average across all 510 to get a single 256-dim vector
o.styleVector = make([]float32, 256)
for i := 0; i < 256; i++ {
var sum float32
for j := 0; j < 510; j++ {
sum += styleVec[j*256+i]
}
o.styleVector[i] = sum / 510.0
}
o.logger.Info("loaded voice embedding", "voice", voiceName)
}
} else {
o.logger.Warn("no voices path configured, using zeros for style")
o.styleVector = make([]float32, 256)
}
}
opts, err := onnx.NewSessionOptions()
if err != nil {
return fmt.Errorf("failed to create session options: %w", err)
}
defer func() { _ = opts.Destroy() }()
if onnx.HasCUDASupport() {
o.logger.Info("session options created with CUDA")
} else {
o.logger.Info("session options created with CPU")
}
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath,
[]string{"input_ids", "style", "speed"},
[]string{"waveform"},
opts,
)
if err != nil {
o.logger.Error("failed to create ONNX session", "error", err)
return fmt.Errorf("failed to create ONNX session: %w", err)
}
o.session = session
o.modelLoaded = true
o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
return nil
}
func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
o.logger.Debug("converting text to phonemes", "text", text)
cmd := exec.Command(o.espeakCmd, "-x", "-q", text)
output, err := cmd.Output()
if err != nil {
o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text)
return "", fmt.Errorf("espeak failed: %w", err)
}
phonemeStr := strings.TrimSpace(string(output))
o.logger.Debug("phonemes generated", "phonemes", phonemeStr)
return phonemeStr, nil
}
func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
o.logger.Debug("converting phonemes to tokens", "phonemes", phonemeStr)
if phonemeStr == "" {
o.logger.Error("empty phoneme string")
return nil, fmt.Errorf("empty phoneme string")
}
// Iterate over each character in the phoneme string
tokens := make([]int, 0)
for _, ch := range phonemeStr {
chStr := string(ch)
if tokenID, ok := o.phonemeMap[chStr]; ok {
tokens = append(tokens, tokenID)
}
}
if len(tokens) == 0 {
o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr)
return nil, fmt.Errorf("no valid phonemes mapped to tokens")
}
o.logger.Debug("tokens generated", "count", len(tokens), "tokens", tokens)
return tokens, nil
}
func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) {
o.logger.Debug("generateAudio called", "text", text, "speed", o.speed)
if err := o.ensureInitialized(o.modelPath); err != nil {
o.logger.Error("ensureInitialized failed", "error", err)
return nil, err
}
phonemeStr, err := o.textToPhonemes(text)
if err != nil {
o.logger.Error("phoneme conversion failed", "error", err)
return nil, fmt.Errorf("phoneme conversion failed: %w", err)
}
tokens, err := o.phonemesToTokens(phonemeStr)
if err != nil {
o.logger.Error("token conversion failed", "error", err)
return nil, fmt.Errorf("token conversion failed: %w", err)
}
if len(tokens) > 510 {
return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens))
}
tokens = append([]int{0}, tokens...)
tokens = append(tokens, 0)
o.logger.Debug("tokens prepared", "count", len(tokens))
inputIDs := make([]int64, len(tokens))
for i, t := range tokens {
inputIDs[i] = int64(t)
}
inputTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(1, int64(len(inputIDs))),
inputIDs,
)
if err != nil {
o.logger.Error("failed to create input tensor", "error", err)
return nil, fmt.Errorf("failed to create input tensor: %w", err)
}
defer func() { _ = inputTensor.Destroy() }()
o.logger.Debug("input tensor created", "shape", fmt.Sprintf("[1,%d]", len(inputIDs)))
styleTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1, 256),
o.styleVector,
)
if err != nil {
o.logger.Error("failed to create style tensor", "error", err)
return nil, fmt.Errorf("failed to create style tensor: %w", err)
}
defer func() { _ = styleTensor.Destroy() }()
speedTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1),
[]float32{o.speed},
)
if err != nil {
o.logger.Error("failed to create speed tensor", "error", err)
return nil, fmt.Errorf("failed to create speed tensor: %w", err)
}
defer func() { _ = speedTensor.Destroy() }()
o.logger.Debug("speed tensor created", "speed", o.speed)
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(1, 512),
)
if err != nil {
o.logger.Error("failed to create output tensor", "error", err)
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer func() { _ = outputTensor.Destroy() }()
o.logger.Debug("output tensor created", "shape", "[1,512]")
o.logger.Info("running ONNX inference", "input_len", len(inputIDs))
err = o.session.Run(
[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
[]onnxruntime_go.Value{outputTensor},
)
if err != nil {
o.logger.Error("ONNX inference failed", "error", err)
return nil, fmt.Errorf("ONNX inference failed: %w", err)
}
o.logger.Debug("ONNX inference completed")
audioData := outputTensor.GetData()
if len(audioData) == 0 {
o.logger.Error("empty audio output from ONNX")
return nil, fmt.Errorf("empty audio output")
}
o.logger.Debug("audio generated", "samples", len(audioData))
audio := make([]float32, len(audioData))
copy(audio, audioData)
return audio, nil
}
func (o *KokoroONNXOrator) Speak(text string) error {
o.logger.Debug("KokoroONNX Speak called", "text_len", len(text))
audio, err := o.generateAudio(text)
if err != nil {
o.logger.Error("audio generation failed", "error", err)
return fmt.Errorf("audio generation failed: %w", err)
}
o.logger.Debug("audio ready for playback", "samples", len(audio))
// Create streamer for encoding
encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) {
for i := range samples {
if i >= len(audio) {
return i, false
}
samples[i][0] = float64(audio[i])
samples[i][1] = float64(audio[i])
}
return len(audio), true
})
buf := &seekableBuffer{new(bytes.Buffer)}
err = wav.Encode(buf, encodeStreamer, beep.Format{
SampleRate: 24000,
NumChannels: 1,
Precision: 2,
})
if err != nil {
o.logger.Error("wav encoding failed", "error", err)
return fmt.Errorf("wav encoding failed: %w", err)
}
o.logger.Debug("wav encoded", "size", buf.Len())
decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes()))
if err != nil {
o.logger.Error("wav decode failed", "error", err)
return fmt.Errorf("wav decode failed: %w", err)
}
defer decodedStreamer.Close()
o.logger.Debug("wav decoded", "format", format)
if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
o.logger.Error("speaker init failed", "error", err)
return fmt.Errorf("speaker init failed: %w", err)
}
o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels)
done := make(chan bool)
o.mu.Lock()
o.currentDone = done
o.currentStream = &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() {
o.logger.Debug("playback finished")
o.mu.Lock()
close(done)
o.currentStream = nil
o.currentDone = nil
o.mu.Unlock()
})), Paused: false}
o.mu.Unlock()
speaker.Play(o.currentStream)
<-done
o.logger.Debug("Speak completed")
return nil
}
func (o *KokoroONNXOrator) Stop() {
o.logger.Debug("stopping KokoroONNX orator")
speaker.Lock()
defer speaker.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if o.currentStream != nil {
o.currentStream.Streamer = nil
}
}
func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
return o.logger
}
func (o *KokoroONNXOrator) stoproutine() {
o.logger.Debug("KokoroONNX stoproutine started")
for {
<-TTSDoneChan
o.logger.Debug("KokoroONNX got done signal")
o.Stop()
for len(TTSTextChan) > 0 {
<-TTSTextChan
}
o.mu.Lock()
o.textBuffer.Reset()
if o.currentDone != nil {
select {
case o.currentDone <- true:
default:
}
}
o.interrupt = true
o.mu.Unlock()
o.logger.Debug("KokoroONNX stoproutine finished")
}
}
func (o *KokoroONNXOrator) readroutine() {
o.logger.Debug("KokoroONNX readroutine started")
tokenizer, _ := english.NewSentenceTokenizer(nil)
for {
select {
case chunk := <-TTSTextChan:
o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk))
o.mu.Lock()
o.interrupt = false
_, err := o.textBuffer.WriteString(chunk)
if err != nil {
o.logger.Warn("failed to write to buffer", "error", err)
o.mu.Unlock()
continue
}
text := o.textBuffer.String()
sentences := tokenizer.Tokenize(text)
o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text)
if len(sentences) <= 1 {
o.logger.Debug("KokoroONNX not enough sentences, waiting")
o.mu.Unlock()
continue
}
completeSentences := sentences[:len(sentences)-1]
remaining := sentences[len(sentences)-1].Text
o.textBuffer.Reset()
o.textBuffer.WriteString(remaining)
o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences))
o.mu.Unlock()
for _, sentence := range completeSentences {
o.mu.Lock()
interrupted := o.interrupt
o.mu.Unlock()
if interrupted {
o.logger.Debug("KokoroONNX interrupted, exiting")
return
}
cleanedText := models.CleanText(sentence.Text)
if cleanedText == "" {
continue
}
o.logger.Info("KokoroONNX speak", "text", cleanedText)
if err := o.Speak(cleanedText); err != nil {
o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
}
}
case <-TTSFlushChan:
o.logger.Debug("KokoroONNX flush signal")
if len(TTSTextChan) > 0 {
for chunk := range TTSTextChan {
o.mu.Lock()
_, err := o.textBuffer.WriteString(chunk)
o.mu.Unlock()
if err != nil {
continue
}
if len(TTSTextChan) == 0 {
break
}
}
}
o.mu.Lock()
remaining := o.textBuffer.String()
remaining = models.CleanText(remaining)
o.textBuffer.Reset()
o.mu.Unlock()
if remaining == "" {
continue
}
sentencesRem := tokenizer.Tokenize(remaining)
for _, rs := range sentencesRem {
o.mu.Lock()
interrupt := o.interrupt
o.mu.Unlock()
if interrupt {
break
}
if err := o.Speak(rs.Text); err != nil {
o.logger.Error("tts failed", "text", rs.Text, "error", err)
}
}
}
}
}

View File

@@ -9,12 +9,10 @@ import (
"fmt"
"gf-lt/config"
"gf-lt/models"
"gf-lt/onnx"
"io"
"log/slog"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
@@ -24,9 +22,7 @@ import (
"github.com/gopxl/beep/v2"
"github.com/gopxl/beep/v2/mp3"
"github.com/gopxl/beep/v2/speaker"
"github.com/gopxl/beep/v2/wav"
"github.com/neurosnap/sentences/english"
"github.com/yalue/onnxruntime_go"
)
var (
@@ -495,476 +491,3 @@ func (o *GoogleTranslateOrator) Stop() {
_ = o.speech.Stop()
}
}
// KokoroONNXOrator implements Kokoro TTS using ONNX runtime
type KokoroONNXOrator struct {
logger *slog.Logger
mu sync.Mutex
session *onnxruntime_go.DynamicAdvancedSession
phonemeMap map[string]int
espeakCmd string
voice string
speed float32
styleVector []float32
currentStream *beep.Ctrl
currentDone chan bool
textBuffer strings.Builder
interrupt bool
modelLoaded bool
modelPath string
voicesPath string
}
// Phoneme to token ID mapping from Kokoro tokenizer.json
var kokoroPhonemeMap = map[string]int{
"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23,
"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42,
"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
"ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177,
}
func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
o.logger.Debug("ensureInitialized called", "modelPath", modelPath)
if o.modelLoaded {
o.logger.Debug("model already loaded")
return nil
}
o.mu.Lock()
defer o.mu.Unlock()
if o.modelLoaded {
return nil
}
if modelPath == "" {
o.logger.Error("modelPath is empty, cannot load ONNX model")
return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
}
// Initialize ONNX runtime (shared with embedder)
if err := onnx.Init(); err != nil {
o.logger.Error("ONNX init failed", "error", err)
return fmt.Errorf("ONNX init failed: %w", err)
}
if onnx.HasCUDASupport() {
o.logger.Info("ONNX using CUDA")
} else {
o.logger.Info("ONNX using CPU fallback")
}
if o.phonemeMap == nil {
o.phonemeMap = kokoroPhonemeMap
}
if o.espeakCmd == "" {
o.espeakCmd = "espeak-ng"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
o.espeakCmd = "espeak"
if _, err := exec.LookPath(o.espeakCmd); err != nil {
return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
}
}
}
o.logger.Info("using espeak command", "cmd", o.espeakCmd)
// Load voice embedding if not already loaded
if o.styleVector == nil {
voiceName := o.voice
if voiceName == "" {
voiceName = "af_bella"
}
if o.voicesPath != "" {
styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName)
if err != nil {
o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
o.styleVector = make([]float32, 256)
} else {
// Shape is (510, 1, 256), we want the last 256 values (or first? let's use mean or just pick one)
// Actually, let's average across all 510 to get a single 256-dim vector
o.styleVector = make([]float32, 256)
for i := 0; i < 256; i++ {
var sum float32
for j := 0; j < 510; j++ {
sum += styleVec[j*256+i]
}
o.styleVector[i] = sum / 510.0
}
o.logger.Info("loaded voice embedding", "voice", voiceName)
}
} else {
o.logger.Warn("no voices path configured, using zeros for style")
o.styleVector = make([]float32, 256)
}
}
opts, err := onnx.NewSessionOptions()
if err != nil {
return fmt.Errorf("failed to create session options: %w", err)
}
defer func() { _ = opts.Destroy() }()
if onnx.HasCUDASupport() {
o.logger.Info("session options created with CUDA")
} else {
o.logger.Info("session options created with CPU")
}
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath,
[]string{"input_ids", "style", "speed"},
[]string{"waveform"},
opts,
)
if err != nil {
o.logger.Error("failed to create ONNX session", "error", err)
return fmt.Errorf("failed to create ONNX session: %w", err)
}
o.session = session
o.modelLoaded = true
o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
return nil
}
func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
o.logger.Debug("converting text to phonemes", "text", text)
cmd := exec.Command(o.espeakCmd, "-x", "-q", text)
output, err := cmd.Output()
if err != nil {
o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text)
return "", fmt.Errorf("espeak failed: %w", err)
}
phonemeStr := strings.TrimSpace(string(output))
o.logger.Debug("phonemes generated", "phonemes", phonemeStr)
return phonemeStr, nil
}
func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
o.logger.Debug("converting phonemes to tokens", "phonemes", phonemeStr)
if phonemeStr == "" {
o.logger.Error("empty phoneme string")
return nil, fmt.Errorf("empty phoneme string")
}
// Iterate over each character in the phoneme string
tokens := make([]int, 0)
for _, ch := range phonemeStr {
chStr := string(ch)
if tokenID, ok := o.phonemeMap[chStr]; ok {
tokens = append(tokens, tokenID)
}
}
if len(tokens) == 0 {
o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr)
return nil, fmt.Errorf("no valid phonemes mapped to tokens")
}
o.logger.Debug("tokens generated", "count", len(tokens), "tokens", tokens)
return tokens, nil
}
func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) {
defer func() {
if r := recover(); r != nil {
fmt.Printf("PANIC RECOVERED in generateAudio: %v\n", r)
}
}()
o.logger.Debug("generateAudio called", "text", text, "speed", o.speed)
if err := o.ensureInitialized(o.modelPath); err != nil {
o.logger.Error("ensureInitialized failed", "error", err)
return nil, err
}
phonemeStr, err := o.textToPhonemes(text)
if err != nil {
o.logger.Error("phoneme conversion failed", "error", err)
return nil, fmt.Errorf("phoneme conversion failed: %w", err)
}
tokens, err := o.phonemesToTokens(phonemeStr)
if err != nil {
o.logger.Error("token conversion failed", "error", err)
return nil, fmt.Errorf("token conversion failed: %w", err)
}
if len(tokens) > 510 {
return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens))
}
tokens = append([]int{0}, tokens...)
tokens = append(tokens, 0)
o.logger.Debug("tokens prepared", "count", len(tokens))
inputIDs := make([]int64, len(tokens))
for i, t := range tokens {
inputIDs[i] = int64(t)
}
inputTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(1, int64(len(inputIDs))),
inputIDs,
)
if err != nil {
o.logger.Error("failed to create input tensor", "error", err)
return nil, fmt.Errorf("failed to create input tensor: %w", err)
}
defer func() { _ = inputTensor.Destroy() }()
o.logger.Debug("input tensor created", "shape", fmt.Sprintf("[1,%d]", len(inputIDs)))
styleTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1, 256),
o.styleVector,
)
if err != nil {
o.logger.Error("failed to create style tensor", "error", err)
return nil, fmt.Errorf("failed to create style tensor: %w", err)
}
defer func() { _ = styleTensor.Destroy() }()
speedTensor, err := onnxruntime_go.NewTensor[float32](
onnxruntime_go.NewShape(1),
[]float32{o.speed},
)
if err != nil {
o.logger.Error("failed to create speed tensor", "error", err)
return nil, fmt.Errorf("failed to create speed tensor: %w", err)
}
defer func() { _ = speedTensor.Destroy() }()
o.logger.Debug("speed tensor created", "speed", o.speed)
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(1, 512),
)
if err != nil {
o.logger.Error("failed to create output tensor", "error", err)
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer func() { _ = outputTensor.Destroy() }()
o.logger.Debug("output tensor created", "shape", "[1,512]")
o.logger.Info("running ONNX inference", "input_len", len(inputIDs))
err = o.session.Run(
[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
[]onnxruntime_go.Value{outputTensor},
)
if err != nil {
o.logger.Error("ONNX inference failed", "error", err)
return nil, fmt.Errorf("ONNX inference failed: %w", err)
}
o.logger.Debug("ONNX inference completed")
audioData := outputTensor.GetData()
if len(audioData) == 0 {
o.logger.Error("empty audio output from ONNX")
return nil, fmt.Errorf("empty audio output")
}
o.logger.Debug("audio generated", "samples", len(audioData))
audio := make([]float32, len(audioData))
copy(audio, audioData)
return audio, nil
}
func (o *KokoroONNXOrator) Speak(text string) error {
o.logger.Debug("KokoroONNX Speak called", "text_len", len(text))
audio, err := o.generateAudio(text)
if err != nil {
o.logger.Error("audio generation failed", "error", err)
return fmt.Errorf("audio generation failed: %w", err)
}
o.logger.Debug("audio ready for playback", "samples", len(audio))
// Create streamer for encoding
encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) {
for i := range samples {
if i >= len(audio) {
return i, false
}
samples[i][0] = float64(audio[i])
samples[i][1] = float64(audio[i])
}
return len(audio), true
})
buf := &seekableBuffer{new(bytes.Buffer)}
err = wav.Encode(buf, encodeStreamer, beep.Format{
SampleRate: 24000,
NumChannels: 1,
Precision: 2,
})
if err != nil {
o.logger.Error("wav encoding failed", "error", err)
return fmt.Errorf("wav encoding failed: %w", err)
}
o.logger.Debug("wav encoded", "size", buf.Len())
decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes()))
if err != nil {
o.logger.Error("wav decode failed", "error", err)
return fmt.Errorf("wav decode failed: %w", err)
}
defer decodedStreamer.Close()
o.logger.Debug("wav decoded", "format", format)
if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
o.logger.Error("speaker init failed", "error", err)
return fmt.Errorf("speaker init failed: %w", err)
}
o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels)
done := make(chan bool)
o.mu.Lock()
o.currentDone = done
o.currentStream = &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() {
o.logger.Debug("playback finished")
o.mu.Lock()
close(done)
o.currentStream = nil
o.currentDone = nil
o.mu.Unlock()
})), Paused: false}
o.mu.Unlock()
speaker.Play(o.currentStream)
<-done
o.logger.Debug("Speak completed")
return nil
}
func (o *KokoroONNXOrator) Stop() {
o.logger.Debug("stopping KokoroONNX orator")
speaker.Lock()
defer speaker.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if o.currentStream != nil {
o.currentStream.Streamer = nil
}
}
func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
return o.logger
}
func (o *KokoroONNXOrator) stoproutine() {
o.logger.Debug("KokoroONNX stoproutine started")
for {
<-TTSDoneChan
o.logger.Debug("KokoroONNX got done signal")
o.Stop()
for len(TTSTextChan) > 0 {
<-TTSTextChan
}
o.mu.Lock()
o.textBuffer.Reset()
if o.currentDone != nil {
select {
case o.currentDone <- true:
default:
}
}
o.interrupt = true
o.mu.Unlock()
o.logger.Debug("KokoroONNX stoproutine finished")
}
}
func (o *KokoroONNXOrator) readroutine() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("PANIC RECOVERED in readroutine: %v\n", r)
// Restart the goroutine
go o.readroutine()
}
}()
o.logger.Debug("KokoroONNX readroutine started")
tokenizer, _ := english.NewSentenceTokenizer(nil)
for {
select {
case chunk := <-TTSTextChan:
o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk))
o.mu.Lock()
o.interrupt = false
_, err := o.textBuffer.WriteString(chunk)
if err != nil {
o.logger.Warn("failed to write to buffer", "error", err)
o.mu.Unlock()
continue
}
text := o.textBuffer.String()
sentences := tokenizer.Tokenize(text)
o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text)
if len(sentences) <= 1 {
o.logger.Debug("KokoroONNX not enough sentences, waiting")
o.mu.Unlock()
continue
}
completeSentences := sentences[:len(sentences)-1]
remaining := sentences[len(sentences)-1].Text
o.textBuffer.Reset()
o.textBuffer.WriteString(remaining)
o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences))
o.mu.Unlock()
for _, sentence := range completeSentences {
o.mu.Lock()
interrupted := o.interrupt
o.mu.Unlock()
if interrupted {
o.logger.Debug("KokoroONNX interrupted, exiting")
return
}
cleanedText := models.CleanText(sentence.Text)
if cleanedText == "" {
continue
}
o.logger.Info("KokoroONNX speak", "text", cleanedText)
if err := o.Speak(cleanedText); err != nil {
o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
}
}
case <-TTSFlushChan:
o.logger.Debug("KokoroONNX flush signal")
if len(TTSTextChan) > 0 {
for chunk := range TTSTextChan {
o.mu.Lock()
_, err := o.textBuffer.WriteString(chunk)
o.mu.Unlock()
if err != nil {
continue
}
if len(TTSTextChan) == 0 {
break
}
}
}
o.mu.Lock()
remaining := o.textBuffer.String()
remaining = models.CleanText(remaining)
o.textBuffer.Reset()
o.mu.Unlock()
if remaining == "" {
continue
}
sentencesRem := tokenizer.Tokenize(remaining)
for _, rs := range sentencesRem {
o.mu.Lock()
interrupt := o.interrupt
o.mu.Unlock()
if interrupt {
break
}
if err := o.Speak(rs.Text); err != nil {
o.logger.Error("tts failed", "text", rs.Text, "error", err)
}
}
}
}
}