Enha: check if model has vision before giving it vision tools
This commit is contained in:
28
bot.go
28
bot.go
@@ -433,6 +433,33 @@ func isModelLoaded(modelID string) (bool, error) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ModelHasVision(api, modelID string) bool {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(api, "deepseek"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(api, "openrouter"):
|
||||||
|
resp, err := http.Get("https://openrouter.ai/api/v1/models")
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("failed to fetch OR models for vision check", "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
orm := &models.ORModels{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(orm); err != nil {
|
||||||
|
logger.Warn("failed to decode OR models for vision check", "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return orm.HasVision(modelID)
|
||||||
|
default:
|
||||||
|
models, err := fetchLCPModelsWithStatus()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("failed to fetch LCP models for vision check", "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return models.HasVision(modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
|
// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
|
||||||
func monitorModelLoad(modelID string) {
|
func monitorModelLoad(modelID string) {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -1381,6 +1408,7 @@ func updateModelLists() {
|
|||||||
chatBody.Model = m
|
chatBody.Model = m
|
||||||
cachedModelColor = "green"
|
cachedModelColor = "green"
|
||||||
updateStatusLine()
|
updateStatusLine()
|
||||||
|
UpdateToolCapabilities()
|
||||||
app.Draw()
|
app.Draw()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -608,6 +608,20 @@ func (lcp *LCPModels) ListModels() []string {
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (lcp *LCPModels) HasVision(modelID string) bool {
|
||||||
|
for _, m := range lcp.Data {
|
||||||
|
if m.ID == modelID {
|
||||||
|
args := m.Status.Args
|
||||||
|
for i := 0; i < len(args)-1; i++ {
|
||||||
|
if args[i] == "--mmproj" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type ResponseStats struct {
|
type ResponseStats struct {
|
||||||
Tokens int
|
Tokens int
|
||||||
Duration float64
|
Duration float64
|
||||||
|
|||||||
@@ -172,3 +172,16 @@ func (orm *ORModels) ListModels(free bool) []string {
|
|||||||
}
|
}
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (orm *ORModels) HasVision(modelID string) bool {
|
||||||
|
for i := range orm.Data {
|
||||||
|
if orm.Data[i].ID == modelID {
|
||||||
|
for _, mod := range orm.Data[i].Architecture.InputModalities {
|
||||||
|
if mod == "image" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ func showAPILinkSelectionPopup() {
|
|||||||
apiListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) {
|
apiListWidget.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) {
|
||||||
// Update the API in config
|
// Update the API in config
|
||||||
cfg.CurrentAPI = mainText
|
cfg.CurrentAPI = mainText
|
||||||
|
UpdateToolCapabilities()
|
||||||
// Update model list based on new API
|
// Update model list based on new API
|
||||||
// Helper function to get model list for a given API (same as in props_table.go)
|
// Helper function to get model list for a given API (same as in props_table.go)
|
||||||
getModelListForAPI := func(api string) []string {
|
getModelListForAPI := func(api string) []string {
|
||||||
@@ -162,6 +163,7 @@ func showAPILinkSelectionPopup() {
|
|||||||
if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) {
|
if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) {
|
||||||
chatBody.Model = strings.TrimPrefix(newModelList[0], models.LoadedMark)
|
chatBody.Model = strings.TrimPrefix(newModelList[0], models.LoadedMark)
|
||||||
cfg.CurrentModel = chatBody.Model
|
cfg.CurrentModel = chatBody.Model
|
||||||
|
UpdateToolCapabilities()
|
||||||
}
|
}
|
||||||
pages.RemovePage("apiLinkSelectionPopup")
|
pages.RemovePage("apiLinkSelectionPopup")
|
||||||
app.SetFocus(textArea)
|
app.SetFocus(textArea)
|
||||||
|
|||||||
41
tools.go
41
tools.go
@@ -202,6 +202,7 @@ var (
|
|||||||
windowToolsAvailable bool
|
windowToolsAvailable bool
|
||||||
xdotoolPath string
|
xdotoolPath string
|
||||||
maimPath string
|
maimPath string
|
||||||
|
modelHasVision bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -233,6 +234,29 @@ func checkWindowTools() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UpdateToolCapabilities() {
|
||||||
|
if !cfg.ToolUse {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelHasVision = false
|
||||||
|
if cfg == nil || cfg.CurrentAPI == "" {
|
||||||
|
logger.Warn("cannot determine model capabilities: cfg or CurrentAPI is nil")
|
||||||
|
registerWindowTools()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
prevHasVision := modelHasVision
|
||||||
|
modelHasVision = ModelHasVision(cfg.CurrentAPI, cfg.CurrentModel)
|
||||||
|
if modelHasVision {
|
||||||
|
logger.Info("model has vision support", "model", cfg.CurrentModel, "api", cfg.CurrentAPI)
|
||||||
|
} else {
|
||||||
|
logger.Info("model does not have vision support", "model", cfg.CurrentModel, "api", cfg.CurrentAPI)
|
||||||
|
if windowToolsAvailable && !prevHasVision && modelHasVision == false {
|
||||||
|
notifyUser("window tools", "Window capture-and-view unavailable: model lacks vision support")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
registerWindowTools()
|
||||||
|
}
|
||||||
|
|
||||||
// getWebAgentClient returns a singleton AgentClient for web agents.
|
// getWebAgentClient returns a singleton AgentClient for web agents.
|
||||||
func getWebAgentClient() *agent.AgentClient {
|
func getWebAgentClient() *agent.AgentClient {
|
||||||
webAgentClientOnce.Do(func() {
|
webAgentClientOnce.Do(func() {
|
||||||
@@ -1344,9 +1368,8 @@ func registerWindowTools() {
|
|||||||
if windowToolsAvailable {
|
if windowToolsAvailable {
|
||||||
fnMap["list_windows"] = listWindows
|
fnMap["list_windows"] = listWindows
|
||||||
fnMap["capture_window"] = captureWindow
|
fnMap["capture_window"] = captureWindow
|
||||||
fnMap["capture_window_and_view"] = captureWindowAndView
|
windowTools := []models.Tool{
|
||||||
baseTools = append(baseTools,
|
{
|
||||||
models.Tool{
|
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: models.ToolFunc{
|
Function: models.ToolFunc{
|
||||||
Name: "list_windows",
|
Name: "list_windows",
|
||||||
@@ -1358,7 +1381,7 @@ func registerWindowTools() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
models.Tool{
|
{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: models.ToolFunc{
|
Function: models.ToolFunc{
|
||||||
Name: "capture_window",
|
Name: "capture_window",
|
||||||
@@ -1375,7 +1398,10 @@ func registerWindowTools() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
models.Tool{
|
}
|
||||||
|
if modelHasVision {
|
||||||
|
fnMap["capture_window_and_view"] = captureWindowAndView
|
||||||
|
windowTools = append(windowTools, models.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: models.ToolFunc{
|
Function: models.ToolFunc{
|
||||||
Name: "capture_window_and_view",
|
Name: "capture_window_and_view",
|
||||||
@@ -1391,8 +1417,9 @@ func registerWindowTools() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
)
|
}
|
||||||
|
baseTools = append(baseTools, windowTools...)
|
||||||
toolSysMsg += windowToolSysMsg
|
toolSysMsg += windowToolSysMsg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user