WIP: adding tool fields into stream resp struct

This commit is contained in:
Grail Finder
2025-08-08 10:22:22 +03:00
parent 3c23ff2403
commit 14558f98cd
2 changed files with 43 additions and 17 deletions

43
llm.go
View File

@@ -9,7 +9,7 @@ import (
) )
type ChunkParser interface { type ChunkParser interface {
ParseChunk([]byte) (string, bool, error) ParseChunk([]byte) (*models.TextChunk, error)
FormMsg(msg, role string, cont bool) (io.Reader, error) FormMsg(msg, role string, cont bool) (io.Reader, error)
GetToken() string GetToken() string
} }
@@ -114,39 +114,47 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
return bytes.NewReader(data), nil return bytes.NewReader(data), nil
} }
func (lcp LlamaCPPeer) ParseChunk(data []byte) (string, bool, error) { func (lcp LlamaCPPeer) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.LlamaCPPResp{} llmchunk := models.LlamaCPPResp{}
resp := &models.TextChunk{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
} }
resp.Chunk = llmchunk.Content
if llmchunk.Stop { if llmchunk.Stop {
if llmchunk.Content != "" { if llmchunk.Content != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return llmchunk.Content, true, nil resp.Finished = true
} }
return llmchunk.Content, false, nil return resp, nil
} }
func (op OpenAIer) GetToken() string { func (op OpenAIer) GetToken() string {
return "" return ""
} }
func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) { func (op OpenAIer) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.LLMRespChunk{} llmchunk := models.LLMRespChunk{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
}
resp := &models.TextChunk{
Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content,
ToolChunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments,
} }
content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content
if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
if content != "" { if resp.Chunk != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return content, true, nil resp.Finished = true
} }
return content, false, nil if resp.ToolChunk != "" {
resp.ToolResp = true
}
return resp, nil
} }
func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) { func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
@@ -171,19 +179,22 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
} }
// deepseek // deepseek
func (ds DeepSeekerCompletion) ParseChunk(data []byte) (string, bool, error) { func (ds DeepSeekerCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.DSCompletionResp{} llmchunk := models.DSCompletionResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil { if err := json.Unmarshal(data, &llmchunk); err != nil {
logger.Error("failed to decode", "error", err, "line", string(data)) logger.Error("failed to decode", "error", err, "line", string(data))
return "", false, err return nil, err
}
resp := &models.TextChunk{
Chunk: llmchunk.Choices[0].Text,
} }
if llmchunk.Choices[0].FinishReason != "" { if llmchunk.Choices[0].FinishReason != "" {
if llmchunk.Choices[0].Text != "" { if resp.Chunk != "" {
logger.Error("text inside of finish llmchunk", "chunk", llmchunk) logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
} }
return llmchunk.Choices[0].Text, true, nil resp.Finished = true
} }
return llmchunk.Choices[0].Text, false, nil return resp, nil
} }
func (ds DeepSeekerCompletion) GetToken() string { func (ds DeepSeekerCompletion) GetToken() string {

View File

@@ -30,6 +30,13 @@ type LLMResp struct {
ID string `json:"id"` ID string `json:"id"`
} }
// ToolDeltaResp is one entry of the "tool_calls" array inside a
// streaming delta: the call's index plus the incremental JSON-argument
// fragment for the function being invoked.
type ToolDeltaResp struct {
	Index    int `json:"index"`
	Function struct {
		// Arguments is a partial JSON string, accumulated across chunks.
		Arguments string `json:"arguments"`
	} `json:"function"`
}
// for streaming // for streaming
type LLMRespChunk struct { type LLMRespChunk struct {
Choices []struct { Choices []struct {
@@ -37,6 +44,7 @@ type LLMRespChunk struct {
Index int `json:"index"` Index int `json:"index"`
Delta struct { Delta struct {
Content string `json:"content"` Content string `json:"content"`
ToolCalls []ToolDeltaResp `json:"tool_calls"`
} `json:"delta"` } `json:"delta"`
} `json:"choices"` } `json:"choices"`
Created int `json:"created"` Created int `json:"created"`
@@ -50,6 +58,13 @@ type LLMRespChunk struct {
} `json:"usage"` } `json:"usage"`
} }
// TextChunk is the backend-agnostic result of parsing one streaming
// chunk: the text fragment, any tool-call argument fragment, and flags
// telling the consumer whether the stream finished and whether this
// chunk carries tool-call data.
type TextChunk struct {
	Chunk     string // plain-text content fragment
	ToolChunk string // tool-call arguments fragment, if any
	Finished  bool   // true when the backend signalled end of stream
	ToolResp  bool   // true when ToolChunk carries tool-call data
}
type RoleMsg struct { type RoleMsg struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`