Feat: kokoro onnx (WIP)

This commit is contained in:
Grail Finder
2026-03-07 08:35:44 +03:00
parent 014e297ae3
commit 0598e3e86d
5 changed files with 541 additions and 86 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"gf-lt/config"
"gf-lt/models"
"gf-lt/onnx"
"log/slog"
"net/http"
"os"
@@ -156,43 +157,6 @@ type ONNXEmbedder struct {
modelPath string
}
var onnxInitOnce sync.Once
var onnxReady bool
var onnxLibPath string
var cudaLibPath string
var onnxLibPaths = []string{
"/usr/lib/libonnxruntime.so",
"/usr/lib/libonnxruntime.so.1.24.2",
"/usr/local/lib/libonnxruntime.so",
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
"/opt/onnxruntime/lib/libonnxruntime.so",
}
var cudaLibPaths = []string{
"/usr/lib/libonnxruntime_providers_cuda.so",
"/usr/local/lib/libonnxruntime_providers_cuda.so",
"/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so",
}
func findONNXLibrary() string {
for _, path := range onnxLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func findCUDALibrary() string {
for _, path := range cudaLibPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Check if model and tokenizer files exist
if _, err := os.Stat(modelPath); err != nil {
@@ -202,17 +166,16 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
return nil, fmt.Errorf("tokenizer not found: %w", err)
}
// Find ONNX library
onnxLibPath = findONNXLibrary()
if onnxLibPath == "" {
return nil, errors.New("ONNX runtime library not found in standard locations")
// Initialize ONNX runtime
if err := onnx.Init(); err != nil {
return nil, fmt.Errorf("ONNX init failed: %w", err)
}
if onnx.HasCUDASupport() {
logger.Info("ONNX CUDA support enabled")
} else {
logger.Info("ONNX using CPU fallback")
}
// Find CUDA provider library (optional)
cudaLibPath = findCUDALibrary()
if cudaLibPath == "" {
fmt.Println("WARNING: CUDA provider library not found, will use CPU")
}
emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath,
dims: dims,
@@ -239,26 +202,12 @@ func (e *ONNXEmbedder) ensureInitialized() error {
}
e.tokenizer = tok
}
onnxInitOnce.Do(func() {
onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
if err := onnxruntime_go.InitializeEnvironment(); err != nil {
e.logger.Error("failed to initialize ONNX runtime", "error", err)
onnxReady = false
return
}
// Register CUDA provider if available
if cudaLibPath != "" {
if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil {
e.logger.Warn("failed to register CUDA provider", "error", err)
}
}
onnxReady = true
})
if !onnxReady {
// ONNX runtime already initialized by onnx.Init() in NewONNXEmbedder
if !onnx.IsReady() {
return errors.New("ONNX runtime not ready")
}
// Create session options
opts, err := onnxruntime_go.NewSessionOptions()
opts, err := onnx.NewSessionOptions()
if err != nil {
return fmt.Errorf("failed to create session options: %w", err)
}
@@ -266,27 +215,7 @@ func (e *ONNXEmbedder) ensureInitialized() error {
_ = opts.Destroy()
}()
// Try to add CUDA provider
useCUDA := cudaLibPath != ""
if useCUDA {
cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions()
if err != nil {
e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
useCUDA = false
} else {
defer func() {
_ = cudaOpts.Destroy()
}()
if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
useCUDA = false
} else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil {
e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err)
useCUDA = false
}
}
}
if useCUDA {
if onnx.HasCUDASupport() {
e.logger.Info("Using CUDA for ONNX inference")
} else {
e.logger.Info("Using CPU for ONNX inference")