Files
gf-lt/extra/kokoro_onnx.go
2026-03-07 09:08:01 +03:00

417 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build extra
// +build extra
package extra
import (
"bytes"
"fmt"
"gf-lt/models"
"gf-lt/onnx"
"log/slog"
"os/exec"
"strings"
"sync"
"time"
"github.com/gopxl/beep/v2"
"github.com/gopxl/beep/v2/speaker"
"github.com/gopxl/beep/v2/wav"
"github.com/neurosnap/sentences/english"
"github.com/yalue/onnxruntime_go"
)
// KokoroONNXOrator implements Kokoro TTS using ONNX runtime.
//
// readroutine buffers text from TTSTextChan and splits it into sentences;
// Speak turns each sentence into audio via generateAudio (espeak phonemes ->
// token IDs -> ONNX session) and plays it through the beep speaker.
// stoproutine reacts to TTSDoneChan by stopping playback and discarding
// pending text.
type KokoroONNXOrator struct {
	logger *slog.Logger
	// mu guards currentStream, currentDone, textBuffer and interrupt, and
	// serializes model initialization in ensureInitialized.
	mu sync.Mutex

	// Inference state, populated lazily by ensureInitialized.
	session *onnxruntime_go.DynamicAdvancedSession
	// phonemeMap maps one phoneme character to its token ID; defaults to
	// kokoroPhonemeMap.
	phonemeMap map[string]int
	// espeakCmd is the phonemizer binary; when empty, "espeak-ng" or
	// "espeak" is looked up on PATH.
	espeakCmd string
	// voice selects the entry loaded from voicesPath ("af_bella" if empty).
	voice string
	// speed is passed to the model's "speed" input.
	speed float32
	// styleVector is the 256-value style embedding passed as the "style" input.
	styleVector []float32

	// Playback of the utterance in progress: set by Speak and cleared by
	// its completion callback.
	currentStream *beep.Ctrl
	currentDone   chan bool
	// textBuffer holds streamed text that has not been spoken yet.
	textBuffer strings.Builder
	// interrupt is set by stoproutine and cleared when readroutine receives
	// a new chunk; readroutine checks it before each sentence.
	interrupt bool
	// modelLoaded becomes true once ensureInitialized has created session.
	modelLoaded bool
	// modelPath is the ONNX model file loaded by ensureInitialized.
	modelPath string
	// voicesPath is the voice-embedding file read with onnx.LoadVoice; when
	// empty, an all-zero style vector is used.
	voicesPath string
}
// kokoroPhonemeMap maps single phoneme symbols to Kokoro-82M token IDs. It
// mirrors the "vocab" table shipped with the model (config.json /
// tokenizer.json); IDs missing from the sequence are unused by the model.
// "$" (ID 0) is the padding token generateAudio places around each utterance.
//
// phonemesToTokens looks keys up one rune at a time, so every key must be a
// single code point. Combining marks and easily confused symbols are written
// as \u escapes so editors cannot silently substitute look-alike characters.
// The previous table contained such substitutions (e.g. "˧" for "ʧ", "ᶻ" for
// "ᵻ", "ˢ" for "ʣ"), which made espeak output containing the real symbols
// drop out of the token stream.
var kokoroPhonemeMap = map[string]int{
	// Padding, punctuation and word separator.
	"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10,
	"\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16,
	// U+0303 combining tilde, affricate ligatures, ᵝ, and U+AB67 (ꭧ).
	"\u0303": 17, "ʣ": 18, "ʥ": 19, "ʦ": 20, "ʨ": 21, "ᵝ": 22, "\uab67": 23,
	// Uppercase letters and superscript schwa.
	"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ᵊ": 42,
	// Lowercase ASCII letters (no plain "g"; IPA "ɡ" U+0261 is used instead).
	"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52,
	"k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61,
	"t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
	// IPA symbols.
	"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80,
	"ð": 81, "ʤ": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "\u0261": 92,
	"ɥ": 99, "ɨ": 101, "ɪ": 102, "ʝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113,
	"ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125,
	"ɻ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "ʧ": 133, "ʊ": 135,
	"ʋ": 136, "ʌ": 138, "ɣ": 139, "ɤ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
	// Stress, length and modifier letters.
	"ˈ": 156, "ˌ": 157, "ː": 158, "ʰ": 162, "ʲ": 164,
	// Intonation arrows and the reduced vowel ᵻ.
	"↓": 169, "→": 171, "↗": 172, "↘": 173, "ᵻ": 177,
}
// ensureInitialized lazily prepares everything generateAudio needs: the
// shared ONNX runtime, the phoneme table, the espeak binary, the style
// embedding and the inference session for modelPath. After the first
// successful call it returns nil immediately.
//
// modelLoaded is only read while holding o.mu. The previous unlocked
// fast-path read raced with the write at the end of this function.
func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
	o.mu.Lock()
	defer o.mu.Unlock()
	if o.modelLoaded {
		return nil
	}
	if modelPath == "" {
		o.logger.Error("modelPath is empty, cannot load ONNX model")
		return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
	}
	// Initialize ONNX runtime (shared with embedder).
	if err := onnx.Init(); err != nil {
		o.logger.Error("ONNX init failed", "error", err)
		return fmt.Errorf("ONNX init failed: %w", err)
	}
	if onnx.HasCUDASupport() {
		o.logger.Info("ONNX using CUDA")
	} else {
		o.logger.Info("ONNX using CPU fallback")
	}
	if o.phonemeMap == nil {
		o.phonemeMap = kokoroPhonemeMap
	}
	if o.espeakCmd == "" {
		// Resolve into a local first: on failure espeakCmd stays empty, so the
		// next call searches PATH again instead of blindly trusting "espeak".
		var resolved string
		for _, candidate := range []string{"espeak-ng", "espeak"} {
			if _, err := exec.LookPath(candidate); err == nil {
				resolved = candidate
				break
			}
		}
		if resolved == "" {
			return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
		}
		o.espeakCmd = resolved
	}
	o.logger.Info("using espeak command", "cmd", o.espeakCmd)
	// Load voice embedding if not already loaded.
	if o.styleVector == nil {
		const styleDim = 256
		voiceName := o.voice
		if voiceName == "" {
			voiceName = "af_bella"
		}
		if o.voicesPath == "" {
			o.logger.Warn("no voices path configured, using zeros for style")
			o.styleVector = make([]float32, styleDim)
		} else if packed, err := onnx.LoadVoice(o.voicesPath, voiceName); err != nil {
			o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
			o.styleVector = make([]float32, styleDim)
		} else {
			// The voice pack is read as consecutive rows of styleDim values
			// (previously documented as shape (510, 1, 256)) and averaged into
			// one vector. A pack shorter than one row now yields zeros instead
			// of an index-out-of-range panic.
			// NOTE(review): reference Kokoro pipelines appear to select one row
			// by phoneme count rather than averaging — confirm the intent.
			o.styleVector = kokoroMeanStyle(packed, styleDim)
			if len(packed) < styleDim {
				o.logger.Warn("voice embedding too short, using zeros", "voice", voiceName, "values", len(packed))
			} else {
				o.logger.Info("loaded voice embedding", "voice", voiceName)
			}
		}
	}
	opts, err := onnx.NewSessionOptions()
	if err != nil {
		return fmt.Errorf("failed to create session options: %w", err)
	}
	defer func() { _ = opts.Destroy() }()
	if onnx.HasCUDASupport() {
		o.logger.Info("session options created with CUDA")
	} else {
		o.logger.Info("session options created with CPU")
	}
	session, err := onnxruntime_go.NewDynamicAdvancedSession(
		modelPath,
		[]string{"input_ids", "style", "speed"},
		[]string{"waveform"},
		opts,
	)
	if err != nil {
		o.logger.Error("failed to create ONNX session", "error", err)
		return fmt.Errorf("failed to create ONNX session: %w", err)
	}
	o.session = session
	o.modelLoaded = true
	o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
	return nil
}

// kokoroMeanStyle averages packed, read as consecutive rows of dim values,
// into a single dim-length vector. A trailing partial row is ignored and an
// input shorter than one row yields all zeros; a non-positive dim yields nil.
func kokoroMeanStyle(packed []float32, dim int) []float32 {
	if dim <= 0 {
		return nil
	}
	out := make([]float32, dim)
	rows := len(packed) / dim
	if rows == 0 {
		return out
	}
	for r := 0; r < rows; r++ {
		for i, v := range packed[r*dim : (r+1)*dim] {
			out[i] += v
		}
	}
	for i := range out {
		out[i] /= float32(rows)
	}
	return out
}
// textToPhonemes converts text to an IPA phoneme string by running espeak.
//
//   - -q suppresses audio output.
//   - --ipa prints IPA symbols, which is what kokoroPhonemeMap is keyed on.
//     The previous -x flag printed espeak's ASCII phoneme mnemonics instead.
//   - "--" ends option parsing, so text that starts with '-' (for example
//     "-w file") is spoken rather than interpreted as an espeak flag.
//
// espeak may put clauses on separate lines; all whitespace runs are
// collapsed into single spaces so word boundaries survive tokenization.
func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
	cmd := exec.Command(o.espeakCmd, "-q", "--ipa", "--", text)
	output, err := cmd.Output()
	if err != nil {
		o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text)
		return "", fmt.Errorf("espeak failed: %w", err)
	}
	return strings.Join(strings.Fields(string(output)), " "), nil
}
// phonemesToTokens translates a phoneme string into Kokoro token IDs, one
// rune at a time, using o.phonemeMap. Runes with no entry are skipped
// silently. An empty input, or one where nothing maps, is reported as an error.
func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
	if len(phonemeStr) == 0 {
		o.logger.Error("empty phoneme string")
		return nil, fmt.Errorf("empty phoneme string")
	}
	ids := make([]int, 0, len(phonemeStr))
	for _, r := range phonemeStr {
		id, known := o.phonemeMap[string(r)]
		if !known {
			continue
		}
		ids = append(ids, id)
	}
	if len(ids) > 0 {
		return ids, nil
	}
	o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr)
	return nil, fmt.Errorf("no valid phonemes mapped to tokens")
}
// generateAudio synthesizes text into mono float32 samples.
//
// Steps:
//  1. Lazily initialize the model.
//  2. Phonemize text with espeak.
//  3. Map the phonemes to token IDs (at most 510, then padded with token 0 on
//     both sides).
//  4. Run the ONNX session with the configured style vector and speed.
//
// The waveform length depends on the input, so the output slot is passed as
// nil and allocated by onnxruntime during Run. The previous fixed (1, 512)
// output tensor could not hold a real utterance. The returned slice is a copy
// that stays valid after the output tensor is destroyed.
func (o *KokoroONNXOrator) generateAudio(text string) ([]float32, error) {
	if err := o.ensureInitialized(o.modelPath); err != nil {
		o.logger.Error("ensureInitialized failed", "error", err)
		return nil, err
	}
	phonemeStr, err := o.textToPhonemes(text)
	if err != nil {
		o.logger.Error("phoneme conversion failed", "error", err)
		return nil, fmt.Errorf("phoneme conversion failed: %w", err)
	}
	tokens, err := o.phonemesToTokens(phonemeStr)
	if err != nil {
		o.logger.Error("token conversion failed", "error", err)
		return nil, fmt.Errorf("token conversion failed: %w", err)
	}
	if len(tokens) > 510 {
		return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens))
	}
	// Surround the sequence with the pad token (0).
	inputIDs := make([]int64, 0, len(tokens)+2)
	inputIDs = append(inputIDs, 0)
	for _, t := range tokens {
		inputIDs = append(inputIDs, int64(t))
	}
	inputIDs = append(inputIDs, 0)
	inputTensor, err := onnxruntime_go.NewTensor[int64](
		onnxruntime_go.NewShape(1, int64(len(inputIDs))),
		inputIDs,
	)
	if err != nil {
		o.logger.Error("failed to create input tensor", "error", err)
		return nil, fmt.Errorf("failed to create input tensor: %w", err)
	}
	defer func() { _ = inputTensor.Destroy() }()
	styleTensor, err := onnxruntime_go.NewTensor[float32](
		onnxruntime_go.NewShape(1, 256),
		o.styleVector,
	)
	if err != nil {
		o.logger.Error("failed to create style tensor", "error", err)
		return nil, fmt.Errorf("failed to create style tensor: %w", err)
	}
	defer func() { _ = styleTensor.Destroy() }()
	// A zero (unset) or negative speed is not a meaningful speaking rate;
	// fall back to normal speed.
	speed := o.speed
	if speed <= 0 {
		speed = 1
	}
	speedTensor, err := onnxruntime_go.NewTensor[float32](
		onnxruntime_go.NewShape(1),
		[]float32{speed},
	)
	if err != nil {
		o.logger.Error("failed to create speed tensor", "error", err)
		return nil, fmt.Errorf("failed to create speed tensor: %w", err)
	}
	defer func() { _ = speedTensor.Destroy() }()
	// nil output slot: onnxruntime allocates a tensor of the actual waveform
	// size, which we own and must destroy.
	outputs := []onnxruntime_go.Value{nil}
	defer func() {
		if outputs[0] != nil {
			_ = outputs[0].Destroy()
		}
	}()
	if err := o.session.Run(
		[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
		outputs,
	); err != nil {
		o.logger.Error("ONNX inference failed", "error", err)
		return nil, fmt.Errorf("ONNX inference failed: %w", err)
	}
	outputTensor, ok := outputs[0].(*onnxruntime_go.Tensor[float32])
	if !ok {
		o.logger.Error("unexpected ONNX output type", "type", fmt.Sprintf("%T", outputs[0]))
		return nil, fmt.Errorf("unexpected waveform output type %T", outputs[0])
	}
	audioData := outputTensor.GetData()
	if len(audioData) == 0 {
		o.logger.Error("empty audio output from ONNX")
		return nil, fmt.Errorf("empty audio output")
	}
	audio := make([]float32, len(audioData))
	copy(audio, audioData)
	return audio, nil
}
// Speak synthesizes text and plays it through the beep speaker. It blocks
// until the utterance finishes (its completion callback closes the done
// channel) or until stoproutine sends on currentDone.
//
// The mono samples are encoded into an in-memory 16-bit, 24 kHz WAV and then
// decoded again to obtain a beep streamer and format for speaker.Init.
func (o *KokoroONNXOrator) Speak(text string) error {
	audio, err := o.generateAudio(text)
	if err != nil {
		o.logger.Error("audio generation failed", "error", err)
		return fmt.Errorf("audio generation failed: %w", err)
	}
	// Stream every sample exactly once, copying it to both channels. pos
	// carries progress across calls. The previous streamer restarted at
	// sample 0 on every call and always reported len(audio) samples with
	// ok=true. It therefore never signalled exhaustion for utterances longer
	// than one buffer, and it claimed more samples than the buffer could hold.
	pos := 0
	encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (n int, ok bool) {
		if pos >= len(audio) {
			return 0, false
		}
		for n < len(samples) && pos < len(audio) {
			v := float64(audio[pos])
			samples[n][0] = v
			samples[n][1] = v
			n++
			pos++
		}
		return n, true
	})
	buf := &seekableBuffer{new(bytes.Buffer)}
	err = wav.Encode(buf, encodeStreamer, beep.Format{
		SampleRate:  24000,
		NumChannels: 1,
		Precision:   2,
	})
	if err != nil {
		o.logger.Error("wav encoding failed", "error", err)
		return fmt.Errorf("wav encoding failed: %w", err)
	}
	decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.Bytes()))
	if err != nil {
		o.logger.Error("wav decode failed", "error", err)
		return fmt.Errorf("wav decode failed: %w", err)
	}
	defer decodedStreamer.Close()
	if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
		o.logger.Error("speaker init failed", "error", err)
		return fmt.Errorf("speaker init failed: %w", err)
	}
	o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels)
	done := make(chan bool)
	// Keep the Ctrl in a local so Play does not re-read the shared field
	// outside the lock.
	ctrl := &beep.Ctrl{Paused: false}
	ctrl.Streamer = beep.Seq(decodedStreamer, beep.Callback(func() {
		o.mu.Lock()
		defer o.mu.Unlock()
		// This callback is the only closer of done.
		close(done)
		// Clear the shared fields only if they still describe this utterance.
		if o.currentDone == done {
			o.currentDone = nil
		}
		if o.currentStream == ctrl {
			o.currentStream = nil
		}
	}))
	o.mu.Lock()
	o.currentDone = done
	o.currentStream = ctrl
	o.mu.Unlock()
	speaker.Play(ctrl)
	<-done
	return nil
}
// Stop silences the current utterance by detaching the streamer from its
// beep.Ctrl (setting Streamer to nil). The speaker lock is taken first and
// o.mu second, and both are released in reverse order.
//
// Stop does not touch currentDone. Waking a Speak call blocked on it is left
// to stoproutine, which signals currentDone after calling Stop.
func (o *KokoroONNXOrator) Stop() {
	speaker.Lock()
	o.mu.Lock()
	if ctrl := o.currentStream; ctrl != nil {
		ctrl.Streamer = nil
	}
	o.mu.Unlock()
	speaker.Unlock()
}
// GetLogger exposes the structured logger this orator writes its
// diagnostics to.
func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
	logger := o.logger
	return logger
}
// stoproutine loops forever handling stop requests from TTSDoneChan. For each
// request it:
//   - stops playback,
//   - discards queued and buffered text,
//   - wakes a Speak call that is waiting on currentDone,
//   - sets interrupt so readroutine abandons the sentences it has not spoken.
func (o *KokoroONNXOrator) stoproutine() {
	for {
		<-TTSDoneChan
		o.Stop()
		// Discard queued text without blocking. The previous len-then-receive
		// loop could block until a new chunk arrived if readroutine took the
		// last queued chunk between the two operations.
	drain:
		for {
			select {
			case <-TTSTextChan:
			default:
				break drain
			}
		}
		o.mu.Lock()
		o.textBuffer.Reset()
		if o.currentDone != nil {
			// Non-blocking send: it only succeeds if Speak is already waiting.
			select {
			case o.currentDone <- true:
			default:
			}
		}
		o.interrupt = true
		o.mu.Unlock()
	}
}
// readroutine consumes streamed text from TTSTextChan and speaks it sentence
// by sentence.
//
// Complete sentences are spoken once a following sentence has begun. The
// trailing fragment stays in textBuffer until more text arrives or
// TTSFlushChan asks for everything left to be spoken. An interrupt set by
// stoproutine abandons the sentences of the current batch, but the goroutine
// keeps serving the channels afterwards.
func (o *KokoroONNXOrator) readroutine() {
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		// The tokenizer is required for every branch below.
		o.logger.Error("failed to create sentence tokenizer", "error", err)
		return
	}
	for {
		select {
		case chunk := <-TTSTextChan:
			o.mu.Lock()
			o.interrupt = false
			// strings.Builder.WriteString always returns a nil error.
			o.textBuffer.WriteString(chunk)
			sentences := tokenizer.Tokenize(o.textBuffer.String())
			if len(sentences) <= 1 {
				// The only sentence may still be incomplete; wait for more text.
				o.mu.Unlock()
				continue
			}
			completeSentences := sentences[:len(sentences)-1]
			o.textBuffer.Reset()
			o.textBuffer.WriteString(sentences[len(sentences)-1].Text)
			o.mu.Unlock()
			for _, sentence := range completeSentences {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					// Drop the rest of this batch but keep the goroutine alive.
					// The previous `return` ended it, leaving TTS silent for the
					// remainder of the session after the first interruption.
					break
				}
				cleanedText := models.CleanText(sentence.Text)
				if cleanedText == "" {
					continue
				}
				o.logger.Info("KokoroONNX speak", "text", cleanedText)
				if err := o.Speak(cleanedText); err != nil {
					o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
				}
			}
		case <-TTSFlushChan:
			o.mu.Lock()
			// Pull in whatever is queued without blocking. Ranging over the
			// channel after a len check could block indefinitely if another
			// receiver (stoproutine also drains TTSTextChan) took the last chunk.
		drain:
			for {
				select {
				case chunk := <-TTSTextChan:
					o.textBuffer.WriteString(chunk)
				default:
					break drain
				}
			}
			remaining := models.CleanText(o.textBuffer.String())
			o.textBuffer.Reset()
			o.mu.Unlock()
			if remaining == "" {
				continue
			}
			for _, rs := range tokenizer.Tokenize(remaining) {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					break
				}
				// Whitespace-only sentences would only produce an espeak/token error.
				if strings.TrimSpace(rs.Text) == "" {
					continue
				}
				if err := o.Speak(rs.Text); err != nil {
					o.logger.Error("tts failed", "text", rs.Text, "error", err)
				}
			}
		}
	}
}