Dep: try sugarme/tokenizer in place of takara-ai/go-tokenizers
This commit is contained in:
7
go.mod
7
go.mod
@@ -16,7 +16,7 @@ require (
|
|||||||
github.com/neurosnap/sentences v1.1.2
|
github.com/neurosnap/sentences v1.1.2
|
||||||
github.com/playwright-community/playwright-go v0.5700.1
|
github.com/playwright-community/playwright-go v0.5700.1
|
||||||
github.com/rivo/tview v0.42.0
|
github.com/rivo/tview v0.42.0
|
||||||
github.com/takara-ai/go-tokenizers v1.0.0
|
github.com/sugarme/tokenizer v0.3.0
|
||||||
github.com/yalue/onnxruntime_go v1.27.0
|
github.com/yalue/onnxruntime_go v1.27.0
|
||||||
github.com/yuin/goldmark v1.4.13
|
github.com/yuin/goldmark v1.4.13
|
||||||
)
|
)
|
||||||
@@ -27,6 +27,7 @@ require (
|
|||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/oto/v3 v3.4.0 // indirect
|
github.com/ebitengine/oto/v3 v3.4.0 // indirect
|
||||||
github.com/ebitengine/purego v0.9.1 // indirect
|
github.com/ebitengine/purego v0.9.1 // indirect
|
||||||
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
github.com/gdamore/encoding v1.0.1 // indirect
|
github.com/gdamore/encoding v1.0.1 // indirect
|
||||||
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
|
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
|
||||||
github.com/go-stack/stack v1.8.1 // indirect
|
github.com/go-stack/stack v1.8.1 // indirect
|
||||||
@@ -35,10 +36,14 @@ require (
|
|||||||
github.com/hajimehoshi/oto/v2 v2.3.1 // indirect
|
github.com/hajimehoshi/oto/v2 v2.3.1 // indirect
|
||||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
|
github.com/schollz/progressbar/v2 v2.15.0 // indirect
|
||||||
|
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect
|
||||||
golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 // indirect
|
golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 // indirect
|
||||||
golang.org/x/net v0.48.0 // indirect
|
golang.org/x/net v0.48.0 // indirect
|
||||||
golang.org/x/sys v0.39.0 // indirect
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
|
|||||||
15
go.sum
15
go.sum
@@ -21,6 +21,8 @@ github.com/ebitengine/oto/v3 v3.4.0 h1:br0PgASsEWaoWn38b2Goe7m1GKFYfNgnsjSd5Gg+/
|
|||||||
github.com/ebitengine/oto/v3 v3.4.0/go.mod h1:IOleLVD0m+CMak3mRVwsYY8vTctQgOM0iiL6S7Ar7eI=
|
github.com/ebitengine/oto/v3 v3.4.0/go.mod h1:IOleLVD0m+CMak3mRVwsYY8vTctQgOM0iiL6S7Ar7eI=
|
||||||
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
|
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
|
||||||
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||||
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||||
github.com/gdamore/tcell/v2 v2.13.2 h1:5j4srfF8ow3HICOv/61/sOhQtA25qxEB2XR3Q/Bhx2g=
|
github.com/gdamore/tcell/v2 v2.13.2 h1:5j4srfF8ow3HICOv/61/sOhQtA25qxEB2XR3Q/Bhx2g=
|
||||||
@@ -61,10 +63,14 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
|||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
|
||||||
|
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw=
|
github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw=
|
||||||
github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ=
|
github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/playwright-community/playwright-go v0.5700.1 h1:PNFb1byWqrTT720rEO0JL88C6Ju0EmUnR5deFLvtP/U=
|
github.com/playwright-community/playwright-go v0.5700.1 h1:PNFb1byWqrTT720rEO0JL88C6Ju0EmUnR5deFLvtP/U=
|
||||||
@@ -77,12 +83,17 @@ github.com/rivo/tview v0.42.0 h1:b/ftp+RxtDsHSaynXTbJb+/n/BxDEi+W3UfF5jILK6c=
|
|||||||
github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY=
|
github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY=
|
||||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
|
github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
|
||||||
|
github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/takara-ai/go-tokenizers v1.0.0 h1:C+UQl3fPFw08YQdwthzPZbqykh6yumzjPrSs+3OSe7o=
|
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
|
||||||
github.com/takara-ai/go-tokenizers v1.0.0/go.mod h1:2A7hN3gMtAARJ2V3sYyIzTDm+GNTudBX+CwUOyIVH2A=
|
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
|
||||||
|
github.com/sugarme/tokenizer v0.3.0 h1:FE8DYbNSz/kSbgEo9l/RjgYHkIJYEdskumitFQBE9FE=
|
||||||
|
github.com/sugarme/tokenizer v0.3.0/go.mod h1:VJ+DLK5ZEZwzvODOWwY0cw+B1dabTd3nCB5HuFCItCc=
|
||||||
github.com/yalue/onnxruntime_go v1.27.0 h1:c1YSgDNtpf0WGtxj3YeRIb8VC5LmM1J+Ve3uHdteC1U=
|
github.com/yalue/onnxruntime_go v1.27.0 h1:c1YSgDNtpf0WGtxj3YeRIb8VC5LmM1J+Ve3uHdteC1U=
|
||||||
github.com/yalue/onnxruntime_go v1.27.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
github.com/yalue/onnxruntime_go v1.27.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||||
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
||||||
|
|||||||
183
rag/embedder.go
183
rag/embedder.go
@@ -10,8 +10,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/takara-ai/go-tokenizers/tokenizers"
|
"github.com/sugarme/tokenizer"
|
||||||
|
"github.com/sugarme/tokenizer/pretrained"
|
||||||
"github.com/yalue/onnxruntime_go"
|
"github.com/yalue/onnxruntime_go"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,59 +141,168 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
|
|||||||
// 1. Loading ONNX models locally
|
// 1. Loading ONNX models locally
|
||||||
// 2. Using a Go ONNX runtime (github.com/yalue/onnxruntime_go)
|
// 2. Using a Go ONNX runtime (github.com/yalue/onnxruntime_go)
|
||||||
// 3. Converting text to embeddings without external API calls
|
// 3. Converting text to embeddings without external API calls
|
||||||
|
|
||||||
type ONNXEmbedder struct {
|
type ONNXEmbedder struct {
|
||||||
session *onnxruntime_go.DynamicAdvancedSession
|
session *onnxruntime_go.DynamicAdvancedSession
|
||||||
tokenizer *tokenizers.Tokenizer
|
tokenizer *tokenizer.Tokenizer
|
||||||
dims int // 768, 512, 256, or 128 for Matryoshka
|
dims int // embedding dimension (e.g., 768)
|
||||||
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
|
||||||
// Batch processing
|
// Load tokenizer using sugarme/tokenizer
|
||||||
inputs := e.prepareBatch(texts)
|
tok, err := pretrained.FromFile(tokenizerPath)
|
||||||
outputs := make([][]float32, len(texts))
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
||||||
// Run batch inference (much faster)
|
}
|
||||||
err := e.session.Run(inputs, outputs)
|
// Create ONNX session
|
||||||
return outputs, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
|
|
||||||
// Load ONNX model
|
|
||||||
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
session, err := onnxruntime_go.NewDynamicAdvancedSession(
|
||||||
modelPath, // onnx/embedgemma/model_q4.onnx
|
modelPath, // onnx/embedgemma/model_q4.onnx
|
||||||
[]string{"input_ids", "attention_mask"},
|
[]string{"input_ids", "attention_mask"},
|
||||||
[]string{"sentence_embedding"},
|
[]string{"sentence_embedding"},
|
||||||
nil,
|
nil, // optional options
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to create ONNX session: %w", err)
|
||||||
}
|
}
|
||||||
// Load tokenizer (from Hugging Face)
|
|
||||||
tokenizer, err := tokenizers.FromFile("./tokenizer.json")
|
|
||||||
return &ONNXEmbedder{
|
return &ONNXEmbedder{
|
||||||
session: session,
|
session: session,
|
||||||
tokenizer: tokenizer,
|
tokenizer: tok,
|
||||||
|
dims: dims,
|
||||||
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
|
||||||
// Tokenize
|
// 1. Tokenize
|
||||||
tokens := e.tokenizer.Encode(text, true)
|
encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
|
||||||
// Prepare inputs
|
if err != nil {
|
||||||
inputIDs := []int64{tokens.GetIds()}
|
return nil, fmt.Errorf("tokenization failed: %w", err)
|
||||||
attentionMask := []int64{tokens.GetAttentionMask()}
|
}
|
||||||
// Run inference
|
// Widen token IDs and attention-mask values to int64, the element type of the ONNX input tensors
|
||||||
output := onnxruntime_go.NewEmptyTensor[float32](
|
inputIDs := make([]int64, len(encoding.GetIDs()))
|
||||||
onnxruntime_go.NewShape(1, 768),
|
for i, id := range encoding.GetIDs() {
|
||||||
)
|
inputIDs[i] = int64(id)
|
||||||
err := e.session.Run(
|
}
|
||||||
map[string]any{
|
attentionMask := make([]int64, len(encoding.GetAttentionMask()))
|
||||||
"input_ids": inputIDs,
|
for i, m := range encoding.GetAttentionMask() {
|
||||||
"attention_mask": attentionMask,
|
attentionMask[i] = int64(m)
|
||||||
|
}
|
||||||
|
// 2. Create input tensors (shape: [1, seq_len])
|
||||||
|
seqLen := int64(len(inputIDs))
|
||||||
|
inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
|
||||||
|
}
|
||||||
|
defer inputIDsTensor.Destroy()
|
||||||
|
maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
|
||||||
|
}
|
||||||
|
defer maskTensor.Destroy()
|
||||||
|
// 3. Create output tensor (shape: [1, dims])
|
||||||
|
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create output tensor: %w", err)
|
||||||
|
}
|
||||||
|
defer outputTensor.Destroy()
|
||||||
|
// 4. Run inference
|
||||||
|
err = e.session.Run(
|
||||||
|
map[string]*onnxruntime_go.Tensor{
|
||||||
|
"input_ids": inputIDsTensor,
|
||||||
|
"attention_mask": maskTensor,
|
||||||
},
|
},
|
||||||
[]string{"sentence_embedding"},
|
[]string{"sentence_embedding"},
|
||||||
[]any{&output},
|
[]*onnxruntime_go.Tensor{outputTensor},
|
||||||
)
|
)
|
||||||
return output.GetData(), nil
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inference failed: %w", err)
|
||||||
|
}
|
||||||
|
// 5. Extract data
|
||||||
|
outputData := outputTensor.GetData()
|
||||||
|
// outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
|
||||||
|
// We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
|
||||||
|
embedding := make([]float32, len(outputData))
|
||||||
|
copy(embedding, outputData)
|
||||||
|
return embedding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedSlice (batch): zero-pads every sequence to the longest one and runs a single inference. Provisional – to be implemented properly.
|
||||||
|
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
|
||||||
|
if len(texts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// 1. Tokenize all texts and find max length for padding
|
||||||
|
encodings := make([]*tokenizer.Encoding, len(texts))
|
||||||
|
maxLen := 0
|
||||||
|
for i, txt := range texts {
|
||||||
|
enc, err := e.tokenizer.Encode(txt, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
|
||||||
|
}
|
||||||
|
encodings[i] = enc
|
||||||
|
if l := len(enc.GetIDs()); l > maxLen {
|
||||||
|
maxLen = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
|
||||||
|
batchSize := len(texts)
|
||||||
|
inputIDs := make([]int64, batchSize*maxLen)
|
||||||
|
attentionMask := make([]int64, batchSize*maxLen)
|
||||||
|
for i, enc := range encodings {
|
||||||
|
ids := enc.GetIDs()
|
||||||
|
mask := enc.GetAttentionMask()
|
||||||
|
offset := i * maxLen
|
||||||
|
// copy actual tokens
|
||||||
|
for j := 0; j < len(ids); j++ {
|
||||||
|
inputIDs[offset+j] = int64(ids[j])
|
||||||
|
attentionMask[offset+j] = int64(mask[j])
|
||||||
|
}
|
||||||
|
// remaining positions (padding) are already zero-initialized
|
||||||
|
}
|
||||||
|
// 3. Create tensors
|
||||||
|
inputIDsTensor, err := onnxruntime_go.NewTensor(
|
||||||
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
|
inputIDs,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer inputIDsTensor.Destroy()
|
||||||
|
maskTensor, err := onnxruntime_go.NewTensor(
|
||||||
|
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
|
||||||
|
attentionMask,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer maskTensor.Destroy()
|
||||||
|
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
|
||||||
|
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer outputTensor.Destroy()
|
||||||
|
// 4. Run
|
||||||
|
err = e.session.Run(
|
||||||
|
map[string]*onnxruntime_go.Tensor{
|
||||||
|
"input_ids": inputIDsTensor,
|
||||||
|
"attention_mask": maskTensor,
|
||||||
|
},
|
||||||
|
[]string{"sentence_embedding"},
|
||||||
|
[]*onnxruntime_go.Tensor{outputTensor},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 5. Extract batch results
|
||||||
|
outputData := outputTensor.GetData()
|
||||||
|
embeddings := make([][]float32, batchSize)
|
||||||
|
for i := 0; i < batchSize; i++ {
|
||||||
|
start := i * e.dims
|
||||||
|
emb := make([]float32, e.dims)
|
||||||
|
copy(emb, outputData[start:start+e.dims])
|
||||||
|
embeddings[i] = emb
|
||||||
|
}
|
||||||
|
return embeddings, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user