Feat: add deepseek integration [WIP] (only completion works)
This commit is contained in:
93
bot.go
93
bot.go
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"elefant/config"
|
||||
"elefant/models"
|
||||
"elefant/rag"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
@@ -20,7 +23,30 @@ import (
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
var httpClient = http.Client{}
|
||||
var httpClient = &http.Client{}
|
||||
|
||||
// createClient builds an HTTP client intended for streaming responses.
//
// connectTimeout bounds three phases of a request: establishing the TCP
// connection, the TLS handshake, and waiting for the response headers.
// The client itself has no overall timeout, so a response body can be
// read for as long as the server keeps streaming.
//
// The transport starts as a copy of http.DefaultTransport so that the
// standard defaults (proxy from environment, HTTP/2 attempt, idle
// connection limits) are kept; only the timeouts are overridden.
func createClient(connectTimeout time.Duration) *http.Client {
	// Fallback in case http.DefaultTransport was replaced by a custom
	// RoundTripper somewhere else in the program.
	transport := &http.Transport{Proxy: http.ProxyFromEnvironment}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	}

	// One dialer is enough; it is safe to reuse for every connection.
	dialer := &net.Dialer{
		Timeout:   connectTimeout,
		KeepAlive: 30 * time.Second,
	}
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		return dialer.DialContext(ctx, network, addr)
	}
	transport.TLSHandshakeTimeout = connectTimeout
	transport.ResponseHeaderTimeout = connectTimeout

	return &http.Client{
		Transport: transport,
		Timeout:   0, // no overall timeout: bodies are streamed
	}
}
|
||||
|
||||
var (
|
||||
cfg *config.Config
|
||||
@@ -36,7 +62,6 @@ var (
|
||||
defaultStarterBytes = []byte{}
|
||||
interruptResp = false
|
||||
ragger *rag.RAG
|
||||
currentModel = "none"
|
||||
chunkParser ChunkParser
|
||||
defaultLCPProps = map[string]float32{
|
||||
"temperature": 0.8,
|
||||
@@ -47,6 +72,7 @@ var (
|
||||
)
|
||||
|
||||
func fetchModelName() *models.LLMModels {
|
||||
// TODO: to config
|
||||
api := "http://localhost:8080/v1/models"
|
||||
//nolint
|
||||
resp, err := httpClient.Get(api)
|
||||
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
|
||||
return nil
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
currentModel = "disconnected"
|
||||
chatBody.Model = "disconnected"
|
||||
return nil
|
||||
}
|
||||
currentModel = path.Base(llmModel.Data[0].ID)
|
||||
chatBody.Model = path.Base(llmModel.Data[0].ID)
|
||||
return &llmModel
|
||||
}
|
||||
|
||||
func fetchDSBalance() *models.DSBalance {
|
||||
url := "https://api.deepseek.com/user/balance"
|
||||
method := "GET"
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
logger.Warn("failed to create request", "error", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Warn("failed to make request", "error", err)
|
||||
return nil
|
||||
}
|
||||
defer res.Body.Close()
|
||||
resp := models.DSBalance{}
|
||||
if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resp
|
||||
}
|
||||
|
||||
func sendMsgToLLM(body io.Reader) {
|
||||
choseChunkParser()
|
||||
req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
|
||||
// nolint
|
||||
resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
|
||||
// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
|
||||
if err != nil {
|
||||
logger.Error("llamacpp api", "error", err)
|
||||
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
|
||||
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
|
||||
streamDone <- true
|
||||
return
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
bodyBytes, _ := io.ReadAll(body)
|
||||
logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
|
||||
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
|
||||
logger.Error("failed to notify", "error", err)
|
||||
}
|
||||
streamDone <- true
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
counter := uint32(0)
|
||||
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
|
||||
// starts with -> data:
|
||||
line = line[6:]
|
||||
logger.Debug("debugging resp", "line", string(line))
|
||||
if bytes.Equal(line, []byte("[DONE]\n")) {
|
||||
streamDone <- true
|
||||
break
|
||||
}
|
||||
content, stop, err = chunkParser.ParseChunk(line)
|
||||
if err != nil {
|
||||
logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
|
||||
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
|
||||
|
||||
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
|
||||
botRespMode = true
|
||||
// reader := formMsg(chatBody, userMsg, role)
|
||||
defer func() { botRespMode = false }()
|
||||
// check that there is a model set to use if is not local
|
||||
if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
|
||||
if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
|
||||
if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
|
||||
logger.Warn("failed ot notify user", "error", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
reader, err := chunkParser.FormMsg(userMsg, role, resume)
|
||||
if reader == nil || err != nil {
|
||||
logger.Error("empty reader from msgs", "role", role, "error", err)
|
||||
@@ -369,7 +447,8 @@ func init() {
|
||||
Stream: true,
|
||||
Messages: lastChat,
|
||||
}
|
||||
initChunkParser()
|
||||
choseChunkParser()
|
||||
httpClient = createClient(time.Second * 15)
|
||||
// go runModelNameTicker(time.Second * 120)
|
||||
// tempLoad()
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type Config struct {
|
||||
ChatAPI string `toml:"ChatAPI"`
|
||||
CompletionAPI string `toml:"CompletionAPI"`
|
||||
CurrentAPI string
|
||||
CurrentProvider string
|
||||
APIMap map[string]string
|
||||
//
|
||||
ShowSys bool `toml:"ShowSys"`
|
||||
@@ -30,6 +31,12 @@ type Config struct {
|
||||
RAGWorkers uint32 `toml:"RAGWorkers"`
|
||||
RAGBatchSize int `toml:"RAGBatchSize"`
|
||||
RAGWordLimit uint32 `toml:"RAGWordLimit"`
|
||||
// deepseek
|
||||
DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
|
||||
DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
|
||||
DeepSeekToken string `toml:"DeepSeekToken"`
|
||||
DeepSeekModel string `toml:"DeepSeekModel"`
|
||||
ApiLinks []string
|
||||
}
|
||||
|
||||
func LoadConfigOrDefault(fn string) *Config {
|
||||
@@ -39,9 +46,11 @@ func LoadConfigOrDefault(fn string) *Config {
|
||||
config := &Config{}
|
||||
_, err := toml.DecodeFile(fn, &config)
|
||||
if err != nil {
|
||||
fmt.Println("failed to read config from file, loading default")
|
||||
fmt.Println("failed to read config from file, loading default", "error", err)
|
||||
config.ChatAPI = "http://localhost:8080/v1/chat/completions"
|
||||
config.CompletionAPI = "http://localhost:8080/completion"
|
||||
config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions"
|
||||
config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions"
|
||||
config.RAGEnabled = false
|
||||
config.EmbedURL = "http://localhost:8080/v1/embiddings"
|
||||
config.ShowSys = true
|
||||
@@ -59,11 +68,15 @@ func LoadConfigOrDefault(fn string) *Config {
|
||||
config.CurrentAPI = config.ChatAPI
|
||||
config.APIMap = map[string]string{
|
||||
config.ChatAPI: config.CompletionAPI,
|
||||
config.DeepSeekChatAPI: config.DeepSeekCompletionAPI,
|
||||
}
|
||||
if config.CompletionAPI != "" {
|
||||
config.CurrentAPI = config.CompletionAPI
|
||||
config.APIMap = map[string]string{
|
||||
config.CompletionAPI: config.ChatAPI,
|
||||
config.APIMap[config.CompletionAPI] = config.ChatAPI
|
||||
}
|
||||
for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} {
|
||||
if el != "" {
|
||||
config.ApiLinks = append(config.ApiLinks, el)
|
||||
}
|
||||
}
|
||||
// if any value is empty fill with default
|
||||
|
||||
90
llm.go
90
llm.go
@@ -13,20 +13,32 @@ type ChunkParser interface {
|
||||
FormMsg(msg, role string, cont bool) (io.Reader, error)
|
||||
}
|
||||
|
||||
func initChunkParser() {
|
||||
func choseChunkParser() {
|
||||
chunkParser = LlamaCPPeer{}
|
||||
if strings.Contains(cfg.CurrentAPI, "v1") {
|
||||
logger.Debug("chosen /v1/chat parser")
|
||||
switch cfg.CurrentAPI {
|
||||
case "http://localhost:8080/completion":
|
||||
chunkParser = LlamaCPPeer{}
|
||||
case "http://localhost:8080/v1/chat/completions":
|
||||
chunkParser = OpenAIer{}
|
||||
return
|
||||
case "https://api.deepseek.com/beta/completions":
|
||||
chunkParser = DeepSeeker{}
|
||||
default:
|
||||
chunkParser = LlamaCPPeer{}
|
||||
}
|
||||
logger.Debug("chosen llamacpp /completion parser")
|
||||
// if strings.Contains(cfg.CurrentAPI, "chat") {
|
||||
// logger.Debug("chosen chat parser")
|
||||
// chunkParser = OpenAIer{}
|
||||
// return
|
||||
// }
|
||||
// logger.Debug("chosen llamacpp /completion parser")
|
||||
}
|
||||
|
||||
type LlamaCPPeer struct {
|
||||
}
|
||||
type OpenAIer struct {
|
||||
}
|
||||
type DeepSeeker struct {
|
||||
}
|
||||
|
||||
func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
||||
if msg != "" { // otherwise let the bot to continue
|
||||
@@ -62,7 +74,12 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
|
||||
}
|
||||
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
||||
"msg", msg, "resume", resume, "prompt", prompt)
|
||||
payload := models.NewLCPReq(prompt, cfg, defaultLCPProps)
|
||||
var payload any
|
||||
payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
|
||||
if strings.Contains(chatBody.Model, "deepseek") {
|
||||
payload = models.NewDSCompletionReq(prompt, chatBody.Model,
|
||||
defaultLCPProps["temp"], cfg)
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.Error("failed to form a msg", "error", err)
|
||||
@@ -129,3 +146,64 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// deepseek
|
||||
func (ds DeepSeeker) ParseChunk(data []byte) (string, bool, error) {
|
||||
llmchunk := models.DSCompletionResp{}
|
||||
if err := json.Unmarshal(data, &llmchunk); err != nil {
|
||||
logger.Error("failed to decode", "error", err, "line", string(data))
|
||||
return "", false, err
|
||||
}
|
||||
if llmchunk.Choices[0].FinishReason != "" {
|
||||
if llmchunk.Choices[0].Text != "" {
|
||||
logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
|
||||
}
|
||||
return llmchunk.Choices[0].Text, true, nil
|
||||
}
|
||||
return llmchunk.Choices[0].Text, false, nil
|
||||
}
|
||||
|
||||
func (ds DeepSeeker) FormMsg(msg, role string, resume bool) (io.Reader, error) {
|
||||
if msg != "" { // otherwise let the bot to continue
|
||||
newMsg := models.RoleMsg{Role: role, Content: msg}
|
||||
chatBody.Messages = append(chatBody.Messages, newMsg)
|
||||
// if rag
|
||||
if cfg.RAGEnabled {
|
||||
ragResp, err := chatRagUse(newMsg.Content)
|
||||
if err != nil {
|
||||
logger.Error("failed to form a rag msg", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
|
||||
chatBody.Messages = append(chatBody.Messages, ragMsg)
|
||||
}
|
||||
}
|
||||
if cfg.ToolUse && !resume {
|
||||
// add to chat body
|
||||
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
|
||||
}
|
||||
messages := make([]string, len(chatBody.Messages))
|
||||
for i, m := range chatBody.Messages {
|
||||
messages[i] = m.ToPrompt()
|
||||
}
|
||||
prompt := strings.Join(messages, "\n")
|
||||
// strings builder?
|
||||
if !resume {
|
||||
botMsgStart := "\n" + cfg.AssistantRole + ":\n"
|
||||
prompt += botMsgStart
|
||||
}
|
||||
if cfg.ThinkUse && !cfg.ToolUse {
|
||||
prompt += "<think>"
|
||||
}
|
||||
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
|
||||
"msg", msg, "resume", resume, "prompt", prompt)
|
||||
var payload any
|
||||
payload = models.NewDSCompletionReq(prompt, chatBody.Model,
|
||||
defaultLCPProps["temp"], cfg)
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.Error("failed to form a msg", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
130
models/models.go
130
models/models.go
@@ -103,6 +103,126 @@ type ChatToolsBody struct {
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
}
|
||||
|
||||
type DSChatReq struct {
|
||||
Messages []RoleMsg `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
FrequencyPenalty int `json:"frequency_penalty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
PresencePenalty int `json:"presence_penalty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
// ResponseFormat struct {
|
||||
// Type string `json:"type"`
|
||||
// } `json:"response_format"`
|
||||
// Stop any `json:"stop"`
|
||||
// StreamOptions any `json:"stream_options"`
|
||||
// Tools any `json:"tools"`
|
||||
// ToolChoice string `json:"tool_choice"`
|
||||
// Logprobs bool `json:"logprobs"`
|
||||
// TopLogprobs any `json:"top_logprobs"`
|
||||
}
|
||||
|
||||
func NewDSCharReq(cb *ChatBody) DSChatReq {
|
||||
return DSChatReq{
|
||||
Messages: cb.Messages,
|
||||
Model: cb.Model,
|
||||
Stream: cb.Stream,
|
||||
MaxTokens: 2048,
|
||||
PresencePenalty: 0,
|
||||
FrequencyPenalty: 0,
|
||||
Temperature: 1.0,
|
||||
TopP: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// DSCompletionReq is the JSON request body produced by NewDSCompletionReq
// for the DeepSeek completion endpoint. Every field is always serialized;
// unset `any` fields are encoded as null.
type DSCompletionReq struct {
	// Model name and the raw prompt to complete.
	Model  string `json:"model"`
	Prompt string `json:"prompt"`

	Echo             bool `json:"echo"`
	FrequencyPenalty int  `json:"frequency_penalty"`
	// Logprobs         int     `json:"logprobs"`
	MaxTokens       int `json:"max_tokens"`
	PresencePenalty int `json:"presence_penalty"`

	// Stop holds the stop sequences; NewDSCompletionReq stores a []string.
	Stop          any  `json:"stop"`
	Stream        bool `json:"stream"`
	StreamOptions any  `json:"stream_options"`
	Suffix        any  `json:"suffix"`

	// Sampling parameters.
	Temperature float32 `json:"temperature"`
	TopP        float32 `json:"top_p"`
}
|
||||
|
||||
func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
|
||||
return DSCompletionReq{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Temperature: temp,
|
||||
Stream: true,
|
||||
Echo: false,
|
||||
MaxTokens: 2048,
|
||||
PresencePenalty: 0,
|
||||
FrequencyPenalty: 0,
|
||||
TopP: 1.0,
|
||||
Stop: []string{
|
||||
cfg.UserRole + ":\n", "<|im_end|>",
|
||||
cfg.ToolRole + ":\n",
|
||||
cfg.AssistantRole + ":\n",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DSCompletionResp mirrors one JSON object received from the DeepSeek
// completion endpoint (a single streamed chunk when streaming).
//
// JSON null values leave the corresponding field at its zero value, so a
// chunk with "finish_reason": null decodes to an empty FinishReason.
type DSCompletionResp struct {
	ID string `json:"id"`
	// Choices carries the generated text; consumers in this project read
	// only the first element.
	Choices []struct {
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     struct {
			TextOffset []int `json:"text_offset"`
			// Log-probabilities are fractional numbers; an integer slice
			// makes json.Unmarshal fail as soon as a value like -0.25
			// shows up.
			TokenLogprobs []float64 `json:"token_logprobs"`
			Tokens        []string  `json:"tokens"`
			TopLogprobs   []struct {
			} `json:"top_logprobs"`
		} `json:"logprobs"`
		Text string `json:"text"`
	} `json:"choices"`
	Created           int    `json:"created"`
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
	// Usage holds the token accounting reported by the server.
	Usage struct {
		CompletionTokens        int `json:"completion_tokens"`
		PromptTokens            int `json:"prompt_tokens"`
		PromptCacheHitTokens    int `json:"prompt_cache_hit_tokens"`
		PromptCacheMissTokens   int `json:"prompt_cache_miss_tokens"`
		TotalTokens             int `json:"total_tokens"`
		CompletionTokensDetails struct {
			ReasoningTokens int `json:"reasoning_tokens"`
		} `json:"completion_tokens_details"`
	} `json:"usage"`
}
|
||||
|
||||
// DSChatResp mirrors one JSON object received from the DeepSeek chat
// endpoint; the incremental text of a choice lives in Delta.Content.
type DSChatResp struct {
	Choices []struct {
		// Delta is the incremental part of the message in this chunk.
		Delta struct {
			Content string `json:"content"`
			Role    any    `json:"role"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
		Logprobs     any    `json:"logprobs"`
	} `json:"choices"`

	// Response metadata.
	Created           int    `json:"created"`
	ID                string `json:"id"`
	Model             string `json:"model"`
	Object            string `json:"object"`
	SystemFingerprint string `json:"system_fingerprint"`

	// Token accounting reported by the server.
	Usage struct {
		CompletionTokens int `json:"completion_tokens"`
		PromptTokens     int `json:"prompt_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
|
||||
|
||||
type EmbeddingResp struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index uint32 `json:"index"`
|
||||
@@ -190,3 +310,13 @@ type LlamaCPPResp struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
}
|
||||
|
||||
// DSBalance is the decoded reply of the DeepSeek balance endpoint as
// consumed by fetchDSBalance. Amounts are kept as the strings the server
// sends.
type DSBalance struct {
	IsAvailable bool `json:"is_available"`
	// BalanceInfos has one entry per currency.
	BalanceInfos []struct {
		Currency        string `json:"currency"`
		TotalBalance    string `json:"total_balance"`
		GrantedBalance  string `json:"granted_balance"`
		ToppedUpBalance string `json:"topped_up_balance"`
	} `json:"balance_infos"`
}
|
||||
|
||||
10
tui.go
10
tui.go
@@ -136,7 +136,7 @@ func colorText() {
|
||||
}
|
||||
|
||||
func updateStatusLine() {
|
||||
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
|
||||
position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
|
||||
}
|
||||
|
||||
func initSysCards() ([]string, error) {
|
||||
@@ -202,6 +202,12 @@ func makePropsForm(props map[string]float32) *tview.Form {
|
||||
}).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
|
||||
func(option string, optionIndex int) {
|
||||
setLogLevel(option)
|
||||
}).AddDropDown("Select an api: ", cfg.ApiLinks, 1,
|
||||
func(option string, optionIndex int) {
|
||||
cfg.CurrentAPI = option
|
||||
}).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0,
|
||||
func(option string, optionIndex int) {
|
||||
chatBody.Model = option
|
||||
}).
|
||||
AddButton("Quit", func() {
|
||||
pages.RemovePage(propsPage)
|
||||
@@ -600,7 +606,7 @@ func init() {
|
||||
}
|
||||
cfg.APIMap[newAPI] = prevAPI
|
||||
cfg.CurrentAPI = newAPI
|
||||
initChunkParser()
|
||||
choseChunkParser()
|
||||
updateStatusLine()
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user