Enha: check if model has vision before giving it vision tools

This commit is contained in:
Grail Finder
2026-03-02 11:25:20 +03:00
parent 9ba46b40cc
commit 4f6000a43a
5 changed files with 91 additions and 7 deletions

28
bot.go
View File

@@ -433,6 +433,33 @@ func isModelLoaded(modelID string) (bool, error) {
return false, nil return false, nil
} }
func ModelHasVision(api, modelID string) bool {
switch {
case strings.Contains(api, "deepseek"):
return false
case strings.Contains(api, "openrouter"):
resp, err := http.Get("https://openrouter.ai/api/v1/models")
if err != nil {
logger.Warn("failed to fetch OR models for vision check", "error", err)
return false
}
defer resp.Body.Close()
orm := &models.ORModels{}
if err := json.NewDecoder(resp.Body).Decode(orm); err != nil {
logger.Warn("failed to decode OR models for vision check", "error", err)
return false
}
return orm.HasVision(modelID)
default:
models, err := fetchLCPModelsWithStatus()
if err != nil {
logger.Warn("failed to fetch LCP models for vision check", "error", err)
return false
}
return models.HasVision(modelID)
}
}
// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded. // monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
func monitorModelLoad(modelID string) { func monitorModelLoad(modelID string) {
go func() { go func() {
@@ -1381,6 +1408,7 @@ func updateModelLists() {
chatBody.Model = m chatBody.Model = m
cachedModelColor = "green" cachedModelColor = "green"
updateStatusLine() updateStatusLine()
UpdateToolCapabilities()
app.Draw() app.Draw()
return return
} }

View File

@@ -608,6 +608,20 @@ func (lcp *LCPModels) ListModels() []string {
return resp return resp
} }
func (lcp *LCPModels) HasVision(modelID string) bool {
for _, m := range lcp.Data {
if m.ID == modelID {
args := m.Status.Args
for i := 0; i < len(args)-1; i++ {
if args[i] == "--mmproj" {
return true
}
}
}
}
return false
}
type ResponseStats struct { type ResponseStats struct {
Tokens int Tokens int
Duration float64 Duration float64

View File

@@ -172,3 +172,16 @@ func (orm *ORModels) ListModels(free bool) []string {
} }
return resp return resp
} }
func (orm *ORModels) HasVision(modelID string) bool {
for i := range orm.Data {
if orm.Data[i].ID == modelID {
for _, mod := range orm.Data[i].Architecture.InputModalities {
if mod == "image" {
return true
}
}
}
}
return false
}

View File

@@ -143,6 +143,7 @@ func showAPILinkSelectionPopup() {
apiListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) { apiListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) {
// Update the API in config // Update the API in config
cfg.CurrentAPI = mainText cfg.CurrentAPI = mainText
UpdateToolCapabilities()
// Update model list based on new API // Update model list based on new API
// Helper function to get model list for a given API (same as in props_table.go) // Helper function to get model list for a given API (same as in props_table.go)
getModelListForAPI := func(api string) []string { getModelListForAPI := func(api string) []string {
@@ -162,6 +163,7 @@ func showAPILinkSelectionPopup() {
if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) { if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) {
chatBody.Model = strings.TrimPrefix(newModelList[0], models.LoadedMark) chatBody.Model = strings.TrimPrefix(newModelList[0], models.LoadedMark)
cfg.CurrentModel = chatBody.Model cfg.CurrentModel = chatBody.Model
UpdateToolCapabilities()
} }
pages.RemovePage("apiLinkSelectionPopup") pages.RemovePage("apiLinkSelectionPopup")
app.SetFocus(textArea) app.SetFocus(textArea)

View File

@@ -202,6 +202,7 @@ var (
windowToolsAvailable bool windowToolsAvailable bool
xdotoolPath string xdotoolPath string
maimPath string maimPath string
modelHasVision bool
) )
func init() { func init() {
@@ -233,6 +234,29 @@ func checkWindowTools() {
} }
} }
func UpdateToolCapabilities() {
if !cfg.ToolUse {
return
}
modelHasVision = false
if cfg == nil || cfg.CurrentAPI == "" {
logger.Warn("cannot determine model capabilities: cfg or CurrentAPI is nil")
registerWindowTools()
return
}
prevHasVision := modelHasVision
modelHasVision = ModelHasVision(cfg.CurrentAPI, cfg.CurrentModel)
if modelHasVision {
logger.Info("model has vision support", "model", cfg.CurrentModel, "api", cfg.CurrentAPI)
} else {
logger.Info("model does not have vision support", "model", cfg.CurrentModel, "api", cfg.CurrentAPI)
if windowToolsAvailable && !prevHasVision && modelHasVision == false {
notifyUser("window tools", "Window capture-and-view unavailable: model lacks vision support")
}
}
registerWindowTools()
}
// getWebAgentClient returns a singleton AgentClient for web agents. // getWebAgentClient returns a singleton AgentClient for web agents.
func getWebAgentClient() *agent.AgentClient { func getWebAgentClient() *agent.AgentClient {
webAgentClientOnce.Do(func() { webAgentClientOnce.Do(func() {
@@ -1344,9 +1368,8 @@ func registerWindowTools() {
if windowToolsAvailable { if windowToolsAvailable {
fnMap["list_windows"] = listWindows fnMap["list_windows"] = listWindows
fnMap["capture_window"] = captureWindow fnMap["capture_window"] = captureWindow
fnMap["capture_window_and_view"] = captureWindowAndView windowTools := []models.Tool{
baseTools = append(baseTools, {
models.Tool{
Type: "function", Type: "function",
Function: models.ToolFunc{ Function: models.ToolFunc{
Name: "list_windows", Name: "list_windows",
@@ -1358,7 +1381,7 @@ func registerWindowTools() {
}, },
}, },
}, },
models.Tool{ {
Type: "function", Type: "function",
Function: models.ToolFunc{ Function: models.ToolFunc{
Name: "capture_window", Name: "capture_window",
@@ -1375,7 +1398,10 @@ func registerWindowTools() {
}, },
}, },
}, },
models.Tool{ }
if modelHasVision {
fnMap["capture_window_and_view"] = captureWindowAndView
windowTools = append(windowTools, models.Tool{
Type: "function", Type: "function",
Function: models.ToolFunc{ Function: models.ToolFunc{
Name: "capture_window_and_view", Name: "capture_window_and_view",
@@ -1391,8 +1417,9 @@ func registerWindowTools() {
}, },
}, },
}, },
}, })
) }
baseTools = append(baseTools, windowTools...)
toolSysMsg += windowToolSysMsg toolSysMsg += windowToolSysMsg
} }
} }