Enha: onnx config vars
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"gf-lt/models"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
@@ -148,7 +149,17 @@ type ONNXEmbedder struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var onnxInitOnce sync.Once
|
||||
|
||||
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||||
// Initialize ONNX runtime environment once
|
||||
onnxInitOnce.Do(func() {
|
||||
onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
|
||||
err := onnxruntime_go.InitializeEnvironment()
|
||||
if err != nil {
|
||||
logger.Error("failed to initialize ONNX runtime", "error", err)
|
||||
}
|
||||
})
|
||||
// Load tokenizer using sugarme/tokenizer
|
||||
tok, err := pretrained.FromFile(tokenizerPath)
|
||||
if err != nil {
|
||||
@@ -195,7 +206,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
||||
}
|
||||
defer inputIDsTensor.Destroy()
|
||||
defer func() { _ = inputIDsTensor.Destroy() }()
|
||||
maskTensor, err := onnxruntime_go.NewTensor[int64](
|
||||
onnxruntime_go.NewShape(1, seqLen),
|
||||
attentionMask,
|
||||
@@ -203,7 +214,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
||||
}
|
||||
defer maskTensor.Destroy()
|
||||
defer func() { _ = maskTensor.Destroy() }()
|
||||
// 4. Create output tensor
|
||||
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
||||
onnxruntime_go.NewShape(1, int64(e.dims)),
|
||||
@@ -211,7 +222,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
||||
}
|
||||
defer outputTensor.Destroy()
|
||||
defer func() { _ = outputTensor.Destroy() }()
|
||||
// 5. Run inference
|
||||
err = e.session.Run(
|
||||
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
|
||||
@@ -257,16 +268,16 @@ func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||
inputIDs,
|
||||
)
|
||||
defer inputTensor.Destroy()
|
||||
defer func() { _ = inputTensor.Destroy() }()
|
||||
maskTensor, _ := onnxruntime_go.NewTensor[int64](
|
||||
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||
attentionMask,
|
||||
)
|
||||
defer maskTensor.Destroy()
|
||||
defer func() { _ = maskTensor.Destroy() }()
|
||||
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
|
||||
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
||||
)
|
||||
defer outputTensor.Destroy()
|
||||
defer func() { _ = outputTensor.Destroy() }()
|
||||
err := e.session.Run(
|
||||
[]onnxruntime_go.Value{inputTensor, maskTensor},
|
||||
[]onnxruntime_go.Value{outputTensor},
|
||||
|
||||
Reference in New Issue
Block a user