Feat: llama.cpp model switch

This commit is contained in:
Grail Finder
2025-12-12 14:07:10 +03:00
parent 2e2e6e9f9c
commit 9edda1fecf
5 changed files with 105 additions and 50 deletions

29
bot.go
View File

@@ -65,6 +65,7 @@ var (
"google/gemma-3-27b-it:free", "google/gemma-3-27b-it:free",
"meta-llama/llama-3.3-70b-instruct:free", "meta-llama/llama-3.3-70b-instruct:free",
} }
LocalModels = []string{}
) )
// cleanNullMessages removes messages with null or empty content to prevent API issues // cleanNullMessages removes messages with null or empty content to prevent API issues
@@ -187,7 +188,7 @@ func createClient(connectTimeout time.Duration) *http.Client {
} }
} }
func fetchLCPModelName() *models.LLMModels { func fetchLCPModelName() *models.LCPModels {
//nolint //nolint
resp, err := httpClient.Get(cfg.FetchModelNameAPI) resp, err := httpClient.Get(cfg.FetchModelNameAPI)
if err != nil { if err != nil {
@@ -199,7 +200,7 @@ func fetchLCPModelName() *models.LLMModels {
return nil return nil
} }
defer resp.Body.Close() defer resp.Body.Close()
llmModel := models.LLMModels{} llmModel := models.LCPModels{}
if err := json.NewDecoder(resp.Body).Decode(&llmModel); err != nil { if err := json.NewDecoder(resp.Body).Decode(&llmModel); err != nil {
logger.Warn("failed to decode resp", "link", cfg.FetchModelNameAPI, "error", err) logger.Warn("failed to decode resp", "link", cfg.FetchModelNameAPI, "error", err)
return nil return nil
@@ -255,6 +256,24 @@ func fetchORModels(free bool) ([]string, error) {
return freeModels, nil return freeModels, nil
} }
// fetchLCPModels queries the llama.cpp server's model-listing endpoint
// (cfg.FetchModelNameAPI) and returns the IDs of the models it serves.
// Returns a non-nil error on transport failure, a non-200 status, or a
// malformed response body.
func fetchLCPModels() ([]string, error) {
	// NOTE(review): uses the default client rather than the shared
	// httpClient; at startup httpClient may not be initialized yet, but
	// this request therefore has no timeout — confirm this is intended.
	resp, err := http.Get(cfg.FetchModelNameAPI)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Message previously said "or models" (copy-paste from the
		// OpenRouter fetcher); this endpoint lists llama.cpp models.
		return nil, fmt.Errorf("fetching llama.cpp models: status %s", resp.Status)
	}
	data := &models.LCPModels{}
	if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
		return nil, err
	}
	return data.ListModels(), nil
}
func sendMsgToLLM(body io.Reader) { func sendMsgToLLM(body io.Reader) {
choseChunkParser() choseChunkParser()
@@ -869,6 +888,12 @@ func init() {
} }
}() }()
} }
go func() {
LocalModels, err = fetchLCPModels()
if err != nil {
logger.Error("failed to fetch llama.cpp models", "error", err)
}
}()
choseChunkParser() choseChunkParser()
httpClient = createClient(time.Second * 15) httpClient = createClient(time.Second * 15)
if cfg.TTS_ENABLED { if cfg.TTS_ENABLED {

2
llm.go
View File

@@ -157,7 +157,7 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData)) "msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData))
payload := models.NewLCPReq(prompt, multimodalData, defaultLCPProps, chatBody.MakeStopSlice()) payload := models.NewLCPReq(prompt, chatBody.Model, multimodalData, defaultLCPProps, chatBody.MakeStopSlice())
data, err := json.Marshal(payload) data, err := json.Marshal(payload)
if err != nil { if err != nil {
logger.Error("failed to form a msg", "error", err) logger.Error("failed to form a msg", "error", err)

17
main.go
View File

@@ -9,14 +9,15 @@ import (
) )
var ( var (
botRespMode = false botRespMode = false
editMode = false editMode = false
roleEditMode = false roleEditMode = false
injectRole = true injectRole = true
selectedIndex = int(-1) selectedIndex = int(-1)
currentAPIIndex = 0 // Index to track current API in ApiLinks slice currentAPIIndex = 0 // Index to track current API in ApiLinks slice
currentORModelIndex = 0 // Index to track current OpenRouter model in ORFreeModels slice currentORModelIndex = 0 // Index to track current OpenRouter model in ORFreeModels slice
shellMode = false currentLocalModelIndex = 0 // Index to track current llama.cpp model
shellMode = false
// indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q)" // indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q)"
indexLineCompletion = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | Insert <think>: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q) | Bot will write as [orange:-:b]%s[-:-:-] (ctrl+x) | role_inject [orange:-:b]%v[-:-:-]" indexLineCompletion = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | Insert <think>: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q) | Bot will write as [orange:-:b]%s[-:-:-] (ctrl+x) | role_inject [orange:-:b]%v[-:-:-]"
focusSwitcher = map[tview.Primitive]tview.Primitive{} focusSwitcher = map[tview.Primitive]tview.Primitive{}

View File

@@ -89,10 +89,10 @@ type ImageContentPart struct {
// RoleMsg represents a message with content that can be either a simple string or structured content parts // RoleMsg represents a message with content that can be either a simple string or structured content parts
type RoleMsg struct { type RoleMsg struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"-"` Content string `json:"-"`
ContentParts []interface{} `json:"-"` ContentParts []interface{} `json:"-"`
ToolCallID string `json:"tool_call_id,omitempty"` // For tool response messages ToolCallID string `json:"tool_call_id,omitempty"` // For tool response messages
hasContentParts bool // Flag to indicate which content type to marshal hasContentParts bool // Flag to indicate which content type to marshal
} }
@@ -215,8 +215,8 @@ func (m RoleMsg) ToPrompt() string {
// NewRoleMsg creates a simple RoleMsg with string content // NewRoleMsg creates a simple RoleMsg with string content
func NewRoleMsg(role, content string) RoleMsg { func NewRoleMsg(role, content string) RoleMsg {
return RoleMsg{ return RoleMsg{
Role: role, Role: role,
Content: content, Content: content,
hasContentParts: false, hasContentParts: false,
} }
} }
@@ -420,34 +420,35 @@ type OpenAIReq struct {
// === // ===
type LLMModels struct { // type LLMModels struct {
Object string `json:"object"` // Object string `json:"object"`
Data []struct { // Data []struct {
ID string `json:"id"` // ID string `json:"id"`
Object string `json:"object"` // Object string `json:"object"`
Created int `json:"created"` // Created int `json:"created"`
OwnedBy string `json:"owned_by"` // OwnedBy string `json:"owned_by"`
Meta struct { // Meta struct {
VocabType int `json:"vocab_type"` // VocabType int `json:"vocab_type"`
NVocab int `json:"n_vocab"` // NVocab int `json:"n_vocab"`
NCtxTrain int `json:"n_ctx_train"` // NCtxTrain int `json:"n_ctx_train"`
NEmbd int `json:"n_embd"` // NEmbd int `json:"n_embd"`
NParams int64 `json:"n_params"` // NParams int64 `json:"n_params"`
Size int64 `json:"size"` // Size int64 `json:"size"`
} `json:"meta"` // } `json:"meta"`
} `json:"data"` // } `json:"data"`
} // }
type LlamaCPPReq struct { type LlamaCPPReq struct {
Stream bool `json:"stream"` Model string `json:"model"`
Stream bool `json:"stream"`
// For multimodal requests, prompt should be an object with prompt_string and multimodal_data // For multimodal requests, prompt should be an object with prompt_string and multimodal_data
// For regular requests, prompt is a string // For regular requests, prompt is a string
Prompt interface{} `json:"prompt"` // Can be string or object with prompt_string and multimodal_data Prompt interface{} `json:"prompt"` // Can be string or object with prompt_string and multimodal_data
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
DryMultiplier float32 `json:"dry_multiplier"` DryMultiplier float32 `json:"dry_multiplier"`
Stop []string `json:"stop"` Stop []string `json:"stop"`
MinP float32 `json:"min_p"` MinP float32 `json:"min_p"`
NPredict int32 `json:"n_predict"` NPredict int32 `json:"n_predict"`
// MaxTokens int `json:"max_tokens"` // MaxTokens int `json:"max_tokens"`
// DryBase float64 `json:"dry_base"` // DryBase float64 `json:"dry_base"`
// DryAllowedLength int `json:"dry_allowed_length"` // DryAllowedLength int `json:"dry_allowed_length"`
@@ -471,12 +472,11 @@ type PromptObject struct {
PromptString string `json:"prompt_string"` PromptString string `json:"prompt_string"`
MultimodalData []string `json:"multimodal_data,omitempty"` MultimodalData []string `json:"multimodal_data,omitempty"`
// Alternative field name used by some llama.cpp implementations // Alternative field name used by some llama.cpp implementations
ImageData []string `json:"image_data,omitempty"` // For compatibility ImageData []string `json:"image_data,omitempty"` // For compatibility
} }
func NewLCPReq(prompt string, multimodalData []string, props map[string]float32, stopStrings []string) LlamaCPPReq { func NewLCPReq(prompt, model string, multimodalData []string, props map[string]float32, stopStrings []string) LlamaCPPReq {
var finalPrompt interface{} var finalPrompt interface{}
if len(multimodalData) > 0 { if len(multimodalData) > 0 {
// When multimodal data is present, use the object format as per Python example: // When multimodal data is present, use the object format as per Python example:
// { "prompt": { "prompt_string": "...", "multimodal_data": [...] } } // { "prompt": { "prompt_string": "...", "multimodal_data": [...] } }
@@ -489,8 +489,8 @@ func NewLCPReq(prompt string, multimodalData []string, props map[string]float32,
// When no multimodal data, use plain string // When no multimodal data, use plain string
finalPrompt = prompt finalPrompt = prompt
} }
return LlamaCPPReq{ return LlamaCPPReq{
Model: model,
Stream: true, Stream: true,
Prompt: finalPrompt, Prompt: finalPrompt,
Temperature: props["temperature"], Temperature: props["temperature"],
@@ -505,3 +505,27 @@ type LlamaCPPResp struct {
Content string `json:"content"` Content string `json:"content"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
} }
// LCPModelData describes one model entry in the llama.cpp server's
// model-listing response. Naming this type (previously an anonymous
// struct inside LCPModels.Data) leaves the JSON wire format unchanged
// while letting callers construct and reference entries directly.
type LCPModelData struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	OwnedBy string `json:"owned_by"`
	Created int    `json:"created"`
	InCache bool   `json:"in_cache"`
	Path    string `json:"path"`
	// Status carries the server-reported load state for the model.
	Status struct {
		Value string   `json:"value"`
		Args  []string `json:"args"`
	} `json:"status"`
}

// LCPModels is the top-level payload returned by the llama.cpp
// model-listing endpoint.
type LCPModels struct {
	Data   []LCPModelData `json:"data"`
	Object string         `json:"object"`
}

// ListModels returns the IDs of all models in the response, in the
// order the server listed them. The result is empty (never nil-checked
// by callers as nil) when no models are present.
func (lcp *LCPModels) ListModels() []string {
	// Pre-size: the final length is known, so avoid append regrowth.
	ids := make([]string, 0, len(lcp.Data))
	for _, m := range lcp.Data {
		ids = append(ids, m.ID)
	}
	return ids
}

15
tui.go
View File

@@ -961,11 +961,16 @@ func init() {
} }
updateStatusLine() updateStatusLine()
} else { } else {
// For non-OpenRouter APIs, use the old logic if len(LocalModels) > 0 {
go func() { currentLocalModelIndex = (currentLocalModelIndex + 1) % len(LocalModels)
fetchLCPModelName() // blocks chatBody.Model = LocalModels[currentLocalModelIndex]
updateStatusLine() }
}() updateStatusLine()
// // For non-OpenRouter APIs, use the old logic
// go func() {
// fetchLCPModelName() // blocks
// updateStatusLine()
// }()
} }
return nil return nil
} }