Refactor: cleanup stt mess, config use

This commit is contained in:
Grail Finder
2025-05-18 14:32:54 +03:00
parent 2b2e45ff00
commit 441225ede8
5 changed files with 43 additions and 34 deletions

16
bot.go
View File

@@ -149,7 +149,7 @@ func sendMsgToLLM(body io.Reader) {
// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
resp, err := httpClient.Do(req)
if err != nil {
logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
logger.Error("failed to notify", "error", err)
}
@@ -498,6 +498,7 @@ func init() {
//
logLevel.Set(slog.LevelInfo)
logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
// TODO: rename and/or put in cfg
store = storage.NewProviderSQL("test.db", logger)
if store == nil {
os.Exit(1)
@@ -511,7 +512,7 @@ func init() {
}
lastChat := loadOldChatOrGetNew()
chatBody = &models.ChatBody{
Model: "modl_name",
Model: "modelname",
Stream: true,
Messages: lastChat,
}
@@ -522,9 +523,10 @@ func init() {
}
choseChunkParser()
httpClient = createClient(time.Second * 15)
// TODO: check config for orator
orator = extra.InitOrator(logger, "http://localhost:8880/v1/audio/speech")
asr = extra.NewWhisperSTT(logger, "http://localhost:8081/inference", 44100)
// go runModelNameTicker(time.Second * 120)
// tempLoad()
if cfg.TTS_ENABLED {
orator = extra.InitOrator(logger, cfg.TTS_URL)
}
if cfg.STT_ENABLED {
asr = extra.NewWhisperSTT(logger, cfg.STT_URL, 16000)
}
}

View File

@@ -15,3 +15,6 @@ RAGWorkers = 5
# extra tts
TTS_ENABLED = false
TTS_URL = "http://localhost:8880/v1/audio/speech"
# extra stt
STT_ENABLED = false
STT_URL = "http://localhost:8081/inference"

View File

@@ -42,6 +42,9 @@ type Config struct {
// TTS
TTS_URL string `toml:"TTS_URL"`
TTS_ENABLED bool `toml:"TTS_ENABLED"`
// STT
STT_URL string `toml:"STT_URL"`
STT_ENABLED bool `toml:"STT_ENABLED"`
}
func LoadConfigOrDefault(fn string) *Config {

View File

@@ -9,7 +9,7 @@ import (
"log/slog"
"mime/multipart"
"net/http"
"time"
"strings"
"github.com/gordonklaus/portaudio"
)
@@ -28,8 +28,7 @@ type WhisperSTT struct {
logger *slog.Logger
ServerURL string
SampleRate int
RawBuffer *bytes.Buffer
WavBuffer *bytes.Buffer
AudioBuffer *bytes.Buffer
streamer StreamCloser
recording bool
}
@@ -39,8 +38,7 @@ func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *Whisp
logger: logger,
ServerURL: serverURL,
SampleRate: sampleRate,
RawBuffer: new(bytes.Buffer),
WavBuffer: new(bytes.Buffer),
AudioBuffer: new(bytes.Buffer),
}
}
@@ -54,17 +52,14 @@ func (stt *WhisperSTT) StartRecording() error {
func (stt *WhisperSTT) StopRecording() (string, error) {
stt.recording = false
time.Sleep(time.Millisecond * 200) // this is not the way
// wait loop to finish?
if stt.RawBuffer == nil {
err := errors.New("unexpected nil RawBuffer")
if stt.AudioBuffer == nil {
err := errors.New("unexpected nil AudioBuffer")
stt.logger.Error(err.Error())
return "", err
}
// Create WAV header first
stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // Write initial header with 0 size
stt.WavBuffer.Write(stt.RawBuffer.Bytes())
body := &bytes.Buffer{} // third buffer?
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add audio file part
part, err := writer.CreateFormFile("file", "recording.wav")
@@ -72,11 +67,15 @@ func (stt *WhisperSTT) StopRecording() (string, error) {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
_, err = io.Copy(part, stt.WavBuffer)
if err != nil {
// Stream directly to multipart writer: header + raw data
dataSize := stt.AudioBuffer.Len()
stt.writeWavHeader(part, dataSize)
if _, err := io.Copy(part, stt.AudioBuffer); err != nil {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
// Reset buffer for next recording
stt.AudioBuffer.Reset()
// Add response format field
err = writer.WriteField("response_format", "text")
if err != nil {
@@ -95,13 +94,12 @@ func (stt *WhisperSTT) StopRecording() (string, error) {
}
defer resp.Body.Close()
// Read and print response
responseText, err := io.ReadAll(resp.Body)
responseTextBytes, err := io.ReadAll(resp.Body)
if err != nil {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
stt.logger.Info("got transcript", "text", string(responseText))
return string(responseText), nil
return strings.TrimRight(string(responseTextBytes), "\n"), nil
}
func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) {
@@ -149,7 +147,7 @@ func (stt *WhisperSTT) microphoneStream(sampleRate int) error {
stt.logger.Error("reading stream", "error", err)
return
}
if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil {
if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil {
stt.logger.Error("writing to buffer", "error", err)
return
}

7
tui.go
View File

@@ -666,6 +666,7 @@ func init() {
pages.AddPage(imgPage, imgView, true, true)
return nil
}
// TODO: move to menu or table
// if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" {
// // rag load
// // menu of the text files from defined rag directory
@@ -685,7 +686,7 @@ func init() {
// pages.AddPage(RAGPage, chatRAGTable, true, true)
// return nil
// }
if event.Key() == tcell.KeyCtrlR {
if event.Key() == tcell.KeyCtrlR && cfg.STT_ENABLED {
defer updateStatusLine()
if asr.IsRecording() {
userSpeech, err := asr.StopRecording()
@@ -694,7 +695,9 @@ func init() {
return nil
}
if userSpeech != "" {
textArea.SetText(userSpeech, true)
// append indtead of replacing
prevText := textArea.GetText()
textArea.SetText(prevText+userSpeech, true)
} else {
logger.Warn("empty user speech")
}