Files
gf-lt/extra/tts.go
2026-03-07 08:35:44 +03:00

971 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build extra
// +build extra
package extra
import (
"bytes"
"encoding/json"
"fmt"
"gf-lt/config"
"gf-lt/models"
"gf-lt/onnx"
"io"
"log/slog"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
google_translate_tts "github.com/GrailFinder/google-translate-tts"
"github.com/GrailFinder/google-translate-tts/handlers"
"github.com/gopxl/beep/v2"
"github.com/gopxl/beep/v2/mp3"
"github.com/gopxl/beep/v2/speaker"
"github.com/gopxl/beep/v2/wav"
"github.com/neurosnap/sentences/english"
"github.com/yalue/onnxruntime_go"
)
// TTSTextChan carries text chunks to the active orator's readroutine, which
// accumulates them and speaks complete sentences.
var TTSTextChan = make(chan string, 10000)

// TTSFlushChan tells the readroutine to speak whatever text is still
// buffered or queued.
var TTSFlushChan = make(chan bool, 1)

// TTSDoneChan tells the stoproutine to stop playback and discard queued text.
var TTSDoneChan = make(chan bool, 1)

// endsWithPunctuation = regexp.MustCompile(`[;.!?]$`)
// seekableBuffer wraps a bytes.Buffer and adds a Seek method so the value
// can be passed where a seeker is required in addition to a writer.
type seekableBuffer struct {
	*bytes.Buffer
}

// Seek is a stub: it ignores both arguments, never moves any position and
// always reports offset 0 with a nil error. Writes made after a Seek call
// are therefore still appended to the end of the underlying buffer.
func (s *seekableBuffer) Seek(_ int64, _ int) (int64, error) {
	var start int64
	return start, nil
}
// Orator is the contract every text-to-speech backend in this file fulfils.
type Orator interface {
	// GetLogger exposes the logger used by the implementation.
	GetLogger() *slog.Logger
	// Speak synthesizes text and plays it.
	Speak(text string) error
	// Stop halts the playback that is currently in progress.
	Stop()
	// pause and resume?
}
// KokoroOrator is an Orator that requests audio from an HTTP TTS endpoint
// and plays the returned mp3 through the beep speaker.
// impl https://github.com/remsky/Kokoro-FastAPI
type KokoroOrator struct {
logger *slog.Logger
// mu is held while currentStream, currentDone, textBuffer and interrupt
// are read or modified by the read/stop routines and the playback callback.
mu sync.Mutex
// URL is the endpoint requestSound POSTs to; an empty value is an error.
URL string
// Format, Stream, Speed, Language and Voice are forwarded in the request
// payload ("response_format"/"download_format", "stream", "speed",
// "lang_code" and "voice").
Format models.AudioFormat
Stream bool
Speed float32
Language string
Voice string
// currentStream is the playback control of the sentence being spoken;
// Stop clears its Streamer to end playback.
currentStream *beep.Ctrl // Added for playback control
// currentDone is the channel Speak blocks on until playback ends.
currentDone chan bool
// textBuffer accumulates incoming chunks until a full sentence is available.
textBuffer strings.Builder
// interrupt is set by stoproutine and cleared when a new chunk arrives.
interrupt bool
// textBuffer bytes.Buffer
}
// Google Translate TTS implementation
// GoogleTranslateOrator is an Orator that generates mp3 speech through the
// google-translate-tts package and plays it through the beep speaker.
type GoogleTranslateOrator struct {
logger *slog.Logger
// mu is held while currentStream, currentDone, textBuffer and interrupt
// are read or modified by the read/stop routines and the playback callback.
mu sync.Mutex
// speech generates the audio in Speak; its Speed field also selects the
// playback resampling ratio.
speech *google_translate_tts.Speech
// currentStream is the playback control of the sentence being spoken.
currentStream *beep.Ctrl
// currentDone is the channel Speak blocks on until playback ends.
currentDone chan bool
// textBuffer accumulates incoming chunks until a full sentence is available.
textBuffer strings.Builder
// interrupt is set by stoproutine and cleared when a new chunk arrives.
interrupt bool
}
// stoproutine runs forever, handling one stop request per value received on
// TTSDoneChan: it stops playback, discards queued text, clears the sentence
// buffer, releases a Speak call that is waiting for playback to end and
// marks the orator as interrupted.
func (o *KokoroOrator) stoproutine() {
	for {
		<-TTSDoneChan
		o.logger.Debug("orator got done signal")
		o.Stop()
		// Discard queued text without blocking. A len()-guarded receive can
		// block (and later swallow a fresh chunk) when readroutine empties
		// the channel between the length check and the receive.
	drain:
		for {
			select {
			case <-TTSTextChan:
			default:
				break drain
			}
		}
		o.mu.Lock()
		o.textBuffer.Reset()
		if o.currentDone != nil {
			select {
			case o.currentDone <- true:
			default:
				// Nobody is waiting on the channel right now; nothing to release.
			}
		}
		o.interrupt = true
		o.mu.Unlock()
	}
}
// readroutine is the long-running consumer of TTSTextChan and TTSFlushChan.
//
// Text chunks are appended to the sentence buffer; once the buffer holds
// more than one sentence, every sentence except the last (possibly
// incomplete) one is cleaned and spoken. A flush signal collects whatever is
// still queued and speaks the whole remainder sentence by sentence.
//
// An interrupt raised by stoproutine abandons the current batch of
// sentences but keeps the routine alive, so later chunks are still spoken.
func (o *KokoroOrator) readroutine() {
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		// Without a tokenizer no sentence can be extracted; bail out instead
		// of dereferencing a nil tokenizer later.
		o.logger.Error("failed to create sentence tokenizer", "error", err)
		return
	}
	for {
		select {
		case chunk := <-TTSTextChan:
			o.mu.Lock()
			o.interrupt = false
			if _, err := o.textBuffer.WriteString(chunk); err != nil {
				o.logger.Warn("failed to write to stringbuilder", "error", err)
				o.mu.Unlock()
				continue
			}
			text := o.textBuffer.String()
			sentences := tokenizer.Tokenize(text)
			o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences))
			if len(sentences) <= 1 {
				o.mu.Unlock()
				continue
			}
			// The last sentence may still be growing; keep it buffered.
			completeSentences := sentences[:len(sentences)-1]
			remaining := sentences[len(sentences)-1].Text
			o.textBuffer.Reset()
			o.textBuffer.WriteString(remaining)
			o.mu.Unlock()
			for _, sentence := range completeSentences {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					// Drop the rest of this batch only; returning here would
					// leave nobody reading TTSTextChan after the first stop.
					break
				}
				cleanedText := models.CleanText(sentence.Text)
				if cleanedText == "" {
					continue
				}
				o.logger.Debug("calling Speak with sentence", "sent", cleanedText)
				if err := o.Speak(cleanedText); err != nil {
					o.logger.Error("tts failed", "sentence", cleanedText, "error", err)
				}
			}
		case <-TTSFlushChan:
			o.logger.Debug("got flushchan signal start")
			// The LLM is done: pull in everything still queued. The receive
			// is non-blocking so a concurrent drain cannot leave us stuck.
		drain:
			for {
				select {
				case chunk := <-TTSTextChan:
					o.mu.Lock()
					_, err := o.textBuffer.WriteString(chunk)
					o.mu.Unlock()
					if err != nil {
						o.logger.Warn("failed to write to stringbuilder", "error", err)
					}
				default:
					break drain
				}
			}
			// flush remaining text
			o.mu.Lock()
			remaining := models.CleanText(o.textBuffer.String())
			o.textBuffer.Reset()
			o.mu.Unlock()
			if remaining == "" {
				continue
			}
			o.logger.Debug("calling Speak with remainder", "rem", remaining)
			// Speak sentence by sentence to avoid dumping a large volume of
			// text into a single request.
			for _, rs := range tokenizer.Tokenize(remaining) {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					break
				}
				if err := o.Speak(rs.Text); err != nil {
					o.logger.Error("tts failed", "sentence", rs.Text, "error", err)
				}
			}
		}
	}
}
// NewOrator builds the Orator selected by cfg.TTS_PROVIDER (matched
// case-insensitively) and starts its read and stop goroutines.
//
// "kokoro" selects the HTTP Kokoro backend and "kokoro_onnx" the ONNX
// backend. Any other value, including an empty one, selects Google Translate
// TTS, which does not require local setup.
func NewOrator(log *slog.Logger, cfg *config.Config) Orator {
	switch strings.ToLower(cfg.TTS_PROVIDER) {
	case "kokoro":
		kokoro := &KokoroOrator{
			logger:   log,
			URL:      cfg.TTS_URL,
			Format:   models.AFMP3,
			Stream:   false,
			Speed:    cfg.TTS_SPEED,
			Language: "a",
			Voice:    "af_bella(1)+af_sky(1)",
		}
		go kokoro.readroutine()
		go kokoro.stoproutine()
		return kokoro
	case "kokoro_onnx":
		log.Info("Initializing Kokoro ONNX TTS", "modelPath", cfg.KokoroModelPath, "voicesPath", cfg.KokoroVoicesPath, "voice", cfg.KokoroVoice, "speed", cfg.TTS_SPEED)
		local := &KokoroONNXOrator{
			logger:     log,
			modelPath:  cfg.KokoroModelPath,
			voicesPath: cfg.KokoroVoicesPath,
			speed:      cfg.TTS_SPEED,
			voice:      cfg.KokoroVoice,
		}
		go local.readroutine()
		go local.stoproutine()
		return local
	}

	// Everything else falls back to Google Translate TTS.
	lang := cfg.TTS_LANGUAGE
	if lang == "" {
		lang = "en"
	}
	google := &GoogleTranslateOrator{
		logger: log,
		speech: &google_translate_tts.Speech{
			Folder:   os.TempDir() + "/gf-lt-tts", // Temporary directory for caching
			Language: lang,
			Proxy:    "", // Proxy not supported
			Speed:    cfg.TTS_SPEED,
			Handler:  &handlers.Beep{},
		},
	}
	go google.readroutine()
	go google.stoproutine()
	return google
}
// GetLogger returns the logger stored on the orator.
func (o *KokoroOrator) GetLogger() *slog.Logger {
return o.logger
}
// requestSound POSTs a JSON synthesis request for text to o.URL and returns
// the response body on HTTP 200. The caller owns the returned body and must
// close it. An empty URL, a marshalling or transport failure, or any other
// status code yields an error; on a non-200 status the body is closed here.
func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) {
	if o.URL == "" {
		return nil, fmt.Errorf("TTS URL is empty")
	}

	encoded, err := json.Marshal(map[string]any{
		"input":           text,
		"voice":           o.Voice,
		"response_format": o.Format,
		"download_format": o.Format,
		"stream":          o.Stream,
		"speed":           o.Speed,
		// "return_download_link": true,
		"lang_code": o.Language,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal payload: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, o.URL, bytes.NewBuffer(encoded)) //nolint:noctx
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("accept", "application/json")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if resp.StatusCode == http.StatusOK {
		return resp.Body, nil
	}
	resp.Body.Close()
	return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
// Speak requests audio for text, decodes it as mp3 and plays it, blocking
// until playback ends or stoproutine releases the wait. Request and decode
// failures are logged and returned; a speaker initialisation failure is
// only logged.
func (o *KokoroOrator) Speak(text string) error {
	o.logger.Debug("fn: Speak is called", "text-len", len(text))

	body, err := o.requestSound(text)
	if err != nil {
		o.logger.Error("request failed", "error", err)
		return fmt.Errorf("request failed: %w", err)
	}
	defer body.Close()

	// Decode the mp3 audio from response body
	streamer, format, err := mp3.Decode(body)
	if err != nil {
		o.logger.Error("mp3 decode failed", "error", err)
		return fmt.Errorf("mp3 decode failed: %w", err)
	}
	defer streamer.Close()

	// here it spams with errors that speaker cannot be initialized more than once, but how would we deal with many audio records then?
	bufferSize := format.SampleRate.N(time.Second / 10)
	if initErr := speaker.Init(format.SampleRate, bufferSize); initErr != nil {
		o.logger.Debug("failed to init speaker", "error", initErr)
	}

	finished := make(chan bool)
	// Runs once the decoded stream is exhausted: wake the waiter below and
	// clear the playback state.
	onEnd := beep.Callback(func() {
		o.mu.Lock()
		defer o.mu.Unlock()
		close(finished)
		o.currentStream = nil
		o.currentDone = nil
	})
	ctrl := &beep.Ctrl{Streamer: beep.Seq(streamer, onEnd), Paused: false}

	o.mu.Lock()
	o.currentDone = finished
	o.currentStream = ctrl
	o.mu.Unlock()

	speaker.Play(ctrl)
	<-finished
	return nil
}
// Stop ends the current playback, if any, by clearing the streamer of the
// active control while holding both the speaker lock and the orator mutex.
func (o *KokoroOrator) Stop() {
	// speaker.Clear()
	o.logger.Debug("attempted to stop orator", "orator", o)
	speaker.Lock()
	o.mu.Lock()
	defer func() {
		o.mu.Unlock()
		speaker.Unlock()
	}()
	if ctrl := o.currentStream; ctrl != nil {
		// ctrl.Paused = true
		ctrl.Streamer = nil
	}
}
// stoproutine runs forever, handling one stop request per value received on
// TTSDoneChan: it stops playback, discards queued text, clears the sentence
// buffer, releases a Speak call that is waiting for playback to end and
// marks the orator as interrupted.
func (o *GoogleTranslateOrator) stoproutine() {
	for {
		<-TTSDoneChan
		o.logger.Debug("orator got done signal")
		o.Stop()
		// Discard queued text without blocking. A len()-guarded receive can
		// block (and later swallow a fresh chunk) when readroutine empties
		// the channel between the length check and the receive.
	drain:
		for {
			select {
			case <-TTSTextChan:
			default:
				break drain
			}
		}
		o.mu.Lock()
		o.textBuffer.Reset()
		if o.currentDone != nil {
			select {
			case o.currentDone <- true:
			default:
				// Nobody is waiting on the channel right now; nothing to release.
			}
		}
		o.interrupt = true
		o.mu.Unlock()
	}
}
// readroutine is the long-running consumer of TTSTextChan and TTSFlushChan.
//
// Text chunks are appended to the sentence buffer; once the buffer holds
// more than one sentence, every sentence except the last (possibly
// incomplete) one is cleaned and spoken. A flush signal collects whatever is
// still queued and speaks the whole remainder sentence by sentence.
//
// An interrupt raised by stoproutine abandons the current batch of
// sentences but keeps the routine alive, so later chunks are still spoken.
func (o *GoogleTranslateOrator) readroutine() {
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		// Without a tokenizer no sentence can be extracted; bail out instead
		// of dereferencing a nil tokenizer later.
		o.logger.Error("failed to create sentence tokenizer", "error", err)
		return
	}
	for {
		select {
		case chunk := <-TTSTextChan:
			o.mu.Lock()
			o.interrupt = false
			if _, err := o.textBuffer.WriteString(chunk); err != nil {
				o.logger.Warn("failed to write to stringbuilder", "error", err)
				o.mu.Unlock()
				continue
			}
			text := o.textBuffer.String()
			sentences := tokenizer.Tokenize(text)
			o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences))
			if len(sentences) <= 1 {
				o.mu.Unlock()
				continue
			}
			// The last sentence may still be growing; keep it buffered.
			completeSentences := sentences[:len(sentences)-1]
			remaining := sentences[len(sentences)-1].Text
			o.textBuffer.Reset()
			o.textBuffer.WriteString(remaining)
			o.mu.Unlock()
			for _, sentence := range completeSentences {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					// Drop the rest of this batch only; returning here would
					// leave nobody reading TTSTextChan after the first stop.
					break
				}
				cleanedText := models.CleanText(sentence.Text)
				if cleanedText == "" {
					continue
				}
				o.logger.Debug("calling Speak with sentence", "sent", cleanedText)
				if err := o.Speak(cleanedText); err != nil {
					o.logger.Error("tts failed", "sentence", cleanedText, "error", err)
				}
			}
		case <-TTSFlushChan:
			o.logger.Debug("got flushchan signal start")
			// The LLM is done: pull in everything still queued. The receive
			// is non-blocking so a concurrent drain cannot leave us stuck.
		drain:
			for {
				select {
				case chunk := <-TTSTextChan:
					o.mu.Lock()
					_, err := o.textBuffer.WriteString(chunk)
					o.mu.Unlock()
					if err != nil {
						o.logger.Warn("failed to write to stringbuilder", "error", err)
					}
				default:
					break drain
				}
			}
			o.mu.Lock()
			remaining := models.CleanText(o.textBuffer.String())
			o.textBuffer.Reset()
			o.mu.Unlock()
			if remaining == "" {
				continue
			}
			o.logger.Debug("calling Speak with remainder", "rem", remaining)
			// Speak sentence by sentence to avoid dumping a large volume of
			// text into a single request.
			for _, rs := range tokenizer.Tokenize(remaining) {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					break
				}
				if err := o.Speak(rs.Text); err != nil {
					o.logger.Error("tts failed", "sentence", rs.Text, "error", err)
				}
			}
		}
	}
}
// GetLogger returns the logger stored on the orator.
func (o *GoogleTranslateOrator) GetLogger() *slog.Logger {
return o.logger
}
// Speak generates mp3 speech for text, decodes it and plays it, blocking
// until playback ends or stoproutine releases the wait. A configured speed
// other than 1.0 is applied by resampling; a non-positive speed is treated
// as 1.0. Generation and decode failures are logged and returned; a speaker
// initialisation failure is only logged.
func (o *GoogleTranslateOrator) Speak(text string) error {
	o.logger.Debug("fn: Speak is called", "text-len", len(text))

	// Generate MP3 data using google-translate-tts
	reader, err := o.speech.GenerateSpeech(text)
	if err != nil {
		o.logger.Error("generate speech failed", "error", err)
		return fmt.Errorf("generate speech failed: %w", err)
	}

	// mp3.Decode wants an io.ReadCloser, hence the NopCloser wrapper.
	streamer, format, err := mp3.Decode(io.NopCloser(reader))
	if err != nil {
		o.logger.Error("mp3 decode failed", "error", err)
		return fmt.Errorf("mp3 decode failed: %w", err)
	}
	defer streamer.Close()

	var playback beep.Streamer = streamer
	ratio := o.speech.Speed
	if ratio <= 0 {
		ratio = 1.0
	}
	if ratio != 1.0 {
		playback = beep.ResampleRatio(3, float64(ratio), streamer)
	}

	// Initialize speaker with the format's sample rate
	bufferSize := format.SampleRate.N(time.Second / 10)
	if initErr := speaker.Init(format.SampleRate, bufferSize); initErr != nil {
		o.logger.Debug("failed to init speaker", "error", initErr)
	}

	finished := make(chan bool)
	// Runs once the stream is exhausted: wake the waiter below and clear the
	// playback state.
	onEnd := beep.Callback(func() {
		o.mu.Lock()
		defer o.mu.Unlock()
		close(finished)
		o.currentStream = nil
		o.currentDone = nil
	})
	ctrl := &beep.Ctrl{Streamer: beep.Seq(playback, onEnd), Paused: false}

	o.mu.Lock()
	o.currentDone = finished
	o.currentStream = ctrl
	o.mu.Unlock()

	speaker.Play(ctrl)
	<-finished // wait for playback to complete
	return nil
}
// Stop ends the current playback, if any, by clearing the streamer of the
// active control, and then asks the speech handler to stop as well. Both the
// speaker lock and the orator mutex are held for the duration.
func (o *GoogleTranslateOrator) Stop() {
	o.logger.Debug("attempted to stop google translate orator")
	speaker.Lock()
	o.mu.Lock()
	defer func() {
		o.mu.Unlock()
		speaker.Unlock()
	}()
	if ctrl := o.currentStream; ctrl != nil {
		ctrl.Streamer = nil
	}
	if o.speech == nil {
		return
	}
	// Also stop the speech handler if possible; its error is ignored.
	_ = o.speech.Stop()
}
// KokoroONNXOrator implements Kokoro TTS using ONNX runtime
type KokoroONNXOrator struct {
logger *slog.Logger
// mu is held during model initialisation and while currentStream,
// currentDone, textBuffer and interrupt are read or modified.
mu sync.Mutex
// session is the ONNX session created by ensureInitialized.
session *onnxruntime_go.DynamicAdvancedSession
// phonemeMap maps single phoneme characters to token IDs; it defaults to
// kokoroPhonemeMap when nil at initialisation.
phonemeMap map[string]int
// espeakCmd is the espeak executable name; when empty, initialisation
// picks "espeak-ng" or "espeak", whichever is found on PATH.
espeakCmd string
// voice is the voice name to load; "af_bella" is used when empty.
voice string
// speed is passed to the model as the "speed" input.
speed float32
// styleVector is the 256-value style input built from the loaded voice
// (or zeros when no voice could be loaded).
styleVector []float32
// currentStream is the playback control of the sentence being spoken.
currentStream *beep.Ctrl
// currentDone is the channel Speak blocks on until playback ends.
currentDone chan bool
// textBuffer accumulates incoming chunks until a full sentence is available.
textBuffer strings.Builder
// interrupt is set by stoproutine and cleared when a new chunk arrives.
interrupt bool
// modelLoaded is set once the ONNX session has been created.
modelLoaded bool
// modelPath and voicesPath are the model and voices file locations.
modelPath string
voicesPath string
}
// Phoneme to token ID mapping from Kokoro tokenizer.json
// Keys are single characters (punctuation, Latin letters, IPA symbols,
// stress/length marks and combining diacritics); phonemesToTokens looks up
// each rune of the phoneme string here and skips runes that are absent.
// NOTE(review): several keys are visually ambiguous Unicode characters —
// verify them against the upstream tokenizer.json before editing.
var kokoroPhonemeMap = map[string]int{
"$": 0, ";": 1, ":": 2, ",": 3, ".": 4, "!": 5, "?": 6, "—": 9, "…": 10, "\"": 11, "(": 12, ")": 13, "“": 14, "”": 15, " ": 16, "̃": 17, "ˢ": 18, "ˤ": 19, "˦": 20, "˨": 21, "ᾝ": 22, "⭧": 23,
"A": 24, "I": 25, "O": 31, "Q": 33, "S": 35, "T": 36, "W": 39, "Y": 41, "ʲ": 42,
"a": 43, "b": 44, "c": 45, "d": 46, "e": 47, "f": 48, "h": 50, "i": 51, "j": 52, "k": 53, "l": 54, "m": 55, "n": 56, "o": 57, "p": 58, "q": 59, "r": 60, "s": 61, "t": 62, "u": 63, "v": 64, "w": 65, "x": 66, "y": 67, "z": 68,
"ɑ": 69, "ɐ": 70, "ɒ": 71, "æ": 72, "β": 75, "ɔ": 76, "ɕ": 77, "ç": 78, "ɖ": 80, "ð": 81, "˔": 82, "ə": 83, "ɚ": 85, "ɛ": 86, "ɜ": 87, "ɟ": 90, "ɡ": 92, "ɥ": 99, "ɨ": 101, "ɪ": 102, "ɝ": 103, "ɯ": 110, "ɰ": 111, "ŋ": 112, "ɳ": 113, "ɲ": 114, "ɴ": 115, "ø": 116, "ɸ": 118, "θ": 119, "œ": 120, "ɹ": 123, "ɾ": 125, "ɺ": 126, "ʁ": 128, "ɽ": 129, "ʂ": 130, "ʃ": 131, "ʈ": 132, "˧": 133, "ʊ": 135, "ʋ": 136, "ʌ": 138, "ɢ": 139, "ɣ": 140, "χ": 142, "ʎ": 143, "ʒ": 147, "ʔ": 148,
"ˈ": 156, "ˌ": 157, "ː": 158, "̰": 162, "̊": 164, "↕": 169, "→": 171, "↗": 172, "↘": 173, "ᶻ": 177,
}
// ensureInitialized lazily prepares everything generateAudio needs: the
// ONNX runtime, the phoneme map, the espeak executable, the voice style
// vector and the inference session for modelPath. It is a no-op once the
// model has been loaded.
//
// The loaded flag is only inspected while holding o.mu, so concurrent
// callers never observe a half-initialised orator. An error is returned when
// modelPath is empty, the runtime cannot be initialised, neither espeak-ng
// nor espeak is installed, or the session cannot be created. A voice that
// cannot be loaded is not an error: a zero style vector is used instead.
func (o *KokoroONNXOrator) ensureInitialized(modelPath string) error {
	const (
		styleDim  = 256 // width of one style vector
		styleRows = 510 // rows averaged from the voice data
	)
	o.logger.Debug("ensureInitialized called", "modelPath", modelPath)
	o.mu.Lock()
	defer o.mu.Unlock()
	if o.modelLoaded {
		o.logger.Debug("model already loaded")
		return nil
	}
	if modelPath == "" {
		o.logger.Error("modelPath is empty, cannot load ONNX model")
		return fmt.Errorf("modelPath is empty, set KokoroModelPath in config")
	}
	// Initialize ONNX runtime (shared with embedder)
	if err := onnx.Init(); err != nil {
		o.logger.Error("ONNX init failed", "error", err)
		return fmt.Errorf("ONNX init failed: %w", err)
	}
	if onnx.HasCUDASupport() {
		o.logger.Info("ONNX using CUDA")
	} else {
		o.logger.Info("ONNX using CPU fallback")
	}
	if o.phonemeMap == nil {
		o.phonemeMap = kokoroPhonemeMap
	}
	if o.espeakCmd == "" {
		o.espeakCmd = "espeak-ng"
		if _, err := exec.LookPath(o.espeakCmd); err != nil {
			o.espeakCmd = "espeak"
			if _, err := exec.LookPath(o.espeakCmd); err != nil {
				return fmt.Errorf("espeak-ng or espeak not found. Install with: sudo apt-get install espeak-ng")
			}
		}
	}
	o.logger.Info("using espeak command", "cmd", o.espeakCmd)
	// Load voice embedding if not already loaded
	if o.styleVector == nil {
		voiceName := o.voice
		if voiceName == "" {
			voiceName = "af_bella"
		}
		style := make([]float32, styleDim)
		if o.voicesPath == "" {
			o.logger.Warn("no voices path configured, using zeros for style")
		} else if styleVec, err := onnx.LoadVoice(o.voicesPath, voiceName); err != nil {
			o.logger.Warn("failed to load voice, using zeros", "error", err, "voice", voiceName)
		} else {
			// The voice data is treated as rows of styleDim values; average
			// up to styleRows rows into a single vector. Clamping the row
			// count to what was actually loaded avoids indexing past the end
			// of a shorter voice file.
			rows := len(styleVec) / styleDim
			if rows > styleRows {
				rows = styleRows
			}
			if rows == 0 {
				o.logger.Warn("voice data too short, using zeros", "voice", voiceName, "values", len(styleVec))
			} else {
				for i := range style {
					var sum float32
					for j := 0; j < rows; j++ {
						sum += styleVec[j*styleDim+i]
					}
					style[i] = sum / float32(rows)
				}
				o.logger.Info("loaded voice embedding", "voice", voiceName)
			}
		}
		o.styleVector = style
	}
	opts, err := onnx.NewSessionOptions()
	if err != nil {
		return fmt.Errorf("failed to create session options: %w", err)
	}
	defer func() { _ = opts.Destroy() }()
	if onnx.HasCUDASupport() {
		o.logger.Info("session options created with CUDA")
	} else {
		o.logger.Info("session options created with CPU")
	}
	session, err := onnxruntime_go.NewDynamicAdvancedSession(
		modelPath,
		[]string{"input_ids", "style", "speed"},
		[]string{"waveform"},
		opts,
	)
	if err != nil {
		o.logger.Error("failed to create ONNX session", "error", err)
		return fmt.Errorf("failed to create ONNX session: %w", err)
	}
	o.session = session
	o.modelLoaded = true
	o.logger.Info("Kokoro ONNX model loaded successfully", "model", modelPath)
	return nil
}
// textToPhonemes runs the configured espeak executable on text and returns
// its phoneme output with surrounding whitespace trimmed.
//
// Phonemes are requested in IPA (--ipa) because the token map is keyed by
// IPA symbols; the -x flag emits espeak's own ASCII mnemonics instead. The
// text is fed through stdin rather than as an argument, so text that starts
// with "-" cannot be parsed as a command-line option.
func (o *KokoroONNXOrator) textToPhonemes(text string) (string, error) {
	o.logger.Debug("converting text to phonemes", "text", text)
	cmd := exec.Command(o.espeakCmd, "-q", "--ipa", "--stdin")
	cmd.Stdin = strings.NewReader(text)
	output, err := cmd.Output()
	if err != nil {
		o.logger.Error("espeak failed", "error", err, "cmd", o.espeakCmd, "text", text)
		return "", fmt.Errorf("espeak failed: %w", err)
	}
	phonemeStr := strings.TrimSpace(string(output))
	o.logger.Debug("phonemes generated", "phonemes", phonemeStr)
	return phonemeStr, nil
}
// phonemesToTokens converts a phoneme string into token IDs by looking up
// each rune in o.phonemeMap. Runes without an entry are silently skipped.
// An empty input, or an input in which no rune is known, is an error.
func (o *KokoroONNXOrator) phonemesToTokens(phonemeStr string) ([]int, error) {
	o.logger.Debug("converting phonemes to tokens", "phonemes", phonemeStr)
	if len(phonemeStr) == 0 {
		o.logger.Error("empty phoneme string")
		return nil, fmt.Errorf("empty phoneme string")
	}

	var ids []int
	for _, r := range phonemeStr {
		id, known := o.phonemeMap[string(r)]
		if !known {
			continue
		}
		ids = append(ids, id)
	}

	if len(ids) == 0 {
		o.logger.Error("no phonemes mapped to tokens", "phonemeStr", phonemeStr)
		return nil, fmt.Errorf("no valid phonemes mapped to tokens")
	}
	o.logger.Debug("tokens generated", "count", len(ids), "tokens", ids)
	return ids, nil
}
// generateAudio synthesizes text with the Kokoro ONNX model and returns a
// copy of the produced waveform samples.
//
// The pipeline is: lazy model initialisation, text to phonemes via espeak,
// phonemes to token IDs, padding with token 0 on both ends, then one
// inference run with the token, style and speed inputs. Texts producing more
// than 510 tokens are rejected. A non-positive configured speed is replaced
// by 1.0, matching the Google orator's handling.
//
// A panic raised anywhere in the pipeline is recovered and reported as an
// error, so callers never receive a nil slice together with a nil error.
func (o *KokoroONNXOrator) generateAudio(text string) (audio []float32, err error) {
	defer func() {
		if r := recover(); r != nil {
			o.logger.Error("panic recovered in generateAudio", "panic", r)
			audio = nil
			err = fmt.Errorf("panic in generateAudio: %v", r)
		}
	}()
	speed := o.speed
	if speed <= 0 {
		speed = 1.0
	}
	o.logger.Debug("generateAudio called", "text", text, "speed", speed)
	if err := o.ensureInitialized(o.modelPath); err != nil {
		o.logger.Error("ensureInitialized failed", "error", err)
		return nil, err
	}
	phonemeStr, err := o.textToPhonemes(text)
	if err != nil {
		o.logger.Error("phoneme conversion failed", "error", err)
		return nil, fmt.Errorf("phoneme conversion failed: %w", err)
	}
	tokens, err := o.phonemesToTokens(phonemeStr)
	if err != nil {
		o.logger.Error("token conversion failed", "error", err)
		return nil, fmt.Errorf("token conversion failed: %w", err)
	}
	if len(tokens) > 510 {
		return nil, fmt.Errorf("text too long: %d tokens (max 510)", len(tokens))
	}
	// Surround the sequence with the padding token (ID 0).
	inputIDs := make([]int64, 0, len(tokens)+2)
	inputIDs = append(inputIDs, 0)
	for _, t := range tokens {
		inputIDs = append(inputIDs, int64(t))
	}
	inputIDs = append(inputIDs, 0)
	o.logger.Debug("tokens prepared", "count", len(inputIDs))
	inputTensor, err := onnxruntime_go.NewTensor[int64](
		onnxruntime_go.NewShape(1, int64(len(inputIDs))),
		inputIDs,
	)
	if err != nil {
		o.logger.Error("failed to create input tensor", "error", err)
		return nil, fmt.Errorf("failed to create input tensor: %w", err)
	}
	defer func() { _ = inputTensor.Destroy() }()
	o.logger.Debug("input tensor created", "shape", fmt.Sprintf("[1,%d]", len(inputIDs)))
	styleTensor, err := onnxruntime_go.NewTensor[float32](
		onnxruntime_go.NewShape(1, 256),
		o.styleVector,
	)
	if err != nil {
		o.logger.Error("failed to create style tensor", "error", err)
		return nil, fmt.Errorf("failed to create style tensor: %w", err)
	}
	defer func() { _ = styleTensor.Destroy() }()
	speedTensor, err := onnxruntime_go.NewTensor[float32](
		onnxruntime_go.NewShape(1),
		[]float32{speed},
	)
	if err != nil {
		o.logger.Error("failed to create speed tensor", "error", err)
		return nil, fmt.Errorf("failed to create speed tensor: %w", err)
	}
	defer func() { _ = speedTensor.Destroy() }()
	o.logger.Debug("speed tensor created", "speed", speed)
	// NOTE(review): the output is pre-allocated with a fixed [1,512] shape,
	// which caps the result at 512 samples. Confirm against the model's real
	// waveform length; the session may need to allocate the output itself.
	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
		onnxruntime_go.NewShape(1, 512),
	)
	if err != nil {
		o.logger.Error("failed to create output tensor", "error", err)
		return nil, fmt.Errorf("failed to create output tensor: %w", err)
	}
	defer func() { _ = outputTensor.Destroy() }()
	o.logger.Debug("output tensor created", "shape", "[1,512]")
	o.logger.Info("running ONNX inference", "input_len", len(inputIDs))
	err = o.session.Run(
		[]onnxruntime_go.Value{inputTensor, styleTensor, speedTensor},
		[]onnxruntime_go.Value{outputTensor},
	)
	if err != nil {
		o.logger.Error("ONNX inference failed", "error", err)
		return nil, fmt.Errorf("ONNX inference failed: %w", err)
	}
	o.logger.Debug("ONNX inference completed")
	audioData := outputTensor.GetData()
	if len(audioData) == 0 {
		o.logger.Error("empty audio output from ONNX")
		return nil, fmt.Errorf("empty audio output")
	}
	o.logger.Debug("audio generated", "samples", len(audioData))
	// Copy out of the tensor: its backing memory is released by the deferred
	// Destroy above.
	result := make([]float32, len(audioData))
	copy(result, audioData)
	return result, nil
}
// onnxWavScratch is a minimal in-memory io.WriteSeeker. Unlike an
// append-only buffer it honours Seek, so a WAV encoder can go back and
// overwrite the header once the data size is known.
type onnxWavScratch struct {
	data []byte
	pos  int
}

// Write copies p into the buffer at the current position, growing the
// buffer when the write extends past its end.
func (w *onnxWavScratch) Write(p []byte) (int, error) {
	end := w.pos + len(p)
	if end > len(w.data) {
		w.data = append(w.data, make([]byte, end-len(w.data))...)
	}
	copy(w.data[w.pos:end], p)
	w.pos = end
	return len(p), nil
}

// Seek moves the write position following the io.Seeker contract.
func (w *onnxWavScratch) Seek(offset int64, whence int) (int64, error) {
	var base int64
	switch whence {
	case io.SeekStart:
		base = 0
	case io.SeekCurrent:
		base = int64(w.pos)
	case io.SeekEnd:
		base = int64(len(w.data))
	default:
		return 0, fmt.Errorf("invalid whence: %d", whence)
	}
	target := base + offset
	if target < 0 {
		return 0, fmt.Errorf("negative seek position: %d", target)
	}
	w.pos = int(target)
	return target, nil
}

// Speak synthesizes text with the ONNX model and plays the result, blocking
// until playback ends or stoproutine releases the wait.
//
// The samples are encoded to an in-memory 24 kHz mono 16-bit WAV and decoded
// again to obtain a streamer for the speaker. Generation, encode and decode
// failures are returned. A speaker initialisation failure is only logged,
// like in the other orators, because the speaker reports an error on every
// initialisation after the first one.
func (o *KokoroONNXOrator) Speak(text string) error {
	o.logger.Debug("KokoroONNX Speak called", "text_len", len(text))
	audio, err := o.generateAudio(text)
	if err != nil {
		o.logger.Error("audio generation failed", "error", err)
		return fmt.Errorf("audio generation failed: %w", err)
	}
	o.logger.Debug("audio ready for playback", "samples", len(audio))
	// Streamer feeding the encoder. It remembers how far it has read so each
	// call continues where the previous one stopped and the stream ends once
	// every sample has been delivered.
	next := 0
	encodeStreamer := beep.StreamerFunc(func(samples [][2]float64) (int, bool) {
		if next >= len(audio) {
			return 0, false
		}
		n := 0
		for n < len(samples) && next < len(audio) {
			v := float64(audio[next])
			samples[n][0] = v
			samples[n][1] = v
			n++
			next++
		}
		return n, true
	})
	buf := &onnxWavScratch{}
	err = wav.Encode(buf, encodeStreamer, beep.Format{
		SampleRate:  24000,
		NumChannels: 1,
		Precision:   2,
	})
	if err != nil {
		o.logger.Error("wav encoding failed", "error", err)
		return fmt.Errorf("wav encoding failed: %w", err)
	}
	o.logger.Debug("wav encoded", "size", len(buf.data))
	decodedStreamer, format, err := wav.Decode(bytes.NewReader(buf.data))
	if err != nil {
		o.logger.Error("wav decode failed", "error", err)
		return fmt.Errorf("wav decode failed: %w", err)
	}
	defer decodedStreamer.Close()
	o.logger.Debug("wav decoded", "format", format)
	if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
		// Not fatal: treating it as an error would make every Speak call
		// after the first one fail.
		o.logger.Debug("speaker init failed", "error", err)
	}
	o.logger.Info("playing audio", "sampleRate", format.SampleRate, "channels", format.NumChannels)
	done := make(chan bool)
	// Runs once the decoded stream is exhausted: wake the waiter below and
	// clear the playback state.
	ctrl := &beep.Ctrl{Streamer: beep.Seq(decodedStreamer, beep.Callback(func() {
		o.logger.Debug("playback finished")
		o.mu.Lock()
		close(done)
		o.currentStream = nil
		o.currentDone = nil
		o.mu.Unlock()
	})), Paused: false}
	o.mu.Lock()
	o.currentDone = done
	o.currentStream = ctrl
	o.mu.Unlock()
	speaker.Play(ctrl)
	<-done
	o.logger.Debug("Speak completed")
	return nil
}
// Stop ends the current playback, if any, by clearing the streamer of the
// active control while holding both the speaker lock and the orator mutex.
func (o *KokoroONNXOrator) Stop() {
	o.logger.Debug("stopping KokoroONNX orator")
	speaker.Lock()
	o.mu.Lock()
	defer func() {
		o.mu.Unlock()
		speaker.Unlock()
	}()
	if ctrl := o.currentStream; ctrl != nil {
		ctrl.Streamer = nil
	}
}
// GetLogger returns the logger stored on the orator.
func (o *KokoroONNXOrator) GetLogger() *slog.Logger {
return o.logger
}
// stoproutine runs forever, handling one stop request per value received on
// TTSDoneChan: it stops playback, discards queued text, clears the sentence
// buffer, releases a Speak call that is waiting for playback to end and
// marks the orator as interrupted.
func (o *KokoroONNXOrator) stoproutine() {
	o.logger.Debug("KokoroONNX stoproutine started")
	for {
		<-TTSDoneChan
		o.logger.Debug("KokoroONNX got done signal")
		o.Stop()
		// Discard queued text without blocking. A len()-guarded receive can
		// block (and later swallow a fresh chunk) when readroutine empties
		// the channel between the length check and the receive.
	drain:
		for {
			select {
			case <-TTSTextChan:
			default:
				break drain
			}
		}
		o.mu.Lock()
		o.textBuffer.Reset()
		if o.currentDone != nil {
			select {
			case o.currentDone <- true:
			default:
				// Nobody is waiting on the channel right now; nothing to release.
			}
		}
		o.interrupt = true
		o.mu.Unlock()
		o.logger.Debug("KokoroONNX stoproutine finished")
	}
}
// readroutine is the long-running consumer of TTSTextChan and TTSFlushChan.
//
// Text chunks are appended to the sentence buffer; once the buffer holds
// more than one sentence, every sentence except the last (possibly
// incomplete) one is cleaned and spoken. A flush signal collects whatever is
// still queued and speaks the whole remainder sentence by sentence.
//
// An interrupt raised by stoproutine abandons the current batch of
// sentences but keeps the routine alive, so later chunks are still spoken.
// A panic inside the loop is logged and the routine is restarted.
func (o *KokoroONNXOrator) readroutine() {
	defer func() {
		if r := recover(); r != nil {
			o.logger.Error("panic recovered in readroutine", "panic", r)
			// Restart the goroutine
			go o.readroutine()
		}
	}()
	o.logger.Debug("KokoroONNX readroutine started")
	tokenizer, err := english.NewSentenceTokenizer(nil)
	if err != nil {
		// Without a tokenizer no sentence can be extracted; bail out instead
		// of dereferencing a nil tokenizer (and restarting in a loop).
		o.logger.Error("failed to create sentence tokenizer", "error", err)
		return
	}
	for {
		select {
		case chunk := <-TTSTextChan:
			o.logger.Debug("KokoroONNX received chunk", "chunk_len", len(chunk))
			o.mu.Lock()
			o.interrupt = false
			if _, err := o.textBuffer.WriteString(chunk); err != nil {
				o.logger.Warn("failed to write to buffer", "error", err)
				o.mu.Unlock()
				continue
			}
			text := o.textBuffer.String()
			sentences := tokenizer.Tokenize(text)
			o.logger.Debug("KokoroONNX tokenized", "total_sentences", len(sentences), "buffer", text)
			if len(sentences) <= 1 {
				o.logger.Debug("KokoroONNX not enough sentences, waiting")
				o.mu.Unlock()
				continue
			}
			// The last sentence may still be growing; keep it buffered.
			completeSentences := sentences[:len(sentences)-1]
			remaining := sentences[len(sentences)-1].Text
			o.textBuffer.Reset()
			o.textBuffer.WriteString(remaining)
			o.logger.Debug("KokoroONNX processing sentences", "count", len(completeSentences))
			o.mu.Unlock()
			for _, sentence := range completeSentences {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					// Drop the rest of this batch only; returning here would
					// leave nobody reading TTSTextChan after the first stop.
					o.logger.Debug("KokoroONNX interrupted, dropping pending sentences")
					break
				}
				cleanedText := models.CleanText(sentence.Text)
				if cleanedText == "" {
					continue
				}
				o.logger.Info("KokoroONNX speak", "text", cleanedText)
				if err := o.Speak(cleanedText); err != nil {
					o.logger.Error("KokoroONNX tts failed", "text", cleanedText, "error", err)
				}
			}
		case <-TTSFlushChan:
			o.logger.Debug("KokoroONNX flush signal")
			// Pull in everything still queued. The receive is non-blocking
			// so a concurrent drain cannot leave us stuck.
		drain:
			for {
				select {
				case chunk := <-TTSTextChan:
					o.mu.Lock()
					_, err := o.textBuffer.WriteString(chunk)
					o.mu.Unlock()
					if err != nil {
						o.logger.Warn("failed to write to buffer", "error", err)
					}
				default:
					break drain
				}
			}
			o.mu.Lock()
			remaining := models.CleanText(o.textBuffer.String())
			o.textBuffer.Reset()
			o.mu.Unlock()
			if remaining == "" {
				continue
			}
			// Speak sentence by sentence so a single request stays short.
			for _, rs := range tokenizer.Tokenize(remaining) {
				o.mu.Lock()
				interrupted := o.interrupt
				o.mu.Unlock()
				if interrupted {
					break
				}
				if err := o.Speak(rs.Text); err != nil {
					o.logger.Error("tts failed", "text", rs.Text, "error", err)
				}
			}
		}
	}
}