diff --git a/main.go b/main.go index f6ba56d..884f88e 100644 --- a/main.go +++ b/main.go @@ -218,7 +218,6 @@ func callLLM(prompt string, apiURL string) ([]byte, error) { } } - client := &http.Client{} maxRetries := 6 baseDelay := 2 * time.Second @@ -234,7 +233,7 @@ func callLLM(prompt string, apiURL string) ([]byte, error) { req.Header.Add("Authorization", "Bearer "+cfg.APIToken) } - resp, err := client.Do(req) + resp, err := httpClient.Do(req) if err != nil { if attempt == maxRetries-1 { return nil, fmt.Errorf("LLM call failed after %d retries on client.Do: %w", maxRetries, err) diff --git a/models/models.go b/models/models.go index 80cbe2c..18a2a08 100644 --- a/models/models.go +++ b/models/models.go @@ -85,3 +85,44 @@ type RPMessage struct { Author string `json:"author"` Content string `json:"content"` } + +// === tools models + +type ToolArgProps struct { + Type string `json:"type"` + Description string `json:"description"` +} + +type ToolFuncParams struct { + Type string `json:"type"` + Properties map[string]ToolArgProps `json:"properties"` + Required []string `json:"required"` +} + +type ToolFunc struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters ToolFuncParams `json:"parameters"` +} + +type Tool struct { + Type string `json:"type"` + Function ToolFunc `json:"function"` +} + +type OpenAIReq struct { + *ChatBody + Tools []Tool `json:"tools"` +} + +type ChatBody struct { + Model string `json:"model"` + Messages []RoleMsg `json:"messages"` +} + +type RoleMsg struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// === diff --git a/tools.go b/tools.go new file mode 100644 index 0000000..fa2dbc7 --- /dev/null +++ b/tools.go @@ -0,0 +1,65 @@ +package main + +import ( + "grailbench/models" + "time" +) + +func sendEmail(args map[string]string) []byte { + logger.Info("send-email is used", "args", args) + return nil +} + +func getCurrentTimestamp(args map[string]string) []byte { + ts := time.Now() + return []byte(ts.Format(time.RFC3339)) +} + +type fnSig func(map[string]string) []byte + +var fnMap = map[string]fnSig{ + "get_current_timestamp": getCurrentTimestamp, + "send_email": sendEmail, +} + +// openai style def +var baseTools = []models.Tool{ + // get_current_timestamp + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "get_current_timestamp", + Description: "Returns current timestamp in RFC3999 format", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + }, + }, + }, + // send_email + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "send_email", + Description: "Sends email to a provided address with given message.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"address", "title", "body"}, + Properties: map[string]models.ToolArgProps{ + "address": models.ToolArgProps{ + Type: "string", + Description: "email address of the recipient", + }, + "title": models.ToolArgProps{ + Type: "string", + Description: "", + }, + "body": models.ToolArgProps{ + Type: "string", + Description: "email body, can be in form of html page, markdown or pure text.", + }, + }, + }, + }, + }, +}