Feat: stt voice typing with whisper.cpp server [WIP]
This commit is contained in:
235
extra/stt.go
235
extra/stt.go
@@ -2,18 +2,16 @@ package extra
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/MarkKremer/microphone/v2"
|
||||
"github.com/gopxl/beep/v2"
|
||||
"github.com/gopxl/beep/v2/wav"
|
||||
"github.com/gordonklaus/portaudio"
|
||||
)
|
||||
|
||||
type STT interface {
|
||||
@@ -22,167 +20,140 @@ type STT interface {
|
||||
IsRecording() bool
|
||||
}
|
||||
|
||||
// StreamCloser is the single-method contract (Close) required of the value
// stored in WhisperSTT.streamer; StopRecording calls Close on it.
type StreamCloser interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// WhisperSTT holds the state for recording microphone audio and sending it to
// a whisper.cpp server for transcription.
// NOTE(review): SampleRate and streamer are each declared twice below (a
// beep-typed pair and an int/StreamCloser pair). This listing appears to show
// both sides of a diff; only one pair can exist in the compiled struct.
type WhisperSTT struct {
|
||||
// logger receives error and info messages from the recording methods.
logger *slog.Logger
|
||||
// ServerURL is the server address passed to NewWhisperSTT.
ServerURL string
|
||||
SampleRate beep.SampleRate
|
||||
// Buffer is the request body used by the http.NewRequest path in StopRecording.
Buffer *bytes.Buffer
|
||||
streamer beep.StreamCloser
|
||||
// SampleRate is passed to portaudio when opening the stream and written into the WAV header.
SampleRate int
|
||||
// RawBuffer accumulates little-endian int16 samples written by the microphoneStream goroutine.
RawBuffer *bytes.Buffer
|
||||
// WavBuffer receives the WAV header followed by a copy of RawBuffer in StopRecording.
WavBuffer *bytes.Buffer
|
||||
streamer StreamCloser
|
||||
// recording is set by StartRecording, cleared by StopRecording, and polled by the reader goroutine.
recording bool
|
||||
}
|
||||
|
||||
// writeseeker is a minimal in-memory io.WriteSeeker backed by a growable byte
// slice. It exists for encoders that must seek back and patch data they have
// already written, which bytes.Buffer cannot do. The zero value is an empty
// buffer ready for use.
type writeseeker struct {
	buf []byte // everything written so far; len(buf) is the logical end of data
	pos int    // offset at which the next Write lands
}

// Write copies p into the buffer at the current position, overwriting existing
// bytes and extending the buffer as needed. If the position lies beyond the
// current end, the gap is zero-filled. It always returns len(p) and a nil error.
func (m *writeseeker) Write(p []byte) (n int, err error) {
	end := m.pos + len(p)
	if end > len(m.buf) {
		// append grows geometrically, so a long run of small writes stays
		// amortized O(1) instead of reallocating on nearly every call.
		m.buf = append(m.buf, make([]byte, end-len(m.buf))...)
	}
	copy(m.buf[m.pos:], p)
	m.pos = end
	return len(p), nil
}

// Seek sets the position for the next Write, following io.Seeker semantics,
// and returns the new absolute position. Seeking past the end is allowed.
// It returns an error, leaving the position unchanged, when whence is not one
// of io.SeekStart, io.SeekCurrent or io.SeekEnd, or when the resulting
// position would be negative or not representable as an int.
func (m *writeseeker) Seek(offset int64, whence int) (int64, error) {
	var base int64
	switch whence {
	case io.SeekStart:
		base = 0
	case io.SeekCurrent:
		base = int64(m.pos)
	case io.SeekEnd:
		base = int64(len(m.buf))
	default:
		// Previously an unknown whence silently rewound to offset 0.
		return 0, fmt.Errorf("invalid whence %d", whence)
	}

	// Do the arithmetic in int64 so a large offset is not truncated before
	// the range checks below.
	newPos := base + offset
	if newPos < 0 {
		return 0, errors.New("negative result pos")
	}
	if int64(int(newPos)) != newPos {
		return 0, errors.New("result pos out of range")
	}
	m.pos = int(newPos)
	return newPos, nil
}

// Reader returns an io.Reader over the bytes written so far, starting at
// offset 0 regardless of the current write position. Use it, for example,
// with io.Copy to drain the writeseeker into another io.Writer.
func (m *writeseeker) Reader() io.Reader {
	return bytes.NewReader(m.buf)
}
|
||||
|
||||
// NewWhisperSTT returns a WhisperSTT that uses the given logger, server URL
// and sample rate, with freshly allocated empty buffers.
// NOTE(review): two signatures (beep.SampleRate and int) and both the Buffer
// and RawBuffer/WavBuffer initialisers are visible here. This looks like both
// sides of a diff; confirm which set matches the struct definition.
func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate beep.SampleRate) *WhisperSTT {
|
||||
func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT {
|
||||
return &WhisperSTT{
|
||||
logger: logger,
|
||||
ServerURL: serverURL,
|
||||
SampleRate: sampleRate,
|
||||
Buffer: new(bytes.Buffer),
|
||||
RawBuffer: new(bytes.Buffer),
|
||||
WavBuffer: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
// StartRecording opens the microphone and marks the recorder as recording.
// It returns a wrapped error when the microphone cannot be initialised.
// NOTE(review): both the function-style call (returning a stream that is then
// stored and drained by capture) and the method-style call to
// stt.microphoneStream are visible here; they appear to be two sides of a diff.
func (stt *WhisperSTT) StartRecording() error {
|
||||
stream, err := microphoneStream(stt.SampleRate)
|
||||
if err != nil {
|
||||
// NOTE(review): the method form of microphoneStream launches its reader
// goroutine before returning, and that goroutine exits as soon as
// IsRecording() is false. The flag is only set to true further down, so the
// goroutine may observe false first. Confirm the ordering.
if err := stt.microphoneStream(stt.SampleRate); err != nil {
|
||||
return fmt.Errorf("failed to init microphone: %w", err)
|
||||
}
|
||||
|
||||
stt.streamer = stream
|
||||
stt.recording = true
|
||||
|
||||
go stt.capture()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stt *WhisperSTT) capture() {
|
||||
sink := beep.NewBuffer(beep.Format{
|
||||
SampleRate: stt.SampleRate,
|
||||
NumChannels: 1,
|
||||
Precision: 2,
|
||||
})
|
||||
|
||||
// Append the streamer to the buffer and encode as WAV
|
||||
sink.Append(stt.streamer)
|
||||
|
||||
// Encode the captured audio to WAV format using beep's WAV encoder
|
||||
// var wavBuf bytes.Buffer
|
||||
var wavBuf writeseeker
|
||||
if err := wav.Encode(&wavBuf, sink.Streamer(0, sink.Len()), beep.Format{
|
||||
SampleRate: stt.SampleRate,
|
||||
NumChannels: 1,
|
||||
Precision: 2,
|
||||
}); err != nil {
|
||||
stt.logger.Error("failed to encode WAV", "error", err)
|
||||
}
|
||||
r := wavBuf.Reader()
|
||||
// stt.Buffer = &wavBuf
|
||||
if _, err := io.Copy(stt.Buffer, r); err != nil {
|
||||
stt.logger.Error("failed to encode WAV", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// StopRecording ends the current recording, uploads the captured audio to the
// transcription server and returns the text it replies with. It returns
// ("", nil) when no recording is in progress.
// NOTE(review): two request paths are visible below. One is a raw audio/wav
// POST built with http.NewRequest on stt.Buffer; the other is a
// multipart/form-data POST built from RawBuffer/WavBuffer. They appear to be
// two sides of a diff; only one can be live.
func (stt *WhisperSTT) StopRecording() (string, error) {
|
||||
if !stt.recording {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
stt.streamer.Close()
|
||||
stt.recording = false
|
||||
|
||||
// Send to Whisper.cpp server
|
||||
req, err := http.NewRequest("POST", stt.ServerURL, stt.Buffer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
// Fixed delay standing in for real synchronisation with the reader goroutine,
// which only stops once it sees recording == false.
time.Sleep(time.Millisecond * 200) // this is not the way
|
||||
// wait loop to finish?
|
||||
if stt.RawBuffer == nil {
|
||||
err := errors.New("unexpected nil RawBuffer")
|
||||
stt.logger.Error(err.Error())
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "audio/wav")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
// Create WAV header first
|
||||
// NOTE(review): RawBuffer.Bytes() does not consume the buffer and nothing
// visible here resets it, so a later recording would presumably include this
// one's samples. Confirm.
stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // header carries the actual PCM byte count, not 0
|
||||
stt.WavBuffer.Write(stt.RawBuffer.Bytes())
|
||||
body := &bytes.Buffer{} // third buffer?
|
||||
writer := multipart.NewWriter(body)
|
||||
// Add audio file part
|
||||
part, err := writer.CreateFormFile("file", "recording.wav")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("transcription request failed: %w", err)
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
_, err = io.Copy(part, stt.WavBuffer)
|
||||
if err != nil {
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
// Add response format field
|
||||
err = writer.WriteField("response_format", "text")
|
||||
if err != nil {
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
// NOTE(review): the error returned by writer.Close() is discarded. err still
// holds WriteField's result, which is nil at this point, so this branch logs
// a nil error and returns ("", nil) on failure. Capture Close's own error.
if writer.Close() != nil {
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
// Send request
|
||||
// NOTE(review): the endpoint is hard-coded here and stt.ServerURL is not used
// on this path. Confirm whether the configured URL should be used instead.
resp, err := http.Post("http://localhost:8081/inference", writer.FormDataContentType(), body)
|
||||
if err != nil {
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
// Read the whole response body; it is returned verbatim as the transcript.
|
||||
responseText, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
stt.logger.Error("fn: StopRecording", "error", err)
|
||||
return "", err
|
||||
}
|
||||
stt.logger.Info("got transcript", "text", string(responseText))
|
||||
return string(responseText), nil
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return result.Text, nil
|
||||
func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) {
|
||||
header := make([]byte, 44)
|
||||
copy(header[0:4], "RIFF")
|
||||
binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize))
|
||||
copy(header[8:12], "WAVE")
|
||||
copy(header[12:16], "fmt ")
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16)
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1)
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1)
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8))
|
||||
binary.LittleEndian.PutUint16(header[32:34], 1*(16/8))
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16)
|
||||
copy(header[36:40], "data")
|
||||
binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize))
|
||||
w.Write(header)
|
||||
}
|
||||
|
||||
// IsRecording reports the current value of the recording flag.
// NOTE(review): the flag is read here without synchronisation, while the
// reader goroutine in microphoneStream polls this method and StopRecording
// writes the flag. Confirm whether a mutex or atomic is needed.
func (stt *WhisperSTT) IsRecording() bool {
|
||||
return stt.recording
|
||||
}
|
||||
|
||||
// microphoneStream opens the default audio input as a single-channel stream
// and starts capturing from it.
// NOTE(review): two implementations are visible below. One is a package-level
// function built on the microphone package that returns a started stream and
// installs a signal handler; the other is a WhisperSTT method built on
// portaudio that starts a goroutine appending samples to stt.RawBuffer. They
// appear to be two sides of a diff.
func microphoneStream(sr beep.SampleRate) (beep.StreamCloser, error) {
|
||||
if err := microphone.Init(); err != nil {
|
||||
return nil, fmt.Errorf("microphone init failed: %w", err)
|
||||
func (stt *WhisperSTT) microphoneStream(sampleRate int) error {
|
||||
if err := portaudio.Initialize(); err != nil {
|
||||
return fmt.Errorf("portaudio init failed: %w", err)
|
||||
}
|
||||
|
||||
stream, _, err := microphone.OpenDefaultStream(sr, 1) // 1 channel mono
|
||||
// in is the 64-sample int16 buffer handed to portaudio; the loop below writes
// it to RawBuffer after every stream.Read.
in := make([]int16, 64)
|
||||
// 1 input channel, 0 output channels.
stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in)
|
||||
if err != nil {
|
||||
microphone.Terminate()
|
||||
return nil, fmt.Errorf("failed to open microphone: %w", err)
|
||||
portaudio.Terminate()
|
||||
return fmt.Errorf("failed to open microphone: %w", err)
|
||||
}
|
||||
|
||||
// Handle OS signals to clean up
|
||||
sig := make(chan os.Signal, 1)
|
||||
// NOTE(review): per the os/signal documentation SIGKILL cannot be caught, so
// listing os.Kill here has no effect.
signal.Notify(sig, os.Interrupt, os.Kill)
|
||||
go func() {
|
||||
<-sig
|
||||
stream.Stop()
|
||||
stream.Close()
|
||||
microphone.Terminate()
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
stream.Start()
|
||||
return stream, nil
|
||||
// Reader goroutine: starts the stream, then appends each filled buffer to
// stt.RawBuffer as little-endian int16 until recording stops or an error occurs.
// NOTE(review): none of the return paths below stop or close the stream or
// call portaudio.Terminate. Confirm where that cleanup is meant to happen.
go func(stream *portaudio.Stream) {
|
||||
if err := stream.Start(); err != nil {
|
||||
stt.logger.Error("microphoneStream", "error", err)
|
||||
return
|
||||
}
|
||||
for {
|
||||
// NOTE(review): StartRecording sets the recording flag only after this
// function returns, so the first check here may already see false. Confirm.
if !stt.IsRecording() {
|
||||
return
|
||||
}
|
||||
if err := stream.Read(); err != nil {
|
||||
stt.logger.Error("reading stream", "error", err)
|
||||
return
|
||||
}
|
||||
if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil {
|
||||
stt.logger.Error("writing to buffer", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(stream)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user